mkql_flatmap.cpp 73 KB


  1. #include "mkql_flatmap.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/utils/cast.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. using NYql::EnsureDynamicCast;
  9. namespace {
  10. class TFlowFlatMapFlowWrapper : public TStatefulFlowCodegeneratorNode<TFlowFlatMapFlowWrapper> {
  11. using TBaseComputation = TStatefulFlowCodegeneratorNode<TFlowFlatMapFlowWrapper>;
  12. public:
  13. TFlowFlatMapFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow, IComputationExternalNode* input, IComputationNode* output)
  14. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Embedded), Flow(flow), Input(input), Output(output)
  15. {}
  16. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  17. if (state.IsInvalid()) {
  18. if (auto item = Flow->GetValue(ctx); item.IsSpecial()) {
  19. return item.Release();
  20. } else {
  21. state = NUdf::TUnboxedValuePod();
  22. Input->SetValue(ctx, std::move(item));
  23. }
  24. }
  25. while (true) {
  26. if (auto output = Output->GetValue(ctx); output.IsFinish()) {
  27. if (auto item = Flow->GetValue(ctx); item.IsSpecial()) {
  28. state = NUdf::TUnboxedValuePod::Invalid();
  29. return item.Release();
  30. } else {
  31. state = NUdf::TUnboxedValuePod();
  32. Input->SetValue(ctx, std::move(item));
  33. }
  34. } else {
  35. return output.Release();
  36. }
  37. }
  38. }
  39. #ifndef MKQL_DISABLE_CODEGEN
  40. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  41. const auto codegenInput = dynamic_cast<ICodegeneratorExternalNode*>(Input);
  42. MKQL_ENSURE(codegenInput, "Input must be codegenerator node.");
  43. auto& context = ctx.Codegen.GetContext();
  44. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  45. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  46. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  47. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  48. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  49. const auto valueType = Type::getInt128Ty(context);
  50. const auto result = PHINode::Create(valueType, 2U, "result", exit);
  51. const auto state = new LoadInst(valueType, statePtr, "state", block);
  52. const auto reset = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, state, GetInvalid(context), "reset", block);
  53. BranchInst::Create(init, work, reset, block);
  54. block = init;
  55. const auto item = GetNodeValue(Flow, ctx, block);
  56. result->addIncoming(item, block);
  57. BranchInst::Create(exit, next, IsSpecial(item, block, context), block);
  58. block = next;
  59. new StoreInst(GetEmpty(context), statePtr, block);
  60. codegenInput->CreateSetValue(ctx, block, item);
  61. BranchInst::Create(work, block);
  62. block = work;
  63. const auto output = GetNodeValue(Output, ctx, block);
  64. const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, output, GetFinish(context), "finish", block);
  65. result->addIncoming(output, block);
  66. BranchInst::Create(step, exit, finish, block);
  67. block = step;
  68. new StoreInst(GetInvalid(context), statePtr, block);
  69. BranchInst::Create(init, block);
  70. block = exit;
  71. return result;
  72. }
  73. #endif
  74. private:
  75. void RegisterDependencies() const final {
  76. if (const auto flow = FlowDependsOn(Flow)) {
  77. Own(flow, Input);
  78. DependsOn(flow, Output);
  79. }
  80. Input->AddDependence(Output->GetSource());
  81. }
  82. IComputationNode* const Flow;
  83. IComputationExternalNode* const Input;
  84. IComputationNode* const Output;
  85. };
  86. class TFlowFlatMapWideWrapper : public TStatefulWideFlowCodegeneratorNode<TFlowFlatMapWideWrapper> {
  87. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TFlowFlatMapWideWrapper>;
  88. public:
  89. TFlowFlatMapWideWrapper(TComputationMutables& mutables, IComputationNode* flow, IComputationExternalNode* input, IComputationWideFlowNode* output)
  90. : TBaseComputation(mutables, flow, EValueRepresentation::Embedded), Flow(flow), Input(input), Output(output)
  91. {}
  92. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  93. if (state.IsInvalid()) {
  94. if (auto item = Flow->GetValue(ctx); item.IsSpecial()) {
  95. return item.IsFinish() ? EFetchResult::Finish : EFetchResult::Yield;
  96. } else {
  97. state = NUdf::TUnboxedValuePod();
  98. Input->SetValue(ctx, std::move(item));
  99. }
  100. }
  101. while (true) {
  102. if (const auto result = Output->FetchValues(ctx, output); EFetchResult::Finish != result)
  103. return result;
  104. else if (auto item = Flow->GetValue(ctx); item.IsSpecial()) {
  105. state = NUdf::TUnboxedValuePod::Invalid();
  106. return item.IsFinish() ? EFetchResult::Finish : EFetchResult::Yield;
  107. } else {
  108. state = NUdf::TUnboxedValuePod();
  109. Input->SetValue(ctx, std::move(item));
  110. }
  111. }
  112. }
  113. #ifndef MKQL_DISABLE_CODEGEN
  114. TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  115. const auto codegenInput = dynamic_cast<ICodegeneratorExternalNode*>(Input);
  116. MKQL_ENSURE(codegenInput, "Input must be codegenerator node.");
  117. auto& context = ctx.Codegen.GetContext();
  118. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  119. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  120. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  121. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  122. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  123. const auto resultType = Type::getInt32Ty(context);
  124. const auto result = PHINode::Create(resultType, 2U, "result", exit);
  125. const auto state = new LoadInst(Type::getInt128Ty(context), statePtr, "state", block);
  126. const auto reset = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, state, GetInvalid(context), "reset", block);
  127. BranchInst::Create(init, work, reset, block);
  128. block = init;
  129. const auto item = GetNodeValue(Flow, ctx, block);
  130. const auto outres = SelectInst::Create(IsFinish(item, block, context), ConstantInt::get(resultType, i32(EFetchResult::Finish)), ConstantInt::get(resultType, i32(EFetchResult::Yield)), "outres", block);
  131. result->addIncoming(outres, block);
  132. BranchInst::Create(exit, next, IsSpecial(item, block, context), block);
  133. block = next;
  134. new StoreInst(GetEmpty(context), statePtr, block);
  135. codegenInput->CreateSetValue(ctx, block, item);
  136. BranchInst::Create(work, block);
  137. block = work;
  138. auto output = GetNodeValues(Output, ctx, block);
  139. const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, output.first, ConstantInt::get(resultType, 0), "finish", block);
  140. result->addIncoming(output.first, block);
  141. BranchInst::Create(step, exit, finish, block);
  142. block = step;
  143. new StoreInst(GetInvalid(context), statePtr, block);
  144. BranchInst::Create(init, block);
  145. block = exit;
  146. return {result, std::move(output.second)};
  147. }
  148. #endif
  149. private:
  150. void RegisterDependencies() const final {
  151. if (const auto flow = FlowDependsOn(Flow)) {
  152. Own(flow, Input);
  153. DependsOn(flow, Output);
  154. }
  155. Input->AddDependence(Output->GetSource());
  156. }
  157. IComputationNode* const Flow;
  158. IComputationExternalNode* const Input;
  159. IComputationWideFlowNode* const Output;
  160. };
  161. class TListFlatMapFlowWrapper : public TStatefulFlowCodegeneratorNode<TListFlatMapFlowWrapper> {
  162. using TBaseComputation = TStatefulFlowCodegeneratorNode<TListFlatMapFlowWrapper>;
  163. public:
  164. TListFlatMapFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* list, IComputationExternalNode* input, IComputationNode* output)
  165. : TBaseComputation(mutables, output, kind, EValueRepresentation::Boxed), List(list), Input(input), Output(output)
  166. {}
  167. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  168. if (state.IsInvalid()) {
  169. state = List->GetValue(ctx).GetListIterator();
  170. if (!state.Next(Input->RefValue(ctx))) {
  171. state = NUdf::TUnboxedValuePod::MakeFinish();
  172. }
  173. }
  174. if (state.IsFinish()) {
  175. return NUdf::TUnboxedValuePod::MakeFinish();
  176. }
  177. while (true) {
  178. if (auto output = Output->GetValue(ctx); output.IsFinish()) {
  179. if (state.Next(Input->RefValue(ctx))) {
  180. continue;
  181. }
  182. return state = NUdf::TUnboxedValuePod::MakeFinish();
  183. } else {
  184. return output.Release();
  185. }
  186. }
  187. }
  188. #ifndef MKQL_DISABLE_CODEGEN
  189. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  190. const auto codegenInput = dynamic_cast<ICodegeneratorExternalNode*>(Input);
  191. MKQL_ENSURE(codegenInput, "Input must be codegenerator node.");
  192. auto& context = ctx.Codegen.GetContext();
  193. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  194. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  195. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  196. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  197. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  198. const auto valueType = Type::getInt128Ty(context);
  199. const auto result = PHINode::Create(valueType, 3U, "result", exit);
  200. result->addIncoming(GetFinish(context), block);
  201. const auto state = new LoadInst(valueType, statePtr, "state", block);
  202. const auto choise = SwitchInst::Create(state, work, 2U, block);
  203. choise->addCase(GetInvalid(context), init);
  204. choise->addCase(GetFinish(context), exit);
  205. block = init;
  206. const auto list = GetNodeValue(List, ctx, block);
  207. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(statePtr, list, ctx.Codegen, block);
  208. if (List->IsTemporaryValue()) {
  209. CleanupBoxed(list, ctx, block);
  210. }
  211. BranchInst::Create(next, block);
  212. block = next;
  213. const auto iterator = new LoadInst(valueType, statePtr, "iterator", block);
  214. const auto itemPtr = codegenInput->CreateRefValue(ctx, block);
  215. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), iterator, ctx.Codegen, block, itemPtr);
  216. BranchInst::Create(work, done, status, block);
  217. block = work;
  218. const auto output = GetNodeValue(Output, ctx, block);
  219. result->addIncoming(output, block);
  220. BranchInst::Create(next, exit, IsFinish(output, block, context), block);
  221. block = done;
  222. UnRefBoxed(iterator, ctx, block);
  223. new StoreInst(GetFinish(context), statePtr, block);
  224. result->addIncoming(GetFinish(context), block);
  225. BranchInst::Create(exit, block);
  226. block = exit;
  227. return result;
  228. }
  229. #endif
  230. private:
  231. void RegisterDependencies() const final {
  232. if (const auto flow = FlowDependsOn(List)) {
  233. Own(flow, Input);
  234. DependsOn(flow, Output);
  235. }
  236. Input->AddDependence(Output->GetSource());
  237. }
  238. IComputationNode* const List;
  239. IComputationExternalNode* const Input;
  240. IComputationNode* const Output;
  241. };
  242. class TListFlatMapWideWrapper : public TStatefulWideFlowCodegeneratorNode<TListFlatMapWideWrapper> {
  243. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TListFlatMapWideWrapper>;
  244. public:
  245. TListFlatMapWideWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationExternalNode* input, IComputationWideFlowNode* output)
  246. : TBaseComputation(mutables, output, EValueRepresentation::Boxed), List(list), Input(input), Output(output)
  247. {}
  248. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  249. if (state.IsInvalid()) {
  250. state = List->GetValue(ctx).GetListIterator();
  251. if (!state.Next(Input->RefValue(ctx))) {
  252. state = NUdf::TUnboxedValuePod::MakeFinish();
  253. }
  254. }
  255. if (state.IsFinish()) {
  256. return EFetchResult::Finish;
  257. }
  258. while (true) {
  259. if (const auto result = Output->FetchValues(ctx, output); EFetchResult::Finish != result)
  260. return result;
  261. else if (state.Next(Input->RefValue(ctx)))
  262. continue;
  263. state = NUdf::TUnboxedValuePod::MakeFinish();
  264. return EFetchResult::Finish;
  265. }
  266. }
  267. #ifndef MKQL_DISABLE_CODEGEN
  268. TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  269. const auto codegenInput = dynamic_cast<ICodegeneratorExternalNode*>(Input);
  270. MKQL_ENSURE(codegenInput, "Input must be codegenerator node.");
  271. auto& context = ctx.Codegen.GetContext();
  272. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  273. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  274. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  275. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  276. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  277. const auto resultType = Type::getInt32Ty(context);
  278. const auto result = PHINode::Create(resultType, 3U, "result", exit);
  279. result->addIncoming(ConstantInt::get(resultType, i32(EFetchResult::Finish)), block);
  280. const auto valueType = Type::getInt128Ty(context);
  281. const auto state = new LoadInst(valueType, statePtr, "state", block);
  282. const auto choise = SwitchInst::Create(state, work, 2U, block);
  283. choise->addCase(GetInvalid(context), init);
  284. choise->addCase(GetFinish(context), exit);
  285. block = init;
  286. const auto list = GetNodeValue(List, ctx, block);
  287. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(statePtr, list, ctx.Codegen, block);
  288. if (List->IsTemporaryValue()) {
  289. CleanupBoxed(list, ctx, block);
  290. }
  291. BranchInst::Create(next, block);
  292. block = next;
  293. const auto iterator = new LoadInst(valueType, statePtr, "iterator", block);
  294. const auto itemPtr = codegenInput->CreateRefValue(ctx, block);
  295. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), iterator, ctx.Codegen, block, itemPtr);
  296. BranchInst::Create(work, done, status, block);
  297. block = work;
  298. auto output = GetNodeValues(Output, ctx, block);
  299. const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLT, output.first, ConstantInt::get(resultType, 0), "finish", block);
  300. result->addIncoming(output.first, block);
  301. BranchInst::Create(next, exit, finish, block);
  302. block = done;
  303. UnRefBoxed(iterator, ctx, block);
  304. new StoreInst(GetFinish(context), statePtr, block);
  305. result->addIncoming(ConstantInt::get(resultType, i32(EFetchResult::Finish)), block);
  306. BranchInst::Create(exit, block);
  307. block = exit;
  308. return {result, std::move(output.second)};
  309. }
  310. #endif
  311. private:
  312. void RegisterDependencies() const final {
  313. if (const auto flow = FlowDependsOn(List)) {
  314. Own(flow, Input);
  315. DependsOn(flow, Output);
  316. }
  317. Input->AddDependence(Output->GetSource());
  318. }
  319. IComputationNode* const List;
  320. IComputationExternalNode* const Input;
  321. IComputationWideFlowNode* const Output;
  322. };
  323. class TNarrowFlatMapFlowWrapper : public TStatefulFlowCodegeneratorNode<TNarrowFlatMapFlowWrapper> {
  324. using TBaseComputation = TStatefulFlowCodegeneratorNode<TNarrowFlatMapFlowWrapper>;
  325. public:
  326. TNarrowFlatMapFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationWideFlowNode* flow, TComputationExternalNodePtrVector&& items, IComputationNode* output)
  327. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Embedded)
  328. , Flow(flow)
  329. , Items(std::move(items))
  330. , Output(output)
  331. , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Items.size()))
  332. {}
  333. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  334. auto** fields = ctx.WideFields.data() + WideFieldsIndex;
  335. if (state.IsInvalid()) {
  336. for (auto i = 0U; i < Items.size(); ++i)
  337. if (Items[i]->GetDependencesCount() > 0U)
  338. fields[i] = &Items[i]->RefValue(ctx);
  339. switch (Flow->FetchValues(ctx, fields)) {
  340. case EFetchResult::Finish:
  341. return NUdf::TUnboxedValuePod::MakeFinish();
  342. case EFetchResult::Yield:
  343. return NUdf::TUnboxedValuePod::MakeYield();
  344. default:
  345. state = NUdf::TUnboxedValuePod();
  346. }
  347. }
  348. while (true) {
  349. if (auto output = Output->GetValue(ctx); output.IsFinish()) {
  350. for (auto i = 0U; i < Items.size(); ++i)
  351. if (Items[i]->GetDependencesCount() > 0U)
  352. fields[i] = &Items[i]->RefValue(ctx);
  353. switch (Flow->FetchValues(ctx, fields)) {
  354. case EFetchResult::Finish:
  355. return NUdf::TUnboxedValuePod::MakeFinish();
  356. case EFetchResult::Yield:
  357. return NUdf::TUnboxedValuePod::MakeYield();
  358. default:
  359. state = NUdf::TUnboxedValuePod();
  360. }
  361. } else {
  362. return output.Release();
  363. }
  364. }
  365. }
  366. #ifndef MKQL_DISABLE_CODEGEN
  367. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  368. auto& context = ctx.Codegen.GetContext();
  369. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  370. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  371. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  372. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  373. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  374. const auto valueType = Type::getInt128Ty(context);
  375. const auto result = PHINode::Create(valueType, 2U, "result", exit);
  376. const auto state = new LoadInst(valueType, statePtr, "state", block);
  377. const auto reset = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, state, GetInvalid(context), "reset", block);
  378. BranchInst::Create(init, work, reset, block);
  379. block = init;
  380. const auto getres = GetNodeValues(Flow, ctx, block);
  381. const auto yield = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, getres.first, ConstantInt::get(getres.first->getType(), 0), "yield", block);
  382. const auto good = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, getres.first, ConstantInt::get(getres.first->getType(), 0), "good", block);
  383. const auto outres = SelectInst::Create(yield, GetYield(context), GetFinish(context), "outres", block);
  384. result->addIncoming(outres, block);
  385. BranchInst::Create(next, exit, good, block);
  386. block = next;
  387. new StoreInst(GetEmpty(context), statePtr, block);
  388. for (auto i = 0U; i < Items.size(); ++i)
  389. if (Items[i]->GetDependencesCount() > 0U)
  390. EnsureDynamicCast<ICodegeneratorExternalNode*>(Items[i])->CreateSetValue(ctx, block, getres.second[i](ctx, block));
  391. BranchInst::Create(work, block);
  392. block = work;
  393. const auto output = GetNodeValue(Output, ctx, block);
  394. const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, output, GetFinish(context), "finish", block);
  395. result->addIncoming(output, block);
  396. BranchInst::Create(step, exit, finish, block);
  397. block = step;
  398. new StoreInst(GetInvalid(context), statePtr, block);
  399. BranchInst::Create(init, block);
  400. block = exit;
  401. return result;
  402. }
  403. #endif
  404. private:
  405. void RegisterDependencies() const final {
  406. if (const auto flow = FlowDependsOn(Flow)) {
  407. std::for_each(Items.cbegin(), Items.cend(), std::bind(&TNarrowFlatMapFlowWrapper::Own, flow, std::placeholders::_1));
  408. DependsOn(flow, Output);
  409. }
  410. std::for_each(Items.cbegin(), Items.cend(), std::bind(&IComputationNode::AddDependence, std::placeholders::_1, Output->GetSource()));
  411. }
  412. IComputationWideFlowNode* const Flow;
  413. const TComputationExternalNodePtrVector Items;
  414. IComputationNode* const Output;
  415. const ui32 WideFieldsIndex;
  416. };
  417. template <bool IsMultiRowPerItem, bool ResultContainerOpt>
  418. class TFlowFlatMapWrapper : public std::conditional_t<IsMultiRowPerItem,
  419. TStatefulFlowCodegeneratorNode<TFlowFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>,
  420. TStatelessFlowCodegeneratorNode<TFlowFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>> {
  421. using TBaseComputation = std::conditional_t<IsMultiRowPerItem,
  422. TStatefulFlowCodegeneratorNode<TFlowFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>,
  423. TStatelessFlowCodegeneratorNode<TFlowFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>>;
  424. public:
  425. TFlowFlatMapWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow, IComputationExternalNode* item, IComputationNode* newItem)
  426. : TBaseComputation(mutables, flow, kind), Flow(flow), Item(item), NewItem(newItem)
  427. {}
  428. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  429. while (true) {
  430. if (auto item = Flow->GetValue(ctx); item.IsSpecial()) {
  431. return item.Release();
  432. } else {
  433. Item->SetValue(ctx, std::move(item));
  434. }
  435. if (auto newItem = NewItem->GetValue(ctx)) {
  436. return newItem.Release().GetOptionalValueIf<!IsMultiRowPerItem && ResultContainerOpt>();
  437. }
  438. }
  439. }
  440. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  441. while (!state.IsFinish()) {
  442. if (state.HasValue()) {
  443. if constexpr (ResultContainerOpt) {
  444. switch (NUdf::TUnboxedValue result; state.Fetch(result)) {
  445. case NUdf::EFetchStatus::Finish: break;
  446. case NUdf::EFetchStatus::Yield: return NUdf::TUnboxedValuePod::MakeYield();
  447. case NUdf::EFetchStatus::Ok: return result.Release();
  448. }
  449. } else if (NUdf::TUnboxedValue result; state.Next(result)) {
  450. return result.Release();
  451. }
  452. state.Clear();
  453. }
  454. NUdf::TUnboxedValue item = DoCalculate(ctx);
  455. if (item.IsSpecial()) {
  456. return item.Release();
  457. } else {
  458. state = ResultContainerOpt ? std::move(item) : item.GetListIterator();
  459. }
  460. }
  461. return state;
  462. }
  463. #ifndef MKQL_DISABLE_CODEGEN
  464. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  465. auto& context = ctx.Codegen.GetContext();
  466. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  467. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  468. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  469. BranchInst::Create(loop, block);
  470. block = loop;
  471. const auto item = GetNodeValue(Flow, ctx, block);
  472. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  473. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  474. const auto result = PHINode::Create(item->getType(), 2, "result", exit);
  475. result->addIncoming(item, block);
  476. BranchInst::Create(exit, work, IsSpecial(item, block, context), block);
  477. block = work;
  478. codegenItem->CreateSetValue(ctx, block, item);
  479. const auto value = GetNodeValue(NewItem, ctx, block);
  480. result->addIncoming(!IsMultiRowPerItem && ResultContainerOpt ? GetOptionalValue(context, value, block) : value, block);
  481. BranchInst::Create(loop, exit, IsEmpty(value, block, context), block);
  482. block = exit;
  483. return result;
  484. }
  485. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* currentPtr, BasicBlock*& block) const {
  486. auto& context = ctx.Codegen.GetContext();
  487. const auto statusType = Type::getInt32Ty(context);
  488. const auto valueType = Type::getInt128Ty(context);
  489. const auto valuePtr = new AllocaInst(valueType, 0U, "value_ptr", &ctx.Func->getEntryBlock().back());
  490. new StoreInst(ConstantInt::get(valueType, 0), valuePtr, block);
  491. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  492. const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
  493. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  494. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  495. const auto result = PHINode::Create(valueType, ResultContainerOpt ? 3U : 2U, "result", over);
  496. BranchInst::Create(more, block);
  497. block = more;
  498. const auto current = new LoadInst(valueType, currentPtr, "current", block);
  499. BranchInst::Create(pull, skip, HasValue(current, block, context), block);
  500. {
  501. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  502. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  503. block = pull;
  504. if constexpr (ResultContainerOpt) {
  505. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, current, ctx.Codegen, block, valuePtr);
  506. result->addIncoming(GetYield(context), block);
  507. const auto choise = SwitchInst::Create(status, good, 2U, block);
  508. choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Yield)), over);
  509. choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), next);
  510. } else {
  511. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), current, ctx.Codegen, block, valuePtr);
  512. BranchInst::Create(good, next, status, block);
  513. }
  514. block = good;
  515. const auto value = new LoadInst(valueType, valuePtr, "value", block);
  516. ValueRelease(static_cast<const IComputationNode*>(this)->GetRepresentation(), value, ctx, block);
  517. result->addIncoming(value, block);
  518. BranchInst::Create(over, block);
  519. block = next;
  520. UnRefBoxed(current, ctx, block);
  521. new StoreInst(ConstantInt::get(current->getType(), 0), currentPtr, block);
  522. BranchInst::Create(skip, block);
  523. }
  524. {
  525. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  526. block = skip;
  527. const auto list = DoGenerateGetValue(ctx, block);
  528. result->addIncoming(list, block);
  529. BranchInst::Create(over, good, IsSpecial(list, block, context), block);
  530. block = good;
  531. if constexpr (ResultContainerOpt) {
  532. new StoreInst(list, currentPtr, block);
  533. AddRefBoxed(list, ctx, block);
  534. } else {
  535. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(currentPtr, list, ctx.Codegen, block);
  536. if (NewItem->IsTemporaryValue()) {
  537. CleanupBoxed(list, ctx, block);
  538. }
  539. }
  540. BranchInst::Create(more, block);
  541. }
  542. block = over;
  543. return result;
  544. }
  545. #endif
  546. private:
  547. void RegisterDependencies() const final {
  548. if (const auto flow = this->FlowDependsOn(this->Flow)) {
  549. this->Own(flow, this->Item);
  550. this->DependsOn(flow, this->NewItem);
  551. }
  552. }
  553. IComputationNode* const Flow;
  554. IComputationExternalNode* const Item;
  555. IComputationNode* const NewItem;
  556. };
  557. template <bool IsMultiRowPerItem, bool ResultContainerOpt>
  558. class TNarrowFlatMapWrapper : public std::conditional_t<IsMultiRowPerItem,
  559. TStatefulFlowCodegeneratorNode<TNarrowFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>,
  560. TStatelessFlowCodegeneratorNode<TNarrowFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>> {
  561. using TBaseComputation = std::conditional_t<IsMultiRowPerItem,
  562. TStatefulFlowCodegeneratorNode<TNarrowFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>,
  563. TStatelessFlowCodegeneratorNode<TNarrowFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>>;
  564. public:
  565. TNarrowFlatMapWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationWideFlowNode* flow, const TComputationExternalNodePtrVector&& items, IComputationNode* newItem)
  566. : TBaseComputation(mutables, flow, kind)
  567. , Flow(flow)
  568. , Items(std::move(items))
  569. , NewItem(newItem)
  570. , PasstroughItem(GetPasstroughtMap(TComputationNodePtrVector{NewItem}, Items).front())
  571. , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Items.size()))
  572. {}
  573. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  574. auto** fields = ctx.WideFields.data() + WideFieldsIndex;
  575. while (true) {
  576. for (auto i = 0U; i < Items.size(); ++i)
  577. if (NewItem == Items[i] || Items[i]->GetDependencesCount() > 0U)
  578. fields[i] = &Items[i]->RefValue(ctx);
  579. switch (const auto result = Flow->FetchValues(ctx, fields)) {
  580. case EFetchResult::Finish:
  581. return NUdf::TUnboxedValuePod::MakeFinish();
  582. case EFetchResult::Yield:
  583. return NUdf::TUnboxedValuePod::MakeYield();
  584. case EFetchResult::One:
  585. break;
  586. }
  587. if (auto newItem = NewItem->GetValue(ctx)) {
  588. return newItem.Release().GetOptionalValueIf<!IsMultiRowPerItem && ResultContainerOpt>();
  589. }
  590. }
  591. }
  592. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  593. while (!state.IsFinish()) {
  594. if (state.HasValue()) {
  595. if constexpr (ResultContainerOpt) {
  596. switch (NUdf::TUnboxedValue result; state.Fetch(result)) {
  597. case NUdf::EFetchStatus::Finish: break;
  598. case NUdf::EFetchStatus::Yield: return NUdf::TUnboxedValuePod::MakeYield();
  599. case NUdf::EFetchStatus::Ok: return result.Release();
  600. }
  601. } else if (NUdf::TUnboxedValue result; state.Next(result)) {
  602. return result.Release();
  603. }
  604. state.Clear();
  605. }
  606. NUdf::TUnboxedValue item = DoCalculate(ctx);
  607. if (item.IsSpecial()) {
  608. return item.Release();
  609. } else {
  610. state = ResultContainerOpt ? std::move(item) : item.GetListIterator();
  611. }
  612. }
  613. return state;
  614. }
  615. #ifndef MKQL_DISABLE_CODEGEN
