mkql_block_skiptake.cpp 20 KB


  1. #include "mkql_block_skiptake.h"
  2. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  3. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  4. #include <yql/essentials/minikql/arrow/arrow_util.h>
  5. #include <yql/essentials/minikql/mkql_type_builder.h>
  6. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  7. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  8. #include <yql/essentials/minikql/mkql_node_builder.h>
  9. #include <yql/essentials/minikql/mkql_node_cast.h>
  10. namespace NKikimr {
  11. namespace NMiniKQL {
  12. namespace {
  13. class TWideSkipBlocksWrapper : public TStatefulWideFlowCodegeneratorNode<TWideSkipBlocksWrapper> {
  14. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideSkipBlocksWrapper>;
  15. public:
  16. TWideSkipBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, ui32 size)
  17. : TBaseComputation(mutables, flow, EValueRepresentation::Embedded), Flow(flow), Count(count), Width(size - 1U)
  18. {}
  19. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  20. if (state.IsInvalid()) {
  21. state = Count->GetValue(ctx);
  22. }
  23. if (auto count = state.Get<ui64>()) while (true) {
  24. if (const auto result = Flow->FetchValues(ctx, output); EFetchResult::One != result) {
  25. state = NUdf::TUnboxedValuePod(count);
  26. return result;
  27. }
  28. if (const auto blockSize = GetBlockCount(*output[Width]); count < blockSize) {
  29. state = NUdf::TUnboxedValuePod::Zero();
  30. *output[Width] = MakeBlockCount(ctx.HolderFactory, blockSize - count);
  31. for (auto i = 0U; i < Width; ++i)
  32. if (const auto out = output[i])
  33. *out = SliceBlock(ctx.HolderFactory, *out, count);
  34. return EFetchResult::One;
  35. } else
  36. count -= blockSize;
  37. }
  38. return Flow->FetchValues(ctx, output);
  39. }
  40. #ifndef MKQL_DISABLE_CODEGEN
  41. TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  42. auto& context = ctx.Codegen.GetContext();
  43. const auto indexType = Type::getInt64Ty(context);
  44. const auto valueType = Type::getInt128Ty(context);
  45. const auto atTop = &ctx.Func->getEntryBlock().back();
  46. const auto offsetPtr = new AllocaInst(indexType, 0U, "offset_ptr", atTop);
  47. const auto sizePtr = new AllocaInst(indexType, 0U, "size_ptr", atTop);
  48. const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideSkipBlocksWrapper::SliceBlock));
  49. const auto sliceType = FunctionType::get(valueType, {ctx.GetFactory()->getType(), valueType, indexType}, false);
  50. const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice", atTop);
  51. const auto name = "GetBlockCount";
  52. ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&GetBlockCount));
  53. const auto getCountType = FunctionType::get(indexType, { valueType }, false);
  54. const auto getCount = ctx.Codegen.GetModule().getOrInsertFunction(name, getCountType);
  55. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  56. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  57. const auto load = new LoadInst(valueType, statePtr, "load", block);
  58. const auto state = PHINode::Create(valueType, 2U, "state", main);
  59. state->addIncoming(load, block);
  60. BranchInst::Create(init, main, IsInvalid(load, block, context), block);
  61. block = init;
  62. GetNodeValue(statePtr, Count, ctx, block);
  63. const auto save = new LoadInst(valueType, statePtr, "save", block);
  64. state->addIncoming(save, block);
  65. BranchInst::Create(main, block);
  66. block = main;
  67. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  68. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  69. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  70. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  71. const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
  72. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  73. const auto resultType = Type::getInt32Ty(context);
  74. const auto result = PHINode::Create(resultType, 2U, "result", done);
  75. const auto trunc = GetterFor<ui64>(state, context, block);
  76. const auto count = PHINode::Create(trunc->getType(), 2U, "count", work);
  77. count->addIncoming(trunc, block);
  78. BranchInst::Create(work, block);
  79. block = work;
  80. const auto getres = GetNodeValues(Flow, ctx, block);
  81. const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
  82. BranchInst::Create(pass, good, special, block);
  83. block = good;
  84. const auto more = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, count, ConstantInt::get(indexType, 0), "more", block);
  85. BranchInst::Create(test, pass, more, block);
  86. block = test;
  87. const auto countValue = getres.second.back()(ctx, block);
  88. const auto height = CallInst::Create(getCount, { countValue }, "height", block);
  89. ValueCleanup(EValueRepresentation::Any, countValue, ctx, block);
  90. const auto part = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, count, height, "part", block);
  91. const auto decr = BinaryOperator::CreateSub(count, height, "decr", block);
  92. count->addIncoming(decr, block);
  93. BranchInst::Create(over, work, part, block);
  94. block = over;
  95. const auto tail = BinaryOperator::CreateSub(height, count, "tail", block);
  96. new StoreInst(count, offsetPtr, block);
  97. new StoreInst(tail, sizePtr, block);
  98. new StoreInst(GetFalse(context), statePtr, block);
  99. result->addIncoming(getres.first, block);
  100. BranchInst::Create(done, block);
  101. block = pass;
  102. new StoreInst(ConstantInt::get(indexType, 0), offsetPtr, block);
  103. new StoreInst(ConstantInt::get(indexType, 0), sizePtr, block);
  104. new StoreInst(SetterFor<ui64>(count, context, block), statePtr, block);
  105. result->addIncoming(getres.first, block);
  106. BranchInst::Create(done, block);
  107. block = done;
  108. ICodegeneratorInlineWideNode::TGettersList getters(getres.second.size());
  109. getters.back() = [sizePtr, indexType, valueType, getSize = getres.second.back()](const TCodegenContext& ctx, BasicBlock*& block) {
  110. auto& context = ctx.Codegen.GetContext();
  111. const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
  112. const auto calc = BasicBlock::Create(context, "calc", ctx.Func);
  113. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  114. const auto height = PHINode::Create(valueType, 2U, "state", exit);
  115. const auto count = new LoadInst(indexType, sizePtr, "count", block);
  116. const auto work = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, count, ConstantInt::get(indexType, 0), "work", block);
  117. BranchInst::Create(calc, pass, work, block);
  118. block = calc;
  119. const auto makeCountFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&MakeBlockCount));
  120. const auto makeCountType = FunctionType::get(valueType, {ctx.GetFactory()->getType(), indexType}, false);
  121. const auto makeCountPtr = CastInst::Create(Instruction::IntToPtr, makeCountFunc, PointerType::getUnqual(makeCountType), "make_count_func", block);
  122. const auto slice = CallInst::Create(makeCountType, makeCountPtr, {ctx.GetFactory(), count}, "slice", block);
  123. height->addIncoming(slice, block);
  124. BranchInst::Create(exit, block);
  125. block = pass;
  126. const auto size = getSize(ctx, block);
  127. height->addIncoming(size, block);
  128. BranchInst::Create(exit, block);
  129. block = exit;
  130. return height;
  131. };
  132. for (auto idx = 0U; idx < Width; ++idx) {
  133. getters[idx] = [offsetPtr, indexType, valueType, sliceType, slicePtr, getBlock = getres.second[idx]](const TCodegenContext& ctx, BasicBlock*& block) {
  134. auto& context = ctx.Codegen.GetContext();
  135. const auto calc = BasicBlock::Create(context, "calc", ctx.Func);
  136. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  137. const auto output = PHINode::Create(valueType, 2U, "output", exit);
  138. const auto offset = new LoadInst(indexType, offsetPtr, "offset", block);
  139. const auto work = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, offset, ConstantInt::get(indexType, 0), "work", block);
  140. const auto value = getBlock(ctx, block);
  141. output->addIncoming(value, block);
  142. BranchInst::Create(calc, exit, work, block);
  143. block = calc;
  144. const auto slice = CallInst::Create(sliceType, slicePtr, {ctx.GetFactory(), value, offset}, "slice", block);
  145. ValueCleanup(EValueRepresentation::Any, value, ctx, block);
  146. output->addIncoming(slice, block);
  147. BranchInst::Create(exit, block);
  148. block = exit;
  149. return output;
  150. };
  151. }
  152. return {result, std::move(getters)};
  153. }
  154. #endif
  155. private:
  156. static NUdf::TUnboxedValuePod SliceBlock(const THolderFactory& holderFactory, NUdf::TUnboxedValuePod block, const uint64_t offset) {
  157. const auto& datum = TArrowBlock::From(block).GetDatum();
  158. return datum.is_scalar() ? block : holderFactory.CreateArrowBlock(DeepSlice(datum.array(), offset, datum.array()->length - offset));
  159. }
  160. void RegisterDependencies() const final {
  161. if (const auto flow = FlowDependsOn(Flow)) {
  162. DependsOn(flow, Count);
  163. }
  164. }
  165. IComputationWideFlowNode* const Flow;
  166. IComputationNode* const Count;
  167. const ui32 Width;
  168. };
