mkql_chain_map.cpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. #include "mkql_chain_map.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/computation/mkql_custom_list.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. struct TComputationNodes {
  10. IComputationExternalNode* const ItemArg;
  11. IComputationExternalNode* const StateArg;
  12. IComputationNode* const NewItem;
  13. IComputationNode* const NewState;
  14. };
  15. class TFoldMapFlowWrapper : public TStatefulFlowCodegeneratorNode<TFoldMapFlowWrapper> {
  16. typedef TStatefulFlowCodegeneratorNode<TFoldMapFlowWrapper> TBaseComputation;
  17. public:
  18. TFoldMapFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow, IComputationNode* initialState, IComputationExternalNode* itemArg, IComputationExternalNode* stateArg, IComputationNode* newItem, IComputationNode* newState)
  19. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Embedded),
  20. Flow(flow), Init(initialState), ComputationNodes({itemArg, stateArg, newItem, newState})
  21. {}
  22. NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  23. if (state.IsInvalid()) {
  24. ComputationNodes.StateArg->SetValue(ctx, Init->GetValue(ctx));
  25. state = NUdf::TUnboxedValuePod(true);
  26. }
  27. auto item = Flow->GetValue(ctx);
  28. if (item.IsSpecial()) {
  29. return item;
  30. }
  31. ComputationNodes.ItemArg->SetValue(ctx, std::move(item));
  32. const auto value = ComputationNodes.NewItem->GetValue(ctx);
  33. ComputationNodes.StateArg->SetValue(ctx, ComputationNodes.NewState->GetValue(ctx));
  34. return value;
  35. }
  36. #ifndef MKQL_DISABLE_CODEGEN
  37. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  38. auto& context = ctx.Codegen.GetContext();
  39. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.ItemArg);
  40. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.StateArg);
  41. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  42. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  43. const auto valueType = Type::getInt128Ty(context);
  44. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  45. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  46. const auto state = new LoadInst(valueType, statePtr, "load", block);
  47. BranchInst::Create(init, work, IsInvalid(state, block, context), block);
  48. block = init;
  49. new StoreInst(GetTrue(context), statePtr, block);
  50. codegenStateArg->CreateSetValue(ctx, block, GetNodeValue(Init, ctx, block));
  51. BranchInst::Create(work, block);
  52. block = work;
  53. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  54. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  55. const auto result = PHINode::Create(valueType, 2U, "result", done);
  56. const auto item = GetNodeValue(Flow, ctx, block);
  57. result->addIncoming(item, block);
  58. BranchInst::Create(done, good, IsSpecial(item, block, context), block);
  59. block = good;
  60. codegenItemArg->CreateSetValue(ctx, block, item);
  61. const auto value = GetNodeValue(ComputationNodes.NewItem, ctx, block);
  62. codegenStateArg->CreateSetValue(ctx, block, GetNodeValue(ComputationNodes.NewState, ctx, block));
  63. result->addIncoming(value, block);
  64. BranchInst::Create(done, block);
  65. block = done;
  66. return result;
  67. }
  68. #endif
  69. private:
  70. void RegisterDependencies() const final {
  71. if (const auto flow = FlowDependsOn(Flow)) {
  72. DependsOn(flow, Init);
  73. DependsOn(flow, ComputationNodes.NewItem);
  74. DependsOn(flow, ComputationNodes.NewState);
  75. Own(flow, ComputationNodes.ItemArg);
  76. Own(flow, ComputationNodes.StateArg);
  77. }
  78. }
  79. IComputationNode* const Flow;
  80. IComputationNode* const Init;
  81. const TComputationNodes ComputationNodes;
  82. };
  83. template <bool IsStream>
  84. class TBaseChainMapWrapper {
  85. public:
  86. class TListValue : public TCustomListValue {
  87. public:
  88. class TIterator : public TComputationValue<TIterator> {
  89. public:
  90. TIterator(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& iter, const NUdf::TUnboxedValue& init, const TComputationNodes& computationNodes)
  91. : TComputationValue<TIterator>(memInfo)
  92. , CompCtx(compCtx)
  93. , ComputationNodes(computationNodes)
  94. , Iter(std::move(iter))
  95. , Init(init)
  96. {
  97. }
  98. private:
  99. bool Next(NUdf::TUnboxedValue& value) final {
  100. if (!Init.IsInvalid()) {
  101. ComputationNodes.StateArg->SetValue(CompCtx, std::move(Init));
  102. Init = NUdf::TUnboxedValue::Invalid();
  103. }
  104. if (!Iter.Next(ComputationNodes.ItemArg->RefValue(CompCtx))) {
  105. return false;
  106. }
  107. value = ComputationNodes.NewItem->GetValue(CompCtx);
  108. ComputationNodes.StateArg->SetValue(CompCtx, ComputationNodes.NewState->GetValue(CompCtx));
  109. return true;
  110. }
  111. TComputationContext& CompCtx;
  112. const TComputationNodes& ComputationNodes;
  113. const NUdf::TUnboxedValue Iter;
  114. NUdf::TUnboxedValue Init;
  115. };
  116. TListValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& list, NUdf::TUnboxedValue&& init, const TComputationNodes& computationNodes)
  117. : TCustomListValue(memInfo)
  118. , CompCtx(compCtx)
  119. , List(std::move(list))
  120. , Init(std::move(init))
  121. , ComputationNodes(computationNodes)
  122. {}
  123. private:
  124. NUdf::TUnboxedValue GetListIterator() const final {
  125. return CompCtx.HolderFactory.Create<TIterator>(CompCtx, List.GetListIterator(), Init, ComputationNodes);
  126. }
  127. ui64 GetListLength() const final {
  128. if (!Length) {
  129. Length = List.GetListLength();
  130. }
  131. return *Length;
  132. }
  133. bool HasListItems() const final {
  134. if (!HasItems) {
  135. HasItems = List.HasListItems();
  136. }
  137. return *HasItems;
  138. }
  139. TComputationContext& CompCtx;
  140. const NUdf::TUnboxedValue List;
  141. const NUdf::TUnboxedValue Init;
  142. const TComputationNodes& ComputationNodes;
  143. };
  144. class TStreamValue : public TComputationValue<TStreamValue> {
  145. public:
  146. using TBase = TComputationValue<TStreamValue>;
  147. TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& list, NUdf::TUnboxedValue&& init, const TComputationNodes& computationNodes)
  148. : TBase(memInfo)
  149. , CompCtx(compCtx)
  150. , ComputationNodes(computationNodes)
  151. , List(std::move(list))
  152. , Init(std::move(init))
  153. {}
  154. private:
  155. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& value) final {
  156. if (!Init.IsInvalid()) {
  157. ComputationNodes.StateArg->SetValue(CompCtx, std::move(Init));
  158. Init = NUdf::TUnboxedValuePod::Invalid();
  159. }
  160. const auto status = List.Fetch(ComputationNodes.ItemArg->RefValue(CompCtx));
  161. if (status != NUdf::EFetchStatus::Ok) {
  162. return status;
  163. }
  164. value = ComputationNodes.NewItem->GetValue(CompCtx);
  165. ComputationNodes.StateArg->SetValue(CompCtx, ComputationNodes.NewState->GetValue(CompCtx));
  166. return NUdf::EFetchStatus::Ok;
  167. }
  168. TComputationContext& CompCtx;
  169. const TComputationNodes& ComputationNodes;
  170. const NUdf::TUnboxedValue List;
  171. NUdf::TUnboxedValue Init;
  172. };
  173. TBaseChainMapWrapper(IComputationNode* list, IComputationNode* init, IComputationExternalNode* itemArg, IComputationExternalNode* stateArg, IComputationNode* newItem, IComputationNode* newState)
  174. : List(list), Init(init), ComputationNodes({itemArg, stateArg, newItem, newState})
  175. {}
  176. #ifndef MKQL_DISABLE_CODEGEN
  177. Function* GenerateMapper(NYql::NCodegen::ICodegen& codegen, const TString& name) const {
  178. auto& module = codegen.GetModule();
  179. auto& context = codegen.GetContext();
  180. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.ItemArg);
  181. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.StateArg);
  182. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  183. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  184. if (const auto f = module.getFunction(name.c_str()))
  185. return f;
  186. const auto valueType = Type::getInt128Ty(context);
  187. const auto containerType = static_cast<Type*>(valueType);
  188. const auto contextType = GetCompContextType(context);
  189. const auto statusType = IsStream ? Type::getInt32Ty(context) : Type::getInt1Ty(context);
  190. const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, PointerType::getUnqual(valueType), PointerType::getUnqual(valueType)}, false);
  191. TCodegenContext ctx(codegen);
  192. ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
  193. DISubprogramAnnotator annotator(ctx, ctx.Func);
  194. auto args = ctx.Func->arg_begin();
  195. ctx.Ctx = &*args;
  196. const auto containerArg = &*++args;
  197. const auto initArg = &*++args;
  198. const auto valuePtr = &*++args;
  199. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  200. auto block = main;
  201. const auto container = static_cast<Value*>(containerArg);
  202. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  203. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  204. const auto load = new LoadInst(valueType, initArg, "load", block);
  205. BranchInst::Create(work, init, IsInvalid(load, block, context), block);
  206. block = init;
  207. codegenStateArg->CreateSetValue(ctx, block, initArg);
  208. new StoreInst(GetInvalid(context), initArg, block);
  209. BranchInst::Create(work, block);
  210. block = work;
  211. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  212. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  213. const auto itemPtr = codegenItemArg->CreateRefValue(ctx, block);
  214. const auto status = IsStream ?
  215. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, container, codegen, block, itemPtr):
  216. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(statusType, container, codegen, block, itemPtr);
  217. const auto icmp = IsStream ?
  218. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, status, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), "cond", block):
  219. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::getFalse(context), "cond", block);
  220. BranchInst::Create(done, good, icmp, block);
  221. block = good;
  222. SafeUnRefUnboxedOne(valuePtr, ctx, block);
  223. GetNodeValue(valuePtr, ComputationNodes.NewItem, ctx, block);
  224. const auto newState = GetNodeValue(ComputationNodes.NewState, ctx, block);
  225. codegenStateArg->CreateSetValue(ctx, block, newState);
  226. BranchInst::Create(done, block);
  227. block = done;
  228. ReturnInst::Create(context, status, block);
  229. return ctx.Func;
  230. }
  231. using TChainMapPtr = std::conditional_t<IsStream, TStreamCodegenStatefulValue::TFetchPtr, TCustomListCodegenStatefulValue::TNextPtr>;
  232. Function* ChainMapFunc = nullptr;
  233. TChainMapPtr ChainMap = nullptr;
  234. #endif
  235. IComputationNode* const List;
  236. IComputationNode* const Init;
  237. const TComputationNodes ComputationNodes;
  238. };
  239. class TStreamChainMapWrapper : public TCustomValueCodegeneratorNode<TStreamChainMapWrapper>, private TBaseChainMapWrapper<true> {
  240. typedef TCustomValueCodegeneratorNode<TStreamChainMapWrapper> TBaseComputation;
  241. typedef TBaseChainMapWrapper<true> TBaseWrapper;
  242. public:
  243. TStreamChainMapWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationNode* init, IComputationExternalNode* itemArg, IComputationExternalNode* stateArg, IComputationNode* newItem, IComputationNode* newState)
  244. : TBaseComputation(mutables), TBaseWrapper(list, init, itemArg, stateArg, newItem, newState)
  245. {}
  246. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  247. #ifndef MKQL_DISABLE_CODEGEN
  248. if (ctx.ExecuteLLVM && ChainMap)
  249. return ctx.HolderFactory.Create<TStreamCodegenStatefulValue>(ChainMap, &ctx, List->GetValue(ctx), Init->GetValue(ctx));
  250. #endif
  251. return ctx.HolderFactory.Create<TStreamValue>(ctx, List->GetValue(ctx), Init->GetValue(ctx), ComputationNodes);
  252. }
  253. private:
  254. void RegisterDependencies() const final {
  255. DependsOn(List);
  256. DependsOn(Init);
  257. DependsOn(ComputationNodes.NewItem);
  258. DependsOn(ComputationNodes.NewState);
  259. Own(ComputationNodes.ItemArg);
  260. Own(ComputationNodes.StateArg);
  261. }
  262. #ifndef MKQL_DISABLE_CODEGEN
  263. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  264. ChainMapFunc = GenerateMapper(codegen, TBaseComputation::MakeName("Fetch"));
  265. codegen.ExportSymbol(ChainMapFunc);
  266. }
  267. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  268. if (ChainMapFunc)
  269. ChainMap = reinterpret_cast<TChainMapPtr>(codegen.GetPointerToFunction(ChainMapFunc));
  270. }
  271. #endif
  272. };
  273. class TListChainMapWrapper : public TBothWaysCodegeneratorNode<TListChainMapWrapper>, private TBaseChainMapWrapper<false> {
  274. typedef TBothWaysCodegeneratorNode<TListChainMapWrapper> TBaseComputation;
  275. typedef TBaseChainMapWrapper<false> TBaseWrapper;
  276. public:
  277. TListChainMapWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationNode* init, IComputationExternalNode* itemArg, IComputationExternalNode* stateArg, IComputationNode* newItem, IComputationNode* newState)
  278. : TBaseComputation(mutables), TBaseWrapper(list, init, itemArg, stateArg, newItem, newState)
  279. {}
  280. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  281. auto init = Init->GetValue(ctx);
  282. auto list = List->GetValue(ctx);
  283. if (auto elements = list.GetElements()) {
  284. auto size = list.GetListLength();
  285. ComputationNodes.StateArg->SetValue(ctx, std::move(init));
  286. NUdf::TUnboxedValue* items = nullptr;
  287. const auto result = ctx.HolderFactory.CreateDirectArrayHolder(size, items);
  288. while (size--) {
  289. ComputationNodes.ItemArg->SetValue(ctx, NUdf::TUnboxedValue(*elements++));
  290. *items++ = ComputationNodes.NewItem->GetValue(ctx);
  291. ComputationNodes.StateArg->SetValue(ctx, ComputationNodes.NewState->GetValue(ctx));
  292. }
  293. return result;
  294. }
  295. return ctx.HolderFactory.Create<TListValue>(ctx, std::move(list), std::move(init), ComputationNodes);
  296. }
  297. #ifndef MKQL_DISABLE_CODEGEN
  298. NUdf::TUnboxedValuePod MakeLazyList(TComputationContext& ctx, const NUdf::TUnboxedValuePod value, const NUdf::TUnboxedValuePod init) const {
  299. return ctx.HolderFactory.Create<TCustomListCodegenStatefulValue>(ChainMap, &ctx, value, init);
  300. }
  301. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  302. auto& context = ctx.Codegen.GetContext();
  303. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.ItemArg);
  304. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.StateArg);
  305. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  306. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  307. const auto init = GetNodeValue(Init, ctx, block);
  308. const auto list = GetNodeValue(List, ctx, block);
  309. const auto lazy = BasicBlock::Create(context, "lazy", ctx.Func);
  310. const auto hard = BasicBlock::Create(context, "hard", ctx.Func);
  311. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  312. const auto map = PHINode::Create(list->getType(), 2U, "map", done);
  313. const auto elementsType = PointerType::getUnqual(list->getType());
  314. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(elementsType, list, ctx.Codegen, block);
  315. const auto fill = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, elements, ConstantPointerNull::get(elementsType), "fill", block);
  316. BranchInst::Create(hard, lazy, fill, block);
  317. {
  318. block = hard;
  319. const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), list, ctx.Codegen, block);
  320. codegenStateArg->CreateSetValue(ctx, block, init);
  321. const auto itemsPtr = *Stateless || ctx.AlwaysInline ?
  322. new AllocaInst(elementsType, 0U, "items_ptr", &ctx.Func->getEntryBlock().back()):
  323. new AllocaInst(elementsType, 0U, "items_ptr", block);
  324. const auto array = GenNewArray(ctx, size, itemsPtr, block);
  325. const auto items = new LoadInst(elementsType, itemsPtr, "items", block);
  326. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  327. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  328. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  329. const auto index = PHINode::Create(size->getType(), 2U, "index", loop);
  330. index->addIncoming(ConstantInt::get(size->getType(), 0), block);
  331. BranchInst::Create(loop, block);
  332. block = loop;
  333. const auto more = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, index, "more", block);
  334. BranchInst::Create(next, stop, more, block);
  335. block = next;
  336. const auto src = GetElementPtrInst::CreateInBounds(list->getType(), elements, {index}, "src", block);
  337. const auto item = new LoadInst(list->getType(), src, "item", block);
  338. codegenItemArg->CreateSetValue(ctx, block, item);
  339. const auto dst = GetElementPtrInst::CreateInBounds(list->getType(), items, {index}, "dst", block);
  340. GetNodeValue(dst, ComputationNodes.NewItem, ctx, block);
  341. const auto newState = GetNodeValue(ComputationNodes.NewState, ctx, block);
  342. codegenStateArg->CreateSetValue(ctx, block, newState);
  343. const auto plus = BinaryOperator::CreateAdd(index, ConstantInt::get(size->getType(), 1), "plus", block);
  344. index->addIncoming(plus, block);
  345. BranchInst::Create(loop, block);
  346. block = stop;
  347. if (List->IsTemporaryValue()) {
  348. CleanupBoxed(list, ctx, block);
  349. }
  350. map->addIncoming(array, block);
  351. BranchInst::Create(done, block);
  352. }
  353. {
  354. block = lazy;
  355. const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TListChainMapWrapper::MakeLazyList));
  356. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  357. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  358. const auto funType = FunctionType::get(list->getType() , {self->getType(), ctx.Ctx->getType(), list->getType(), init->getType()}, false);
  359. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(funType), "function", block);
  360. const auto value = CallInst::Create(funType, doFuncPtr, {self, ctx.Ctx, list, init}, "value", block);
  361. map->addIncoming(value, block);
  362. BranchInst::Create(done, block);
  363. }
  364. block = done;
  365. return map;
  366. }
  367. #endif
  368. private:
  369. void RegisterDependencies() const final {
  370. DependsOn(List);
  371. DependsOn(Init);
  372. DependsOn(ComputationNodes.NewItem);
  373. DependsOn(ComputationNodes.NewState);
  374. Own(ComputationNodes.ItemArg);
  375. Own(ComputationNodes.StateArg);
  376. }
  377. #ifndef MKQL_DISABLE_CODEGEN
  378. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  379. TMutableCodegeneratorRootNode<TListChainMapWrapper>::GenerateFunctions(codegen);
  380. ChainMapFunc = GenerateMapper(codegen, TBaseComputation::MakeName("Next"));
  381. codegen.ExportSymbol(ChainMapFunc);
  382. }
  383. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  384. TMutableCodegeneratorRootNode<TListChainMapWrapper>::FinalizeFunctions(codegen);
  385. if (ChainMapFunc)
  386. ChainMap = reinterpret_cast<TChainMapPtr>(codegen.GetPointerToFunction(ChainMapFunc));
  387. }
  388. #endif
  389. };
  390. }
  391. IComputationNode* WrapChainMap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  392. MKQL_ENSURE(callable.GetInputsCount() == 6, "Expected 6 args");
  393. const auto type = callable.GetType()->GetReturnType();
  394. const auto flow = LocateNode(ctx.NodeLocator, callable, 0);
  395. const auto initialState = LocateNode(ctx.NodeLocator, callable, 1);
  396. const auto newItem = LocateNode(ctx.NodeLocator, callable, 4);
  397. const auto newState = LocateNode(ctx.NodeLocator, callable, 5);
  398. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 2);
  399. const auto stateArg = LocateExternalNode(ctx.NodeLocator, callable, 3);
  400. if (type->IsFlow()) {
  401. return new TFoldMapFlowWrapper(ctx.Mutables, GetValueRepresentation(type), flow, initialState, itemArg, stateArg, newItem, newState);
  402. } else if (type->IsStream()) {
  403. return new TStreamChainMapWrapper(ctx.Mutables, flow, initialState, itemArg, stateArg, newItem, newState);
  404. } else if (type->IsList()) {
  405. return new TListChainMapWrapper(ctx.Mutables, flow, initialState, itemArg, stateArg, newItem, newState);
  406. }
  407. THROW yexception() << "Expected flow, list or stream.";
  408. }
  409. }
  410. }