mkql_reduce.cpp 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. #include "mkql_reduce.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. template<bool IsStream>
  9. class TReduceWrapper : public TMutableCodegeneratorRootNode<TReduceWrapper<IsStream>> {
  10. typedef TMutableCodegeneratorRootNode<TReduceWrapper<IsStream>> TBaseComputation;
  11. public:
  12. TReduceWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* list, IComputationExternalNode* item, IComputationExternalNode* state1,
  13. IComputationNode* newState1, IComputationNode* newState2,
  14. IComputationNode* initialState1, IComputationExternalNode* itemState2, IComputationExternalNode* state3,
  15. IComputationNode* newState3, IComputationNode* initialState3)
  16. : TBaseComputation(mutables, kind)
  17. , List(list)
  18. , Item(item)
  19. , State1(state1)
  20. , NewState1(newState1)
  21. , NewState2(newState2)
  22. , InitialState1(initialState1)
  23. , ItemState2(itemState2)
  24. , State3(state3)
  25. , NewState3(newState3)
  26. , InitialState3(initialState3)
  27. {
  28. }
  29. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  30. State1->SetValue(compCtx, InitialState1->GetValue(compCtx));
  31. State3->SetValue(compCtx, InitialState3->GetValue(compCtx));
  32. TThresher<IsStream>::DoForEachItem(List->GetValue(compCtx),
  33. [this, &compCtx] (NUdf::TUnboxedValue&& item) {
  34. Item->SetValue(compCtx, std::move(item));
  35. State1->SetValue(compCtx, NewState1->GetValue(compCtx));
  36. }
  37. );
  38. ItemState2->SetValue(compCtx, NewState2->GetValue(compCtx));
  39. return NewState3->GetValue(compCtx).Release();
  40. }
  41. #ifndef MKQL_DISABLE_CODEGEN
  42. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  43. auto &context = ctx.Codegen.GetContext();
  44. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  45. const auto codegenState1 = dynamic_cast<ICodegeneratorExternalNode*>(State1);
  46. const auto codegenItemState2 = dynamic_cast<ICodegeneratorExternalNode*>(ItemState2);
  47. const auto codegenState3 = dynamic_cast<ICodegeneratorExternalNode*>(State3);
  48. MKQL_ENSURE(codegenState1, "State1 must be codegenerator node.");
  49. MKQL_ENSURE(codegenState3, "State3 must be codegenerator node.");
  50. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  51. MKQL_ENSURE(codegenItemState2, "ItemState2 must be codegenerator node.");
  52. const auto valueType = Type::getInt128Ty(context);
  53. const auto init1 = GetNodeValue(InitialState1, ctx, block);
  54. codegenState1->CreateSetValue(ctx, block, init1);
  55. const auto init3 = GetNodeValue(InitialState3, ctx, block);
  56. codegenState3->CreateSetValue(ctx, block, init3);
  57. const auto list = GetNodeValue(List, ctx, block);
  58. const auto itemPtr = *this->Stateless || ctx.AlwaysInline ?
  59. new AllocaInst(valueType, 0U, "item_ptr", &ctx.Func->getEntryBlock().back()):
  60. new AllocaInst(valueType, 0U, "item_ptr", block);
  61. new StoreInst(ConstantInt::get(valueType, 0), itemPtr, block);
  62. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  63. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  64. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  65. if constexpr (IsStream) {
  66. BranchInst::Create(loop, block);
  67. block = loop;
  68. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(Type::getInt32Ty(context), list, ctx.Codegen, block, itemPtr);
  69. const auto icmp = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, status, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Ok)), "cond", block);
  70. BranchInst::Create(done, good, icmp, block);
  71. block = good;
  72. codegenItem->CreateSetValue(ctx, block, itemPtr);
  73. const auto newState1 = GetNodeValue(NewState1, ctx, block);
  74. codegenState1->CreateSetValue(ctx, block, newState1);
  75. BranchInst::Create(loop, block);
  76. block = done;
  77. } else {
  78. const auto iterPtr = *this->Stateless || ctx.AlwaysInline ?
  79. new AllocaInst(valueType, 0U, "iter_ptr", &ctx.Func->getEntryBlock().back()):
  80. new AllocaInst(valueType, 0U, "iter_ptr", block);
  81. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(iterPtr, list, ctx.Codegen, block);
  82. const auto iter = new LoadInst(valueType, iterPtr, "iter", block);
  83. BranchInst::Create(loop, block);
  84. block = loop;
  85. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), iter, ctx.Codegen, block, itemPtr);
  86. const auto icmp = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::getFalse(context), "cond", block);
  87. BranchInst::Create(done, good, icmp, block);
  88. block = good;
  89. codegenItem->CreateSetValue(ctx, block, itemPtr);
  90. const auto newState1 = GetNodeValue(NewState1, ctx, block);
  91. codegenState1->CreateSetValue(ctx, block, newState1);
  92. BranchInst::Create(loop, block);
  93. block = done;
  94. UnRefBoxed(iter, ctx, block);
  95. }
  96. const auto newState2 = GetNodeValue(NewState2, ctx, block);
  97. codegenItemState2->CreateSetValue(ctx, block, newState2);
  98. const auto newState3 = GetNodeValue(NewState3, ctx, block);
  99. return newState3;
  100. }
  101. #endif
  102. private:
  103. void RegisterDependencies() const final {
  104. this->DependsOn(List);
  105. this->DependsOn(InitialState1);
  106. this->DependsOn(InitialState3);
  107. this->DependsOn(NewState1);
  108. this->DependsOn(NewState2);
  109. this->DependsOn(NewState3);
  110. this->Own(Item);
  111. this->Own(State1);
  112. this->Own(ItemState2);
  113. this->Own(State3);
  114. }
  115. IComputationNode* const List;
  116. IComputationExternalNode* const Item;
  117. IComputationExternalNode* const State1;
  118. IComputationNode* const NewState1;
  119. IComputationNode* const NewState2;
  120. IComputationNode* const InitialState1;
  121. IComputationExternalNode* const ItemState2;
  122. IComputationExternalNode* const State3;
  123. IComputationNode* const NewState3;
  124. IComputationNode* const InitialState3;
  125. };
  126. }
  127. IComputationNode* WrapReduce(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  128. MKQL_ENSURE(callable.GetInputsCount() == 10, "Expected 10 args");
  129. auto list = LocateNode(ctx.NodeLocator, callable, 0);
  130. auto initialState1 = LocateNode(ctx.NodeLocator, callable, 1);
  131. auto initialState3 = LocateNode(ctx.NodeLocator, callable, 2);
  132. auto newState1 = LocateNode(ctx.NodeLocator, callable, 5);
  133. auto newState2 = LocateNode(ctx.NodeLocator, callable, 6);
  134. auto newState3 = LocateNode(ctx.NodeLocator, callable, 9);
  135. auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 3);
  136. auto state1NodeArg = LocateExternalNode(ctx.NodeLocator, callable, 4);
  137. auto itemState2Arg = LocateExternalNode(ctx.NodeLocator, callable, 7);
  138. auto state3NodeArg = LocateExternalNode(ctx.NodeLocator, callable, 8);
  139. const auto kind = GetValueRepresentation(callable.GetType()->GetReturnType());
  140. if (callable.GetInput(0).GetStaticType()->IsStream()) {
  141. return new TReduceWrapper<true>(ctx.Mutables, kind, list, itemArg, state1NodeArg, newState1, newState2,
  142. initialState1, itemState2Arg, state3NodeArg, newState3, initialState3);
  143. } else {
  144. return new TReduceWrapper<false>(ctx.Mutables, kind, list, itemArg, state1NodeArg, newState1, newState2,
  145. initialState1, itemState2Arg, state3NodeArg, newState3, initialState3);
  146. }
  147. }
  148. }
  149. }