// Wide block flow node implementing WideTakeBlocks: forwards at most `Count` rows of the
// incoming block stream and then reports Finish.
//
// The wide flow has `Width + 1` columns. The last one (index `Width`) carries the block length
// (read with GetBlockCount, rebuilt with MakeBlockCount); the others carry the column blocks.
// The node state holds the number of rows that may still be emitted (as ui64); it is
// initialized lazily from `Count` on the first call.
class TWideTakeBlocksWrapper : public TStatefulWideFlowCodegeneratorNode<TWideTakeBlocksWrapper> {
    using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideTakeBlocksWrapper>;
public:
    // `size` is the total number of wide columns, including the trailing block-length column.
    TWideTakeBlocksWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, IComputationNode* count, ui32 size)
        : TBaseComputation(mutables, flow, EValueRepresentation::Embedded), Flow(flow), Count(count), Width(size - 1U)
    {}

    // Interpreted path.
    // - A block that fits into the remaining budget is forwarded unchanged, and the budget is
    //   decreased by its length.
    // - A block that exceeds the budget is truncated to the remaining rows, and the budget
    //   drops to zero.
    // - With a zero budget the upstream flow is no longer polled and Finish is returned.
    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        if (state.IsInvalid()) {
            state = Count->GetValue(ctx);
        }
        if (const auto count = state.Get<ui64>()) {
            if (const auto result = Flow->FetchValues(ctx, output); EFetchResult::One == result) {
                if (const auto blockSize = GetBlockCount(*output[Width]); count < blockSize) {
                    // Only a prefix of this block fits: keep its first `count` rows.
                    state = NUdf::TUnboxedValuePod::Zero();
                    *output[Width] = MakeBlockCount(ctx.HolderFactory, count);
                    for (auto i = 0U; i < Width; ++i)
                        if (const auto out = output[i]) // null output slots are left untouched
                            *out = SliceBlock(ctx.HolderFactory, *out, count);
                } else
                    state = NUdf::TUnboxedValuePod(ui64(count - blockSize));
                return EFetchResult::One;
            } else {
                // Yield/Finish from upstream is forwarded; the budget stays unchanged.
                return result;
            }
        }
        return EFetchResult::Finish;
    }
