mkql_fold.cpp 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. #include "mkql_fold.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. namespace NKikimr {
  5. namespace NMiniKQL {
  6. namespace {
/// Computation node for the MiniKQL `Fold` callable.
///
/// Evaluation:
///   1. `InitialState` is evaluated and stored into the external `State` node.
///   2. For every item of `List`, the item is stored into the external `Item` node,
///      then `NewState` is evaluated and its result is stored back into `State`.
///   3. The final content of `State` is returned.
///
/// Both an interpreted path (`DoCalculate`) and an LLVM code-generated path
/// (`DoGenerateGetValue`, unless MKQL_DISABLE_CODEGEN is defined) implement this.
class TFoldWrapper : public TMutableCodegeneratorRootNode<TFoldWrapper> {
    typedef TMutableCodegeneratorRootNode<TFoldWrapper> TBaseComputation;
public:
    /// @param mutables      forwarded to the base node.
    /// @param kind          value representation of the node's result, forwarded to the base node.
    /// @param list          node producing the list being folded.
    /// @param item          external node that receives the current list item.
    /// @param state         external node that holds the running accumulator.
    /// @param newState      node evaluated once per item; its result becomes the new accumulator
    ///                      (presumably it reads `item` and `state` — confirm at the call site).
    /// @param initialState  node producing the starting accumulator.
    TFoldWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* list, IComputationExternalNode* item, IComputationExternalNode* state,
        IComputationNode* newState, IComputationNode* initialState)
        : TBaseComputation(mutables, kind)
        , List(list)
        , Item(item)
        , State(state)
        , NewState(newState)
        , InitialState(initialState)
    {
    }

    /// Interpreted evaluation: seeds `State`, then feeds each list item through `NewState`.
    /// @return the accumulator held in `State` after the last item (or the initial value
    ///         for an empty list).
    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
        // Seed the accumulator before touching the list.
        State->SetValue(compCtx, InitialState->GetValue(compCtx));
        TThresher<false>::DoForEachItem(List->GetValue(compCtx),
            [this, &compCtx] (NUdf::TUnboxedValue&& item) {
                // Publish the current item, then replace the accumulator with the step result.
                Item->SetValue(compCtx, std::move(item));
                State->SetValue(compCtx, NewState->GetValue(compCtx));
            }
        );
        return State->GetValue(compCtx).Release();
    }
#ifndef MKQL_DISABLE_CODEGEN
    /// Emits LLVM IR equivalent to DoCalculate into `block` and returns the final accumulator.
    /// On return, `block` points at the basic block where the caller should continue emitting.
    /// NOTE: the order of instruction and basic-block creation below is significant; do not reorder.
    llvm::Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
        auto &context = ctx.Codegen.GetContext();
        const auto codegenState = dynamic_cast<ICodegeneratorExternalNode*>(State);
        const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
        MKQL_ENSURE(codegenState, "State must be codegenerator node.");
        MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
        // Unboxed values are handled in generated code as 128-bit integers.
        const auto valueType = Type::getInt128Ty(context);
        const auto ptrType = PointerType::getUnqual(valueType);

        // Seed the accumulator.
        const auto init = GetNodeValue(InitialState, ctx, block);
        codegenState->CreateSetValue(ctx, block, init);

        // GetElements() either yields a pointer that is indexed below as an array of values,
        // or null; null selects the iterator-based slow path.
        const auto list = GetNodeValue(List, ctx, block);
        const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(ptrType, list, ctx.Codegen, block);
        const auto null = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, elements, ConstantPointerNull::get(ptrType), "null", block);
        const auto fast = BasicBlock::Create(context, "fast", ctx.Func);
        const auto slow = BasicBlock::Create(context, "slow", ctx.Func);
        const auto done = BasicBlock::Create(context, "done", ctx.Func);
        BranchInst::Create(slow, fast, null, block);
        {
            // Fast path: iterate index over [0, GetListLength()) and load elements[index].
            block = fast;
            const auto sizeType = Type::getInt64Ty(context);
            const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(sizeType, list, ctx.Codegen, block);
            const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
            // Loop counter: 0 on entry from `fast`, `next` on the back edge added below.
            const auto index = PHINode::Create(sizeType, 2U, "index", loop);
            index->addIncoming(ConstantInt::get(sizeType, 0), block);
            BranchInst::Create(loop, block);
            block = loop;
            const auto more = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, index, size, "more", block);
            const auto good = BasicBlock::Create(context, "good", ctx.Func);
            BranchInst::Create(good, done, more, block);
            block = good;
            const auto itemPtr = GetElementPtrInst::CreateInBounds(valueType, elements, {index}, "item_ptr", block);
            const auto item = new LoadInst(valueType, itemPtr, "item", block);
            codegenItem->CreateSetValue(ctx, block, item);
            const auto newState = GetNodeValue(NewState, ctx, block);
            codegenState->CreateSetValue(ctx, block, newState);
            const auto next = BinaryOperator::CreateAdd(index, ConstantInt::get(sizeType, 1), "next", block);
            // `block` here may differ from `good` if evaluating NewState emitted extra blocks.
            index->addIncoming(next, block);
            BranchInst::Create(loop, block);
        }
        {
            // Slow path: obtain an iterator and call Next() until it reports no more items.
            block = slow;
            // NOTE(review): for stateless or always-inlined code the iterator slot is placed in the
            // function's entry block, presumably so the stack slot is allocated once per function
            // rather than each time this path executes — confirm against codegen conventions.
            const auto iterPtr = *Stateless || ctx.AlwaysInline ?
                new AllocaInst(valueType, 0U, "iter_ptr", &ctx.Func->getEntryBlock().back()):
                new AllocaInst(valueType, 0U, "iter_ptr", block);
            CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListIterator>(iterPtr, list, ctx.Codegen, block);
            const auto iter = new LoadInst(valueType, iterPtr, "iter", block);
            const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
            const auto good = BasicBlock::Create(context, "good", ctx.Func);
            const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
            BranchInst::Create(loop, block);
            block = loop;
            // Next() writes the fetched item straight into the Item node's storage.
            const auto itemPtr = codegenItem->CreateRefValue(ctx, block);
            const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(Type::getInt1Ty(context), iter, ctx.Codegen, block, itemPtr);
            BranchInst::Create(good, stop, status, block);
            block = good;
            const auto newState = GetNodeValue(NewState, ctx, block);
            codegenState->CreateSetValue(ctx, block, newState);
            BranchInst::Create(loop, block);
            block = stop;
            // Release the iterator obtained above.
            UnRefBoxed(iter, ctx, block);
            BranchInst::Create(done, block);
        }
        // Both paths converge here.
        block = done;
        // Only a temporary list value is cleaned up here.
        if (List->IsTemporaryValue())
            CleanupBoxed(list, ctx, block);
        return codegenState->CreateGetValue(ctx, block);
    }