// Stateless JIT path: emits code that pulls wide rows from Flow until one maps to a
// non-empty value, and returns that value. A yield/finish status from Flow is returned
// as the corresponding special value.
Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
    auto& context = ctx.Codegen.GetContext();
    const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
    const auto work = BasicBlock::Create(context, "work", ctx.Func);
    const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
    // Two incoming edges: the special value from "loop" and the mapped value from "work".
    const auto result = PHINode::Create(Type::getInt128Ty(context), 2, "result", exit);
    BranchInst::Create(loop, block);
    block = loop;
    // getres.first is the fetch status; getres.second holds per-field value getters.
    const auto getres = GetNodeValues(Flow, ctx, block);
    // status == 0 -> Yield, status > 0 -> a row was fetched, anything else -> Finish.
    const auto yield = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, getres.first, ConstantInt::get(getres.first->getType(), 0), "yield", block);
    const auto good = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SGT, getres.first, ConstantInt::get(getres.first->getType(), 0), "good", block);
    const auto outres = SelectInst::Create(yield, GetYield(context), GetFinish(context), "outres", block);
    result->addIncoming(outres, block);
    BranchInst::Create(work, exit, good, block);
    block = work;
    Value* value = nullptr;
    if (const auto passtrough = PasstroughItem) {
        // Use the selected wide field directly; NewItem is not evaluated on this path.
        value = getres.second[*passtrough](ctx, block);
    } else {
        // Bind only the fields that have dependents, then evaluate NewItem.
        for (auto i = 0U; i < Items.size(); ++i)
            if (Items[i]->GetDependencesCount() > 0U)
                EnsureDynamicCast<ICodegeneratorExternalNode*>(Items[i])->CreateSetValue(ctx, block, getres.second[i](ctx, block));
        value = GetNodeValue(NewItem, ctx, block);
    }
    // For one row per item with ResultContainerOpt, the optional is unwrapped before returning.
    result->addIncoming(!IsMultiRowPerItem && ResultContainerOpt ? GetOptionalValue(context, value, block) : value, block);
    // An empty mapped value is skipped: go back and fetch the next row.
    BranchInst::Create(loop, exit, IsEmpty(value, block, context), block);
    block = exit;
    return result;
}
// Stateful JIT path (JIT counterpart of DoCalculate(state, ctx)). *currentPtr caches the
// container for the current input row. It is a stream (Fetch) when ResultContainerOpt is
// set, otherwise a list iterator (Next). The generated code drains it, and once it is
// exhausted releases it and asks the stateless generator for the next container.
Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* currentPtr, BasicBlock*& block) const {
    auto& context = ctx.Codegen.GetContext();
    const auto statusType = Type::getInt32Ty(context);
    const auto valueType = Type::getInt128Ty(context);
    // Output slot for Fetch/Next; allocated once in the entry block and zeroed here.
    const auto valuePtr = new AllocaInst(valueType, 0U, "value_ptr", &ctx.Func->getEntryBlock().back());
    new StoreInst(ConstantInt::get(valueType, 0), valuePtr, block);
    const auto more = BasicBlock::Create(context, "more", ctx.Func);
    const auto pull = BasicBlock::Create(context, "pull", ctx.Func);
    const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
    const auto over = BasicBlock::Create(context, "over", ctx.Func);
    // Incoming: fetched value, special value from the stateless generator, and (streams only) Yield.
    const auto result = PHINode::Create(valueType, ResultContainerOpt ? 3U : 2U, "result", over);
    BranchInst::Create(more, block);
    block = more;
    const auto current = new LoadInst(valueType, currentPtr, "current", block);
    BranchInst::Create(pull, skip, HasValue(current, block, context), block);
    {
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        const auto next = BasicBlock::Create(context, "next", ctx.Func);
        block = pull;
        if constexpr (ResultContainerOpt) {
            // Stream: Ok -> good, Yield -> return Yield, Finish -> next.
            const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, current, ctx.Codegen, block, valuePtr);
            result->addIncoming(GetYield(context), block);
            const auto choise = SwitchInst::Create(status, good, 2U, block);
            choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Yield)), over);
            choise->addCase(ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), next);
        } else {
            // List iterator: true -> good, false -> next.
            const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), current, ctx.Codegen, block, valuePtr);
            BranchInst::Create(good, next, status, block);
        }
        block = good;
        const auto value = new LoadInst(valueType, valuePtr, "value", block);
        // The value is handed to the caller; ValueRelease is applied per this node's representation.
        ValueRelease(static_cast<const IComputationNode*>(this)->GetRepresentation(), value, ctx, block);
        result->addIncoming(value, block);
        BranchInst::Create(over, block);
        block = next;
        // Exhausted container: drop our reference and clear the cached state.
        UnRefBoxed(current, ctx, block);
        new StoreInst(ConstantInt::get(current->getType(), 0), currentPtr, block);
        BranchInst::Create(skip, block);
    }
    {
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        block = skip;
        // Ask the stateless generator for the next container (or a special value).
        const auto list = DoGenerateGetValue(ctx, block);
        result->addIncoming(list, block);
        BranchInst::Create(over, good, IsSpecial(list, block, context), block);
        block = good;
        if constexpr (ResultContainerOpt) {
            // Keep the stream itself as state.
            new StoreInst(list, currentPtr, block);
            AddRefBoxed(list, ctx, block);
        } else {
            // Keep an iterator over the list as state; the list itself is cleaned up if temporary.
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(currentPtr, list, ctx.Codegen, block);
            if (NewItem->IsTemporaryValue()) {
                CleanupBoxed(list, ctx, block);
            }
        }
        BranchInst::Create(more, block);
    }
    block = over;
    return result;
}
  705. #endif
  706. private:
  707. void RegisterDependencies() const final {
  708. if (const auto flow = this->FlowDependsOn(Flow)) {
  709. for (const auto& item : this->Items)
  710. this->Own(flow, item);
  711. this->DependsOn(flow, this->NewItem);
  712. }
  713. }
  714. IComputationWideFlowNode* const Flow;
  715. const TComputationExternalNodePtrVector Items;
  716. IComputationNode* const NewItem;
  717. const std::optional<size_t> PasstroughItem;
  718. const ui32 WideFieldsIndex;
  719. };
  720. template <bool MultiOptional>
  721. class TSimpleListValue : public TCustomListValue {
  722. public:
  723. class TIterator : public TComputationValue<TIterator> {
  724. public:
  725. TIterator(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& iter, IComputationExternalNode* item, IComputationNode* newItem)
  726. : TComputationValue<TIterator>(memInfo)
  727. , CompCtx(compCtx)
  728. , Iter(std::move(iter))
  729. , Item(item)
  730. , NewItem(newItem)
  731. {}
  732. private:
  733. bool Next(NUdf::TUnboxedValue& value) final {
  734. for (;;) {
  735. if (!Iter.Next(Item->RefValue(CompCtx))) {
  736. return false;
  737. }
  738. if (auto newItem = NewItem->GetValue(CompCtx)) {
  739. value = newItem.Release().template GetOptionalValueIf<MultiOptional>();
  740. return true;
  741. }
  742. }
  743. }
  744. TComputationContext& CompCtx;
  745. const NUdf::TUnboxedValue Iter;
  746. IComputationExternalNode* const Item;
  747. IComputationNode* const NewItem;
  748. };
  749. TSimpleListValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& list, IComputationExternalNode* item, IComputationNode* newItem)
  750. : TCustomListValue(memInfo)
  751. , CompCtx(compCtx)
  752. , List(std::move(list))
  753. , Item(item)
  754. , NewItem(newItem)
  755. {
  756. }
  757. private:
  758. NUdf::TUnboxedValue GetListIterator() const final {
  759. return CompCtx.HolderFactory.Create<TIterator>(CompCtx, List.GetListIterator(), Item, NewItem);
  760. }
  761. TComputationContext& CompCtx;
  762. const NUdf::TUnboxedValue List;
  763. IComputationExternalNode* const Item;
  764. IComputationNode* const NewItem;
  765. };
  766. template <bool MultiOptional>
  767. class TSimpleStreamValue : public TComputationValue<TSimpleStreamValue<MultiOptional>> {
  768. public:
  769. using TBase = TComputationValue<TSimpleStreamValue>;
  770. TSimpleStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& stream, IComputationExternalNode* item, IComputationNode* newItem)
  771. : TBase(memInfo)
  772. , CompCtx(compCtx)
  773. , Stream(std::move(stream))
  774. , Item(item)
  775. , NewItem(newItem)
  776. {}
  777. private:
  778. ui32 GetTraverseCount() const override {
  779. return 1;
  780. }
  781. NUdf::TUnboxedValue GetTraverseItem(ui32 index) const override {
  782. Y_UNUSED(index);
  783. return Stream;
  784. }
  785. NUdf::TUnboxedValue Save() const override {
  786. return NUdf::TUnboxedValue::Zero();
  787. }
  788. void Load(const NUdf::TStringRef& state) override {
  789. Y_UNUSED(state);
  790. }
  791. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
  792. for (;;) {
  793. const auto status = Stream.Fetch(Item->RefValue(CompCtx));
  794. if (NUdf::EFetchStatus::Ok != status) {
  795. return status;
  796. }
  797. if (auto newItem = NewItem->GetValue(CompCtx)) {
  798. result = newItem.Release().template GetOptionalValueIf<MultiOptional>();
  799. return NUdf::EFetchStatus::Ok;
  800. }
  801. }
  802. }
  803. private:
  804. TComputationContext& CompCtx;
  805. const NUdf::TUnboxedValue Stream;
  806. IComputationExternalNode* const Item;
  807. IComputationNode* const NewItem;
  808. };
  809. template <bool IsNewStream>
  810. class TListValue : public TCustomListValue {
  811. public:
  812. class TIterator : public TComputationValue<TIterator> {
  813. public:
  814. TIterator(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& iter, IComputationExternalNode* item, IComputationNode* newItem)
  815. : TComputationValue<TIterator>(memInfo)
  816. , CompCtx(compCtx)
  817. , Iter(std::move(iter))
  818. , Item(item)
  819. , NewItem(newItem)
  820. {}
  821. private:
  822. bool Next(NUdf::TUnboxedValue& value) final {
  823. for (NUdf::TUnboxedValue current = std::move(Current);; current.Clear()) {
  824. if (!current) {
  825. if (Iter.Next(Item->RefValue(CompCtx))) {
  826. current = IsNewStream ? NewItem->GetValue(CompCtx) : NewItem->GetValue(CompCtx).GetListIterator();
  827. } else {
  828. return false;
  829. }
  830. }
  831. if constexpr (IsNewStream) {
  832. const auto status = current.Fetch(value);
  833. MKQL_ENSURE(status != NUdf::EFetchStatus::Yield, "Unexpected stream status");
  834. if (NUdf::EFetchStatus::Finish == status) {
  835. continue;
  836. }
  837. } else {
  838. if (!current.Next(value)) {
  839. continue;
  840. }
  841. }
  842. Current = std::move(current);
  843. return true;
  844. }
  845. }
  846. TComputationContext& CompCtx;
  847. const NUdf::TUnboxedValue Iter;
  848. IComputationExternalNode* const Item;
  849. IComputationNode* const NewItem;
  850. NUdf::TUnboxedValue Current;
  851. };
  852. TListValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& list, IComputationExternalNode* item, IComputationNode* newItem)
  853. : TCustomListValue(memInfo)
  854. , CompCtx(compCtx)
  855. , List(std::move(list))
  856. , Item(item)
  857. , NewItem(newItem)
  858. {}
  859. private:
  860. NUdf::TUnboxedValue GetListIterator() const final {
  861. return CompCtx.HolderFactory.Create<TIterator>(CompCtx, List.GetListIterator(), Item, NewItem);
  862. }
  863. TComputationContext& CompCtx;
  864. const NUdf::TUnboxedValue List;
  865. IComputationExternalNode* const Item;
  866. IComputationNode* const NewItem;
  867. };
  868. template <bool IsNewStream>
  869. class TStreamValue : public TComputationValue<TStreamValue<IsNewStream>> {
  870. public:
  871. using TBase = TComputationValue<TStreamValue<IsNewStream>>;
  872. TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& stream, IComputationExternalNode* item, IComputationNode* newItem)
  873. : TBase(memInfo)
  874. , CompCtx(compCtx)
  875. , Stream(std::move(stream))
  876. , Item(item)
  877. , NewItem(newItem)
  878. {}
  879. private:
  880. ui32 GetTraverseCount() const override {
  881. return 1;
  882. }
  883. NUdf::TUnboxedValue GetTraverseItem(ui32 index) const override {
  884. Y_UNUSED(index);
  885. return Stream;
  886. }
  887. NUdf::TUnboxedValue Save() const override {
  888. return NUdf::TUnboxedValue::Zero();
  889. }
  890. void Load(const NUdf::TStringRef& state) override {
  891. Y_UNUSED(state);
  892. }
  893. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
  894. for (NUdf::TUnboxedValue current = std::move(Current);; current.Clear()) {
  895. if (!current) {
  896. const auto status = Stream.Fetch(Item->RefValue(CompCtx));
  897. if (NUdf::EFetchStatus::Ok != status) {
  898. return status;
  899. }
  900. current = IsNewStream ? NewItem->GetValue(CompCtx) : NewItem->GetValue(CompCtx).GetListIterator();
  901. }
  902. auto status = NUdf::EFetchStatus::Ok;
  903. if constexpr (IsNewStream) {
  904. status = current.Fetch(result);
  905. if (NUdf::EFetchStatus::Finish == status) {
  906. continue;
  907. }
  908. } else {
  909. if (!current.Next(result)) {
  910. continue;
  911. }
  912. }
  913. Current = std::move(current);
  914. return status;
  915. }
  916. }
  917. private:
  918. TComputationContext& CompCtx;
  919. const NUdf::TUnboxedValue Stream;
  920. IComputationExternalNode* const Item;
  921. IComputationNode* const NewItem;
  922. NUdf::TUnboxedValue Current;
  923. };
