mkql_computation_node_graph.cpp 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. #include "mkql_computation_node_holders.h"
  2. #include "mkql_computation_node_holders_codegen.h"
  3. #include "mkql_value_builder.h"
  4. #include "mkql_computation_node_codegen.h" // Y_IGNORE
  5. #include <yql/essentials/public/udf/arrow/memory_pool.h>
  6. #include <yql/essentials/minikql/computation/mkql_computation_pattern_cache.h>
  7. #include <yql/essentials/minikql/comp_nodes/mkql_saveload.h>
  8. #include <yql/essentials/minikql/mkql_type_builder.h>
  9. #include <yql/essentials/minikql/mkql_terminator.h>
  10. #include <yql/essentials/minikql/mkql_string_util.h>
  11. #include <util/system/env.h>
  12. #include <util/system/mutex.h>
  13. #include <util/digest/city.h>
  14. #ifndef MKQL_DISABLE_CODEGEN
  15. #include <llvm/Support/raw_ostream.h> // Y_IGNORE
  16. #endif
  17. namespace NKikimr {
  18. namespace NMiniKQL {
  19. using namespace NDetail;
  20. namespace {
  21. #ifndef MKQL_DISABLE_CODEGEN
  22. constexpr ui64 TotalFunctionsLimit = 1000;
  23. constexpr ui64 TotalInstructionsLimit = 100000;
  24. constexpr ui64 MaxFunctionInstructionsLimit = 50000;
  25. #endif
  26. const ui64 IS_NODE_REACHABLE = 1;
  27. const static TStatKey PagePool_PeakAllocated("PagePool_PeakAllocated", false);
  28. const static TStatKey PagePool_PeakUsed("PagePool_PeakUsed", false);
  29. const static TStatKey PagePool_AllocCount("PagePool_AllocCount", true);
  30. const static TStatKey PagePool_PageAllocCount("PagePool_PageAllocCount", true);
  31. const static TStatKey PagePool_PageHitCount("PagePool_PageHitCount", true);
  32. const static TStatKey PagePool_PageMissCount("PagePool_PageMissCount", true);
  33. const static TStatKey PagePool_OffloadedAllocCount("PagePool_OffloadedAllocCount", true);
  34. const static TStatKey PagePool_OffloadedBytes("PagePool_OffloadedBytes", true);
  35. const static TStatKey CodeGen_FullTime("CodeGen_FullTime", true);
  36. const static TStatKey CodeGen_GenerateTime("CodeGen_GenerateTime", true);
  37. const static TStatKey CodeGen_CompileTime("CodeGen_CompileTime", true);
  38. const static TStatKey CodeGen_TotalFunctions("CodeGen_TotalFunctions", true);
  39. const static TStatKey CodeGen_TotalInstructions("CodeGen_TotalInstructions", true);
  40. const static TStatKey CodeGen_MaxFunctionInstructions("CodeGen_MaxFunctionInstructions", false);
  41. const static TStatKey CodeGen_FunctionPassTime("CodeGen_FunctionPassTime", true);
  42. const static TStatKey CodeGen_ModulePassTime("CodeGen_ModulePassTime", true);
  43. const static TStatKey CodeGen_FinalizeTime("CodeGen_FinalizeTime", true);
  44. const static TStatKey Mkql_TotalNodes("Mkql_TotalNodes", true);
  45. const static TStatKey Mkql_CodegenFunctions("Mkql_CodegenFunctions", true);
  46. class TDependencyScanVisitor : public TEmptyNodeVisitor {
  47. public:
  48. void Walk(TNode* root, const TTypeEnvironment& env) {
  49. Stack = &env.GetNodeStack();
  50. Stack->clear();
  51. Stack->push_back(root);
  52. while (!Stack->empty()) {
  53. auto top = Stack->back();
  54. Stack->pop_back();
  55. if (top->GetCookie() != IS_NODE_REACHABLE) {
  56. top->SetCookie(IS_NODE_REACHABLE);
  57. top->Accept(*this);
  58. }
  59. }
  60. Stack = nullptr;
  61. }
  62. private:
  63. using TEmptyNodeVisitor::Visit;
  64. void Visit(TStructLiteral& node) override {
  65. for (ui32 i = 0; i < node.GetValuesCount(); ++i) {
  66. AddNode(node.GetValue(i).GetNode());
  67. }
  68. }
  69. void Visit(TListLiteral& node) override {
  70. for (ui32 i = 0; i < node.GetItemsCount(); ++i) {
  71. AddNode(node.GetItems()[i].GetNode());
  72. }
  73. }
  74. void Visit(TOptionalLiteral& node) override {
  75. if (node.HasItem()) {
  76. AddNode(node.GetItem().GetNode());
  77. }
  78. }
  79. void Visit(TDictLiteral& node) override {
  80. for (ui32 i = 0; i < node.GetItemsCount(); ++i) {
  81. AddNode(node.GetItem(i).first.GetNode());
  82. AddNode(node.GetItem(i).second.GetNode());
  83. }
  84. }
  85. void Visit(TCallable& node) override {
  86. if (node.HasResult()) {
  87. AddNode(node.GetResult().GetNode());
  88. } else {
  89. for (ui32 i = 0; i < node.GetInputsCount(); ++i) {
  90. AddNode(node.GetInput(i).GetNode());
  91. }
  92. }
  93. }
  94. void Visit(TAny& node) override {
  95. if (node.HasItem()) {
  96. AddNode(node.GetItem().GetNode());
  97. }
  98. }
  99. void Visit(TTupleLiteral& node) override {
  100. for (ui32 i = 0; i < node.GetValuesCount(); ++i) {
  101. AddNode(node.GetValue(i).GetNode());
  102. }
  103. }
  104. void Visit(TVariantLiteral& node) override {
  105. AddNode(node.GetItem().GetNode());
  106. }
  107. void AddNode(TNode* node) {
  108. if (node->GetCookie() != IS_NODE_REACHABLE) {
  109. Stack->push_back(node);
  110. }
  111. }
  112. std::vector<TNode*>* Stack = nullptr;
  113. };
// Owns everything produced while building a computation pattern: the holder
// factory, the value builder, the flat list of computation nodes, and the
// mapping from runtime entry-point indices to external computation nodes.
// Ref-counted so it can be shared between the pattern and every graph
// instantiated from it.
class TPatternNodes: public TAtomicRefCount<TPatternNodes> {
public:
    typedef TIntrusivePtr<TPatternNodes> TPtr;

    TPatternNodes(TAllocState& allocState)
        : AllocState(allocState)
        , MemInfo(MakeIntrusive<TMemoryUsageInfo>("ComputationPatternNodes"))
    {
#ifndef NDEBUG
        // Register the tracker so allocations are attributed in debug builds.
        AllocState.ActiveMemInfo.emplace(MemInfo.Get(), MemInfo);
#else
        Y_UNUSED(AllocState);
#endif
    }

    ~TPatternNodes()
    {
        // Release nodes in reverse creation order: later nodes may reference
        // earlier ones, so tear down dependents first.
        for (auto it = ComputationNodesList.rbegin(); it != ComputationNodesList.rend(); ++it) {
            *it = nullptr;
        }
        ComputationNodesList.clear();
        // Skip the debug-tracker erase while unwinding from an exception to
        // avoid touching possibly-inconsistent state.
        if (!UncaughtException()) {
#ifndef NDEBUG
            AllocState.ActiveMemInfo.erase(MemInfo.Get());
#endif
        }
    }

    ITerminator& GetTerminator() {
        return *ValueBuilder;
    }

    const TComputationMutables& GetMutables() const {
        return Mutables;
    }

    const TComputationNodePtrDeque& GetNodes() const {
        return ComputationNodesList;
    }

    // Resolves a MiniKQL node to its computation node via the node cookie.
    // Cookie protocol: values <= IS_NODE_REACHABLE mean "not built"; anything
    // larger is the IComputationNode pointer itself. With `pop` the cookie is
    // cleared after resolution; with `require` a missing node is an error.
    IComputationNode* GetComputationNode(TNode* node, bool pop = false, bool require = true) {
        const auto cookie = node->GetCookie();
        const auto result = reinterpret_cast<IComputationNode*>(cookie);
        if (cookie <= IS_NODE_REACHABLE) {
            MKQL_ENSURE(!require, "Computation graph builder, node not found, type:"
                << node->GetType()->GetKindAsStr());
            return result;
        }
        if (pop) {
            node->SetCookie(0);
        }
        return result;
    }

    // Returns the external node registered for the given entry-point index;
    // with `require` a null slot is an error.
    IComputationExternalNode* GetEntryPoint(size_t index, bool require) {
        MKQL_ENSURE(index < Runtime2ComputationEntryPoints.size() && (!require || Runtime2ComputationEntryPoints[index]),
            "Pattern nodes can not get computation node by index: " << index << ", require: " << require
            << ", Runtime2ComputationEntryPoints size: " << Runtime2ComputationEntryPoints.size());
        return Runtime2ComputationEntryPoints[index];
    }

    IComputationNode* GetRoot() {
        return RootNode;
    }

    bool GetSuitableForCache() const {
        return SuitableForCache;
    }

    size_t GetEntryPointsCount() const {
        return Runtime2ComputationEntryPoints.size();
    }

private:
    friend class TComputationGraphBuildingVisitor;
    friend class TComputationGraph;

    TAllocState& AllocState;
    TIntrusivePtr<TMemoryUsageInfo> MemInfo;
    THolder<THolderFactory> HolderFactory;
    THolder<TDefaultValueBuilder> ValueBuilder;
    TComputationMutables Mutables;
    // Creation-ordered list of all built computation nodes.
    TComputationNodePtrDeque ComputationNodesList;
    IComputationNode* RootNode = nullptr;
    TComputationExternalNodePtrVector Runtime2ComputationEntryPoints;
    TComputationNodeOnNodeMap ElementsCache;
    // Cleared when the pattern contains nodes that must not be cached
    // (see the name blacklist in the building visitor).
    bool SuitableForCache = true;
};
// Walks a MiniKQL node graph and builds the corresponding computation nodes,
// recording them into a shared TPatternNodes. Each visited node's cookie is
// set to the pointer of its computation node so later nodes can resolve their
// dependencies through GetComputationNode().
class TComputationGraphBuildingVisitor:
    public INodeVisitor,
    private TNonCopyable
{
public:
    TComputationGraphBuildingVisitor(const TComputationPatternOpts& opts)
        : Env(opts.Env)
        , TypeInfoHelper(new TTypeInfoHelper())
        , CountersProvider(opts.CountersProvider)
        , SecureParamsProvider(opts.SecureParamsProvider)
        , Factory(opts.Factory)
        , FunctionRegistry(*opts.FunctionRegistry)
        , ValidateMode(opts.ValidateMode)
        , ValidatePolicy(opts.ValidatePolicy)
        , GraphPerProcess(opts.GraphPerProcess)
        , PatternNodes(MakeIntrusive<TPatternNodes>(opts.AllocState))
        , ExternalAlloc(opts.PatternEnv)
    {
        PatternNodes->HolderFactory = MakeHolder<THolderFactory>(opts.AllocState, *PatternNodes->MemInfo, &FunctionRegistry);
        PatternNodes->ValueBuilder = MakeHolder<TDefaultValueBuilder>(*PatternNodes->HolderFactory, ValidatePolicy);
        PatternNodes->ValueBuilder->SetSecureParamsProvider(opts.SecureParamsProvider);
        NodeFactory = MakeHolder<TNodeFactory>(*PatternNodes->MemInfo, PatternNodes->Mutables);
    }

    ~TComputationGraphBuildingVisitor() {
        // Destroy factory and pattern nodes under the environment's allocator,
        // since their memory was allocated through it.
        auto g = Env.BindAllocator();
        NodeFactory.Reset();
        PatternNodes.Reset();
    }

    const TTypeEnvironment& GetTypeEnvironment() const {
        return Env;
    }

    const IFunctionRegistry& GetFunctionRegistry() const {
        return FunctionRegistry;
    }

private:
    // All type nodes share one handling: wrap the type in a type node.
    template <typename T>
    void VisitType(T& node) {
        AddNode(node, NodeFactory->CreateTypeNode(&node));
    }

    void Visit(TTypeType& node) override {
        VisitType<TTypeType>(node);
    }

    void Visit(TVoidType& node) override {
        VisitType<TVoidType>(node);
    }

    void Visit(TNullType& node) override {
        VisitType<TNullType>(node);
    }

    void Visit(TEmptyListType& node) override {
        VisitType<TEmptyListType>(node);
    }

    void Visit(TEmptyDictType& node) override {
        VisitType<TEmptyDictType>(node);
    }

    void Visit(TDataType& node) override {
        VisitType<TDataType>(node);
    }

    void Visit(TPgType& node) override {
        VisitType<TPgType>(node);
    }

    void Visit(TStructType& node) override {
        VisitType<TStructType>(node);
    }

    void Visit(TListType& node) override {
        VisitType<TListType>(node);
    }

    void Visit(TStreamType& node) override {
        VisitType<TStreamType>(node);
    }

    void Visit(TFlowType& node) override {
        VisitType<TFlowType>(node);
    }

    void Visit(TBlockType& node) override {
        VisitType<TBlockType>(node);
    }

    void Visit(TMultiType& node) override {
        VisitType<TMultiType>(node);
    }

    void Visit(TTaggedType& node) override {
        VisitType<TTaggedType>(node);
    }

    void Visit(TOptionalType& node) override {
        VisitType<TOptionalType>(node);
    }

    void Visit(TDictType& node) override {
        VisitType<TDictType>(node);
    }

    void Visit(TCallableType& node) override {
        VisitType<TCallableType>(node);
    }

    void Visit(TAnyType& node) override {
        VisitType<TAnyType>(node);
    }

    void Visit(TTupleType& node) override {
        VisitType<TTupleType>(node);
    }

    void Visit(TResourceType& node) override {
        VisitType<TResourceType>(node);
    }

    void Visit(TVariantType& node) override {
        VisitType<TVariantType>(node);
    }

    void Visit(TVoid& node) override {
        AddNode(node, NodeFactory->CreateImmutableNode(NUdf::TUnboxedValue::Void()));
    }

    void Visit(TNull& node) override {
        AddNode(node, NodeFactory->CreateImmutableNode(NUdf::TUnboxedValue()));
    }

    void Visit(TEmptyList& node) override {
        AddNode(node, NodeFactory->CreateImmutableNode(PatternNodes->HolderFactory->GetEmptyContainerLazy()));
    }

    void Visit(TEmptyDict& node) override {
        AddNode(node, NodeFactory->CreateImmutableNode(PatternNodes->HolderFactory->GetEmptyContainerLazy()));
    }

    void Visit(TDataLiteral& node) override {
        auto value = node.AsValue();
        NUdf::TDataTypeId typeId = node.GetType()->GetSchemeType();
        if (typeId != 0x101) { // TODO remove
            const auto slot = NUdf::GetDataSlot(typeId);
            MKQL_ENSURE(IsValidValue(slot, value),
                "Bad data literal for type: " << NUdf::GetDataTypeInfo(slot).Name << ", " << value);
        }
        NUdf::TUnboxedValue externalValue;
        // When the pattern lives in its own allocator environment, string
        // payloads are re-materialized so they do not alias pattern memory.
        if (ExternalAlloc) {
            if (value.IsString()) {
                externalValue = MakeString(value.AsStringRef());
            }
        }
        if (!externalValue) {
            externalValue = std::move(value);
        }
        AddNode(node, NodeFactory->CreateImmutableNode(std::move(externalValue)));
    }

    void Visit(TStructLiteral& node) override {
        TComputationNodePtrVector values;
        values.reserve(node.GetValuesCount());
        for (ui32 i = 0, e = node.GetValuesCount(); i < e; ++i) {
            values.push_back(GetComputationNode(node.GetValue(i).GetNode()));
        }
        AddNode(node, NodeFactory->CreateArrayNode(std::move(values)));
    }

    void Visit(TListLiteral& node) override {
        TComputationNodePtrVector items;
        items.reserve(node.GetItemsCount());
        for (ui32 i = 0; i < node.GetItemsCount(); ++i) {
            items.push_back(GetComputationNode(node.GetItems()[i].GetNode()));
        }
        AddNode(node, NodeFactory->CreateArrayNode(std::move(items)));
    }

    void Visit(TOptionalLiteral& node) override {
        auto item = node.HasItem() ? GetComputationNode(node.GetItem().GetNode()) : nullptr;
        AddNode(node, NodeFactory->CreateOptionalNode(item));
    }

    void Visit(TDictLiteral& node) override {
        auto keyType = node.GetType()->GetKeyType();
        TKeyTypes types;
        bool isTuple;
        bool encoded;
        bool useIHash;
        GetDictionaryKeyTypes(keyType, types, isTuple, encoded, useIHash);
        std::vector<std::pair<IComputationNode*, IComputationNode*>> items;
        items.reserve(node.GetItemsCount());
        for (ui32 i = 0, e = node.GetItemsCount(); i < e; ++i) {
            auto item = node.GetItem(i);
            items.push_back(std::make_pair(GetComputationNode(item.first.GetNode()), GetComputationNode(item.second.GetNode())));
        }
        // Keys that cannot be hashed fall back to a sorted dict with a
        // comparator; hashable keys get a hash implementation instead.
        bool isSorted = !CanHash(keyType);
        AddNode(node, NodeFactory->CreateDictNode(std::move(items), types, isTuple, encoded ? keyType : nullptr,
            useIHash && !isSorted ? MakeHashImpl(keyType) : nullptr,
            useIHash ? MakeEquateImpl(keyType) : nullptr,
            useIHash && isSorted ? MakeCompareImpl(keyType) : nullptr, isSorted));
    }

    void Visit(TCallable& node) override {
        // A callable with a materialized result reuses the result's
        // computation node: visit the result and copy its node pointer into
        // this callable's cookie.
        if (node.HasResult()) {
            node.GetResult().GetNode()->Accept(*this);
            auto computationNode = PatternNodes->ComputationNodesList.back().Get();
            node.SetCookie((ui64)computationNode);
            return;
        }
        // "Steal" callables are handled elsewhere — build nothing here.
        if (node.GetType()->GetName() == "Steal") {
            return;
        }
        TNodeLocator nodeLocator = [this](TNode* dependentNode, bool pop) {
            return GetComputationNode(dependentNode, pop);
        };
        TComputationNodeFactoryContext ctx(
            nodeLocator,
            FunctionRegistry,
            Env,
            TypeInfoHelper,
            CountersProvider,
            SecureParamsProvider,
            *NodeFactory,
            *PatternNodes->HolderFactory,
            PatternNodes->ValueBuilder.Get(),
            ValidateMode,
            ValidatePolicy,
            GraphPerProcess,
            PatternNodes->Mutables,
            PatternNodes->ElementsCache,
            std::bind(&TComputationGraphBuildingVisitor::PushBackNode, this, std::placeholders::_1));
        const auto computationNode = Factory(node, ctx);
        const auto& name = node.GetType()->GetName();
        // Patterns containing these table-read callables must not be cached.
        if (name == "KqpWideReadTable" ||
            name == "KqpWideReadTableRanges" ||
            name == "KqpBlockReadTableRanges" ||
            name == "KqpLookupTable" ||
            name == "KqpReadTable"
        ) {
            PatternNodes->SuitableForCache = false;
        }
        if (!computationNode) {
            THROW yexception()
                << "Computation graph builder, unsupported function: " << name << " type: " << Factory.target_type().name() ;
        }
        AddNode(node, computationNode);
    }

    void Visit(TAny& node) override {
        if (!node.HasItem()) {
            AddNode(node, NodeFactory->CreateImmutableNode(NUdf::TUnboxedValue::Void()));
        } else {
            AddNode(node, GetComputationNode(node.GetItem().GetNode()));
        }
    }

    void Visit(TTupleLiteral& node) override {
        TComputationNodePtrVector values;
        values.reserve(node.GetValuesCount());
        for (ui32 i = 0, e = node.GetValuesCount(); i < e; ++i) {
            values.push_back(GetComputationNode(node.GetValue(i).GetNode()));
        }
        AddNode(node, NodeFactory->CreateArrayNode(std::move(values)));
    }

    void Visit(TVariantLiteral& node) override {
        auto item = GetComputationNode(node.GetItem().GetNode());
        AddNode(node, NodeFactory->CreateVariantNode(item, node.GetIndex()));
    }

public:
    IComputationNode* GetComputationNode(TNode* node, bool pop = false, bool require = true) {
        return PatternNodes->GetComputationNode(node, pop, require);
    }

    TMemoryUsageInfo& GetMemInfo() {
        return *PatternNodes->MemInfo;
    }

    const THolderFactory& GetHolderFactory() const {
        return *PatternNodes->HolderFactory;
    }

    TPatternNodes::TPtr GetPatternNodes() {
        return PatternNodes;
    }

    const TComputationNodePtrDeque& GetNodes() const {
        return PatternNodes->GetNodes();
    }

    void PreserveRoot(IComputationNode* rootNode) {
        PatternNodes->RootNode = rootNode;
    }

    void PreserveEntryPoints(TComputationExternalNodePtrVector&& runtime2ComputationEntryPoints) {
        PatternNodes->Runtime2ComputationEntryPoints = std::move(runtime2ComputationEntryPoints);
    }

private:
    // Registers dependencies and appends the node to the creation-ordered list.
    void PushBackNode(const IComputationNode::TPtr& computationNode) {
        computationNode->RegisterDependencies();
        PatternNodes->ComputationNodesList.push_back(computationNode);
    }

    // Records the computation node and stores its pointer in the MiniKQL
    // node's cookie for later dependency resolution.
    void AddNode(TNode& node, const IComputationNode::TPtr& computationNode) {
        PushBackNode(computationNode);
        node.SetCookie((ui64)computationNode.Get());
    }

private:
    const TTypeEnvironment& Env;
    NUdf::ITypeInfoHelper::TPtr TypeInfoHelper;
    NUdf::ICountersProvider* CountersProvider;
    const NUdf::ISecureParamsProvider* SecureParamsProvider;
    const TComputationNodeFactory Factory;
    const IFunctionRegistry& FunctionRegistry;
    TIntrusivePtr<TMemoryUsageInfo> MemInfo;
    THolder<TNodeFactory> NodeFactory;
    NUdf::EValidateMode ValidateMode;
    NUdf::EValidatePolicy ValidatePolicy;
    EGraphPerProcess GraphPerProcess;
    TPatternNodes::TPtr PatternNodes;
    const bool ExternalAlloc; // obsolete, will be removed after YQL-13977
};
// An executable instance of a compiled pattern: owns the per-graph holder
// factory, value builder and computation context, and exposes evaluation,
// kernel-topology extraction and state save/load.
class TComputationGraph final : public IComputationGraph {
public:
    TComputationGraph(TPatternNodes::TPtr& patternNodes, const TComputationOptsFull& compOpts, NYql::NCodegen::ICodegen::TSharedPtr codegen)
        : PatternNodes(patternNodes)
        , MemInfo(MakeIntrusive<TMemoryUsageInfo>("ComputationGraph"))
        , CompOpts(compOpts)
        , Codegen(std::move(codegen))
    {
#ifndef NDEBUG
        CompOpts.AllocState.ActiveMemInfo.emplace(MemInfo.Get(), MemInfo);
#endif
        HolderFactory = MakeHolder<THolderFactory>(CompOpts.AllocState, *MemInfo, patternNodes->HolderFactory->GetFunctionRegistry());
        ValueBuilder = MakeHolder<TDefaultValueBuilder>(*HolderFactory.Get(), compOpts.ValidatePolicy);
        ValueBuilder->SetSecureParamsProvider(CompOpts.SecureParamsProvider);
    }

    ~TComputationGraph() {
        // Flush page-pool counters into the stats registry on teardown.
        auto stats = CompOpts.Stats;
        auto& pagePool = HolderFactory->GetPagePool();
        MKQL_SET_MAX_STAT(stats, PagePool_PeakAllocated, pagePool.GetPeakAllocated());
        MKQL_SET_MAX_STAT(stats, PagePool_PeakUsed, pagePool.GetPeakUsed());
        MKQL_ADD_STAT(stats, PagePool_AllocCount, pagePool.GetAllocCount());
        MKQL_ADD_STAT(stats, PagePool_PageAllocCount, pagePool.GetPageAllocCount());
        MKQL_ADD_STAT(stats, PagePool_PageHitCount, pagePool.GetPageHitCount());
        MKQL_ADD_STAT(stats, PagePool_PageMissCount, pagePool.GetPageMissCount());
        MKQL_ADD_STAT(stats, PagePool_OffloadedAllocCount, pagePool.GetOffloadedAllocCount());
        MKQL_ADD_STAT(stats, PagePool_OffloadedBytes, pagePool.GetOffloadedBytes());
    }

    // Lazily creates the computation context and initializes every node.
    // Idempotent: subsequent calls are no-ops.
    void Prepare() override {
        if (!IsPrepared) {
            Ctx.Reset(new TComputationContext(*HolderFactory,
                ValueBuilder.Get(),
                CompOpts,
                PatternNodes->GetMutables(),
                *NYql::NUdf::GetYqlMemoryPool()));
            Ctx->ExecuteLLVM = Codegen.get() != nullptr;
            ValueBuilder->SetCalleePositionHolder(Ctx->CalleePosition);
            for (auto& node : PatternNodes->GetNodes()) {
                node->InitNode(*Ctx);
            }
            IsPrepared = true;
        }
    }

    TComputationContext& GetContext() override {
        Prepare();
        return *Ctx;
    }

    NUdf::TUnboxedValue GetValue() override {
        Prepare();
        return PatternNodes->GetRoot()->GetValue(*Ctx);
    }

    IComputationExternalNode* GetEntryPoint(size_t index, bool require) override {
        Prepare();
        return PatternNodes->GetEntryPoint(index, require);
    }

    const TArrowKernelsTopology* GetKernelsTopology() override {
        Prepare();
        // Computed once and memoized.
        if (!KernelsTopology.has_value()) {
            CalculateKernelTopology(*Ctx);
        }
        return &KernelsTopology.value();
    }

    // Builds the Arrow kernel topology by a two-phase (post-order) DFS from
    // the root: the first visit of a node pushes its kernel arguments, the
    // second visit assigns the node its output index and records its inputs.
    // Entry points are pre-seeded with indices [0, InputArgsCount).
    void CalculateKernelTopology(TComputationContext& ctx) {
        KernelsTopology.emplace();
        KernelsTopology->InputArgsCount = PatternNodes->GetEntryPointsCount();

        std::stack<const IComputationNode*> stack;
        struct TNodeState {
            bool Visited;
            ui32 Index;
        };

        std::unordered_map<const IComputationNode*, TNodeState> deps;
        for (ui32 i = 0; i < KernelsTopology->InputArgsCount; ++i) {
            auto entryPoint = PatternNodes->GetEntryPoint(i, false);
            if (!entryPoint) {
                continue;
            }

            deps.emplace(entryPoint, TNodeState{ true, i});
        }

        stack.push(PatternNodes->GetRoot());
        while (!stack.empty()) {
            auto node = stack.top();
            auto [iter, inserted] = deps.emplace(node, TNodeState{ false, 0 });

            auto extNode = dynamic_cast<const IComputationExternalNode*>(node);
            if (extNode) {
                // External nodes must have been seeded as entry points above.
                MKQL_ENSURE(!inserted, "Unexpected external node");
                stack.pop();
                continue;
            }

            auto kernelNode = node->PrepareArrowKernelComputationNode(ctx);
            MKQL_ENSURE(kernelNode, "No kernel for node: " << node->DebugString());
            auto argsCount = kernelNode->GetArgsDesc().size();

            if (!iter->second.Visited) {
                // First visit: schedule arguments, revisit this node later.
                for (ui32 j = 0; j < argsCount; ++j) {
                    stack.push(kernelNode->GetArgument(j));
                }
                iter->second.Visited = true;
            } else {
                // Second visit: all arguments already have indices.
                iter->second.Index = KernelsTopology->InputArgsCount + KernelsTopology->Items.size();
                KernelsTopology->Items.emplace_back();
                auto& i = KernelsTopology->Items.back();
                i.Inputs.reserve(argsCount);
                for (ui32 j = 0; j < argsCount; ++j) {
                    auto it = deps.find(kernelNode->GetArgument(j));
                    MKQL_ENSURE(it != deps.end(), "Missing argument");
                    i.Inputs.emplace_back(it->second.Index);
                }
                i.Node = std::move(kernelNode);
                stack.pop();
            }
        }
    }

    // Resets every mutable slot to the Invalid marker.
    void Invalidate() override {
        std::fill_n(Ctx->MutableValues.get(), PatternNodes->GetMutables().CurValueIndex, NUdf::TUnboxedValue(NUdf::TUnboxedValuePod::Invalid()));
    }

    // Resets only the slots registered as cached values.
    void InvalidateCaches() override {
        for (const auto cachedIndex : Ctx->Mutables.CachedValues) {
            Ctx->MutableValues[cachedIndex] = NUdf::TUnboxedValuePod::Invalid();
        }
    }

    const TComputationNodePtrDeque& GetNodes() const override {
        return PatternNodes->GetNodes();
    }

    TMemoryUsageInfo& GetMemInfo() const override {
        return *MemInfo;
    }

    const THolderFactory& GetHolderFactory() const override {
        return *HolderFactory;
    }

    ITerminator* GetTerminator() const override {
        return ValueBuilder.Get();
    }

    // Toggles LLVM execution; returns the previous setting.
    bool SetExecuteLLVM(bool value) override {
        const bool old = Ctx->ExecuteLLVM;
        Ctx->ExecuteLLVM = value;
        return old;
    }

    // Serializes every SerializableValues slot as: ui64 byte size (ui64 max
    // acts as the sentinel for an Invalid slot) followed by the raw bytes.
    // Boxed values are flattened through Save()/list iteration; plain
    // string/embedded values are written directly.
    TString SaveGraphState() override {
        Prepare();

        TString result;
        for (ui32 i : PatternNodes->GetMutables().SerializableValues) {
            const NUdf::TUnboxedValuePod& mutableValue = Ctx->MutableValues[i];
            if (mutableValue.IsInvalid()) {
                WriteUi64(result, std::numeric_limits<ui64>::max()); // -1.
            } else if (mutableValue.IsBoxed()) {
                TList<TString> taskState;
                size_t taskStateSize = 0;

                // Collect every chunk of a list value into taskState.
                auto saveList = [&](auto& list) {
                    auto listIt = list.GetListIterator();
                    NUdf::TUnboxedValue str;
                    while (listIt.Next(str)) {
                        const TStringBuf strRef = str.AsStringRef();
                        taskStateSize += strRef.Size();
                        taskState.push_back({});
                        taskState.back().AppendNoAlias(strRef.Data(), strRef.Size());
                    }
                };

                bool isList = mutableValue.HasListItems();
                NUdf::TUnboxedValue list;
                if (isList) { // No load was done during previous runs.
                    saveList(mutableValue);
                } else {
                    NUdf::TUnboxedValue saved = mutableValue.Save();
                    if (saved.IsString() || saved.IsEmbedded()) { // Old version.
                        const TStringBuf savedBuf = saved.AsStringRef();
                        taskState.push_back({});
                        taskState.back().AppendNoAlias(savedBuf.Data(), savedBuf.Size());
                        taskStateSize = savedBuf.Size();
                    } else {
                        saveList(saved);
                    }
                }

                WriteUi64(result, taskStateSize);
                // Append and free chunks one by one to cap peak memory.
                for (auto it = taskState.begin(); it != taskState.end();) {
                    result.AppendNoAlias(it->data(), it->size());
                    it = taskState.erase(it);
                }
            } else { // No load was done during previous runs (if any).
                MKQL_ENSURE(mutableValue.HasValue() && (mutableValue.IsString() || mutableValue.IsEmbedded()), "State is expected to have data or invalid value");
                const NUdf::TStringRef savedRef = mutableValue.AsStringRef();
                WriteUi64(result, savedRef.Size());
                result.AppendNoAlias(savedRef.Data(), savedRef.Size());
            }
        }
        return result;
    }

    // Inverse of SaveGraphState: reads a (size, bytes) record per serializable
    // slot; the ui64-max sentinel leaves the slot Invalid.
    // NOTE(review): MakeArray presumably reconstructs the boxed state value
    // from the raw bytes — confirm against TOutputSerializer.
    void LoadGraphState(TStringBuf state) override {
        Prepare();

        for (ui32 i : PatternNodes->GetMutables().SerializableValues) {
            if (const ui64 size = ReadUi64(state); size != std::numeric_limits<ui64>::max()) {
                MKQL_ENSURE(state.Size() >= size, "Serialized state is corrupted - buffer is too short (" << state.Size() << ") for specified size: " << size);
                TStringBuf savedRef(state.Data(), size);
                Ctx->MutableValues[i] = NKikimr::NMiniKQL::TOutputSerializer::MakeArray(*Ctx, savedRef);
                state.Skip(size);
            } // else leave it Invalid()
        }

        MKQL_ENSURE(state.Empty(), "Serialized state is corrupted - extra bytes left: " << state.Size());
    }

private:
    const TPatternNodes::TPtr PatternNodes;
    const TIntrusivePtr<TMemoryUsageInfo> MemInfo;
    THolder<THolderFactory> HolderFactory;
    THolder<TDefaultValueBuilder> ValueBuilder;
    THolder<TComputationContext> Ctx;
    TComputationOptsFull CompOpts;
    NYql::NCodegen::ICodegen::TSharedPtr Codegen;
    bool IsPrepared = false;
    std::optional<TArrowKernelsTopology> KernelsTopology;
};
/// Concrete IComputationPattern: owns the prepared computation-node graph
/// (PatternNodes) and, optionally, a shared LLVM codegen backend used to
/// JIT-compile the generated functions. Graph instances are stamped out of
/// this pattern via Clone().
class TComputationPatternImpl final : public IComputationPattern {
public:
    /// Takes ownership of the nodes collected by the building visitor,
    /// runs the two-stage node preparation, and JIT-compiles immediately
    /// when a codegen backend was selected below.
    TComputationPatternImpl(THolder<TComputationGraphBuildingVisitor>&& builder, const TComputationPatternOpts& opts)
#if defined(MKQL_DISABLE_CODEGEN)
    // Codegen is compiled out entirely in this configuration.
    : Codegen()
#elif defined(MKQL_FORCE_USE_CODEGEN)
    // Codegen is unconditionally enabled in this configuration.
    : Codegen(NYql::NCodegen::ICodegen::MakeShared(NYql::NCodegen::ETarget::Native))
#else
    // Default: enable codegen when LLVM is available and not switched off via
    // opts.OptLLVM == "OFF", or when forced by the MKQL_FORCE_USE_LLVM
    // environment variable (which overrides the "OFF" setting).
    : Codegen((NYql::NCodegen::ICodegen::IsCodegenAvailable() && opts.OptLLVM != "OFF") || GetEnv(TString("MKQL_FORCE_USE_LLVM")) ? NYql::NCodegen::ICodegen::MakeShared(NYql::NCodegen::ETarget::Native) : NYql::NCodegen::ICodegen::TPtr())
#endif
    {
        /// TODO: Enable JIT for AARCH64/Win
#if defined(__aarch64__) || defined(_win_)
        Codegen = {};
#endif
        // Two-pass preparation: all nodes must finish stage one before any
        // node runs stage two.
        const auto& nodes = builder->GetNodes();
        for (const auto& node : nodes)
            node->PrepareStageOne();
        for (const auto& node : nodes)
            node->PrepareStageTwo();
        MKQL_ADD_STAT(opts.Stats, Mkql_TotalNodes, nodes.size());
        PatternNodes = builder->GetPatternNodes();
        if (Codegen) {
            Compile(opts.OptLLVM, opts.Stats);
        }
    }

    ~TComputationPatternImpl() {
        // NOTE(review): TypeEnv is never assigned in this class (it stays
        // nullptr), so this guarded release appears unreachable as written —
        // confirm whether a setter was lost or the member is vestigial.
        if (TypeEnv) {
            auto guard = TypeEnv->BindAllocator();
            PatternNodes.Reset();
        }
    }

    /// JIT-compiles the pattern's generated LLVM functions. Idempotent:
    /// returns immediately once IsPatternCompiled is set. With
    /// MKQL_DISABLE_CODEGEN this reduces to setting the compiled flag.
    void Compile(TString optLLVM, IStatsRegistry* stats) {
        if (IsPatternCompiled.load())
            return;
#ifndef MKQL_DISABLE_CODEGEN
        // Lazily create a backend when the constructor chose not to
        // (e.g. explicit Compile() call on an "OFF" pattern).
        if (!Codegen)
            Codegen = NYql::NCodegen::ICodegen::Make(NYql::NCodegen::ETarget::Native);
        const auto& nodes = PatternNodes->GetNodes();
        TStatTimer timerFull(CodeGen_FullTime);
        timerFull.Acquire();
        {
            TStatTimer timerGen(CodeGen_GenerateTime);
            timerGen.Acquire();
            // Generate in reverse node order.
            for (auto it = nodes.crbegin(); nodes.crend() != it; ++it) {
                if (const auto codegen = dynamic_cast<ICodegeneratorRootNode*>(it->Get())) {
                    codegen->GenerateFunctions(*Codegen);
                }
            }
            timerGen.Release();
            timerGen.Report(stats);
        }
        // Debug switches are passed embedded in the OptLLVM string.
        if (optLLVM.Contains("--dump-generated")) {
            Cerr << "############### Begin generated module ###############" << Endl;
            Codegen->GetModule().print(llvm::errs(), nullptr);
            Cerr << "################ End generated module ################" << Endl;
        }
        TStatTimer timerComp(CodeGen_CompileTime);
        timerComp.Acquire();
        NYql::NCodegen::TCodegenStats codegenStats;
        Codegen->GetStats(codegenStats);
        MKQL_ADD_STAT(stats, CodeGen_TotalFunctions, codegenStats.TotalFunctions);
        MKQL_ADD_STAT(stats, CodeGen_TotalInstructions, codegenStats.TotalInstructions);
        MKQL_SET_MAX_STAT(stats, CodeGen_MaxFunctionInstructions, codegenStats.MaxFunctionInstructions);
        if (optLLVM.Contains("--dump-stats")) {
            Cerr << "TotalFunctions: " << codegenStats.TotalFunctions << Endl;
            Cerr << "TotalInstructions: " << codegenStats.TotalInstructions << Endl;
            Cerr << "MaxFunctionInstructions: " << codegenStats.MaxFunctionInstructions << Endl;
        }
        if (optLLVM.Contains("--dump-perf-map")) {
            Codegen->TogglePerfJITEventListener();
        }
        // Bail out of compilation (but still mark the pattern compiled below)
        // when the module exceeds any size limit; the graph then runs
        // interpreted.
        if (codegenStats.TotalFunctions >= TotalFunctionsLimit ||
            codegenStats.TotalInstructions >= TotalInstructionsLimit ||
            codegenStats.MaxFunctionInstructions >= MaxFunctionInstructionsLimit) {
            Codegen.reset();
        } else {
            Codegen->Verify();
            Codegen->Compile(GetCompileOptions(optLLVM), &CompileStats);
            MKQL_ADD_STAT(stats, CodeGen_FunctionPassTime, CompileStats.FunctionPassTime);
            MKQL_ADD_STAT(stats, CodeGen_ModulePassTime, CompileStats.ModulePassTime);
            MKQL_ADD_STAT(stats, CodeGen_FinalizeTime, CompileStats.FinalizeTime);
        }
        timerComp.Release();
        timerComp.Report(stats);
        if (Codegen) {
            if (optLLVM.Contains("--dump-compiled")) {
                Cerr << "############### Begin compiled module ###############" << Endl;
                Codegen->GetModule().print(llvm::errs(), nullptr);
                Cerr << "################ End compiled module ################" << Endl;
            }
            if (optLLVM.Contains("--asm-compiled")) {
                Cerr << "############### Begin compiled asm ###############" << Endl;
                Codegen->ShowGeneratedFunctions(&Cerr);
                Cerr << "################ End compiled asm ################" << Endl;
            }
            // Bind the compiled machine code back into each root node.
            ui64 count = 0U;
            for (const auto& node : nodes) {
                if (const auto codegen = dynamic_cast<ICodegeneratorRootNode*>(node.Get())) {
                    codegen->FinalizeFunctions(*Codegen);
                    ++count;
                }
            }
            if (count) {
                MKQL_ADD_STAT(stats, Mkql_CodegenFunctions, count);
            }
        }
        timerFull.Release();
        timerFull.Report(stats);
#else
        Y_UNUSED(optLLVM);
        Y_UNUSED(stats);
#endif
        IsPatternCompiled.store(true);
    }

    bool IsCompiled() const {
        return IsPatternCompiled.load();
    }

    size_t CompiledCodeSize() const {
        return CompileStats.TotalObjectSize;
    }

    /// Drops the JIT artifacts so the pattern falls back to interpretation
    /// (and could be recompiled by a later Compile() call).
    void RemoveCompiledCode() {
        IsPatternCompiled.store(false);
        CompileStats = {};
        Codegen.reset();
    }

    /// Creates a runnable graph over the shared pattern nodes. The codegen
    /// backend is handed to the graph only after compilation finished.
    THolder<IComputationGraph> Clone(const TComputationOptsFull& compOpts) {
        if (IsPatternCompiled.load()) {
            return MakeHolder<TComputationGraph>(PatternNodes, compOpts, Codegen);
        }
        return MakeHolder<TComputationGraph>(PatternNodes, compOpts, nullptr);
    }

    bool GetSuitableForCache() const {
        return PatternNodes->GetSuitableForCache();
    }

private:
    /// Extracts the substring following "--compile-options" up to the next
    /// " --" separator (or end of string); empty when the flag is absent.
    /// NOTE(review): the returned buffer keeps the leading space after the
    /// flag (e.g. " foo") — confirm the codegen backend tolerates it.
    TStringBuf GetCompileOptions(const TString& s) {
        const TString flag = "--compile-options";
        auto lpos = s.rfind(flag);
        if (lpos == TString::npos)
            return TStringBuf();
        lpos += flag.size();
        auto rpos = s.find(" --", lpos);
        if (rpos == TString::npos)
            return TStringBuf(s, lpos);
        else
            return TStringBuf(s, lpos, rpos - lpos);
    };

    TTypeEnvironment* TypeEnv = nullptr;                 // see NOTE in the destructor: never assigned here
    TPatternNodes::TPtr PatternNodes;                    // shared graph structure cloned into each TComputationGraph
    NYql::NCodegen::ICodegen::TSharedPtr Codegen;        // null when codegen is disabled, limited out, or removed
    std::atomic<bool> IsPatternCompiled = false;         // set once Compile() completes (even without codegen)
    NYql::NCodegen::TCompileStats CompileStats;
};
  833. TIntrusivePtr<TComputationPatternImpl> MakeComputationPatternImpl(TExploringNodeVisitor& explorer, const TRuntimeNode& root,
  834. const std::vector<TNode*>& entryPoints, const TComputationPatternOpts& opts) {
  835. TDependencyScanVisitor depScanner;
  836. depScanner.Walk(root.GetNode(), opts.Env);
  837. auto builder = MakeHolder<TComputationGraphBuildingVisitor>(opts);
  838. const TBindTerminator bind(&builder->GetPatternNodes()->GetTerminator());
  839. for (const auto& node : explorer.GetNodes()) {
  840. Y_ABORT_UNLESS(node->GetCookie() <= IS_NODE_REACHABLE, "TNode graph should not be reused");
  841. if (node->GetCookie() == IS_NODE_REACHABLE) {
  842. node->Accept(*builder);
  843. }
  844. }
  845. const auto rootNode = builder->GetComputationNode(root.GetNode());
  846. TComputationExternalNodePtrVector runtime2ComputationEntryPoints;
  847. runtime2ComputationEntryPoints.resize(entryPoints.size(), nullptr);
  848. std::unordered_map<TNode*, std::vector<ui32>> entryPointIndex;
  849. for (ui32 i = 0; i < entryPoints.size(); ++i) {
  850. entryPointIndex[entryPoints[i]].emplace_back(i);
  851. }
  852. for (const auto& node : explorer.GetNodes()) {
  853. auto it = entryPointIndex.find(node);
  854. if (it == entryPointIndex.cend()) {
  855. continue;
  856. }
  857. auto compNode = dynamic_cast<IComputationExternalNode*>(builder->GetComputationNode(node));
  858. for (auto index : it->second) {
  859. runtime2ComputationEntryPoints[index] = compNode;
  860. }
  861. }
  862. for (const auto& node : explorer.GetNodes()) {
  863. node->SetCookie(0);
  864. }
  865. builder->PreserveRoot(rootNode);
  866. builder->PreserveEntryPoints(std::move(runtime2ComputationEntryPoints));
  867. return MakeIntrusive<TComputationPatternImpl>(std::move(builder), opts);
  868. }
  869. } // namespace
  870. IComputationPattern::TPtr MakeComputationPattern(TExploringNodeVisitor& explorer, const TRuntimeNode& root,
  871. const std::vector<TNode*>& entryPoints, const TComputationPatternOpts& opts) {
  872. return MakeComputationPatternImpl(explorer, root, entryPoints, opts);
  873. }
  874. } // namespace NMiniKQL
  875. } // namespace NKikimr