#endif
private:
    /// Declares graph edges: List, InitialState and NewState are dependencies;
    /// the external Item and State nodes are owned by this node.
    void RegisterDependencies() const final {
        this->DependsOn(List);
        this->DependsOn(InitialState);
        this->Own(Item);
        this->Own(State);
        this->DependsOn(NewState);
    }

    IComputationNode* const List;               // list being folded
    IComputationExternalNode* const Item;       // receives the current item
    IComputationExternalNode* const State;      // running accumulator
    IComputationNode* const NewState;           // per-item step result
    IComputationNode* const InitialState;       // starting accumulator
};
  113. }
  114. IComputationNode* WrapFold(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  115. MKQL_ENSURE(callable.GetInputsCount() == 5, "Expected 5 args");
  116. MKQL_ENSURE(callable.GetInput(0).GetStaticType()->IsList(), "Expected List");
  117. const auto list = LocateNode(ctx.NodeLocator, callable, 0);
  118. const auto initialState = LocateNode(ctx.NodeLocator, callable, 1);
  119. const auto newState = LocateNode(ctx.NodeLocator, callable, 4);
  120. const auto item = LocateExternalNode(ctx.NodeLocator, callable, 2);
  121. const auto state = LocateExternalNode(ctx.NodeLocator, callable, 3);
  122. const auto kind = GetValueRepresentation(callable.GetType()->GetReturnType());
  123. return new TFoldWrapper(ctx.Mutables, kind, list, item, state, newState, initialState);
  124. }
  125. }
  126. }