// Shared members and LLVM function generators for the stream and list flat-map wrappers.
// IsInputStream: the input container is a stream (Fetch, i32 status) rather than a
//   list iterator (Next, i1 result).
// IsMultiRowPerItem: NewItem produces a container per input item (GenerateMapper)
//   rather than a single optional row (GenerateSimpleMapper).
// ResultContainerOpt: in GenerateMapper, selects a stream (Fetch) rather than a list
//   (GetListIterator + Next) as the per-item container.
template <bool IsInputStream, bool IsMultiRowPerItem, bool ResultContainerOpt>
class TBaseFlatMapWrapper {
protected:
    TBaseFlatMapWrapper(IComputationNode* list, IComputationExternalNode* item, IComputationNode* newItem)
        : List(list), Item(item), NewItem(newItem)
    {}
#ifndef MKQL_DISABLE_CODEGEN
    // Boxed value type that is driven by the JIT-compiled function (see TFlatMapPtr).
    using TCodegenValue = std::conditional_t<IsInputStream,
        typename std::conditional_t<IsMultiRowPerItem, TStreamCodegenStatefulValue, TStreamCodegenValueStateless>,
        typename std::conditional_t<IsMultiRowPerItem, TCustomListCodegenStatefulValue, TCustomListCodegenValue>>;