#ifndef MKQL_DISABLE_CODEGEN
    // Generated path. It mirrors DoCalculate:
    //  * When the remaining budget is zero, `main` jumps straight to `done` with Finish.
    //  * Otherwise `work` fetches a block and forwards Yield/Finish to `done`.
    //  * `good` computes the new budget and the truncation length in `size_ptr` (0 = no
    //    truncation) with selects, without branching.
    // The returned getters truncate lazily when `size_ptr` is non-zero.
    TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();

        const auto indexType = Type::getInt64Ty(context);
        const auto valueType = Type::getInt128Ty(context);

        const auto atTop = &ctx.Func->getEntryBlock().back();

        // Truncation length shared between this fetch code and the getters; starts at 0.
        const auto sizePtr = new AllocaInst(indexType, 0U, "size_ptr", atTop);
        new StoreInst(ConstantInt::get(indexType, 0), sizePtr, atTop);

        const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideTakeBlocksWrapper::SliceBlock));
        const auto sliceType = FunctionType::get(valueType, {ctx.GetFactory()->getType(), valueType, indexType}, false);
        const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice", atTop);

        const auto name = "GetBlockCount";
        ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&GetBlockCount));
        const auto getCountType = FunctionType::get(indexType, { valueType }, false);
        const auto getCount = ctx.Codegen.GetModule().getOrInsertFunction(name, getCountType);

        const auto init = BasicBlock::Create(context, "init", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);

        // Lazily initialize the state from `Count` on first use.
        const auto load = new LoadInst(valueType, statePtr, "load", block);
        const auto state = PHINode::Create(valueType, 2U, "state", main);
        state->addIncoming(load, block);

        BranchInst::Create(init, main, IsInvalid(load, block, context), block);

        block = init;

        GetNodeValue(statePtr, Count, ctx, block);
        const auto save = new LoadInst(valueType, statePtr, "save", block);
        state->addIncoming(save, block);
        BranchInst::Create(main, block);

        block = main;

        const auto work = BasicBlock::Create(context, "work", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        const auto done = BasicBlock::Create(context, "done", ctx.Func);

        const auto resultType = Type::getInt32Ty(context);
        const auto result = PHINode::Create(resultType, 3U, "result", done);
        result->addIncoming(ConstantInt::get(resultType, static_cast<i32>(EFetchResult::Finish)), block);

        const auto count = GetterFor<ui64>(state, context, block);

        const auto plus = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, count, ConstantInt::get(count->getType(), 0ULL), "plus", block);

        BranchInst::Create(work, done, plus, block);

        block = work;

        const auto getres = GetNodeValues(Flow, ctx, block);

        const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
        result->addIncoming(getres.first, block);
        BranchInst::Create(done, good, special, block);

        block = good;

        const auto countValue = getres.second.back()(ctx, block);
        const auto height = CallInst::Create(getCount, { countValue }, "height", block);
        ValueCleanup(EValueRepresentation::Any, countValue, ctx, block);

        // part: only a prefix of this block fits -> budget becomes 0, truncate to `count` rows;
        // otherwise the whole block passes and the budget shrinks by its length.
        const auto part = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, count, height, "part", block);
        const auto decr = BinaryOperator::CreateSub(count, height, "decr", block);
        const auto next = SelectInst::Create(part, ConstantInt::get(indexType, 0), decr, "next", block);
        const auto size = SelectInst::Create(part, count, ConstantInt::get(indexType, 0), "size", block);
        new StoreInst(SetterFor<ui64>(next, context, block), statePtr, block);
        new StoreInst(size, sizePtr, block);
        result->addIncoming(getres.first, block);
        BranchInst::Create(done, block);

        block = done;

        ICodegeneratorInlineWideNode::TGettersList getters(getres.second.size());
        // Block-length column: rebuilt from `size_ptr` when truncated, otherwise the upstream
        // value is forwarded.
        getters.back() = [sizePtr, indexType, valueType, getSize = getres.second.back()](const TCodegenContext& ctx, BasicBlock*& block) {
            auto& context = ctx.Codegen.GetContext();
            const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
            const auto calc = BasicBlock::Create(context, "calc", ctx.Func);
            const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
            const auto height = PHINode::Create(valueType, 2U, "state", exit);

            const auto count = new LoadInst(indexType, sizePtr, "count", block);
            const auto work = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, count, ConstantInt::get(indexType, 0), "work", block);

            BranchInst::Create(calc, pass, work, block);

            block = calc;

            const auto makeCountFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&MakeBlockCount));
            const auto makeCountType = FunctionType::get(valueType, {ctx.GetFactory()->getType(), indexType}, false);
            const auto makeCountPtr = CastInst::Create(Instruction::IntToPtr, makeCountFunc, PointerType::getUnqual(makeCountType), "make_count_func", block);
            const auto slice = CallInst::Create(makeCountType, makeCountPtr, {ctx.GetFactory(), count}, "slice", block);
            height->addIncoming(slice, block);
            BranchInst::Create(exit, block);

            block = pass;

            const auto size = getSize(ctx, block);
            height->addIncoming(size, block);
            BranchInst::Create(exit, block);

            block = exit;
            return height;
        };
        // Data columns: truncated to `size_ptr` rows when it is non-zero (the source value is
        // then released), otherwise forwarded as is.
        for (auto idx = 0U; idx < Width; ++idx) {
            getters[idx] = [sizePtr, indexType, valueType, sliceType, slicePtr, getBlock = getres.second[idx]](const TCodegenContext& ctx, BasicBlock*& block) {
                auto& context = ctx.Codegen.GetContext();
                const auto calc = BasicBlock::Create(context, "calc", ctx.Func);
                const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
                const auto output = PHINode::Create(valueType, 2U, "output", exit);

                const auto size = new LoadInst(indexType, sizePtr, "size", block);
                const auto work = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, size, ConstantInt::get(indexType, 0), "work", block);

                const auto value = getBlock(ctx, block);
                output->addIncoming(value, block);
                BranchInst::Create(calc, exit, work, block);

                block = calc;

                const auto slice = CallInst::Create(sliceType, slicePtr, {ctx.GetFactory(), value, size}, "slice", block);
                ValueCleanup(EValueRepresentation::Any, value, ctx, block);
                output->addIncoming(slice, block);
                BranchInst::Create(exit, block);

                block = exit;
                return output;
            };
        }

        return {result, std::move(getters)};
    }
