mkql_skip.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. #include "mkql_skip.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/computation/mkql_simple_codegen.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. class TSkipFlowWrapper : public TStatefulFlowCodegeneratorNode<TSkipFlowWrapper> {
  10. using TBaseComputation = TStatefulFlowCodegeneratorNode<TSkipFlowWrapper>;
  11. public:
  12. TSkipFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow, IComputationNode* count)
  13. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Embedded), Flow(flow), Count(count)
  14. {}
  15. NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  16. if (state.IsInvalid()) {
  17. state = Count->GetValue(ctx);
  18. }
  19. if (auto count = state.Get<ui64>()) {
  20. do {
  21. const auto item = Flow->GetValue(ctx);
  22. if (item.IsSpecial()) {
  23. state = NUdf::TUnboxedValuePod(count);
  24. return item;
  25. }
  26. } while (--count);
  27. state = NUdf::TUnboxedValuePod::Zero();
  28. }
  29. return Flow->GetValue(ctx);
  30. }
  31. #ifndef MKQL_DISABLE_CODEGEN
  32. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  33. auto& context = ctx.Codegen.GetContext();
  34. const auto valueType = Type::getInt128Ty(context);
  35. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  36. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  37. const auto load = new LoadInst(valueType, statePtr, "load", block);
  38. const auto state = PHINode::Create(valueType, 2U, "state", main);
  39. state->addIncoming(load, block);
  40. BranchInst::Create(init, main, IsInvalid(load, block, context), block);
  41. block = init;
  42. GetNodeValue(statePtr, Count, ctx, block);
  43. const auto save = new LoadInst(valueType, statePtr, "save", block);
  44. state->addIncoming(save, block);
  45. BranchInst::Create(main, block);
  46. block = main;
  47. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  48. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  49. const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
  50. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  51. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  52. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  53. const auto result = PHINode::Create(valueType, 2U, "result", done);
  54. const auto trunc = GetterFor<ui64>(state, context, block);
  55. const auto count = PHINode::Create(trunc->getType(), 2U, "count", work);
  56. count->addIncoming(trunc, block);
  57. const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, trunc, ConstantInt::get(trunc->getType(), 0ULL), "plus", block);
  58. BranchInst::Create(work, skip, plus, block);
  59. block = work;
  60. const auto item = GetNodeValue(Flow, ctx, block);
  61. BranchInst::Create(pass, good, IsSpecial(item, block, context), block);
  62. block = pass;
  63. result->addIncoming(item, block);
  64. new StoreInst(SetterFor<ui64>(count, context, block), statePtr, block);
  65. BranchInst::Create(done, block);
  66. block = good;
  67. ValueCleanup(Flow->GetRepresentation(), item, ctx, block);
  68. const auto decr = BinaryOperator::CreateSub(count, ConstantInt::get(count->getType(), 1ULL), "decr", block);
  69. count->addIncoming(decr, block);
  70. const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, decr, ConstantInt::get(decr->getType(), 0ULL), "next", block);
  71. BranchInst::Create(work, exit, next, block);
  72. block = exit;
  73. new StoreInst(SetterFor<ui64>(decr, context, block), statePtr, block);
  74. BranchInst::Create(skip, block);
  75. block = skip;
  76. const auto res = GetNodeValue(Flow, ctx, block);
  77. result->addIncoming(res, block);
  78. BranchInst::Create(done, block);
  79. block = done;
  80. return result;
  81. }
  82. #endif
  83. private:
  84. void RegisterDependencies() const final {
  85. if (const auto flow = FlowDependsOn(Flow))
  86. DependsOn(flow, Count);
  87. }
  88. IComputationNode* const Flow;
  89. IComputationNode* const Count;
  90. };
  91. class TWideSkipWrapper : public TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWrapper, ui64> {
  92. using TBaseComputation = TSimpleStatefulWideFlowCodegeneratorNode<TWideSkipWrapper, ui64>;
  93. public:
  94. TWideSkipWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, ui32 size)
  95. : TBaseComputation(mutables, flow, size, size)
  96. , Flow(flow)
  97. , Count(count)
  98. , StubsIndex(mutables.IncrementWideFieldsIndex(size))
  99. {}
  100. void InitState(NUdf::TUnboxedValue& cntToSkip, TComputationContext& ctx) const {
  101. cntToSkip = Count->GetValue(ctx);
  102. }
  103. NUdf::TUnboxedValue*const* PrepareInput(NUdf::TUnboxedValue& cntToSkip, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  104. return cntToSkip.Get<ui64>() ? ctx.WideFields.data() + StubsIndex : output;
  105. }
  106. TMaybeFetchResult DoProcess(NUdf::TUnboxedValue& cntToSkip, TComputationContext&, TMaybeFetchResult fetchRes, NUdf::TUnboxedValue*const*) const {
  107. if (fetchRes.Get() == EFetchResult::One && cntToSkip.Get<ui64>()) {
  108. cntToSkip = NUdf::TUnboxedValuePod(cntToSkip.Get<ui64>() - 1);
  109. return TMaybeFetchResult::None();
  110. }
  111. return fetchRes;
  112. }
  113. #ifndef MKQL_DISABLE_CODEGEN
  114. TGenerateResult GenFetchProcess(Value* statePtrVal, const TCodegenContext& ctx, const TResultCodegenerator& fetchGenerator, BasicBlock*& block) const override {
  115. auto& context = ctx.Codegen.GetContext();
  116. const auto decr = BasicBlock::Create(context, "decr", ctx.Func);
  117. const auto end = BasicBlock::Create(context, "end", ctx.Func);
  118. const auto fetched = fetchGenerator(ctx, block);
  119. const auto cntToSkipVal = GetterFor<ui64>(new LoadInst(IntegerType::getInt128Ty(context), statePtrVal, "unboxed_state", block), context, block);
  120. const auto needSkipCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, cntToSkipVal, ConstantInt::get(cntToSkipVal->getType(), 0), "need_skip", block);
  121. const auto gotOneCond = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, fetched.first, ConstantInt::get(fetched.first->getType(), 1), "got_one", block);
  122. const auto willSkipCond = BinaryOperator::Create(Instruction::And, needSkipCond, gotOneCond, "will_skip", block);
  123. BranchInst::Create(decr, end, willSkipCond, block);
  124. block = decr;
  125. const auto cntToSkipNewVal = BinaryOperator::CreateSub(cntToSkipVal, ConstantInt::get(cntToSkipVal->getType(), 1), "decr", block);
  126. new StoreInst(SetterFor<ui64>(cntToSkipNewVal, context, block), statePtrVal, block);
  127. BranchInst::Create(end, block);
  128. block = end;
  129. const auto result = SelectInst::Create(willSkipCond, TMaybeFetchResult::None().LLVMConst(context), TMaybeFetchResult::LLVMFromFetchResult(fetched.first, "fetch_res_ext", block), "result", block);
  130. return {result, fetched.second};
  131. }
  132. #endif
  133. private:
  134. void RegisterDependencies() const final {
  135. if (const auto flow = FlowDependsOn(Flow))
  136. DependsOn(flow, Count);
  137. }
  138. IComputationWideFlowNode* const Flow;
  139. IComputationNode* const Count;
  140. const ui32 StubsIndex;
  141. };
  142. class TSkipStreamWrapper : public TMutableComputationNode<TSkipStreamWrapper> {
  143. typedef TMutableComputationNode<TSkipStreamWrapper> TBaseComputation;
  144. public:
  145. class TStreamValue : public TComputationValue<TStreamValue> {
  146. public:
  147. using TBase = TComputationValue<TStreamValue>;
  148. TStreamValue(TMemoryUsageInfo* memInfo, NUdf::TUnboxedValue&& input, ui64 count)
  149. : TBase(memInfo)
  150. , Input_(std::move(input))
  151. , Count_(count)
  152. , Index_(0)
  153. {}
  154. private:
  155. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override {
  156. for (;;) {
  157. if (Index_ >= Count_) {
  158. return Input_.Fetch(result);
  159. }
  160. auto status = Input_.Fetch(result);
  161. if (status != NUdf::EFetchStatus::Ok) {
  162. return status;
  163. }
  164. ++Index_;
  165. }
  166. }
  167. const NUdf::TUnboxedValue Input_;
  168. const ui64 Count_;
  169. ui64 Index_;
  170. };
  171. TSkipStreamWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationNode* count)
  172. : TBaseComputation(mutables, list->GetRepresentation())
  173. , List(list)
  174. , Count(count)
  175. {
  176. }
  177. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  178. return ctx.HolderFactory.Create<TStreamValue>(List->GetValue(ctx), Count->GetValue(ctx).Get<ui64>());
  179. }
  180. void RegisterDependencies() const final {
  181. DependsOn(List);
  182. DependsOn(Count);
  183. }
  184. private:
  185. IComputationNode* const List;
  186. IComputationNode* const Count;
  187. };
  188. class TSkipWrapper : public TMutableCodegeneratorNode<TSkipWrapper> {
  189. typedef TMutableCodegeneratorNode<TSkipWrapper> TBaseComputation;
  190. public:
  191. TSkipWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationNode* count)
  192. : TBaseComputation(mutables, list->GetRepresentation())
  193. , List(list)
  194. , Count(count)
  195. {
  196. }
  197. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  198. return ctx.HolderFactory.SkipList(ctx.Builder, List->GetValue(ctx).Release(), Count->GetValue(ctx).Get<ui64>());
  199. }
  200. #ifndef MKQL_DISABLE_CODEGEN
  201. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  202. auto& context = ctx.Codegen.GetContext();
  203. const auto factory = ctx.GetFactory();
  204. const auto builder = ctx.GetBuilder();
  205. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&THolderFactory::SkipList));
  206. const auto list = GetNodeValue(List, ctx, block);
  207. const auto cnt = GetNodeValue(Count, ctx, block);
  208. const auto count = GetterFor<ui64>(cnt, context, block);
  209. const auto funType = FunctionType::get(list->getType(), {factory->getType(), builder->getType(), list->getType(), count->getType()}, false);
  210. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  211. const auto result = CallInst::Create(funType, funcPtr, {factory, builder, list, count}, "result", block);
  212. return result;
  213. }
  214. #endif
  215. private:
  216. void RegisterDependencies() const final {
  217. DependsOn(List);
  218. DependsOn(Count);
  219. }
  220. IComputationNode* const List;
  221. IComputationNode* const Count;
  222. };
  223. }
  224. IComputationNode* WrapSkip(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  225. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  226. const auto type = callable.GetInput(0).GetStaticType();
  227. const auto flow = LocateNode(ctx.NodeLocator, callable, 0);
  228. const auto count = LocateNode(ctx.NodeLocator, callable, 1);
  229. if (type->IsFlow()) {
  230. if (const auto wide = dynamic_cast<IComputationWideFlowNode*>(flow))
  231. return new TWideSkipWrapper(ctx.Mutables, wide, count, GetWideComponentsCount(AS_TYPE(TFlowType, type)));
  232. else
  233. return new TSkipFlowWrapper(ctx.Mutables, GetValueRepresentation(type), flow, count);
  234. } else if (type->IsStream()) {
  235. return new TSkipStreamWrapper(ctx.Mutables, flow, count);
  236. } else if (type->IsList()) {
  237. return new TSkipWrapper(ctx.Mutables, flow, count);
  238. }
  239. THROW yexception() << "Expected flow, list or stream.";
  240. }
  241. }
  242. }