    // Emits (or reuses, if `name` already exists in the module) the one-row-per-item mapper:
    //   status fn(ctx*, container, value*)
    // It pulls items into Item, evaluates NewItem, and skips empty results. The first
    // non-empty result replaces *value (old value unref'd, new one unwrapped with
    // GetOptionalValue and add-ref'd). The input status is then returned.
    Function* GenerateSimpleMapper(NYql::NCodegen::ICodegen& codegen, const TString& name) const {
        auto& module = codegen.GetModule();
        auto& context = codegen.GetContext();
        const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
        MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
        if (const auto f = module.getFunction(name.c_str()))
            return f;
        const auto valueType = Type::getInt128Ty(context);
        const auto containerType = static_cast<Type*>(valueType);
        const auto contextType = GetCompContextType(context);
        const auto statusType = IsInputStream ? Type::getInt32Ty(context) : Type::getInt1Ty(context);
        const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, PointerType::getUnqual(valueType)}, false);
        TCodegenContext ctx(codegen);
        ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
        DISubprogramAnnotator annotator(ctx, ctx.Func);
        auto args = ctx.Func->arg_begin();
        ctx.Ctx = &*args;
        const auto containerArg = &*++args;
        const auto valuePtr = &*++args;
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        auto block = main;
        const auto container = static_cast<Value*>(containerArg);
        const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
        const auto done = BasicBlock::Create(context, "done", ctx.Func);
        BranchInst::Create(loop, block);
        block = loop;
        // Pull the next input item directly into Item's storage.
        const auto itemPtr = codegenItem->CreateRefValue(ctx, block);
        const auto status = IsInputStream ?
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, container, codegen, block, itemPtr):
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(statusType, container, codegen, block, itemPtr);
        const auto icmp = IsInputStream ?
            CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), "cond", block):
            status;
        BranchInst::Create(good, done, icmp, block);
        block = good;
        const auto resItem = GetNodeValue(NewItem, ctx, block);
        // Empty mapping result: skip this item.
        BranchInst::Create(loop, pass, IsEmpty(resItem, block, context), block);
        block = pass;
        SafeUnRefUnboxedOne(valuePtr, ctx, block);
        const auto getOpt = GetOptionalValue(context, resItem, block);
        new StoreInst(getOpt, valuePtr, block);
        ValueAddRef(NewItem->GetRepresentation(), valuePtr, ctx, block);
        BranchInst::Create(done, block);
        block = done;
        ReturnInst::Create(context, status, block);
        return ctx.Func;
    }

    // Emits (or reuses) the multi-row-per-item mapper:
    //   status fn(ctx*, container, current*, value*)
    // *current caches the inner container of the current input item (zero when none).
    // The generated code reads the inner container into *value. When the inner container
    // is exhausted it is unref'd and the next input item is mapped into a new container.
    // When the input itself is not Ok/true, *current is zeroed and that input status is returned.
    Function* GenerateMapper(NYql::NCodegen::ICodegen& codegen, const TString& name) const {
        auto& module = codegen.GetModule();
        auto& context = codegen.GetContext();
        const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
        MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
        if (const auto f = module.getFunction(name.c_str()))
            return f;
        const auto valueType = Type::getInt128Ty(context);
        const auto containerType = static_cast<Type*>(valueType);
        const auto contextType = GetCompContextType(context);
        // statusType: outer container protocol; stateType: inner container protocol.
        const auto statusType = IsInputStream ? Type::getInt32Ty(context) : Type::getInt1Ty(context);
        const auto stateType = ResultContainerOpt ? Type::getInt32Ty(context) : Type::getInt1Ty(context);
        const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, PointerType::getUnqual(valueType), PointerType::getUnqual(valueType)}, false);
        TCodegenContext ctx(codegen);
        ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
        DISubprogramAnnotator annotator(ctx, ctx.Func);
        auto args = ctx.Func->arg_begin();
        ctx.Ctx = &*args;
        const auto containerArg = &*++args;
        const auto currentArg = &*++args;
        const auto valuePtr = &*++args;
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        auto block = main;
        const auto container = static_cast<Value*>(containerArg);
        const auto zero = ConstantInt::get(valueType, 0);
        const auto init = new LoadInst(valueType, currentArg, "init", block);
        const auto next = BasicBlock::Create(context, "next", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
        const auto cont = BasicBlock::Create(context, "cont", ctx.Func);
        const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
        const auto done = BasicBlock::Create(context, "done", ctx.Func);
        // Inner container being read: the cached one on entry, or a freshly built one from "good".
        const auto current = PHINode::Create(valueType, 2, "result", pass);
        current->addIncoming(init, block);
        // No cached inner container -> fetch the next input item first.
        const auto step = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, init, zero, "step", block);
        BranchInst::Create(next, pass, step, block);
        block = next;
        const auto itemPtr = codegenItem->CreateRefValue(ctx, block);
        const auto status = IsInputStream ?
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, container, codegen, block, itemPtr):
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(statusType, container, codegen, block, itemPtr);
        const auto icmp = IsInputStream ?
            CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), "cond", block):
            status;
        BranchInst::Create(good, done, icmp, block);
        block = good;
        if constexpr (ResultContainerOpt) {
            // Store NewItem's value (a stream) directly as the inner container.
            GetNodeValue(currentArg, NewItem, ctx, block);
        } else {
            // Store an iterator over NewItem's list; clean up the list if it is temporary.
            const auto list = GetNodeValue(NewItem, ctx, block);
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(currentArg, list, codegen, block);
            if (NewItem->IsTemporaryValue())
                CleanupBoxed(list, ctx, block);
        }
        const auto iter = new LoadInst(valueType, currentArg, "iter", block);
        current->addIncoming(iter, block);
        BranchInst::Create(pass, block);
        block = pass;
        const auto state = ResultContainerOpt ?
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(stateType, current, codegen, block, valuePtr):
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(stateType, current, codegen, block, valuePtr);
        // Anything but Finish (i.e. Ok or Yield) from an inner stream leaves this function via "exit".
        const auto scmp = ResultContainerOpt ?
            CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, state, ConstantInt::get(stateType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), "scmp", block):
            state;
        BranchInst::Create(exit, cont, scmp, block);
        block = cont;
        // Inner container exhausted: release it. The stale value stays in *current until
        // "good" overwrites it or "done" zeroes it.
        UnRefBoxed(current, ctx, block);
        BranchInst::Create(next, block);
        block = exit;
        // NOTE(review): for list input (!IsInputStream) with ResultContainerOpt, an inner Yield is
        // reported as `true` here, whereas the interpreted TListValue raises on Yield — confirm intended.
        ReturnInst::Create(context, IsInputStream ? (ResultContainerOpt ? state : ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok))) : ConstantInt::getTrue(context), block);
        block = done;
        new StoreInst(zero, currentArg, block);
        ReturnInst::Create(context, status, block);
        return ctx.Func;
    }

    // Native signature of the compiled fetch/next function, matching TCodegenValue.
    using TFlatMapPtr = std::conditional_t<IsInputStream,
        typename std::conditional_t<IsMultiRowPerItem, TStreamCodegenStatefulValue, TStreamCodegenValueStateless>::TFetchPtr,
        typename std::conditional_t<IsMultiRowPerItem, TCustomListCodegenStatefulValue, TCustomListCodegenValue>::TNextPtr
    >;
    // Generated IR function and, after finalization, its native entry point.
    Function* FlatMapFunc = nullptr;
    TFlatMapPtr FlatMap = nullptr;
