mkql_iterator.cpp 10 KB


  1. #include "mkql_iterator.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. class TIteratorWrapper : public TMutableCodegeneratorNode<TIteratorWrapper> {
  10. typedef TMutableCodegeneratorNode<TIteratorWrapper> TBaseComputation;
  11. public:
  12. TIteratorWrapper(TComputationMutables& mutables, IComputationNode* list, TComputationNodePtrVector&& dependentNodes)
  13. : TBaseComputation(mutables, EValueRepresentation::Boxed), List(list), DependentNodes(std::move(dependentNodes))
  14. {}
  15. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  16. return ctx.HolderFactory.CreateIteratorOverList(List->GetValue(ctx).Release());
  17. }
  18. #ifndef MKQL_DISABLE_CODEGEN
  19. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  20. auto& context = ctx.Codegen.GetContext();
  21. const auto value = GetNodeValue(List, ctx, block);
  22. const auto factory = ctx.GetFactory();
  23. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::CreateIteratorOverList));
  24. const auto signature = FunctionType::get(value->getType(), {factory->getType(), value->getType()}, false);
  25. const auto creator = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(signature), "creator", block);
  26. const auto output = CallInst::Create(signature, creator, {factory, value}, "output", block);
  27. return output;
  28. }
  29. #endif
  30. private:
  31. void RegisterDependencies() const final {
  32. DependsOn(List);
  33. std::for_each(DependentNodes.cbegin(), DependentNodes.cend(),std::bind(&TIteratorWrapper::DependsOn, this, std::placeholders::_1));
  34. }
  35. IComputationNode *const List;
  36. const TComputationNodePtrVector DependentNodes;
  37. };
  38. class TForwardListWrapper : public TMutableCodegeneratorNode<TForwardListWrapper> {
  39. typedef TMutableCodegeneratorNode<TForwardListWrapper> TBaseComputation;
  40. public:
  41. TForwardListWrapper(TComputationMutables& mutables, IComputationNode* stream)
  42. : TBaseComputation(mutables, EValueRepresentation::Boxed), Stream(stream)
  43. {}
  44. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  45. return ctx.HolderFactory.CreateForwardList(Stream->GetValue(ctx).Release());
  46. }
  47. #ifndef MKQL_DISABLE_CODEGEN
  48. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  49. auto& context = ctx.Codegen.GetContext();
  50. const auto value = GetNodeValue(Stream, ctx, block);
  51. const auto factory = ctx.GetFactory();
  52. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::CreateForwardList));
  53. const auto signature = FunctionType::get(value->getType(), {factory->getType(), value->getType()}, false);
  54. const auto creator = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(signature), "creator", block);
  55. const auto output = CallInst::Create(signature, creator, {factory, value}, "output", block);
  56. return output;
  57. }
  58. #endif
  59. private:
  60. void RegisterDependencies() const final {
  61. DependsOn(Stream);
  62. }
  63. IComputationNode *const Stream;
  64. };
  // Computation node that exposes a flow as a forward (single-pass) list value.
  // The produced list hands out exactly one iterator; a second GetListIterator()
  // call throws "Second pass on forward list.". Each iterator Next() pulls one item
  // from the flow: Finish ends iteration, and Yield is treated as fatal (see Throw()).
  class TFlowForwardListWrapper : public TCustomValueCodegeneratorNode<TFlowForwardListWrapper> {
      typedef TCustomValueCodegeneratorNode<TFlowForwardListWrapper> TBaseComputation;
  public:
      // Interpreted iterator: evaluates the flow node directly on every Next() call.
      class TIterator : public TComputationValue<TIterator> {
      public:
          // Handle stored by TForwardListValue<TIterator> until the iterator is created.
          using TPtr = IComputationNode*;
          TIterator(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, TPtr flow)
              : TComputationValue<TIterator>(memInfo), CompCtx(compCtx), Flow(flow)
          {}
      private:
          // Stores the next flow item into `value`. Returns false once the flow reports
          // Finish. A Yield status cannot be expressed through a list iterator, so it
          // goes to Throw(), which does not return.
          bool Next(NUdf::TUnboxedValue& value) final {
              value = Flow->GetValue(CompCtx);
              if (value.IsYield()) {
                  Throw();
              }
              return !value.IsFinish();
          }
          TComputationContext& CompCtx;
          const TPtr Flow;
      };
      // JIT iterator: delegates to the generated "Next" function (see GenerateNext),
      // which implements the same contract as TIterator::Next.
      class TCodegenIterator : public TComputationValue<TCodegenIterator> {
      public:
          // Signature of the generated function; `value` is passed by reference.
          using TPtr = bool (*)(TComputationContext*, NUdf::TUnboxedValuePod&);
          TCodegenIterator(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, TPtr func)
              : TComputationValue<TCodegenIterator>(memInfo), CompCtx(compCtx), Func(func)
          {}
      private:
          bool Next(NUdf::TUnboxedValue& value) final {
              return Func(&CompCtx, value);
          }
          TComputationContext& CompCtx;
          const TPtr Func;
      };
      // List value that produces a single iterator of type TIterator. It is
      // instantiated with either of the two classes above. Note that the template
      // parameter name shadows the nested TIterator class inside this template.
      template <class TIterator>
      class TForwardListValue : public TCustomListValue {
      public:
          TForwardListValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, typename TIterator::TPtr ptr)
              : TCustomListValue(memInfo), CompCtx(compCtx), Ptr(ptr)
          {}
      private:
          // The first call consumes Ptr: it is reset to nullptr, which is why the
          // member is `mutable` in a const method. Every later call throws because
          // there is no handle left to build another iterator from.
          NUdf::TUnboxedValue GetListIterator() const final {
              if (const auto ptr = Ptr) {
                  Ptr = nullptr;
                  return CompCtx.HolderFactory.Create<TIterator>(CompCtx, ptr);
              }
              THROW yexception() << "Second pass on forward list.";
          }
          TComputationContext& CompCtx;
          mutable typename TIterator::TPtr Ptr;
      };
      TFlowForwardListWrapper(TComputationMutables& mutables, IComputationNode* flow)
          : TBaseComputation(mutables), Flow(flow)
      {}
      // Builds the forward list. It uses the JIT iterator when LLVM execution is
      // enabled and the generated function's address has been resolved
      // (FinalizeFunctions). Otherwise it falls back to the interpreted iterator.
      NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  #ifndef MKQL_DISABLE_CODEGEN
          if (ctx.ExecuteLLVM && Next)
              return ctx.HolderFactory.Create<TForwardListValue<TCodegenIterator>>(ctx, Next);
  #endif
          return ctx.HolderFactory.Create<TForwardListValue<TIterator>>(ctx, Flow);
      }
  private:
      void RegisterDependencies() const final {
          this->DependsOn(Flow);
      }
      // Handler for an unexpected Yield from the flow. Despite its name it does not
      // throw a C++ exception itself; it calls UdfTerminate, presumably aborting the
      // computation (NOTE(review): confirm UdfTerminate semantics). It is static so
      // that generated code can call it through a plain void() function pointer.
      [[noreturn]] static void Throw() {
          UdfTerminate("Unexpected flow status.");
      }
  #ifndef MKQL_DISABLE_CODEGEN
      // Emits the "Next" function into the module and exports its symbol so that
      // its address can be looked up after compilation.
      void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
          NextFunc = GenerateNext(codegen);
          codegen.ExportSymbol(NextFunc);
      }
      // Resolves the compiled address of "Next". DoCalculate only uses the JIT
      // iterator when this produced a non-null pointer.
      void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
          if (NextFunc)
              Next = reinterpret_cast<TCodegenIterator::TPtr>(codegen.GetPointerToFunction(NextFunc));
      }
      // Generates IR equivalent to:
      //   i1 Next(ctx*, i128* value) {
      //       <release the previous *value>; *value = Flow(ctx);
      //       if (IsYield(*value)) { Throw(); unreachable; }
      //       return *value != Finish;
      //   }
      // An existing function with the same name in the module is reused as-is.
      Function* GenerateNext(NYql::NCodegen::ICodegen& codegen) const {
          auto& module = codegen.GetModule();
          auto& context = codegen.GetContext();
          const auto& name = TBaseComputation::MakeName("Next");
          if (const auto f = module.getFunction(name.c_str()))
              return f;
          // Unboxed values are modelled as 128-bit integers in the generated code.
          const auto valueType = Type::getInt128Ty(context);
          const auto contextType = GetCompContextType(context);
          const auto funcType = FunctionType::get(Type::getInt1Ty(context), {PointerType::getUnqual(contextType), PointerType::getUnqual(valueType)}, false);
          TCodegenContext ctx(codegen);
          ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
          DISubprogramAnnotator annotator(ctx, ctx.Func);
          // Argument 0: computation context; argument 1: pointer to the output value.
          auto args = ctx.Func->arg_begin();
          ctx.Ctx = &*args;
          const auto valuePtr = &*++args;
          const auto main = BasicBlock::Create(context, "main", ctx.Func);
          auto block = main;
          // Presumably drops the reference held by the previous *value before it is
          // overwritten, mirroring the assignment in TIterator::Next.
          SafeUnRefUnboxedOne(valuePtr, ctx, block);
          GetNodeValue(valuePtr, Flow, ctx, block);
          const auto value = new LoadInst(valueType, valuePtr, "value", block);
          const auto kill = BasicBlock::Create(context, "kill", ctx.Func);
          const auto good = BasicBlock::Create(context, "good", ctx.Func);
          BranchInst::Create(kill, good, IsYield(value, block, context), block);
          // Yield branch: call the static Throw() through its raw address; it never returns.
          block = kill;
          const auto doThrow = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TFlowForwardListWrapper::Throw));
          const auto doThrowType = FunctionType::get(Type::getVoidTy(context), {}, false);
          const auto doThrowPtr = CastInst::Create(Instruction::IntToPtr, doThrow, PointerType::getUnqual(doThrowType), "thrower", block);
          CallInst::Create(doThrowType, doThrowPtr, {}, "", block)->setTailCall();
          new UnreachableInst(context, block);
          // Normal branch: "has item" is true unless the flow reported Finish.
          block = good;
          const auto result = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, value, GetFinish(context), "result", block);
          ReturnInst::Create(context, result, block);
          return ctx.Func;
      }
      // Generated IR function (set in GenerateFunctions) and its resolved address
      // (set in FinalizeFunctions).
      Function* NextFunc = nullptr;
      TCodegenIterator::TPtr Next = nullptr;
  #endif
      IComputationNode* const Flow;
  };
  180. }
  181. IComputationNode* WrapEmptyIterator(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  182. MKQL_ENSURE(callable.GetInputsCount() == 0, "Expected 0 arg");
  183. const auto type = callable.GetType()->GetReturnType();
  184. if (type->IsFlow()) {
  185. return ctx.NodeFactory.CreateImmutableNode(NUdf::TUnboxedValuePod::MakeFinish());
  186. } else if (type->IsStream()) {
  187. return ctx.NodeFactory.CreateEmptyNode();
  188. }
  189. THROW yexception() << "Expected flow or stream.";
  190. }
  191. IComputationNode* WrapIterator(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  192. MKQL_ENSURE(callable.GetInputsCount() >= 1, "Expected at least 1 arg");
  193. const auto type = callable.GetInput(0).GetStaticType();
  194. MKQL_ENSURE(type->IsList(), "Requires list");
  195. TComputationNodePtrVector dependentNodes(callable.GetInputsCount() - 1);
  196. for (ui32 i = 1; i < callable.GetInputsCount(); ++i) {
  197. dependentNodes[i - 1] = LocateNode(ctx.NodeLocator, callable, i);
  198. }
  199. return new TIteratorWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0), std::move(dependentNodes));
  200. }
  201. IComputationNode* WrapForwardList(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  202. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");
  203. const auto type = callable.GetInput(0).GetStaticType();
  204. if (type->IsFlow()) {
  205. return new TFlowForwardListWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0));
  206. } else if (type->IsStream()) {
  207. return new TForwardListWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0));
  208. }
  209. THROW yexception() << "Expected flow or stream.";
  210. }
  211. }
  212. }