mkql_chain1_map.cpp 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. #include "mkql_chain1_map.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/computation/mkql_custom_list.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. struct TComputationNodes {
  10. IComputationExternalNode* const ItemArg;
  11. IComputationExternalNode* const StateArg;
  12. IComputationNode* const InitItem;
  13. IComputationNode* const InitState;
  14. IComputationNode* const UpdateItem;
  15. IComputationNode* const UpdateState;
  16. };
  17. class TFold1MapFlowWrapper : public TStatefulFlowCodegeneratorNode<TFold1MapFlowWrapper> {
  18. typedef TStatefulFlowCodegeneratorNode<TFold1MapFlowWrapper> TBaseComputation;
  19. public:
  20. TFold1MapFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow,
  21. IComputationExternalNode* itemArg, IComputationExternalNode* stateArg,
  22. IComputationNode* initItem, IComputationNode* initState,
  23. IComputationNode* updateItem, IComputationNode* updateState)
  24. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Embedded),
  25. Flow(flow), ComputationNodes({itemArg, stateArg, initItem, initState, updateItem, updateState})
  26. {}
  27. NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  28. auto item = Flow->GetValue(ctx);
  29. if (item.IsSpecial()) {
  30. return item;
  31. }
  32. ComputationNodes.ItemArg->SetValue(ctx, std::move(item));
  33. const bool init = state.IsInvalid();
  34. const auto value = (init ? ComputationNodes.InitItem : ComputationNodes.UpdateItem)->GetValue(ctx);
  35. ComputationNodes.StateArg->SetValue(ctx, (init ? ComputationNodes.InitState : ComputationNodes.UpdateState)->GetValue(ctx));
  36. if (init) {
  37. state = NUdf::TUnboxedValuePod(true);
  38. }
  39. return value;
  40. }
  41. #ifndef MKQL_DISABLE_CODEGEN
  42. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  43. auto& context = ctx.Codegen.GetContext();
  44. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.ItemArg);
  45. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.StateArg);
  46. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  47. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  48. const auto valueType = Type::getInt128Ty(context);
  49. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  50. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  51. const auto result = PHINode::Create(valueType, 3U, "result", done);
  52. const auto item = GetNodeValue(Flow, ctx, block);
  53. result->addIncoming(item, block);
  54. BranchInst::Create(done, good, IsSpecial(item, block, context), block);
  55. block = good;
  56. codegenItemArg->CreateSetValue(ctx, block, item);
  57. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  58. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  59. const auto state = new LoadInst(valueType, statePtr, "load", block);
  60. BranchInst::Create(init, next, IsInvalid(state, block, context), block);
  61. block = init;
  62. const auto one = GetNodeValue(ComputationNodes.InitItem, ctx, block);
  63. codegenStateArg->CreateSetValue(ctx, block, GetNodeValue(ComputationNodes.InitState, ctx, block));
  64. result->addIncoming(one, block);
  65. new StoreInst(GetTrue(context), statePtr, block);
  66. BranchInst::Create(done, block);
  67. block = next;
  68. const auto two = GetNodeValue(ComputationNodes.UpdateItem, ctx, block);
  69. codegenStateArg->CreateSetValue(ctx, block, GetNodeValue(ComputationNodes.UpdateState, ctx, block));
  70. result->addIncoming(two, block);
  71. BranchInst::Create(done, block);
  72. block = done;
  73. return result;
  74. }
  75. #endif
  76. private:
  77. void RegisterDependencies() const final {
  78. if (const auto flow = FlowDependsOn(Flow)) {
  79. DependsOn(flow, ComputationNodes.InitItem);
  80. DependsOn(flow, ComputationNodes.InitState);
  81. DependsOn(flow, ComputationNodes.UpdateItem);
  82. DependsOn(flow, ComputationNodes.UpdateState);
  83. Own(flow, ComputationNodes.ItemArg);
  84. Own(flow, ComputationNodes.StateArg);
  85. }
  86. }
  87. IComputationNode* const Flow;
  88. const TComputationNodes ComputationNodes;
  89. };
  90. template <bool IsStream>
  91. class TBaseChain1MapWrapper {
  92. public:
  93. class TListValue : public TCustomListValue {
  94. public:
  95. class TIterator : public TComputationValue<TIterator> {
  96. public:
  97. TIterator(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& iter, const TComputationNodes& computationNodes)
  98. : TComputationValue<TIterator>(memInfo)
  99. , CompCtx(compCtx)
  100. , Iter(std::move(iter))
  101. , ComputationNodes(computationNodes)
  102. {}
  103. private:
  104. bool Next(NUdf::TUnboxedValue& value) final {
  105. if (!Iter.Next(ComputationNodes.ItemArg->RefValue(CompCtx))) {
  106. return false;
  107. }
  108. ++Length;
  109. auto itemNode = Length == 1 ? ComputationNodes.InitItem : ComputationNodes.UpdateItem;
  110. auto stateNode = Length == 1 ? ComputationNodes.InitState : ComputationNodes.UpdateState;
  111. value = itemNode->GetValue(CompCtx);
  112. ComputationNodes.StateArg->SetValue(CompCtx, stateNode->GetValue(CompCtx));
  113. return true;
  114. }
  115. TComputationContext& CompCtx;
  116. const NUdf::TUnboxedValue Iter;
  117. const TComputationNodes& ComputationNodes;
  118. ui64 Length = 0;
  119. };
  120. TListValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& list, const TComputationNodes& computationNodes)
  121. : TCustomListValue(memInfo)
  122. , CompCtx(compCtx)
  123. , List(std::move(list))
  124. , ComputationNodes(computationNodes)
  125. {}
  126. private:
  127. NUdf::TUnboxedValue GetListIterator() const final {
  128. return CompCtx.HolderFactory.Create<TIterator>(CompCtx, List.GetListIterator(), ComputationNodes);
  129. }
  130. ui64 GetListLength() const final {
  131. if (!Length) {
  132. Length = List.GetListLength();
  133. }
  134. return *Length;
  135. }
  136. bool HasListItems() const final {
  137. if (!HasItems) {
  138. HasItems = List.HasListItems();
  139. }
  140. return *HasItems;
  141. }
  142. TComputationContext& CompCtx;
  143. const NUdf::TUnboxedValue List;
  144. const TComputationNodes& ComputationNodes;
  145. };
  146. class TStreamValue : public TComputationValue<TStreamValue> {
  147. public:
  148. using TBase = TComputationValue<TStreamValue>;
  149. TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& list, const TComputationNodes& computationNodes)
  150. : TBase(memInfo)
  151. , CompCtx(compCtx)
  152. , List(std::move(list))
  153. , ComputationNodes(computationNodes)
  154. {}
  155. private:
  156. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& value) final {
  157. const auto status = List.Fetch(ComputationNodes.ItemArg->RefValue(CompCtx));
  158. if (status != NUdf::EFetchStatus::Ok) {
  159. return status;
  160. }
  161. ++Length;
  162. auto itemNode = Length == 1 ? ComputationNodes.InitItem : ComputationNodes.UpdateItem;
  163. auto stateNode = Length == 1 ? ComputationNodes.InitState : ComputationNodes.UpdateState;
  164. value = itemNode->GetValue(CompCtx);
  165. ComputationNodes.StateArg->SetValue(CompCtx, stateNode->GetValue(CompCtx));
  166. return NUdf::EFetchStatus::Ok;
  167. }
  168. TComputationContext& CompCtx;
  169. const NUdf::TUnboxedValue List;
  170. const TComputationNodes& ComputationNodes;
  171. ui64 Length = 0;
  172. };
  173. TBaseChain1MapWrapper(IComputationNode* list, IComputationExternalNode* itemArg, IComputationExternalNode* stateArg,
  174. IComputationNode* initItem, IComputationNode* initState,
  175. IComputationNode* updateItem, IComputationNode* updateState)
  176. : List(list), ComputationNodes({itemArg, stateArg, initItem, initState, updateItem, updateState})
  177. {}
  178. #ifndef MKQL_DISABLE_CODEGEN
  179. template<bool IsFirst>
  180. Function* GenerateMapper(NYql::NCodegen::ICodegen& codegen, const TString& name) const {
  181. auto& module = codegen.GetModule();
  182. auto& context = codegen.GetContext();
  183. const auto newItem = IsFirst ? ComputationNodes.InitItem : ComputationNodes.UpdateItem;
  184. const auto newState = IsFirst ? ComputationNodes.InitState : ComputationNodes.UpdateState;
  185. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.ItemArg);
  186. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.StateArg);
  187. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  188. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  189. if (const auto f = module.getFunction(name.c_str()))
  190. return f;
  191. const auto valueType = Type::getInt128Ty(context);
  192. const auto containerType = static_cast<Type*>(valueType);
  193. const auto contextType = GetCompContextType(context);
  194. const auto statusType = IsStream ? Type::getInt32Ty(context) : Type::getInt1Ty(context);
  195. const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, PointerType::getUnqual(valueType)}, false);
  196. TCodegenContext ctx(codegen);
  197. ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
  198. DISubprogramAnnotator annotator(ctx, ctx.Func);
  199. auto args = ctx.Func->arg_begin();
  200. ctx.Ctx = &*args;
  201. const auto containerArg = &*++args;
  202. const auto valuePtr = &*++args;
  203. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  204. auto block = main;
  205. const auto container = static_cast<Value*>(containerArg);
  206. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  207. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  208. const auto itemPtr = codegenItemArg->CreateRefValue(ctx, block);
  209. const auto status = IsStream ?
  210. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, container, codegen, block, itemPtr):
  211. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(statusType, container, codegen, block, itemPtr);
  212. const auto icmp = IsStream ?
  213. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, status, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), "cond", block):
  214. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::getFalse(context), "cond", block);
  215. BranchInst::Create(done, good, icmp, block);
  216. block = good;
  217. SafeUnRefUnboxedOne(valuePtr, ctx, block);
  218. GetNodeValue(valuePtr, newItem, ctx, block);
  219. const auto nextState = GetNodeValue(newState, ctx, block);
  220. codegenStateArg->CreateSetValue(ctx, block, nextState);
  221. BranchInst::Create(done, block);
  222. block = done;
  223. ReturnInst::Create(context, status, block);
  224. return ctx.Func;
  225. }
  226. using TChainMapPtr = std::conditional_t<IsStream, TStreamCodegenValueOne::TFetchPtr, TListCodegenValueOne::TNextPtr>;
  227. Function* MapFuncOne = nullptr;
  228. Function* MapFuncTwo = nullptr;
  229. TChainMapPtr MapOne = nullptr;
  230. TChainMapPtr MapTwo = nullptr;
  231. #endif
  232. IComputationNode* const List;
  233. const TComputationNodes ComputationNodes;
  234. };
  235. class TStreamChain1MapWrapper : public TCustomValueCodegeneratorNode<TStreamChain1MapWrapper>, private TBaseChain1MapWrapper<true> {
  236. typedef TCustomValueCodegeneratorNode<TStreamChain1MapWrapper> TBaseComputation;
  237. typedef TBaseChain1MapWrapper<true> TBaseWrapper;
  238. public:
  239. TStreamChain1MapWrapper(TComputationMutables& mutables, IComputationNode* list,
  240. IComputationExternalNode* itemArg, IComputationExternalNode* stateArg,
  241. IComputationNode* initItem, IComputationNode* initState,
  242. IComputationNode* updateItem, IComputationNode* updateState
  243. ) : TBaseComputation(mutables), TBaseWrapper(list, itemArg, stateArg, initItem, initState, updateItem, updateState)
  244. {}
  245. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  246. #ifndef MKQL_DISABLE_CODEGEN
  247. if (ctx.ExecuteLLVM && MapOne && MapTwo)
  248. return ctx.HolderFactory.Create<TStreamCodegenValueOne>(MapOne, MapTwo, &ctx, List->GetValue(ctx));
  249. #endif
  250. return ctx.HolderFactory.Create<TStreamValue>(ctx, List->GetValue(ctx), ComputationNodes);
  251. }
  252. private:
  253. void RegisterDependencies() const final {
  254. DependsOn(List);
  255. DependsOn(ComputationNodes.InitItem);
  256. DependsOn(ComputationNodes.InitState);
  257. DependsOn(ComputationNodes.UpdateItem);
  258. DependsOn(ComputationNodes.UpdateState);
  259. Own(ComputationNodes.ItemArg);
  260. Own(ComputationNodes.StateArg);
  261. }
  262. #ifndef MKQL_DISABLE_CODEGEN
  263. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  264. MapFuncOne = GenerateMapper<true>(codegen, TBaseComputation::MakeName("Fetch_One"));
  265. MapFuncTwo = GenerateMapper<false>(codegen, TBaseComputation::MakeName("Fetch_Two"));
  266. codegen.ExportSymbol(MapFuncOne);
  267. codegen.ExportSymbol(MapFuncTwo);
  268. }
  269. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  270. if (MapFuncOne)
  271. MapOne = reinterpret_cast<TChainMapPtr>(codegen.GetPointerToFunction(MapFuncOne));
  272. if (MapFuncTwo)
  273. MapTwo = reinterpret_cast<TChainMapPtr>(codegen.GetPointerToFunction(MapFuncTwo));
  274. }
  275. #endif
  276. };
  277. class TListChain1MapWrapper : public TBothWaysCodegeneratorNode<TListChain1MapWrapper>, private TBaseChain1MapWrapper<false> {
  278. typedef TBothWaysCodegeneratorNode<TListChain1MapWrapper> TBaseComputation;
  279. typedef TBaseChain1MapWrapper<false> TBaseWrapper;
  280. public:
  281. TListChain1MapWrapper(TComputationMutables& mutables, IComputationNode* list,
  282. IComputationExternalNode* itemArg, IComputationExternalNode* stateArg,
  283. IComputationNode* initItem, IComputationNode* initState,
  284. IComputationNode* updateItem, IComputationNode* updateState
  285. ) : TBaseComputation(mutables), TBaseWrapper(list, itemArg, stateArg, initItem, initState, updateItem, updateState)
  286. {}
  287. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  288. auto list = List->GetValue(ctx);
  289. if (auto elements = list.GetElements()) {
  290. auto size = list.GetListLength();
  291. NUdf::TUnboxedValue* items = nullptr;
  292. const auto result = ctx.HolderFactory.CreateDirectArrayHolder(size, items);
  293. if (size) {
  294. ComputationNodes.ItemArg->SetValue(ctx, NUdf::TUnboxedValue(*elements++));
  295. *items++ = ComputationNodes.InitItem->GetValue(ctx);
  296. ComputationNodes.StateArg->SetValue(ctx, ComputationNodes.InitState->GetValue(ctx));
  297. while (--size) {
  298. ComputationNodes.ItemArg->SetValue(ctx, NUdf::TUnboxedValue(*elements++));
  299. *items++ = ComputationNodes.UpdateItem->GetValue(ctx);
  300. ComputationNodes.StateArg->SetValue(ctx, ComputationNodes.UpdateState->GetValue(ctx));
  301. }
  302. }
  303. return result;
  304. }
  305. return ctx.HolderFactory.Create<TListValue>(ctx, std::move(list), ComputationNodes);
  306. }
  307. #ifndef MKQL_DISABLE_CODEGEN
  308. NUdf::TUnboxedValuePod MakeLazyList(TComputationContext& ctx, const NUdf::TUnboxedValuePod value) const {
  309. return ctx.HolderFactory.Create<TListCodegenValueOne>(MapOne, MapTwo, &ctx, value);
  310. }
  311. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  312. auto& context = ctx.Codegen.GetContext();
  313. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.ItemArg);
  314. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(ComputationNodes.StateArg);
  315. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  316. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  317. const auto list = GetNodeValue(List, ctx, block);
  318. const auto lazy = BasicBlock::Create(context, "lazy", ctx.Func);
  319. const auto hard = BasicBlock::Create(context, "hard", ctx.Func);
  320. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  321. const auto map = PHINode::Create(list->getType(), 3U, "map", done);
  322. const auto elementsType = PointerType::getUnqual(list->getType());
  323. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(elementsType, list, ctx.Codegen, block);
  324. const auto fill = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, elements, ConstantPointerNull::get(elementsType), "fill", block);
  325. BranchInst::Create(hard, lazy, fill, block);
  326. {
  327. block = hard;
  328. const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), list, ctx.Codegen, block);
  329. const auto itemsPtr = *Stateless || ctx.AlwaysInline ?
  330. new AllocaInst(elementsType, 0U, "items_ptr", &ctx.Func->getEntryBlock().back()):
  331. new AllocaInst(elementsType, 0U, "items_ptr", block);
  332. const auto array = GenNewArray(ctx, size, itemsPtr, block);
  333. const auto items = new LoadInst(elementsType, itemsPtr, "items", block);
  334. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  335. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  336. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  337. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  338. const auto good = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(size->getType(), 0), "good", block);
  339. map->addIncoming(array, block);
  340. BranchInst::Create(init, done, good, block);
  341. block = init;
  342. const auto head = new LoadInst(list->getType(), elements, "head", block);
  343. codegenItemArg->CreateSetValue(ctx, block, head);
  344. GetNodeValue(items, ComputationNodes.InitItem, ctx, block);
  345. const auto state = GetNodeValue(ComputationNodes.InitState, ctx, block);
  346. codegenStateArg->CreateSetValue(ctx, block, state);
  347. const auto index = PHINode::Create(size->getType(), 2U, "index", loop);
  348. index->addIncoming(ConstantInt::get(size->getType(), 1), block);
  349. BranchInst::Create(loop, block);
  350. block = loop;
  351. const auto more = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, index, "more", block);
  352. BranchInst::Create(next, stop, more, block);
  353. block = next;
  354. const auto src = GetElementPtrInst::CreateInBounds(list->getType(), elements, {index}, "src", block);
  355. const auto item = new LoadInst(list->getType(), src, "item", block);
  356. codegenItemArg->CreateSetValue(ctx, block, item);
  357. const auto dst = GetElementPtrInst::CreateInBounds(list->getType(), items, {index}, "dst", block);
  358. GetNodeValue(dst, ComputationNodes.UpdateItem, ctx, block);
  359. const auto newState = GetNodeValue(ComputationNodes.UpdateState, ctx, block);
  360. codegenStateArg->CreateSetValue(ctx, block, newState);
  361. const auto plus = BinaryOperator::CreateAdd(index, ConstantInt::get(size->getType(), 1), "plus", block);
  362. index->addIncoming(plus, block);
  363. BranchInst::Create(loop, block);
  364. block = stop;
  365. if (List->IsTemporaryValue()) {
  366. CleanupBoxed(list, ctx, block);
  367. }
  368. map->addIncoming(array, block);
  369. BranchInst::Create(done, block);
  370. }
  371. {
  372. block = lazy;
  373. const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TListChain1MapWrapper::MakeLazyList));
  374. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  375. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  376. const auto funType = FunctionType::get(list->getType() , {self->getType(), ctx.Ctx->getType(), list->getType()}, false);
  377. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(funType), "function", block);
  378. const auto value = CallInst::Create(funType, doFuncPtr, {self, ctx.Ctx, list}, "value", block);
  379. map->addIncoming(value, block);
  380. BranchInst::Create(done, block);
  381. }
  382. block = done;
  383. return map;
  384. }
  385. #endif
  386. private:
  387. void RegisterDependencies() const final {
  388. DependsOn(List);
  389. DependsOn(ComputationNodes.InitItem);
  390. DependsOn(ComputationNodes.InitState);
  391. DependsOn(ComputationNodes.UpdateItem);
  392. DependsOn(ComputationNodes.UpdateState);
  393. Own(ComputationNodes.ItemArg);
  394. Own(ComputationNodes.StateArg);
  395. }
  396. #ifndef MKQL_DISABLE_CODEGEN
  397. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  398. TMutableCodegeneratorRootNode<TListChain1MapWrapper>::GenerateFunctions(codegen);
  399. MapFuncOne = GenerateMapper<true>(codegen, TBaseComputation::MakeName("Next_One"));
  400. MapFuncTwo = GenerateMapper<false>(codegen, TBaseComputation::MakeName("Next_Two"));
  401. codegen.ExportSymbol(MapFuncOne);
  402. codegen.ExportSymbol(MapFuncTwo);
  403. }
  404. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  405. TMutableCodegeneratorRootNode<TListChain1MapWrapper>::FinalizeFunctions(codegen);
  406. if (MapFuncOne)
  407. MapOne = reinterpret_cast<TChainMapPtr>(codegen.GetPointerToFunction(MapFuncOne));
  408. if (MapFuncTwo)
  409. MapTwo = reinterpret_cast<TChainMapPtr>(codegen.GetPointerToFunction(MapFuncTwo));
  410. }
  411. #endif
  412. };
  413. }
  414. IComputationNode* WrapChain1Map(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  415. MKQL_ENSURE(callable.GetInputsCount() == 7, "Expected 7 args");
  416. const auto type = callable.GetType()->GetReturnType();
  417. const auto flow = LocateNode(ctx.NodeLocator, callable, 0);
  418. const auto initItem = LocateNode(ctx.NodeLocator, callable, 2);
  419. const auto initState = LocateNode(ctx.NodeLocator, callable, 3);
  420. const auto updateItem = LocateNode(ctx.NodeLocator, callable, 5);
  421. const auto updateState = LocateNode(ctx.NodeLocator, callable, 6);
  422. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 1);
  423. const auto stateArg = LocateExternalNode(ctx.NodeLocator, callable, 4);
  424. if (type->IsFlow()) {
  425. return new TFold1MapFlowWrapper(ctx.Mutables, GetValueRepresentation(type), flow, itemArg, stateArg, initItem, initState, updateItem, updateState);
  426. } else if (type->IsStream()) {
  427. return new TStreamChain1MapWrapper(ctx.Mutables, flow, itemArg, stateArg, initItem, initState, updateItem, updateState);
  428. } else if (type->IsList()) {
  429. return new TListChain1MapWrapper(ctx.Mutables, flow, itemArg, stateArg, initItem, initState, updateItem, updateState);
  430. }
  431. THROW yexception() << "Expected flow, list or stream.";
  432. }
  433. }
  434. }