#endif
    IComputationNode* const List;
    IComputationExternalNode* const Item;
    IComputationNode* const NewItem;
};
  1069. template <bool IsMultiRowPerItem, bool ResultContainerOpt>
  1070. class TStreamFlatMapWrapper : public TCustomValueCodegeneratorNode<TStreamFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>,
  1071. private TBaseFlatMapWrapper<true, IsMultiRowPerItem, ResultContainerOpt> {
  1072. typedef TBaseFlatMapWrapper<true, IsMultiRowPerItem, ResultContainerOpt> TBaseWrapper;
  1073. typedef TCustomValueCodegeneratorNode<TStreamFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>> TBaseComputation;
  1074. public:
  1075. TStreamFlatMapWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationExternalNode* item, IComputationNode* newItem)
  1076. : TBaseComputation(mutables), TBaseWrapper(list, item, newItem)
  1077. {}
  1078. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  1079. #ifndef MKQL_DISABLE_CODEGEN
  1080. if (ctx.ExecuteLLVM && this->FlatMap)
  1081. return ctx.HolderFactory.Create<typename TBaseWrapper::TCodegenValue>(this->FlatMap, &ctx, this->List->GetValue(ctx));
  1082. #endif
  1083. return ctx.HolderFactory.Create<std::conditional_t<IsMultiRowPerItem, TStreamValue<ResultContainerOpt>, TSimpleStreamValue<ResultContainerOpt>>>(ctx, this->List->GetValue(ctx), this->Item, this->NewItem);
  1084. }
  1085. private:
  1086. void RegisterDependencies() const final {
  1087. this->DependsOn(this->List);
  1088. this->Own(this->Item);
  1089. this->DependsOn(this->NewItem);
  1090. }
  1091. #ifndef MKQL_DISABLE_CODEGEN
  1092. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  1093. this->FlatMapFunc = IsMultiRowPerItem ?
  1094. this->GenerateMapper(codegen, TBaseComputation::MakeName("Fetch")):
  1095. this->GenerateSimpleMapper(codegen, TBaseComputation::MakeName("Fetch"));
  1096. codegen.ExportSymbol(this->FlatMapFunc);
  1097. }
  1098. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  1099. if (this->FlatMapFunc)
  1100. this->FlatMap = reinterpret_cast<typename TBaseWrapper::TFlatMapPtr>(codegen.GetPointerToFunction(this->FlatMapFunc));
  1101. }
  1102. #endif
  1103. };
  1104. #ifndef MKQL_DISABLE_CODEGEN
  1105. NUdf::TUnboxedValuePod* MyArrayAlloc(const ui64 size) {
  1106. return TMKQLAllocator<NUdf::TUnboxedValuePod>::allocate(size);
  1107. }
  1108. void MyArrayFree(const NUdf::TUnboxedValuePod *const ptr, const ui64 size) noexcept {
  1109. TMKQLAllocator<NUdf::TUnboxedValuePod>::deallocate(ptr, size);
  1110. }
  1111. #endif
  1112. template <bool IsMultiRowPerItem, bool ResultContainerOpt>
  1113. class TListFlatMapWrapper : public TBothWaysCodegeneratorNode<TListFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>>,
  1114. private TBaseFlatMapWrapper<false, IsMultiRowPerItem, ResultContainerOpt> {
  1115. typedef TBaseFlatMapWrapper<false, IsMultiRowPerItem, ResultContainerOpt> TBaseWrapper;
  1116. typedef TBothWaysCodegeneratorNode<TListFlatMapWrapper<IsMultiRowPerItem, ResultContainerOpt>> TBaseComputation;
  1117. static constexpr size_t UseOnStack = 1ULL << 8ULL;
  1118. public:
  1119. TListFlatMapWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationExternalNode* item, IComputationNode* newItem)
  1120. : TBaseComputation(mutables), TBaseWrapper(list, item, newItem)
  1121. {}
  1122. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  1123. auto list = this->List->GetValue(ctx);
  1124. if (const auto elements = list.GetElements()) {
  1125. const auto size = list.GetListLength();
  1126. TUnboxedValueVector values(size);
  1127. auto it = values.begin();
  1128. std::for_each(elements, elements + size, [&] (NUdf::TUnboxedValue item) {
  1129. this->Item->SetValue(ctx, std::move(item));
  1130. *it = this->NewItem->GetValue(ctx);
  1131. if (IsMultiRowPerItem || *it) {
  1132. auto value = it->GetOptionalValueIf<!IsMultiRowPerItem && ResultContainerOpt>();
  1133. *it++ = value;
  1134. }
  1135. });
  1136. if constexpr (IsMultiRowPerItem) {
  1137. return ctx.HolderFactory.ExtendList<ResultContainerOpt>(values.data(), values.size());
  1138. }
  1139. NUdf::TUnboxedValue* items = nullptr;
  1140. const auto result = ctx.HolderFactory.CreateDirectArrayHolder(std::distance(values.begin(), it), items);
  1141. std::move(values.begin(), it, items);
  1142. return result;
  1143. }
  1144. return ctx.HolderFactory.Create<std::conditional_t<IsMultiRowPerItem, TListValue<ResultContainerOpt>, TSimpleListValue<ResultContainerOpt>>>(ctx, std::move(list), this->Item, this->NewItem);
  1145. }
  1146. #ifndef MKQL_DISABLE_CODEGEN
  1147. NUdf::TUnboxedValuePod MakeLazyList(TComputationContext& ctx, const NUdf::TUnboxedValuePod value) const {
  1148. return ctx.HolderFactory.Create<typename TBaseWrapper::TCodegenValue>(this->FlatMap, &ctx, value);
  1149. }
