// mkql_take.cpp (11 KB)
  1. #include "mkql_take.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. class TTakeFlowWrapper : public TStatefulFlowCodegeneratorNode<TTakeFlowWrapper> {
  9. using TBaseComputation = TStatefulFlowCodegeneratorNode<TTakeFlowWrapper>;
  10. public:
  11. TTakeFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow, IComputationNode* count)
  12. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Embedded), Flow(flow), Count(count)
  13. {}
  14. NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  15. if (state.IsInvalid()) {
  16. state = Count->GetValue(ctx);
  17. }
  18. if (auto count = state.Get<ui64>()) {
  19. const auto item = Flow->GetValue(ctx);
  20. if (!(item.IsSpecial())) {
  21. state = NUdf::TUnboxedValuePod(--count);
  22. }
  23. return item;
  24. }
  25. return NUdf::TUnboxedValuePod::MakeFinish();
  26. }
  27. #ifndef MKQL_DISABLE_CODEGEN
  28. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  29. auto& context = ctx.Codegen.GetContext();
  30. const auto valueType = Type::getInt128Ty(context);
  31. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  32. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  33. const auto load = new LoadInst(valueType, statePtr, "load", block);
  34. const auto state = PHINode::Create(load->getType(), 2U, "state", main);
  35. state->addIncoming(load, block);
  36. BranchInst::Create(init, main, IsInvalid(load, block, context), block);
  37. block = init;
  38. GetNodeValue(statePtr, Count, ctx, block);
  39. const auto save = new LoadInst(valueType, statePtr, "save", block);
  40. state->addIncoming(save, block);
  41. BranchInst::Create(main, block);
  42. block = main;
  43. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  44. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  45. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  46. const auto result = PHINode::Create(valueType, 3U, "result", done);
  47. result->addIncoming(GetFinish(context), block);
  48. const auto trunc = GetterFor<ui64>(state, context, block);
  49. const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, trunc, ConstantInt::get(trunc->getType(), 0ULL), "plus", block);
  50. BranchInst::Create(work, done, plus, block);
  51. block = work;
  52. const auto item = GetNodeValue(Flow, ctx, block);
  53. result->addIncoming(item, block);
  54. BranchInst::Create(done, good, IsSpecial(item, block, context), block);
  55. block = good;
  56. result->addIncoming(item, block);
  57. const auto decr = BinaryOperator::CreateSub(trunc, ConstantInt::get(trunc->getType(), 1ULL), "decr", block);
  58. new StoreInst(SetterFor<ui64>(decr, context, block), statePtr, block);
  59. BranchInst::Create(done, block);
  60. block = done;
  61. return result;
  62. }
  63. #endif
  64. private:
  65. void RegisterDependencies() const final {
  66. if (const auto flow = FlowDependsOn(Flow))
  67. DependsOn(flow, Count);
  68. }
  69. IComputationNode* const Flow;
  70. IComputationNode* const Count;
  71. };
  72. class TWideTakeWrapper : public TStatefulWideFlowCodegeneratorNode<TWideTakeWrapper> {
  73. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTakeWrapper>;
  74. public:
  75. TWideTakeWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count)
  76. : TBaseComputation(mutables, flow, EValueRepresentation::Embedded), Flow(flow), Count(count)
  77. {}
  78. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  79. if (state.IsInvalid()) {
  80. state = Count->GetValue(ctx);
  81. }
  82. if (auto count = state.Get<ui64>()) {
  83. if (const auto result = Flow->FetchValues(ctx, output); EFetchResult::One == result) {
  84. state = NUdf::TUnboxedValuePod(--count);
  85. return EFetchResult::One;
  86. } else {
  87. return result;
  88. }
  89. }
  90. return EFetchResult::Finish;
  91. }
  92. #ifndef MKQL_DISABLE_CODEGEN
  93. TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  94. auto& context = ctx.Codegen.GetContext();
  95. const auto valueType = Type::getInt128Ty(context);
  96. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  97. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  98. const auto load = new LoadInst(valueType, statePtr, "load", block);
  99. const auto state = PHINode::Create(load->getType(), 2U, "state", main);
  100. state->addIncoming(load, block);
  101. BranchInst::Create(init, main, IsInvalid(load, block, context), block);
  102. block = init;
  103. GetNodeValue(statePtr, Count, ctx, block);
  104. const auto save = new LoadInst(valueType, statePtr, "save", block);
  105. state->addIncoming(save, block);
  106. BranchInst::Create(main, block);
  107. block = main;
  108. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  109. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  110. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  111. const auto resultType = Type::getInt32Ty(context);
  112. const auto result = PHINode::Create(resultType, 3U, "result", done);
  113. result->addIncoming(ConstantInt::get(resultType, static_cast<i32>(EFetchResult::Finish)), block);
  114. const auto trunc = GetterFor<ui64>(state, context, block);
  115. const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, trunc, ConstantInt::get(trunc->getType(), 0ULL), "plus", block);
  116. BranchInst::Create(work, done, plus, block);
  117. block = work;
  118. const auto getres = GetNodeValues(Flow, ctx, block);
  119. const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), 0), "special", block);
  120. result->addIncoming(getres.first, block);
  121. BranchInst::Create(done, good, special, block);
  122. block = good;
  123. const auto decr = BinaryOperator::CreateSub(trunc, ConstantInt::get(trunc->getType(), 1ULL), "decr", block);
  124. new StoreInst(SetterFor<ui64>(decr, context, block), statePtr, block);
  125. result->addIncoming(getres.first, block);
  126. BranchInst::Create(done, block);
  127. block = done;
  128. return {result, std::move(getres.second)};
  129. }
  130. #endif
  131. private:
  132. void RegisterDependencies() const final {
  133. if (const auto flow = FlowDependsOn(Flow))
  134. DependsOn(flow, Count);
  135. }
  136. IComputationWideFlowNode* const Flow;
  137. IComputationNode* const Count;
  138. };
  139. class TTakeStreamWrapper : public TMutableComputationNode<TTakeStreamWrapper> {
  140. typedef TMutableComputationNode<TTakeStreamWrapper> TBaseComputation;
  141. public:
  142. class TStreamValue : public TComputationValue<TStreamValue> {
  143. public:
  144. using TBase = TComputationValue<TStreamValue>;
  145. TStreamValue(TMemoryUsageInfo* memInfo, NUdf::TUnboxedValue&& input, ui64 count)
  146. : TBase(memInfo)
  147. , Input_(std::move(input))
  148. , Count_(count)
  149. , Index_(0)
  150. {}
  151. private:
  152. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override {
  153. if (Index_ >= Count_) {
  154. return NUdf::EFetchStatus::Finish;
  155. }
  156. const auto status = Input_.Fetch(result);
  157. if (status != NUdf::EFetchStatus::Ok) {
  158. return status;
  159. }
  160. ++Index_;
  161. return status;
  162. }
  163. const NUdf::TUnboxedValue Input_;
  164. const ui64 Count_;
  165. ui64 Index_;
  166. };
  167. TTakeStreamWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationNode* count)
  168. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  169. , List(list)
  170. , Count(count)
  171. {
  172. }
  173. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  174. return ctx.HolderFactory.Create<TStreamValue>(List->GetValue(ctx), Count->GetValue(ctx).Get<ui64>());
  175. }
  176. private:
  177. void RegisterDependencies() const final {
  178. DependsOn(List);
  179. DependsOn(Count);
  180. }
  181. IComputationNode* const List;
  182. IComputationNode* const Count;
  183. };
  184. class TTakeWrapper : public TMutableCodegeneratorNode<TTakeWrapper> {
  185. typedef TMutableCodegeneratorNode<TTakeWrapper> TBaseComputation;
  186. public:
  187. TTakeWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationNode* count)
  188. : TBaseComputation(mutables, list->GetRepresentation())
  189. , List(list)
  190. , Count(count)
  191. {
  192. }
  193. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  194. return ctx.HolderFactory.TakeList(ctx.Builder, List->GetValue(ctx).Release(), Count->GetValue(ctx).Get<ui64>());
  195. }
  196. #ifndef MKQL_DISABLE_CODEGEN
  197. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  198. auto& context = ctx.Codegen.GetContext();
  199. const auto factory = ctx.GetFactory();
  200. const auto builder = ctx.GetBuilder();
  201. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::TakeList));
  202. const auto list = GetNodeValue(List, ctx, block);
  203. const auto cnt = GetNodeValue(Count, ctx, block);
  204. const auto count = GetterFor<ui64>(cnt, context, block);
  205. const auto funType = FunctionType::get(list->getType(), {factory->getType(), builder->getType(), list->getType(), count->getType()}, false);
  206. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  207. const auto result = CallInst::Create(funType, funcPtr, {factory, builder, list, count}, "result", block);
  208. return result;
  209. }
  210. #endif
  211. private:
  212. void RegisterDependencies() const final {
  213. DependsOn(List);
  214. DependsOn(Count);
  215. }
  216. IComputationNode* const List;
  217. IComputationNode* const Count;
  218. };
  219. }
  220. IComputationNode* WrapTake(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  221. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  222. const auto type = callable.GetInput(0).GetStaticType();
  223. const auto flow = LocateNode(ctx.NodeLocator, callable, 0);
  224. const auto count = LocateNode(ctx.NodeLocator, callable, 1);
  225. if (type->IsFlow()) {
  226. if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow))
  227. return new TWideTakeWrapper(ctx.Mutables, wide, count);
  228. else
  229. return new TTakeFlowWrapper(ctx.Mutables, GetValueRepresentation(type), flow, count);
  230. } else if (type->IsStream()) {
  231. return new TTakeStreamWrapper(ctx.Mutables, flow, count);
  232. } else if (type->IsList()) {
  233. return new TTakeWrapper(ctx.Mutables, flow, count);
  234. }
  235. THROW yexception() << "Expected flow, list or stream.";
  236. }
  237. }
  238. }