mkql_collect.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. #include "mkql_collect.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. namespace NKikimr {
  5. namespace NMiniKQL {
  6. namespace {
  7. class TCollectFlowWrapper : public TMutableCodegeneratorRootNode<TCollectFlowWrapper> {
  8. using TBaseComputation = TMutableCodegeneratorRootNode<TCollectFlowWrapper>;
  9. public:
  10. TCollectFlowWrapper(TComputationMutables& mutables, IComputationNode* flow)
  11. : TBaseComputation(mutables, EValueRepresentation::Boxed), Flow(flow)
  12. {}
  13. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  14. for (NUdf::TUnboxedValue list = ctx.HolderFactory.GetEmptyContainerLazy();;) {
  15. auto item = Flow->GetValue(ctx);
  16. if (item.IsFinish()) {
  17. return list.Release();
  18. }
  19. MKQL_ENSURE(!item.IsYield(), "Unexpected flow status!");
  20. list = ctx.HolderFactory.Append(list.Release(), item.Release());
  21. }
  22. }
  23. #ifndef MKQL_DISABLE_CODEGEN
  24. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  25. auto& context = ctx.Codegen.GetContext();
  26. const auto factory = ctx.GetFactory();
  27. const auto valueType = Type::getInt128Ty(context);
  28. const auto empty = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::GetEmptyContainerLazy));
  29. const auto append = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::Append));
  30. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  31. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  32. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  33. const auto burn = BasicBlock::Create(context, "burn", ctx.Func);
  34. const auto list = PHINode::Create(valueType, 2U, "list", work);
  35. if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) {
  36. const auto funType = FunctionType::get(valueType, {factory->getType()}, false);
  37. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, empty, PointerType::getUnqual(funType), "empty", block);
  38. const auto first = CallInst::Create(funType, funcPtr, {factory}, "init", block);
  39. list->addIncoming(first, block);
  40. } else {
  41. const auto ptr = new AllocaInst(valueType, 0U, "ptr", block);
  42. const auto funType = FunctionType::get(Type::getVoidTy(context), {factory->getType(), ptr->getType()}, false);
  43. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, empty, PointerType::getUnqual(funType), "empty", block);
  44. CallInst::Create(funType, funcPtr, {factory, ptr}, "", block);
  45. const auto first = new LoadInst(valueType, ptr, "init", block);
  46. list->addIncoming(first, block);
  47. }
  48. BranchInst::Create(work, block);
  49. block = work;
  50. const auto item = GetNodeValue(Flow, ctx, block);
  51. const auto select = SwitchInst::Create(item, good, 2U, block);
  52. select->addCase(GetFinish(context), done);
  53. select->addCase(GetYield(context), burn);
  54. {
  55. block = good;
  56. if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) {
  57. const auto funType = FunctionType::get(valueType, {factory->getType(), list->getType(), item->getType()}, false);
  58. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, append, PointerType::getUnqual(funType), "append", block);
  59. const auto next = CallInst::Create(funType, funcPtr, {factory, list, item}, "next", block);
  60. list->addIncoming(next, block);
  61. } else {
  62. const auto retPtr = new AllocaInst(list->getType(), 0U, "ret_ptr", block);
  63. const auto itemPtr = new AllocaInst(item->getType(), 0U, "item_ptr", block);
  64. new StoreInst(list, retPtr, block);
  65. new StoreInst(item, itemPtr, block);
  66. const auto funType = FunctionType::get(Type::getVoidTy(context), {factory->getType(), retPtr->getType(), retPtr->getType(), itemPtr->getType()}, false);
  67. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, append, PointerType::getUnqual(funType), "append", block);
  68. CallInst::Create(funType, funcPtr, {factory, retPtr, retPtr, itemPtr}, "", block);
  69. const auto next = new LoadInst(list->getType(), retPtr, "next", block);
  70. list->addIncoming(next, block);
  71. }
  72. BranchInst::Create(work, block);
  73. }
  74. {
  75. block = burn;
  76. const auto thrower = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TCollectFlowWrapper::Throw));
  77. const auto throwerType = FunctionType::get(Type::getVoidTy(context), {}, false);
  78. const auto throwerPtr = CastInst::Create(Instruction::IntToPtr, thrower, PointerType::getUnqual(throwerType), "thrower", block);
  79. CallInst::Create(throwerType, throwerPtr, {}, "", block);
  80. new UnreachableInst(context, block);
  81. }
  82. block = done;
  83. return list;
  84. }
  85. #endif
  86. private:
  87. [[noreturn]] static void Throw() {
  88. UdfTerminate("Unexpected flow status!");
  89. }
  90. void RegisterDependencies() const final {
  91. this->DependsOn(Flow);
  92. }
  93. IComputationNode* const Flow;
  94. };
  95. template <bool IsList>
  96. class TCollectWrapper : public TMutableCodegeneratorNode<TCollectWrapper<IsList>> {
  97. typedef TMutableCodegeneratorNode<TCollectWrapper<IsList>> TBaseComputation;
  98. public:
  99. TCollectWrapper(TComputationMutables& mutables, IComputationNode* seq)
  100. : TBaseComputation(mutables, EValueRepresentation::Boxed), Seq(seq)
  101. {}
  102. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  103. auto seq = Seq->GetValue(ctx);
  104. if (IsList && seq.GetElements()) {
  105. return seq.Release();
  106. }
  107. return ctx.HolderFactory.Collect<!IsList>(seq.Release());
  108. }
  109. #ifndef MKQL_DISABLE_CODEGEN
  110. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  111. auto& context = ctx.Codegen.GetContext();
  112. const auto factory = ctx.GetFactory();
  113. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::Collect<!IsList>));
  114. const auto seq = GetNodeValue(Seq, ctx, block);
  115. if constexpr (IsList) {
  116. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  117. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  118. const auto valueType = Type::getInt128Ty(context);
  119. const auto ptrType = PointerType::getUnqual(valueType);
  120. const auto result = PHINode::Create(valueType, 2U, "result", done);
  121. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(ptrType, seq, ctx.Codegen, block);
  122. const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block);
  123. result->addIncoming(seq, block);
  124. BranchInst::Create(work, done, null, block);
  125. block = work;
  126. if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) {
  127. const auto funType = FunctionType::get(seq->getType(), {factory->getType(), seq->getType()}, false);
  128. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  129. const auto res = CallInst::Create(funType, funcPtr, {factory, seq}, "res", block);
  130. result->addIncoming(res, block);
  131. } else {
  132. const auto ptr = new AllocaInst(seq->getType(), 0U, "ptr", block);
  133. new StoreInst(seq, ptr, block);
  134. const auto funType = FunctionType::get(Type::getVoidTy(context), {factory->getType(), ptr->getType(), ptr->getType()}, false);
  135. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  136. CallInst::Create(funType, funcPtr, {factory, ptr, ptr}, "", block);
  137. const auto res = new LoadInst(seq->getType(), ptr, "res", block);
  138. result->addIncoming(res, block);
  139. }
  140. BranchInst::Create(done, block);
  141. block = done;
  142. return result;
  143. } else {
  144. if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) {
  145. const auto funType = FunctionType::get(seq->getType(), {factory->getType(), seq->getType()}, false);
  146. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  147. const auto res = CallInst::Create(funType, funcPtr, {factory, seq}, "res", block);
  148. return res;
  149. } else {
  150. const auto ptr = new AllocaInst(seq->getType(), 0U, "ptr", block);
  151. new StoreInst(seq, ptr, block);
  152. const auto funType = FunctionType::get(Type::getVoidTy(context), {factory->getType(), ptr->getType(), ptr->getType()}, false);
  153. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  154. CallInst::Create(funType, funcPtr, {factory, ptr, ptr}, "", block);
  155. const auto res = new LoadInst(seq->getType(), ptr, "res", block);
  156. return res;
  157. }
  158. }
  159. }
  160. #endif
  161. private:
  162. void RegisterDependencies() const final {
  163. this->DependsOn(Seq);
  164. }
  165. IComputationNode* const Seq;
  166. };
  167. }
  168. IComputationNode* WrapCollect(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  169. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");
  170. const auto type = callable.GetInput(0).GetStaticType();
  171. const auto list = LocateNode(ctx.NodeLocator, callable, 0);
  172. if (type->IsFlow()) {
  173. return new TCollectFlowWrapper(ctx.Mutables, list);
  174. } else if (type->IsList()) {
  175. return new TCollectWrapper<true>(ctx.Mutables, list);
  176. } else if (type->IsStream()) {
  177. return new TCollectWrapper<false>(ctx.Mutables, list);
  178. }
  179. THROW yexception() << "Expected flow, list or stream.";
  180. }
  181. }
  182. }