#endif
private:
    // Keeps the first `offset` rows of an array block. Here `offset` is the prefix length,
    // not a start position. Scalar blocks are returned unchanged.
    static NUdf::TUnboxedValuePod SliceBlock(const THolderFactory& holderFactory, NUdf::TUnboxedValuePod block, const uint64_t offset) {
        const auto& datum = TArrowBlock::From(block).GetDatum();
        return datum.is_scalar() ? block : holderFactory.CreateArrowBlock(DeepSlice(datum.array(), 0ULL, offset));
    }

    // The flow's value depends on `Count` (it is read when the state is initialized).
    void RegisterDependencies() const final {
        if (const auto flow = FlowDependsOn(Flow)) {
            DependsOn(flow, Count);
        }
    }

    IComputationWideFlowNode* const Flow;
    IComputationNode* const Count;
    const ui32 Width; // index of the block-length column == number of data columns
};
  311. IComputationNode* WrapSkipTake(bool skip, TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  312. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args");
  313. const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
  314. const auto flowWidth = GetWideComponentsCount(flowType);
  315. MKQL_ENSURE(flowWidth > 0, "Expected at least one column");
  316. auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
  317. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  318. const auto count = LocateNode(ctx.NodeLocator, callable, 1);
  319. const auto countType = AS_TYPE(TDataType, callable.GetInput(1).GetStaticType());
  320. MKQL_ENSURE(countType->GetSchemeType() == NUdf::TDataType<ui64>::Id, "Expected ui64");
  321. if (skip) {
  322. return new TWideSkipBlocksWrapper(ctx.Mutables, wideFlow, count, flowWidth);
  323. }
  324. return new TWideTakeBlocksWrapper(ctx.Mutables, wideFlow, count, flowWidth);
  325. }
  326. } //namespace
  327. IComputationNode* WrapWideSkipBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  328. bool skip = true;
  329. return WrapSkipTake(skip, callable, ctx);
  330. }
  331. IComputationNode* WrapWideTakeBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  332. bool skip = false;
  333. return WrapSkipTake(skip, callable, ctx);
  334. }
  335. }
  336. }