mkql_collect.cpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. #include "mkql_collect.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. namespace NKikimr {
  5. namespace NMiniKQL {
  6. namespace {
  7. class TCollectFlowWrapper : public TMutableCodegeneratorRootNode<TCollectFlowWrapper> {
  8. using TBaseComputation = TMutableCodegeneratorRootNode<TCollectFlowWrapper>;
  9. public:
  10. TCollectFlowWrapper(TComputationMutables& mutables, IComputationNode* flow)
  11. : TBaseComputation(mutables, EValueRepresentation::Boxed), Flow(flow)
  12. {}
  13. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  14. for (NUdf::TUnboxedValue list = ctx.HolderFactory.GetEmptyContainerLazy();;) {
  15. auto item = Flow->GetValue(ctx);
  16. if (item.IsFinish()) {
  17. return list.Release();
  18. }
  19. MKQL_ENSURE(!item.IsYield(), "Unexpected flow status!");
  20. list = ctx.HolderFactory.Append(list.Release(), item.Release());
  21. }
  22. }
  23. #ifndef MKQL_DISABLE_CODEGEN
  24. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  25. auto& context = ctx.Codegen.GetContext();
  26. const auto factory = ctx.GetFactory();
  27. const auto valueType = Type::getInt128Ty(context);
  28. const auto empty = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::GetEmptyContainerLazy));
  29. const auto append = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::Append));
  30. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  31. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  32. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  33. const auto burn = BasicBlock::Create(context, "burn", ctx.Func);
  34. const auto list = PHINode::Create(valueType, 2U, "list", work);
  35. const auto funType = FunctionType::get(valueType, {factory->getType()}, false);
  36. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, empty, PointerType::getUnqual(funType), "empty", block);
  37. const auto first = CallInst::Create(funType, funcPtr, {factory}, "init", block);
  38. list->addIncoming(first, block);
  39. BranchInst::Create(work, block);
  40. block = work;
  41. const auto item = GetNodeValue(Flow, ctx, block);
  42. const auto select = SwitchInst::Create(item, good, 2U, block);
  43. select->addCase(GetFinish(context), done);
  44. select->addCase(GetYield(context), burn);
  45. {
  46. block = good;
  47. const auto funType = FunctionType::get(valueType, {factory->getType(), list->getType(), item->getType()}, false);
  48. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, append, PointerType::getUnqual(funType), "append", block);
  49. const auto next = CallInst::Create(funType, funcPtr, {factory, list, item}, "next", block);
  50. list->addIncoming(next, block);
  51. BranchInst::Create(work, block);
  52. }
  53. {
  54. block = burn;
  55. const auto thrower = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TCollectFlowWrapper::Throw));
  56. const auto throwerType = FunctionType::get(Type::getVoidTy(context), {}, false);
  57. const auto throwerPtr = CastInst::Create(Instruction::IntToPtr, thrower, PointerType::getUnqual(throwerType), "thrower", block);
  58. CallInst::Create(throwerType, throwerPtr, {}, "", block);
  59. new UnreachableInst(context, block);
  60. }
  61. block = done;
  62. return list;
  63. }
  64. #endif
  65. private:
  66. [[noreturn]] static void Throw() {
  67. UdfTerminate("Unexpected flow status!");
  68. }
  69. void RegisterDependencies() const final {
  70. this->DependsOn(Flow);
  71. }
  72. IComputationNode* const Flow;
  73. };
  74. template <bool IsList>
  75. class TCollectWrapper : public TMutableCodegeneratorNode<TCollectWrapper<IsList>> {
  76. typedef TMutableCodegeneratorNode<TCollectWrapper<IsList>> TBaseComputation;
  77. public:
  78. TCollectWrapper(TComputationMutables& mutables, IComputationNode* seq)
  79. : TBaseComputation(mutables, EValueRepresentation::Boxed), Seq(seq)
  80. {}
  81. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  82. auto seq = Seq->GetValue(ctx);
  83. if (IsList && seq.GetElements()) {
  84. return seq.Release();
  85. }
  86. return ctx.HolderFactory.Collect<!IsList>(seq.Release());
  87. }
  88. #ifndef MKQL_DISABLE_CODEGEN
  89. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  90. auto& context = ctx.Codegen.GetContext();
  91. const auto factory = ctx.GetFactory();
  92. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::Collect<!IsList>));
  93. const auto seq = GetNodeValue(Seq, ctx, block);
  94. if constexpr (IsList) {
  95. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  96. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  97. const auto valueType = Type::getInt128Ty(context);
  98. const auto ptrType = PointerType::getUnqual(valueType);
  99. const auto result = PHINode::Create(valueType, 2U, "result", done);
  100. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(ptrType, seq, ctx.Codegen, block);
  101. const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block);
  102. result->addIncoming(seq, block);
  103. BranchInst::Create(work, done, null, block);
  104. block = work;
  105. const auto funType = FunctionType::get(seq->getType(), {factory->getType(), seq->getType()}, false);
  106. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  107. const auto res = CallInst::Create(funType, funcPtr, {factory, seq}, "res", block);
  108. result->addIncoming(res, block);
  109. BranchInst::Create(done, block);
  110. block = done;
  111. return result;
  112. } else {
  113. const auto funType = FunctionType::get(seq->getType(), {factory->getType(), seq->getType()}, false);
  114. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  115. const auto res = CallInst::Create(funType, funcPtr, {factory, seq}, "res", block);
  116. return res;
  117. }
  118. }
  119. #endif
  120. private:
  121. void RegisterDependencies() const final {
  122. this->DependsOn(Seq);
  123. }
  124. IComputationNode* const Seq;
  125. };
  126. }
  127. IComputationNode* WrapCollect(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  128. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");
  129. const auto type = callable.GetInput(0).GetStaticType();
  130. const auto list = LocateNode(ctx.NodeLocator, callable, 0);
  131. if (type->IsFlow()) {
  132. return new TCollectFlowWrapper(ctx.Mutables, list);
  133. } else if (type->IsList()) {
  134. return new TCollectWrapper<true>(ctx.Mutables, list);
  135. } else if (type->IsStream()) {
  136. return new TCollectWrapper<false>(ctx.Mutables, list);
  137. }
  138. THROW yexception() << "Expected flow, list or stream.";
  139. }
  140. }
  141. }