// JIT version of DoCalculate for list input. If the list exposes contiguous elements
// (GetElements != null), every element is mapped eagerly into a scratch buffer. The
// buffer lives on the stack when size <= UseOnStack, otherwise it comes from MyArrayAlloc.
// Without contiguous elements the list is wrapped lazily through MakeLazyList.
Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
    auto& context = ctx.Codegen.GetContext();
    const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(this->Item);
    MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
    const auto list = GetNodeValue(this->List, ctx, block);
    const auto lazy = BasicBlock::Create(context, "lazy", ctx.Func);
    const auto hard = BasicBlock::Create(context, "hard", ctx.Func);
    const auto done = BasicBlock::Create(context, "done", ctx.Func);
    // Incoming: eager result without heap buffer, eager result after freeing the heap buffer, lazy list.
    const auto map = PHINode::Create(list->getType(), 3U, "map", done);
    const auto elementsType = PointerType::getUnqual(list->getType());
    const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(elementsType, list, ctx.Codegen, block);
    const auto fill = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, elements, ConstantPointerNull::get(elementsType), "fill", block);
    BranchInst::Create(hard, lazy, fill, block);
    {
        block = hard;
        const auto smsk = BasicBlock::Create(context, "smsk", ctx.Func);
        const auto hmsk = BasicBlock::Create(context, "hmsk", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        const auto free = BasicBlock::Create(context, "free", ctx.Func);
        // Scratch buffer pointer: stack array (smsk) or heap array (hmsk).
        const auto vector = PHINode::Create(PointerType::getUnqual(list->getType()), 2U, "vector", main);
        const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), list, ctx.Codegen, block);
        const auto zeroSize = ConstantInt::get(size->getType(), 0);
        const auto plusSize = ConstantInt::get(size->getType(), 1);
        const auto heap = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(size->getType(), UseOnStack), "heap", block);
        BranchInst::Create(hmsk, smsk, heap, block);
        {
            block = smsk;
            const auto arrayType = ArrayType::get(list->getType(), UseOnStack);
            // Placed in the entry block for stateless/always-inlined nodes, otherwise in the current block.
            const auto array = *this->Stateless || ctx.AlwaysInline ?
                new AllocaInst(arrayType, 0U, "array", &ctx.Func->getEntryBlock().back()):
                new AllocaInst(arrayType, 0U, "array", block);
            const auto ptr = GetElementPtrInst::CreateInBounds(arrayType, array, {zeroSize, zeroSize}, "ptr", block);
            vector->addIncoming(ptr, block);
            BranchInst::Create(main, block);
        }
        {
            block = hmsk;
            const auto fnType = FunctionType::get(vector->getType(), {size->getType()}, false);
            const auto name = "MyArrayAlloc";
            ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&MyArrayAlloc));
            const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
            const auto ptr = CallInst::Create(func, {size}, "ptr", block);
            vector->addIncoming(ptr, block);
            BranchInst::Create(main, block);
        }
        block = main;
        const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
        const auto next = BasicBlock::Create(context, "next", ctx.Func);
        const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
        // index walks the source elements; idx counts written outputs (same as index for multi-row).
        const auto index = PHINode::Create(size->getType(), 2U, "index", loop);
        index->addIncoming(zeroSize, block);
        const auto idx = IsMultiRowPerItem ? index : PHINode::Create(size->getType(), 2U, "idx", loop);
        if constexpr (!IsMultiRowPerItem) {
            idx->addIncoming(zeroSize, block);
        }
        BranchInst::Create(loop, block);
        block = loop;
        const auto more = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, index, "more", block);
        BranchInst::Create(next, stop, more, block);
        block = next;
        const auto src = GetElementPtrInst::CreateInBounds(list->getType(), elements, {index}, "src", block);
        const auto item = new LoadInst(list->getType(), src, "item", block);
        codegenItem->CreateSetValue(ctx, block, item);
        // Evaluate NewItem straight into the next output slot.
        const auto dst = GetElementPtrInst::CreateInBounds(list->getType(), vector, {idx}, "dst", block);
        GetNodeValue(dst, this->NewItem, ctx, block);
        const auto inc = BinaryOperator::CreateAdd(index, plusSize, "inc", block);
        index->addIncoming(inc, block);
        if constexpr (!IsMultiRowPerItem) {
            // Unwrap the optional in place; advance the output cursor only for existing values.
            const auto plus = BinaryOperator::CreateAdd(idx, plusSize, "plus", block);
            const auto load = new LoadInst(list->getType(), dst, "load", block);
            new StoreInst(GetOptionalValue(context, load, block), dst, block);
            const auto move = SelectInst::Create(IsExists(load, block, context), plus, idx, "move", block);
            idx->addIncoming(move, block);
        }
        BranchInst::Create(loop, block);
        block = stop;
        if (this->List->IsTemporaryValue()) {
            CleanupBoxed(list, ctx, block);
        }
        Value* res;
        if constexpr (!IsMultiRowPerItem) {
            // Allocate an array of idx values and memcpy the kept results into it
            // (idx << 4 bytes, i.e. 16 bytes per i128 value).
            // NOTE(review): values are copied bitwise and the scratch buffer is later freed
            // without unref, so ownership presumably transfers to the new array — confirm.
            const auto newType = PointerType::getUnqual(list->getType());
            const auto newPtr = *this->Stateless || ctx.AlwaysInline ?
                new AllocaInst(newType, 0U, "new_ptr", &ctx.Func->getEntryBlock().back()):
                new AllocaInst(newType, 0U, "new_ptr", block);
            res = GenNewArray(ctx, idx, newPtr, block);
            const auto target = new LoadInst(newType, newPtr, "target", block);
            const auto pType = PointerType::getUnqual(Type::getInt8Ty(context));
            const auto pdst = CastInst::Create(Instruction::BitCast, target, pType, "pdst", block);
            const auto psrc = CastInst::Create(Instruction::BitCast, vector, pType, "psrc", block);
            const auto bytes = BinaryOperator::CreateShl(idx, ConstantInt::get(idx->getType(), 4), "bytes", block);
            const auto fnType = FunctionType::get(Type::getVoidTy(context), {pType, pType, bytes->getType(), Type::getInt1Ty(context)}, false);
            const auto memcpyName = (LLVM_VERSION_MAJOR < 16) ? "llvm.memcpy.p0i8.p0i8.i64" : "llvm.memcpy.p0.p0.i64";
            const auto func = ctx.Codegen.GetModule().getOrInsertFunction(memcpyName, fnType);
            CallInst::Create(func, {pdst, psrc, bytes, ConstantInt::getFalse(context)}, "", block);
        } else {
            // Concatenate the produced containers via THolderFactory::ExtendList<ResultContainerOpt>.
            const auto factory = ctx.GetFactory();
            const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::ExtendList<ResultContainerOpt>));
            const auto funType = FunctionType::get(list->getType(), {factory->getType(), vector->getType(), index->getType()}, false);
            const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
            res = CallInst::Create(funType, funcPtr, {factory, vector, index}, "res", block);
        }
        map->addIncoming(res, block);
        // Heap buffers must be released before leaving.
        BranchInst::Create(free, done, heap, block);
        {
            block = free;
            const auto fnType = FunctionType::get(Type::getVoidTy(context), {vector->getType(), size->getType()}, false);
            const auto name = "MyArrayFree";
            ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&MyArrayFree));
            const auto func = ctx.Codegen.GetModule().getOrInsertFunction(name, fnType);
            CallInst::Create(func, {vector, size}, "", block);
            map->addIncoming(res, block);
            BranchInst::Create(done, block);
        }
    }
    {
        // No contiguous elements: call MakeLazyList on this node through a raw method pointer.
        block = lazy;
        const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TListFlatMapWrapper::MakeLazyList));
        const auto ptrType = PointerType::getUnqual(StructType::get(context));
        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
        const auto funType = FunctionType::get(list->getType() , {self->getType(), ctx.Ctx->getType(), list->getType()}, false);
        const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(funType), "function", block);
        const auto value = CallInst::Create(funType, doFuncPtr, {self, ctx.Ctx, list}, "value", block);
        map->addIncoming(value, block);
        BranchInst::Create(done, block);
    }
    block = done;
    return map;
}
  1279. #endif
  1280. private:
  1281. void RegisterDependencies() const final {
  1282. this->DependsOn(this->List);
  1283. this->Own(this->Item);
  1284. this->DependsOn(this->NewItem);
  1285. }
  1286. #ifndef MKQL_DISABLE_CODEGEN
  1287. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  1288. TMutableCodegeneratorRootNode<TListFlatMapWrapper>::GenerateFunctions(codegen);
  1289. this->FlatMapFunc = IsMultiRowPerItem ?
  1290. this->GenerateMapper(codegen, TBaseComputation::MakeName("Next")):
  1291. this->GenerateSimpleMapper(codegen, TBaseComputation::MakeName("Next"));
  1292. codegen.ExportSymbol(this->FlatMapFunc);
  1293. }
  1294. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  1295. TMutableCodegeneratorRootNode<TListFlatMapWrapper>::FinalizeFunctions(codegen);
  1296. if (this->FlatMapFunc)
  1297. this->FlatMap = reinterpret_cast<typename TBaseWrapper::TFlatMapPtr>(codegen.GetPointerToFunction(this->FlatMapFunc));
  1298. }
  1299. #endif
  1300. };
  1301. }
  1302. IComputationNode* WrapFlatMap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1303. MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");
  1304. const auto listType = callable.GetInput(0).GetStaticType();;
  1305. const auto newListType = callable.GetInput(2).GetStaticType();
  1306. const auto type = callable.GetType()->GetReturnType();
  1307. const auto list = LocateNode(ctx.NodeLocator, callable, 0);
  1308. const auto newItem = LocateNode(ctx.NodeLocator, callable, 2);
  1309. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 1);
  1310. const auto kind = GetValueRepresentation(type);
  1311. if (listType->IsFlow()) {
  1312. if (newListType->IsFlow()) {
  1313. if (const auto wideOut = dynamic_cast<IComputationWideFlowNode*>(newItem))
  1314. return new TFlowFlatMapWideWrapper(ctx.Mutables, list, itemArg, wideOut);
  1315. else
  1316. return new TFlowFlatMapFlowWrapper(ctx.Mutables, kind, list, itemArg, newItem);
  1317. } else if (newListType->IsList()) {
  1318. return new TFlowFlatMapWrapper<true, false>(ctx.Mutables, kind, list, itemArg, newItem);
  1319. } else if (newListType->IsStream()) {
  1320. return new TFlowFlatMapWrapper<true, true>(ctx.Mutables, kind, list, itemArg, newItem);
  1321. } else if (newListType->IsOptional()) {
  1322. if (AS_TYPE(TOptionalType, newListType)->GetItemType()->IsOptional()) {
  1323. return new TFlowFlatMapWrapper<false, true>(ctx.Mutables, kind, list, itemArg, newItem);
  1324. } else {
  1325. return new TFlowFlatMapWrapper<false, false>(ctx.Mutables, kind, list, itemArg, newItem);
  1326. }
  1327. }
  1328. } else if (listType->IsStream()) {
  1329. if (newListType->IsList()) {
  1330. return new TStreamFlatMapWrapper<true, false>(ctx.Mutables, list, itemArg, newItem);
  1331. } else if (newListType->IsStream()) {
  1332. return new TStreamFlatMapWrapper<true, true>(ctx.Mutables, list, itemArg, newItem);
  1333. } else if (newListType->IsOptional()) {
  1334. if (AS_TYPE(TOptionalType, newListType)->GetItemType()->IsOptional()) {
  1335. return new TStreamFlatMapWrapper<false, true>(ctx.Mutables, list, itemArg, newItem);
  1336. } else {
  1337. return new TStreamFlatMapWrapper<false, false>(ctx.Mutables, list, itemArg, newItem);
  1338. }
  1339. }
  1340. } else if (listType->IsList()) {
  1341. if (newListType->IsFlow()) {
  1342. if (const auto wideOut = dynamic_cast<IComputationWideFlowNode*>(newItem))
  1343. return new TListFlatMapWideWrapper(ctx.Mutables, list, itemArg, wideOut);
  1344. else
  1345. return new TListFlatMapFlowWrapper(ctx.Mutables, kind, list, itemArg, newItem);
  1346. } else if (newListType->IsList()) {
  1347. return new TListFlatMapWrapper<true, false>(ctx.Mutables, list, itemArg, newItem);
  1348. } else if (newListType->IsStream()) {
  1349. return new TListFlatMapWrapper<true, true>(ctx.Mutables, list, itemArg, newItem);
  1350. } else if (newListType->IsOptional()) {
  1351. if (AS_TYPE(TOptionalType, newListType)->GetItemType()->IsOptional()) {
  1352. return new TListFlatMapWrapper<false, true>(ctx.Mutables, list, itemArg, newItem);
  1353. } else {
  1354. return new TListFlatMapWrapper<false, false>(ctx.Mutables, list, itemArg, newItem);
  1355. }
  1356. }
  1357. }
  1358. THROW yexception() << "Expected flow, list or stream of lists, streams or optionals.";
  1359. }
  1360. IComputationNode* WrapNarrowFlatMap(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1361. MKQL_ENSURE(callable.GetInputsCount() > 1U, "Expected at least two args.");
  1362. const auto width = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType()));
  1363. MKQL_ENSURE(callable.GetInputsCount() == width + 2U, "Wrong signature.");
  1364. const auto last = callable.GetInputsCount() - 1U;
  1365. const auto flow = LocateNode(ctx.NodeLocator, callable, 0U);
  1366. const auto newItem = LocateNode(ctx.NodeLocator, callable, last);
  1367. TComputationExternalNodePtrVector args(width, nullptr);
  1368. ui32 index = 0U;
  1369. std::generate(args.begin(), args.end(), [&](){ return LocateExternalNode(ctx.NodeLocator, callable, ++index); });
  1370. const auto newListType = callable.GetInput(last).GetStaticType();
  1371. const auto kind = GetValueRepresentation(callable.GetType()->GetReturnType());
  1372. if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow)) {
  1373. if (newListType->IsFlow()) {
  1374. return new TNarrowFlatMapFlowWrapper(ctx.Mutables, kind, wide, std::move(args), newItem);
  1375. } else if (newListType->IsList()) {
  1376. return new TNarrowFlatMapWrapper<true, false>(ctx.Mutables, kind, wide, std::move(args), newItem);
  1377. } else if (newListType->IsStream()) {
  1378. return new TNarrowFlatMapWrapper<true, true>(ctx.Mutables, kind, wide, std::move(args), newItem);
  1379. } else if (newListType->IsOptional()) {
  1380. if (AS_TYPE(TOptionalType, newListType)->GetItemType()->IsOptional()) {
  1381. return new TNarrowFlatMapWrapper<false, true>(ctx.Mutables, kind, wide, std::move(args), newItem);
  1382. } else {
  1383. return new TNarrowFlatMapWrapper<false, false>(ctx.Mutables, kind, wide, std::move(args), newItem);
  1384. }
  1385. }
  1386. }
  1387. THROW yexception() << "Expected wide flow.";
  1388. }
  1389. }
  1390. }