// mkql_computation_node_graph.cpp
  1. #include "mkql_computation_node_holders.h"
  2. #include "mkql_computation_node_holders_codegen.h"
  3. #include "mkql_value_builder.h"
  4. #include "mkql_computation_node_codegen.h" // Y_IGNORE
  5. #include <yql/essentials/public/udf/arrow/memory_pool.h>
  6. #include <yql/essentials/minikql/computation/mkql_computation_pattern_cache.h>
  7. #include <yql/essentials/minikql/comp_nodes/mkql_saveload.h>
  8. #include <yql/essentials/minikql/mkql_type_builder.h>
  9. #include <yql/essentials/minikql/mkql_terminator.h>
  10. #include <yql/essentials/minikql/mkql_string_util.h>
  11. #include <util/system/env.h>
  12. #include <util/system/mutex.h>
  13. #include <util/digest/city.h>
  14. #ifndef MKQL_DISABLE_CODEGEN
  15. #include <llvm/Support/raw_ostream.h> // Y_IGNORE
  16. #endif
  17. namespace NKikimr {
  18. namespace NMiniKQL {
  19. using namespace NDetail;
  20. namespace {
#ifndef MKQL_DISABLE_CODEGEN
// Sanity limits for JIT compilation: patterns whose generated LLVM module
// exceeds these sizes are presumably rejected/skipped (checked in Compile()).
constexpr ui64 TotalFunctionsLimit = 1000;
constexpr ui64 TotalInstructionsLimit = 100000;
constexpr ui64 MaxFunctionInstructionsLimit = 50000;
#endif
// Cookie value used by TDependencyScanVisitor to mark an AST node as visited.
// Any cookie value above this is a pointer to the node's IComputationNode
// (see TPatternNodes::GetComputationNode).
const ui64 IS_NODE_REACHABLE = 1;
// Statistic keys. Second constructor argument: true = additive counter
// (MKQL_ADD_STAT), false = maximum/gauge (MKQL_SET_MAX_STAT).
const static TStatKey PagePool_PeakAllocated("PagePool_PeakAllocated", false);
const static TStatKey PagePool_PeakUsed("PagePool_PeakUsed", false);
const static TStatKey PagePool_AllocCount("PagePool_AllocCount", true);
const static TStatKey PagePool_PageAllocCount("PagePool_PageAllocCount", true);
const static TStatKey PagePool_PageHitCount("PagePool_PageHitCount", true);
const static TStatKey PagePool_PageMissCount("PagePool_PageMissCount", true);
const static TStatKey PagePool_OffloadedAllocCount("PagePool_OffloadedAllocCount", true);
const static TStatKey PagePool_OffloadedBytes("PagePool_OffloadedBytes", true);
const static TStatKey CodeGen_FullTime("CodeGen_FullTime", true);
const static TStatKey CodeGen_GenerateTime("CodeGen_GenerateTime", true);
const static TStatKey CodeGen_CompileTime("CodeGen_CompileTime", true);
const static TStatKey CodeGen_TotalFunctions("CodeGen_TotalFunctions", true);
const static TStatKey CodeGen_TotalInstructions("CodeGen_TotalInstructions", true);
const static TStatKey CodeGen_MaxFunctionInstructions("CodeGen_MaxFunctionInstructions", false);
const static TStatKey CodeGen_FunctionPassTime("CodeGen_FunctionPassTime", true);
const static TStatKey CodeGen_ModulePassTime("CodeGen_ModulePassTime", true);
const static TStatKey CodeGen_FinalizeTime("CodeGen_FinalizeTime", true);
const static TStatKey Mkql_TotalNodes("Mkql_TotalNodes", true);
const static TStatKey Mkql_CodegenFunctions("Mkql_CodegenFunctions", true);
  46. class TDependencyScanVisitor : public TEmptyNodeVisitor {
  47. public:
  48. void Walk(TNode* root, const TTypeEnvironment& env) {
  49. Stack = &env.GetNodeStack();
  50. Stack->clear();
  51. Stack->push_back(root);
  52. while (!Stack->empty()) {
  53. auto top = Stack->back();
  54. Stack->pop_back();
  55. if (top->GetCookie() != IS_NODE_REACHABLE) {
  56. top->SetCookie(IS_NODE_REACHABLE);
  57. top->Accept(*this);
  58. }
  59. }
  60. Stack = nullptr;
  61. }
  62. private:
  63. using TEmptyNodeVisitor::Visit;
  64. void Visit(TStructLiteral& node) override {
  65. for (ui32 i = 0; i < node.GetValuesCount(); ++i) {
  66. AddNode(node.GetValue(i).GetNode());
  67. }
  68. }
  69. void Visit(TListLiteral& node) override {
  70. for (ui32 i = 0; i < node.GetItemsCount(); ++i) {
  71. AddNode(node.GetItems()[i].GetNode());
  72. }
  73. }
  74. void Visit(TOptionalLiteral& node) override {
  75. if (node.HasItem()) {
  76. AddNode(node.GetItem().GetNode());
  77. }
  78. }
  79. void Visit(TDictLiteral& node) override {
  80. for (ui32 i = 0; i < node.GetItemsCount(); ++i) {
  81. AddNode(node.GetItem(i).first.GetNode());
  82. AddNode(node.GetItem(i).second.GetNode());
  83. }
  84. }
  85. void Visit(TCallable& node) override {
  86. if (node.HasResult()) {
  87. AddNode(node.GetResult().GetNode());
  88. } else {
  89. for (ui32 i = 0; i < node.GetInputsCount(); ++i) {
  90. AddNode(node.GetInput(i).GetNode());
  91. }
  92. }
  93. }
  94. void Visit(TAny& node) override {
  95. if (node.HasItem()) {
  96. AddNode(node.GetItem().GetNode());
  97. }
  98. }
  99. void Visit(TTupleLiteral& node) override {
  100. for (ui32 i = 0; i < node.GetValuesCount(); ++i) {
  101. AddNode(node.GetValue(i).GetNode());
  102. }
  103. }
  104. void Visit(TVariantLiteral& node) override {
  105. AddNode(node.GetItem().GetNode());
  106. }
  107. void AddNode(TNode* node) {
  108. if (node->GetCookie() != IS_NODE_REACHABLE) {
  109. Stack->push_back(node);
  110. }
  111. }
  112. std::vector<TNode*>* Stack = nullptr;
  113. };
// Owns the result of building a computation pattern: the list of computation
// nodes, the entry-point mapping, the holder factory and the value builder.
// Ref-counted because it is shared between the pattern and every graph
// instantiated from it.
class TPatternNodes: public TAtomicRefCount<TPatternNodes> {
public:
    typedef TIntrusivePtr<TPatternNodes> TPtr;

    TPatternNodes(TAllocState& allocState)
        : AllocState(allocState)
        , MemInfo(MakeIntrusive<TMemoryUsageInfo>("ComputationPatternNodes"))
    {
#ifndef NDEBUG
        // Register this memory scope for leak diagnostics in debug builds.
        AllocState.ActiveMemInfo.emplace(MemInfo.Get(), MemInfo);
#else
        Y_UNUSED(AllocState);
#endif
    }

    ~TPatternNodes()
    {
        // Release nodes in reverse creation order: later-created nodes may
        // reference earlier ones.
        for (auto it = ComputationNodesList.rbegin(); it != ComputationNodesList.rend(); ++it) {
            *it = nullptr;
        }
        ComputationNodesList.clear();
        // Skip the debug bookkeeping during stack unwinding to avoid
        // cascading failures.
        if (!UncaughtException()) {
#ifndef NDEBUG
            AllocState.ActiveMemInfo.erase(MemInfo.Get());
#endif
        }
    }

    const TComputationMutables& GetMutables() const {
        return Mutables;
    }

    const TComputationNodePtrDeque& GetNodes() const {
        return ComputationNodesList;
    }

    // Resolves an AST node to its computation node via the node cookie.
    // Cookie values <= IS_NODE_REACHABLE mean "not built"; larger values are
    // reinterpreted as an IComputationNode pointer. If pop is set, the cookie
    // is cleared after resolution. If require is set, an unresolved node is
    // a hard error.
    IComputationNode* GetComputationNode(TNode* node, bool pop = false, bool require = true) {
        const auto cookie = node->GetCookie();
        const auto result = reinterpret_cast<IComputationNode*>(cookie);
        if (cookie <= IS_NODE_REACHABLE) {
            MKQL_ENSURE(!require, "Computation graph builder, node not found, type:"
                << node->GetType()->GetKindAsStr());
            return result;
        }
        if (pop) {
            node->SetCookie(0);
        }
        return result;
    }

    // Returns the external (parameter) node registered at the given entry
    // point index; with require=false a null entry is allowed.
    IComputationExternalNode* GetEntryPoint(size_t index, bool require) {
        MKQL_ENSURE(index < Runtime2ComputationEntryPoints.size() && (!require || Runtime2ComputationEntryPoints[index]),
            "Pattern nodes can not get computation node by index: " << index << ", require: " << require
            << ", Runtime2ComputationEntryPoints size: " << Runtime2ComputationEntryPoints.size());
        return Runtime2ComputationEntryPoints[index];
    }

    IComputationNode* GetRoot() {
        return RootNode;
    }

    // False when the pattern contains nodes that must not be shared through
    // the pattern cache (set during build; see Visit(TCallable&)).
    bool GetSuitableForCache() const {
        return SuitableForCache;
    }

    size_t GetEntryPointsCount() const {
        return Runtime2ComputationEntryPoints.size();
    }

private:
    friend class TComputationGraphBuildingVisitor;
    friend class TComputationGraph;

    TAllocState& AllocState;
    TIntrusivePtr<TMemoryUsageInfo> MemInfo;
    THolder<THolderFactory> HolderFactory;
    THolder<TDefaultValueBuilder> ValueBuilder;
    TComputationMutables Mutables;
    // Nodes in creation order; destroyed in reverse order (see dtor).
    TComputationNodePtrDeque ComputationNodesList;
    IComputationNode* RootNode = nullptr;
    TComputationExternalNodePtrVector Runtime2ComputationEntryPoints;
    TComputationNodeOnNodeMap ElementsCache;
    bool SuitableForCache = true;
};
// Builds the computation graph from the MiniKQL AST. Each visited runtime
// node produces an IComputationNode (via the node factory or the pluggable
// callable factory) which is recorded in TPatternNodes; the AST node's
// cookie is set to the computation node's pointer so later nodes can resolve
// their dependencies.
class TComputationGraphBuildingVisitor:
    public INodeVisitor,
    private TNonCopyable
{
public:
    TComputationGraphBuildingVisitor(const TComputationPatternOpts& opts)
        : Env(opts.Env)
        , TypeInfoHelper(new TTypeInfoHelper())
        , CountersProvider(opts.CountersProvider)
        , SecureParamsProvider(opts.SecureParamsProvider)
        , Factory(opts.Factory)
        , FunctionRegistry(*opts.FunctionRegistry)
        , ValidateMode(opts.ValidateMode)
        , ValidatePolicy(opts.ValidatePolicy)
        , GraphPerProcess(opts.GraphPerProcess)
        , PatternNodes(MakeIntrusive<TPatternNodes>(opts.AllocState))
        , ExternalAlloc(opts.PatternEnv)
    {
        PatternNodes->HolderFactory = MakeHolder<THolderFactory>(opts.AllocState, *PatternNodes->MemInfo, &FunctionRegistry);
        PatternNodes->ValueBuilder = MakeHolder<TDefaultValueBuilder>(*PatternNodes->HolderFactory, ValidatePolicy);
        PatternNodes->ValueBuilder->SetSecureParamsProvider(opts.SecureParamsProvider);
        NodeFactory = MakeHolder<TNodeFactory>(*PatternNodes->MemInfo, PatternNodes->Mutables);
    }

    ~TComputationGraphBuildingVisitor() {
        // Release pattern-owned state under the pattern's allocator.
        auto g = Env.BindAllocator();
        NodeFactory.Reset();
        PatternNodes.Reset();
    }

    const TTypeEnvironment& GetTypeEnvironment() const {
        return Env;
    }

    const IFunctionRegistry& GetFunctionRegistry() const {
        return FunctionRegistry;
    }

private:
    // Type nodes all map to a trivial type computation node.
    template <typename T>
    void VisitType(T& node) {
        AddNode(node, NodeFactory->CreateTypeNode(&node));
    }

    void Visit(TTypeType& node) override {
        VisitType<TTypeType>(node);
    }

    void Visit(TVoidType& node) override {
        VisitType<TVoidType>(node);
    }

    void Visit(TNullType& node) override {
        VisitType<TNullType>(node);
    }

    void Visit(TEmptyListType& node) override {
        VisitType<TEmptyListType>(node);
    }

    void Visit(TEmptyDictType& node) override {
        VisitType<TEmptyDictType>(node);
    }

    void Visit(TDataType& node) override {
        VisitType<TDataType>(node);
    }

    void Visit(TPgType& node) override {
        VisitType<TPgType>(node);
    }

    void Visit(TStructType& node) override {
        VisitType<TStructType>(node);
    }

    void Visit(TListType& node) override {
        VisitType<TListType>(node);
    }

    void Visit(TStreamType& node) override {
        VisitType<TStreamType>(node);
    }

    void Visit(TFlowType& node) override {
        VisitType<TFlowType>(node);
    }

    void Visit(TBlockType& node) override {
        VisitType<TBlockType>(node);
    }

    void Visit(TMultiType& node) override {
        VisitType<TMultiType>(node);
    }

    void Visit(TTaggedType& node) override {
        VisitType<TTaggedType>(node);
    }

    void Visit(TOptionalType& node) override {
        VisitType<TOptionalType>(node);
    }

    void Visit(TDictType& node) override {
        VisitType<TDictType>(node);
    }

    void Visit(TCallableType& node) override {
        VisitType<TCallableType>(node);
    }

    void Visit(TAnyType& node) override {
        VisitType<TAnyType>(node);
    }

    void Visit(TTupleType& node) override {
        VisitType<TTupleType>(node);
    }

    void Visit(TResourceType& node) override {
        VisitType<TResourceType>(node);
    }

    void Visit(TVariantType& node) override {
        VisitType<TVariantType>(node);
    }

    void Visit(TVoid& node) override {
        AddNode(node, NodeFactory->CreateImmutableNode(NUdf::TUnboxedValue::Void()));
    }

    void Visit(TNull& node) override {
        AddNode(node, NodeFactory->CreateImmutableNode(NUdf::TUnboxedValue()));
    }

    void Visit(TEmptyList& node) override {
        AddNode(node, NodeFactory->CreateImmutableNode(PatternNodes->HolderFactory->GetEmptyContainerLazy()));
    }

    void Visit(TEmptyDict& node) override {
        AddNode(node, NodeFactory->CreateImmutableNode(PatternNodes->HolderFactory->GetEmptyContainerLazy()));
    }

    void Visit(TDataLiteral& node) override {
        auto value = node.AsValue();
        NUdf::TDataTypeId typeId = node.GetType()->GetSchemeType();
        // 0x101 is a legacy scheme type id exempt from validation.
        if (typeId != 0x101) { // TODO remove
            const auto slot = NUdf::GetDataSlot(typeId);
            MKQL_ENSURE(IsValidValue(slot, value),
                "Bad data literal for type: " << NUdf::GetDataTypeInfo(slot).Name << ", " << value);
        }
        // When the pattern lives in its own allocation environment, string
        // payloads are copied so they do not reference the build allocator.
        NUdf::TUnboxedValue externalValue;
        if (ExternalAlloc) {
            if (value.IsString()) {
                externalValue = MakeString(value.AsStringRef());
            }
        }
        if (!externalValue) {
            externalValue = std::move(value);
        }
        AddNode(node, NodeFactory->CreateImmutableNode(std::move(externalValue)));
    }

    void Visit(TStructLiteral& node) override {
        TComputationNodePtrVector values;
        values.reserve(node.GetValuesCount());
        for (ui32 i = 0, e = node.GetValuesCount(); i < e; ++i) {
            values.push_back(GetComputationNode(node.GetValue(i).GetNode()));
        }
        AddNode(node, NodeFactory->CreateArrayNode(std::move(values)));
    }

    void Visit(TListLiteral& node) override {
        TComputationNodePtrVector items;
        items.reserve(node.GetItemsCount());
        for (ui32 i = 0; i < node.GetItemsCount(); ++i) {
            items.push_back(GetComputationNode(node.GetItems()[i].GetNode()));
        }
        AddNode(node, NodeFactory->CreateArrayNode(std::move(items)));
    }

    void Visit(TOptionalLiteral& node) override {
        auto item = node.HasItem() ? GetComputationNode(node.GetItem().GetNode()) : nullptr;
        AddNode(node, NodeFactory->CreateOptionalNode(item));
    }

    void Visit(TDictLiteral& node) override {
        auto keyType = node.GetType()->GetKeyType();
        TKeyTypes types;
        bool isTuple;
        bool encoded;
        bool useIHash;
        GetDictionaryKeyTypes(keyType, types, isTuple, encoded, useIHash);
        std::vector<std::pair<IComputationNode*, IComputationNode*>> items;
        items.reserve(node.GetItemsCount());
        for (ui32 i = 0, e = node.GetItemsCount(); i < e; ++i) {
            auto item = node.GetItem(i);
            items.push_back(std::make_pair(GetComputationNode(item.first.GetNode()), GetComputationNode(item.second.GetNode())));
        }
        // Keys that cannot be hashed fall back to a sorted dict representation.
        bool isSorted = !CanHash(keyType);
        AddNode(node, NodeFactory->CreateDictNode(std::move(items), types, isTuple, encoded ? keyType : nullptr,
            useIHash && !isSorted ? MakeHashImpl(keyType) : nullptr,
            useIHash ? MakeEquateImpl(keyType) : nullptr,
            useIHash && isSorted ? MakeCompareImpl(keyType) : nullptr, isSorted));
    }

    void Visit(TCallable& node) override {
        // A callable with a materialized result reduces to that result's
        // computation node; the cookie is pointed at it directly.
        if (node.HasResult()) {
            node.GetResult().GetNode()->Accept(*this);
            auto computationNode = PatternNodes->ComputationNodesList.back().Get();
            node.SetCookie((ui64)computationNode);
            return;
        }
        // "Steal" callables are intentionally skipped during the build.
        if (node.GetType()->GetName() == "Steal") {
            return;
        }
        TNodeLocator nodeLocator = [this](TNode* dependentNode, bool pop) {
            return GetComputationNode(dependentNode, pop);
        };
        TComputationNodeFactoryContext ctx(
            nodeLocator,
            FunctionRegistry,
            Env,
            TypeInfoHelper,
            CountersProvider,
            SecureParamsProvider,
            *NodeFactory,
            *PatternNodes->HolderFactory,
            PatternNodes->ValueBuilder.Get(),
            ValidateMode,
            ValidatePolicy,
            GraphPerProcess,
            PatternNodes->Mutables,
            PatternNodes->ElementsCache,
            std::bind(&TComputationGraphBuildingVisitor::PushBackNode, this, std::placeholders::_1));
        const auto computationNode = Factory(node, ctx);
        const auto& name = node.GetType()->GetName();
        // Table-read callables make the pattern unsafe to share via the
        // pattern cache.
        if (name == "KqpWideReadTable" ||
            name == "KqpWideReadTableRanges" ||
            name == "KqpBlockReadTableRanges" ||
            name == "KqpLookupTable" ||
            name == "KqpReadTable"
        ) {
            PatternNodes->SuitableForCache = false;
        }
        if (!computationNode) {
            THROW yexception()
                << "Computation graph builder, unsupported function: " << name << " type: " << Factory.target_type().name() ;
        }
        AddNode(node, computationNode);
    }

    void Visit(TAny& node) override {
        if (!node.HasItem()) {
            AddNode(node, NodeFactory->CreateImmutableNode(NUdf::TUnboxedValue::Void()));
        } else {
            AddNode(node, GetComputationNode(node.GetItem().GetNode()));
        }
    }

    void Visit(TTupleLiteral& node) override {
        TComputationNodePtrVector values;
        values.reserve(node.GetValuesCount());
        for (ui32 i = 0, e = node.GetValuesCount(); i < e; ++i) {
            values.push_back(GetComputationNode(node.GetValue(i).GetNode()));
        }
        AddNode(node, NodeFactory->CreateArrayNode(std::move(values)));
    }

    void Visit(TVariantLiteral& node) override {
        auto item = GetComputationNode(node.GetItem().GetNode());
        AddNode(node, NodeFactory->CreateVariantNode(item, node.GetIndex()));
    }

public:
    // Resolves an already-built computation node by AST node (see
    // TPatternNodes::GetComputationNode for cookie semantics).
    IComputationNode* GetComputationNode(TNode* node, bool pop = false, bool require = true) {
        return PatternNodes->GetComputationNode(node, pop, require);
    }

    TMemoryUsageInfo& GetMemInfo() {
        return *PatternNodes->MemInfo;
    }

    const THolderFactory& GetHolderFactory() const {
        return *PatternNodes->HolderFactory;
    }

    TPatternNodes::TPtr GetPatternNodes() {
        return PatternNodes;
    }

    const TComputationNodePtrDeque& GetNodes() const {
        return PatternNodes->GetNodes();
    }

    void PreserveRoot(IComputationNode* rootNode) {
        PatternNodes->RootNode = rootNode;
    }

    void PreserveEntryPoints(TComputationExternalNodePtrVector&& runtime2ComputationEntryPoints) {
        PatternNodes->Runtime2ComputationEntryPoints = std::move(runtime2ComputationEntryPoints);
    }

private:
    // Registers a freshly built node: wires its dependencies and appends it
    // to the creation-ordered node list.
    void PushBackNode(const IComputationNode::TPtr& computationNode) {
        computationNode->RegisterDependencies();
        PatternNodes->ComputationNodesList.push_back(computationNode);
    }

    // Registers the node and points the AST node's cookie at it.
    void AddNode(TNode& node, const IComputationNode::TPtr& computationNode) {
        PushBackNode(computationNode);
        node.SetCookie((ui64)computationNode.Get());
    }

private:
    const TTypeEnvironment& Env;
    NUdf::ITypeInfoHelper::TPtr TypeInfoHelper;
    NUdf::ICountersProvider* CountersProvider;
    const NUdf::ISecureParamsProvider* SecureParamsProvider;
    const TComputationNodeFactory Factory;
    const IFunctionRegistry& FunctionRegistry;
    TIntrusivePtr<TMemoryUsageInfo> MemInfo;
    THolder<TNodeFactory> NodeFactory;
    NUdf::EValidateMode ValidateMode;
    NUdf::EValidatePolicy ValidatePolicy;
    EGraphPerProcess GraphPerProcess;
    TPatternNodes::TPtr PatternNodes;
    const bool ExternalAlloc; // obsolete, will be removed after YQL-13977
};
// An executable instance of a computation pattern: owns the per-run context,
// holder factory and value builder, while sharing the immutable pattern
// nodes. Also reports page-pool statistics on destruction.
class TComputationGraph final : public IComputationGraph {
public:
    TComputationGraph(TPatternNodes::TPtr& patternNodes, const TComputationOptsFull& compOpts, NYql::NCodegen::ICodegen::TSharedPtr codegen)
        : PatternNodes(patternNodes)
        , MemInfo(MakeIntrusive<TMemoryUsageInfo>("ComputationGraph"))
        , CompOpts(compOpts)
        , Codegen(std::move(codegen))
    {
#ifndef NDEBUG
        CompOpts.AllocState.ActiveMemInfo.emplace(MemInfo.Get(), MemInfo);
#endif
        HolderFactory = MakeHolder<THolderFactory>(CompOpts.AllocState, *MemInfo, patternNodes->HolderFactory->GetFunctionRegistry());
        ValueBuilder = MakeHolder<TDefaultValueBuilder>(*HolderFactory.Get(), compOpts.ValidatePolicy);
        ValueBuilder->SetSecureParamsProvider(CompOpts.SecureParamsProvider);
    }

    ~TComputationGraph() {
        // Flush page-pool usage metrics into the stats registry (if any).
        auto stats = CompOpts.Stats;
        auto& pagePool = HolderFactory->GetPagePool();
        MKQL_SET_MAX_STAT(stats, PagePool_PeakAllocated, pagePool.GetPeakAllocated());
        MKQL_SET_MAX_STAT(stats, PagePool_PeakUsed, pagePool.GetPeakUsed());
        MKQL_ADD_STAT(stats, PagePool_AllocCount, pagePool.GetAllocCount());
        MKQL_ADD_STAT(stats, PagePool_PageAllocCount, pagePool.GetPageAllocCount());
        MKQL_ADD_STAT(stats, PagePool_PageHitCount, pagePool.GetPageHitCount());
        MKQL_ADD_STAT(stats, PagePool_PageMissCount, pagePool.GetPageMissCount());
        MKQL_ADD_STAT(stats, PagePool_OffloadedAllocCount, pagePool.GetOffloadedAllocCount());
        MKQL_ADD_STAT(stats, PagePool_OffloadedBytes, pagePool.GetOffloadedBytes());
    }

    // Lazily creates the computation context and initializes every node.
    // Idempotent; called implicitly by all value-producing entry points.
    void Prepare() override {
        if (!IsPrepared) {
            Ctx.Reset(new TComputationContext(*HolderFactory,
                ValueBuilder.Get(),
                CompOpts,
                PatternNodes->GetMutables(),
                *NYql::NUdf::GetYqlMemoryPool()));
            Ctx->ExecuteLLVM = Codegen.get() != nullptr;
            ValueBuilder->SetCalleePositionHolder(Ctx->CalleePosition);
            for (auto& node : PatternNodes->GetNodes()) {
                node->InitNode(*Ctx);
            }
            IsPrepared = true;
        }
    }

    TComputationContext& GetContext() override {
        Prepare();
        return *Ctx;
    }

    NUdf::TUnboxedValue GetValue() override {
        Prepare();
        return PatternNodes->GetRoot()->GetValue(*Ctx);
    }

    IComputationExternalNode* GetEntryPoint(size_t index, bool require) override {
        Prepare();
        return PatternNodes->GetEntryPoint(index, require);
    }

    const TArrowKernelsTopology* GetKernelsTopology() override {
        Prepare();
        if (!KernelsTopology.has_value()) {
            CalculateKernelTopology(*Ctx);
        }
        return &KernelsTopology.value();
    }

    // Builds the Arrow kernel topology: a post-order (children before
    // parents) listing of kernel nodes, with inputs expressed as indices.
    // Indices [0, InputArgsCount) are the entry points; subsequent indices
    // refer to Items in emission order.
    void CalculateKernelTopology(TComputationContext& ctx) {
        KernelsTopology.emplace();
        KernelsTopology->InputArgsCount = PatternNodes->GetEntryPointsCount();
        std::stack<const IComputationNode*> stack;
        struct TNodeState {
            bool Visited;
            ui32 Index;
        };
        std::unordered_map<const IComputationNode*, TNodeState> deps;
        // Pre-seed the entry points; they are "visited" from the start.
        for (ui32 i = 0; i < KernelsTopology->InputArgsCount; ++i) {
            auto entryPoint = PatternNodes->GetEntryPoint(i, false);
            if (!entryPoint) {
                continue;
            }
            deps.emplace(entryPoint, TNodeState{ true, i});
        }
        stack.push(PatternNodes->GetRoot());
        // Two-pass DFS: first touch pushes children, second touch (Visited)
        // assigns the node's output index and records its kernel item.
        while (!stack.empty()) {
            auto node = stack.top();
            auto [iter, inserted] = deps.emplace(node, TNodeState{ false, 0 });
            auto extNode = dynamic_cast<const IComputationExternalNode*>(node);
            if (extNode) {
                MKQL_ENSURE(!inserted, "Unexpected external node");
                stack.pop();
                continue;
            }
            auto kernelNode = node->PrepareArrowKernelComputationNode(ctx);
            MKQL_ENSURE(kernelNode, "No kernel for node: " << node->DebugString());
            auto argsCount = kernelNode->GetArgsDesc().size();
            if (!iter->second.Visited) {
                for (ui32 j = 0; j < argsCount; ++j) {
                    stack.push(kernelNode->GetArgument(j));
                }
                iter->second.Visited = true;
            } else {
                iter->second.Index = KernelsTopology->InputArgsCount + KernelsTopology->Items.size();
                KernelsTopology->Items.emplace_back();
                auto& i = KernelsTopology->Items.back();
                i.Inputs.reserve(argsCount);
                for (ui32 j = 0; j < argsCount; ++j) {
                    auto it = deps.find(kernelNode->GetArgument(j));
                    MKQL_ENSURE(it != deps.end(), "Missing argument");
                    i.Inputs.emplace_back(it->second.Index);
                }
                i.Node = std::move(kernelNode);
                stack.pop();
            }
        }
    }

    // Resets every mutable slot to Invalid.
    // NOTE(review): dereferences Ctx without calling Prepare() first — looks
    // like callers are expected to have prepared the graph; confirm.
    void Invalidate() override {
        std::fill_n(Ctx->MutableValues.get(), PatternNodes->GetMutables().CurValueIndex, NUdf::TUnboxedValue(NUdf::TUnboxedValuePod::Invalid()));
    }

    // Invalidates only the cached-value slots (subset of mutables).
    void InvalidateCaches() override {
        for (const auto cachedIndex : Ctx->Mutables.CachedValues) {
            Ctx->MutableValues[cachedIndex] = NUdf::TUnboxedValuePod::Invalid();
        }
    }

    const TComputationNodePtrDeque& GetNodes() const override {
        return PatternNodes->GetNodes();
    }

    TMemoryUsageInfo& GetMemInfo() const override {
        return *MemInfo;
    }

    const THolderFactory& GetHolderFactory() const override {
        return *HolderFactory;
    }

    ITerminator* GetTerminator() const override {
        return ValueBuilder.Get();
    }

    // Toggles LLVM execution; returns the previous setting.
    // NOTE(review): also assumes Prepare() has run (uses Ctx) — confirm.
    bool SetExecuteLLVM(bool value) override {
        const bool old = Ctx->ExecuteLLVM;
        Ctx->ExecuteLLVM = value;
        return old;
    }

    // Serializes all serializable mutable values into a flat byte string:
    // for each slot, a ui64 length (max() encodes an Invalid slot) followed
    // by the raw bytes.
    TString SaveGraphState() override {
        Prepare();
        TString result;
        for (ui32 i : PatternNodes->GetMutables().SerializableValues) {
            const NUdf::TUnboxedValuePod& mutableValue = Ctx->MutableValues[i];
            if (mutableValue.IsInvalid()) {
                WriteUi64(result, std::numeric_limits<ui64>::max()); // -1.
            } else if (mutableValue.IsBoxed()) {
                TList<TString> taskState;
                size_t taskStateSize = 0;
                // Accumulates the string chunks of a list-shaped state.
                auto saveList = [&](auto& list) {
                    auto listIt = list.GetListIterator();
                    NUdf::TUnboxedValue str;
                    while (listIt.Next(str)) {
                        const TStringBuf strRef = str.AsStringRef();
                        taskStateSize += strRef.Size();
                        taskState.push_back({});
                        taskState.back().AppendNoAlias(strRef.Data(), strRef.Size());
                    }
                };
                bool isList = mutableValue.HasListItems();
                NUdf::TUnboxedValue list;
                if (isList) { // No load was done during previous runs.
                    saveList(mutableValue);
                } else {
                    NUdf::TUnboxedValue saved = mutableValue.Save();
                    if (saved.IsString() || saved.IsEmbedded()) { // Old version.
                        const TStringBuf savedBuf = saved.AsStringRef();
                        taskState.push_back({});
                        taskState.back().AppendNoAlias(savedBuf.Data(), savedBuf.Size());
                        taskStateSize = savedBuf.Size();
                    } else {
                        saveList(saved);
                    }
                }
                WriteUi64(result, taskStateSize);
                // Append and free chunk-by-chunk to limit peak memory.
                for (auto it = taskState.begin(); it != taskState.end();) {
                    result.AppendNoAlias(it->data(), it->size());
                    it = taskState.erase(it);
                }
            } else { // No load was done during previous runs (if any).
                MKQL_ENSURE(mutableValue.HasValue() && (mutableValue.IsString() || mutableValue.IsEmbedded()), "State is expected to have data or invalid value");
                const NUdf::TStringRef savedRef = mutableValue.AsStringRef();
                WriteUi64(result, savedRef.Size());
                result.AppendNoAlias(savedRef.Data(), savedRef.Size());
            }
        }
        return result;
    }

    // Inverse of SaveGraphState: restores serializable mutable values from
    // the flat byte string produced above. Slots saved as max() stay Invalid.
    void LoadGraphState(TStringBuf state) override {
        Prepare();
        for (ui32 i : PatternNodes->GetMutables().SerializableValues) {
            if (const ui64 size = ReadUi64(state); size != std::numeric_limits<ui64>::max()) {
                MKQL_ENSURE(state.Size() >= size, "Serialized state is corrupted - buffer is too short (" << state.Size() << ") for specified size: " << size);
                TStringBuf savedRef(state.Data(), size);
                Ctx->MutableValues[i] = NKikimr::NMiniKQL::TOutputSerializer::MakeArray(*Ctx, savedRef);
                state.Skip(size);
            } // else leave it Invalid()
        }
        MKQL_ENSURE(state.Empty(), "Serialized state is corrupted - extra bytes left: " << state.Size());
    }

private:
    const TPatternNodes::TPtr PatternNodes;
    const TIntrusivePtr<TMemoryUsageInfo> MemInfo;
    THolder<THolderFactory> HolderFactory;
    THolder<TDefaultValueBuilder> ValueBuilder;
    THolder<TComputationContext> Ctx;
    TComputationOptsFull CompOpts;
    NYql::NCodegen::ICodegen::TSharedPtr Codegen;
    bool IsPrepared = false;
    std::optional<TArrowKernelsTopology> KernelsTopology;
};
// A reusable computation pattern: owns the prepared node graph (PatternNodes)
// and, optionally, the LLVM codegen module compiled from it. Per-execution
// graphs are produced via Clone().
class TComputationPatternImpl final : public IComputationPattern {
public:
    // Takes ownership of the finished building visitor, runs two-phase node
    // preparation, and (when codegen is enabled) compiles the pattern.
    // Codegen availability is decided by build flags and opts.
    TComputationPatternImpl(THolder<TComputationGraphBuildingVisitor>&& builder, const TComputationPatternOpts& opts)
#if defined(MKQL_DISABLE_CODEGEN)
    : Codegen()
#elif defined(MKQL_FORCE_USE_CODEGEN)
    : Codegen(NYql::NCodegen::ICodegen::MakeShared(NYql::NCodegen::ETarget::Native))
#else
    // LLVM is used when it is available and not explicitly disabled via
    // OptLLVM == "OFF", or unconditionally when the MKQL_FORCE_USE_LLVM
    // environment variable is set.
    : Codegen((NYql::NCodegen::ICodegen::IsCodegenAvailable() && opts.OptLLVM != "OFF") || GetEnv(TString("MKQL_FORCE_USE_LLVM")) ? NYql::NCodegen::ICodegen::MakeShared(NYql::NCodegen::ETarget::Native) : NYql::NCodegen::ICodegen::TPtr())
#endif
    {
        /// TODO: Enable JIT for AARCH64/Win
#if defined(__aarch64__) || defined(_win_)
        Codegen = {};
#endif

        const auto& nodes = builder->GetNodes();
        // Two-phase preparation: every node finishes stage one before any
        // node starts stage two.
        for (const auto& node : nodes)
            node->PrepareStageOne();
        for (const auto& node : nodes)
            node->PrepareStageTwo();

        MKQL_ADD_STAT(opts.Stats, Mkql_TotalNodes, nodes.size());
        PatternNodes = builder->GetPatternNodes();

        if (Codegen) {
            Compile(opts.OptLLVM, opts.Stats);
        }
    }

    ~TComputationPatternImpl() {
        // Pattern nodes must be released under the type environment's
        // allocator. NOTE(review): TypeEnv is never assigned in the code
        // visible here (only initialized to nullptr) — confirm it is set
        // elsewhere, otherwise PatternNodes is never reset via this path.
        if (TypeEnv) {
            auto guard = TypeEnv->BindAllocator();
            PatternNodes.Reset();
        }
    }

    // Generates and JIT-compiles LLVM functions for every codegen-capable
    // node. Idempotent: returns immediately once the pattern is marked
    // compiled. optLLVM doubles as an option string ("--dump-generated",
    // "--dump-stats", "--dump-perf-map", "--dump-compiled", "--asm-compiled",
    // "--compile-options ...").
    void Compile(TString optLLVM, IStatsRegistry* stats) {
        if (IsPatternCompiled.load())
            return;

#ifndef MKQL_DISABLE_CODEGEN
        if (!Codegen)
            Codegen = NYql::NCodegen::ICodegen::Make(NYql::NCodegen::ETarget::Native);

        const auto& nodes = PatternNodes->GetNodes();

        TStatTimer timerFull(CodeGen_FullTime);
        timerFull.Acquire();
        {
            TStatTimer timerGen(CodeGen_GenerateTime);
            timerGen.Acquire();
            // IR generation walks the node list in reverse order.
            for (auto it = nodes.crbegin(); nodes.crend() != it; ++it) {
                if (const auto codegen = dynamic_cast<ICodegeneratorRootNode*>(it->Get())) {
                    codegen->GenerateFunctions(*Codegen);
                }
            }
            timerGen.Release();
            timerGen.Report(stats);
        }

        if (optLLVM.Contains("--dump-generated")) {
            Cerr << "############### Begin generated module ###############" << Endl;
            Codegen->GetModule().print(llvm::errs(), nullptr);
            Cerr << "################ End generated module ################" << Endl;
        }

        TStatTimer timerComp(CodeGen_CompileTime);
        timerComp.Acquire();

        NYql::NCodegen::TCodegenStats codegenStats;
        Codegen->GetStats(codegenStats);
        MKQL_ADD_STAT(stats, CodeGen_TotalFunctions, codegenStats.TotalFunctions);
        MKQL_ADD_STAT(stats, CodeGen_TotalInstructions, codegenStats.TotalInstructions);
        MKQL_SET_MAX_STAT(stats, CodeGen_MaxFunctionInstructions, codegenStats.MaxFunctionInstructions);
        if (optLLVM.Contains("--dump-stats")) {
            Cerr << "TotalFunctions: " << codegenStats.TotalFunctions << Endl;
            Cerr << "TotalInstructions: " << codegenStats.TotalInstructions << Endl;
            Cerr << "MaxFunctionInstructions: " << codegenStats.MaxFunctionInstructions << Endl;
        }
        if (optLLVM.Contains("--dump-perf-map")) {
            Codegen->TogglePerfJITEventListener();
        }

        // Give up on JIT (dropping the codegen module, falling back to the
        // interpreted nodes) when the generated module exceeds size limits.
        if (codegenStats.TotalFunctions >= TotalFunctionsLimit ||
            codegenStats.TotalInstructions >= TotalInstructionsLimit ||
            codegenStats.MaxFunctionInstructions >= MaxFunctionInstructionsLimit) {
            Codegen.reset();
        } else {
            Codegen->Verify();
            Codegen->Compile(GetCompileOptions(optLLVM), &CompileStats);

            MKQL_ADD_STAT(stats, CodeGen_FunctionPassTime, CompileStats.FunctionPassTime);
            MKQL_ADD_STAT(stats, CodeGen_ModulePassTime, CompileStats.ModulePassTime);
            MKQL_ADD_STAT(stats, CodeGen_FinalizeTime, CompileStats.FinalizeTime);
        }

        timerComp.Release();
        timerComp.Report(stats);

        if (Codegen) {
            if (optLLVM.Contains("--dump-compiled")) {
                Cerr << "############### Begin compiled module ###############" << Endl;
                Codegen->GetModule().print(llvm::errs(), nullptr);
                Cerr << "################ End compiled module ################" << Endl;
            }

            if (optLLVM.Contains("--asm-compiled")) {
                Cerr << "############### Begin compiled asm ###############" << Endl;
                Codegen->ShowGeneratedFunctions(&Cerr);
                Cerr << "################ End compiled asm ################" << Endl;
            }

            // Hand the compiled functions back to their nodes.
            ui64 count = 0U;
            for (const auto& node : nodes) {
                if (const auto codegen = dynamic_cast<ICodegeneratorRootNode*>(node.Get())) {
                    codegen->FinalizeFunctions(*Codegen);
                    ++count;
                }
            }

            if (count) {
                MKQL_ADD_STAT(stats, Mkql_CodegenFunctions, count);
            }
        }

        timerFull.Release();
        timerFull.Report(stats);
#else
        Y_UNUSED(optLLVM);
        Y_UNUSED(stats);
#endif
        // Marked compiled even when the module was discarded over the size
        // limits above — Compile() is not retried for this pattern.
        IsPatternCompiled.store(true);
    }

    bool IsCompiled() const {
        return IsPatternCompiled.load();
    }

    size_t CompiledCodeSize() const {
        return CompileStats.TotalObjectSize;
    }

    // Drops the compiled module and resets stats; a later Compile() call
    // will regenerate everything from the pattern nodes.
    void RemoveCompiledCode() {
        IsPatternCompiled.store(false);
        CompileStats = {};
        Codegen.reset();
    }

    // Creates a fresh executable graph sharing this pattern's nodes. The
    // codegen module is handed over only after compilation has finished.
    THolder<IComputationGraph> Clone(const TComputationOptsFull& compOpts) {
        if (IsPatternCompiled.load()) {
            return MakeHolder<TComputationGraph>(PatternNodes, compOpts, Codegen);
        }

        return MakeHolder<TComputationGraph>(PatternNodes, compOpts, nullptr);
    }

    bool GetSuitableForCache() const {
        return PatternNodes->GetSuitableForCache();
    }

private:
    // Extracts the value that follows the last "--compile-options" flag in
    // the option string s, up to the next " --" separator (or end of
    // string). Returns an empty TStringBuf when the flag is absent.
    TStringBuf GetCompileOptions(const TString& s) {
        const TString flag = "--compile-options";
        auto lpos = s.rfind(flag);
        if (lpos == TString::npos)
            return TStringBuf();
        lpos += flag.size();
        auto rpos = s.find(" --", lpos);
        if (rpos == TString::npos)
            return TStringBuf(s, lpos);
        else
            return TStringBuf(s, lpos, rpos - lpos);
    };

    TTypeEnvironment* TypeEnv = nullptr;
    TPatternNodes::TPtr PatternNodes;
    NYql::NCodegen::ICodegen::TSharedPtr Codegen;
    std::atomic<bool> IsPatternCompiled = false;
    NYql::NCodegen::TCompileStats CompileStats;
};
  830. TIntrusivePtr<TComputationPatternImpl> MakeComputationPatternImpl(TExploringNodeVisitor& explorer, const TRuntimeNode& root,
  831. const std::vector<TNode*>& entryPoints, const TComputationPatternOpts& opts) {
  832. TDependencyScanVisitor depScanner;
  833. depScanner.Walk(root.GetNode(), opts.Env);
  834. auto builder = MakeHolder<TComputationGraphBuildingVisitor>(opts);
  835. for (const auto& node : explorer.GetNodes()) {
  836. Y_ABORT_UNLESS(node->GetCookie() <= IS_NODE_REACHABLE, "TNode graph should not be reused");
  837. if (node->GetCookie() == IS_NODE_REACHABLE) {
  838. node->Accept(*builder);
  839. }
  840. }
  841. const auto rootNode = builder->GetComputationNode(root.GetNode());
  842. TComputationExternalNodePtrVector runtime2ComputationEntryPoints;
  843. runtime2ComputationEntryPoints.resize(entryPoints.size(), nullptr);
  844. std::unordered_map<TNode*, std::vector<ui32>> entryPointIndex;
  845. for (ui32 i = 0; i < entryPoints.size(); ++i) {
  846. entryPointIndex[entryPoints[i]].emplace_back(i);
  847. }
  848. for (const auto& node : explorer.GetNodes()) {
  849. auto it = entryPointIndex.find(node);
  850. if (it == entryPointIndex.cend()) {
  851. continue;
  852. }
  853. auto compNode = dynamic_cast<IComputationExternalNode*>(builder->GetComputationNode(node));
  854. for (auto index : it->second) {
  855. runtime2ComputationEntryPoints[index] = compNode;
  856. }
  857. }
  858. for (const auto& node : explorer.GetNodes()) {
  859. node->SetCookie(0);
  860. }
  861. builder->PreserveRoot(rootNode);
  862. builder->PreserveEntryPoints(std::move(runtime2ComputationEntryPoints));
  863. return MakeIntrusive<TComputationPatternImpl>(std::move(builder), opts);
  864. }
  865. } // namespace
  866. IComputationPattern::TPtr MakeComputationPattern(TExploringNodeVisitor& explorer, const TRuntimeNode& root,
  867. const std::vector<TNode*>& entryPoints, const TComputationPatternOpts& opts) {
  868. return MakeComputationPatternImpl(explorer, root, entryPoints, opts);
  869. }
  870. } // namespace NMiniKQL
  871. } // namespace NKikimr