mkql_block_skiptake_ut.cpp 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. #include "mkql_computation_node_ut.h"
  2. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  5. #include <arrow/array/builder_primitive.h>
  6. namespace NKikimr {
  7. namespace NMiniKQL {
  8. namespace {
  9. class TTestBlockFlowWrapper: public TStatefulWideFlowCodegeneratorNode<TTestBlockFlowWrapper> {
  10. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TTestBlockFlowWrapper>;
  11. public:
  12. TTestBlockFlowWrapper(TComputationMutables& mutables, size_t blockSize, size_t blockCount)
  13. : TBaseComputation(mutables, nullptr, EValueRepresentation::Embedded)
  14. , BlockSize(blockSize)
  15. , BlockCount(blockCount)
  16. {
  17. mutables.CurValueIndex += 3U;
  18. }
  19. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  20. return DoCalculateImpl(state, ctx, *output[0], *output[1], *output[2]);
  21. }
  22. #ifndef MKQL_DISABLE_CODEGEN
  23. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  24. auto& context = ctx.Codegen.GetContext();
  25. const auto valueType = Type::getInt128Ty(context);
  26. const auto ptrValueType = PointerType::getUnqual(valueType);
  27. const auto statusType = Type::getInt32Ty(context);
  28. const auto atTop = &ctx.Func->getEntryBlock().back();
  29. const auto values0Ptr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), static_cast<const IComputationNode*>(this)->GetIndex() + 1U)}, "values_0_ptr", atTop);
  30. const auto values1Ptr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), static_cast<const IComputationNode*>(this)->GetIndex() + 2U)}, "values_1_ptr", atTop);
  31. const auto values2Ptr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), static_cast<const IComputationNode*>(this)->GetIndex() + 3U)}, "values_2_ptr", atTop);
  32. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  33. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", atTop);
  34. const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TTestBlockFlowWrapper::DoCalculateImpl));
  35. const auto doType = FunctionType::get(statusType, {self->getType(), ptrValueType, ctx.Ctx->getType(), ptrValueType, ptrValueType, ptrValueType}, false);
  36. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(doType), "function", atTop);
  37. const auto result = CallInst::Create(doType, doFuncPtr, {self, statePtr, ctx.Ctx, values0Ptr, values1Ptr, values2Ptr}, "result", block);
  38. ICodegeneratorInlineWideNode::TGettersList getters{
  39. [values0Ptr, valueType](const TCodegenContext&, BasicBlock*& block) { return new LoadInst(valueType, values0Ptr, "value", block); },
  40. [values1Ptr, valueType](const TCodegenContext&, BasicBlock*& block) { return new LoadInst(valueType, values1Ptr, "value", block); },
  41. [values2Ptr, valueType](const TCodegenContext&, BasicBlock*& block) { return new LoadInst(valueType, values2Ptr, "value", block); }
  42. };
  43. return {result, std::move(getters)};
  44. }
  45. #endif
  46. private:
  47. EFetchResult DoCalculateImpl(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue& val1, NUdf::TUnboxedValue& val2, NUdf::TUnboxedValue& val3) const {
  48. if (state.IsInvalid()) {
  49. state = NUdf::TUnboxedValue::Zero();
  50. }
  51. auto index = state.Get<ui64>();
  52. if (index >= BlockCount) {
  53. return EFetchResult::Finish;
  54. }
  55. arrow::UInt64Builder builder(&ctx.ArrowMemoryPool);
  56. ARROW_OK(builder.Reserve(BlockSize));
  57. for (size_t i = 0; i < BlockSize; ++i) {
  58. builder.UnsafeAppend(index * BlockSize + i);
  59. }
  60. std::shared_ptr<arrow::ArrayData> block;
  61. ARROW_OK(builder.FinishInternal(&block));
  62. val1 = ctx.HolderFactory.CreateArrowBlock(std::move(block));
  63. val2 = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(index)));
  64. val3 = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(BlockSize)));
  65. state = NUdf::TUnboxedValuePod(++index);
  66. return EFetchResult::One;
  67. }
  68. void RegisterDependencies() const final {
  69. }
  70. const size_t BlockSize;
  71. const size_t BlockCount;
  72. };
  73. IComputationNode* WrapTestBlockFlow(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  74. MKQL_ENSURE(callable.GetInputsCount() == 0, "Expected no args");
  75. return new TTestBlockFlowWrapper(ctx.Mutables, 5, 2);
  76. }
  77. TComputationNodeFactory GetNodeFactory() {
  78. return [](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* {
  79. if (callable.GetType()->GetName() == "TestBlockFlow") {
  80. return WrapTestBlockFlow(callable, ctx);
  81. }
  82. return GetBuiltinFactory()(callable, ctx);
  83. };
  84. } //namespace
  85. template<bool LLVM>
  86. TRuntimeNode MakeFlow(TSetup<LLVM>& setup) {
  87. TProgramBuilder& pb = *setup.PgmBuilder;
  88. TCallableBuilder callableBuilder(*setup.Env, "TestBlockFlow",
  89. pb.NewFlowType(
  90. pb.NewMultiType({
  91. pb.NewBlockType(pb.NewDataType(NUdf::EDataSlot::Uint64), TBlockType::EShape::Many),
  92. pb.NewBlockType(pb.NewDataType(NUdf::EDataSlot::Uint64), TBlockType::EShape::Scalar),
  93. pb.NewBlockType(pb.NewDataType(NUdf::EDataSlot::Uint64), TBlockType::EShape::Scalar),
  94. })));
  95. return TRuntimeNode(callableBuilder.Build(), false);
  96. }
  97. } // namespace
  98. Y_UNIT_TEST_SUITE(TMiniKQLWideTakeSkipBlocks) {
  99. Y_UNIT_TEST_LLVM(TestWideSkipBlocks) {
  100. TSetup<LLVM> setup(GetNodeFactory());
  101. TProgramBuilder& pb = *setup.PgmBuilder;
  102. const auto flow = MakeFlow(setup);
  103. const auto part = pb.WideSkipBlocks(flow, pb.NewDataLiteral<ui64>(7));
  104. const auto plain = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(part)));
  105. const auto singleValueFlow = pb.NarrowMap(plain, [&](TRuntimeNode::TList items) -> TRuntimeNode {
  106. return pb.Add(items[0], items[1]);
  107. });
  108. const auto pgmReturn = pb.ForwardList(singleValueFlow);
  109. const auto graph = setup.BuildGraph(pgmReturn);
  110. const auto iterator = graph->GetValue().GetListIterator();
  111. NUdf::TUnboxedValue item;
  112. UNIT_ASSERT(iterator.Next(item));
  113. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 8);
  114. UNIT_ASSERT(iterator.Next(item));
  115. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 9);
  116. UNIT_ASSERT(iterator.Next(item));
  117. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 10);
  118. UNIT_ASSERT(!iterator.Next(item));
  119. UNIT_ASSERT(!iterator.Next(item));
  120. }
  121. Y_UNIT_TEST_LLVM(TestWideTakeBlocks) {
  122. TSetup<LLVM> setup(GetNodeFactory());
  123. TProgramBuilder& pb = *setup.PgmBuilder;
  124. const auto flow = MakeFlow(setup);
  125. const auto part = pb.WideTakeBlocks(flow, pb.NewDataLiteral<ui64>(4));
  126. const auto plain = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(part)));
  127. const auto singleValueFlow = pb.NarrowMap(plain, [&](TRuntimeNode::TList items) -> TRuntimeNode {
  128. return pb.Add(items[0], items[1]);
  129. });
  130. const auto pgmReturn = pb.ForwardList(singleValueFlow);
  131. const auto graph = setup.BuildGraph(pgmReturn);
  132. const auto iterator = graph->GetValue().GetListIterator();
  133. NUdf::TUnboxedValue item;
  134. UNIT_ASSERT(iterator.Next(item));
  135. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 0);
  136. UNIT_ASSERT(iterator.Next(item));
  137. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 1);
  138. UNIT_ASSERT(iterator.Next(item));
  139. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 2);
  140. UNIT_ASSERT(iterator.Next(item));
  141. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 3);
  142. UNIT_ASSERT(!iterator.Next(item));
  143. UNIT_ASSERT(!iterator.Next(item));
  144. }
  145. Y_UNIT_TEST_LLVM(TestWideTakeSkipBlocks) {
  146. TSetup<LLVM> setup(GetNodeFactory());
  147. TProgramBuilder& pb = *setup.PgmBuilder;
  148. const auto flow = MakeFlow(setup);
  149. const auto part = pb.WideTakeBlocks(pb.WideSkipBlocks(flow, pb.NewDataLiteral<ui64>(3)), pb.NewDataLiteral<ui64>(5));
  150. const auto plain = pb.ToFlow(pb.WideFromBlocks(pb.FromFlow(part)));
  151. const auto singleValueFlow = pb.NarrowMap(plain, [&](TRuntimeNode::TList items) -> TRuntimeNode {
  152. // 0, 0;
  153. // 1, 0;
  154. // 2, 0;
  155. // 3, 0; -> 3
  156. // 4, 0; -> 4
  157. // 5, 1; -> 6
  158. // 6, 1; -> 7
  159. // 7, 1; -> 8
  160. // 8, 1;
  161. // 9, 1;
  162. // 10, 1;
  163. return pb.Add(items[0], items[1]);
  164. });
  165. const auto pgmReturn = pb.ForwardList(singleValueFlow);
  166. const auto graph = setup.BuildGraph(pgmReturn);
  167. const auto iterator = graph->GetValue().GetListIterator();
  168. NUdf::TUnboxedValue item;
  169. UNIT_ASSERT(iterator.Next(item));
  170. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 3);
  171. UNIT_ASSERT(iterator.Next(item));
  172. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 4);
  173. UNIT_ASSERT(iterator.Next(item));
  174. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 6);
  175. UNIT_ASSERT(iterator.Next(item));
  176. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 7);
  177. UNIT_ASSERT(iterator.Next(item));
  178. UNIT_ASSERT_VALUES_EQUAL(item.Get<ui64>(), 8);
  179. UNIT_ASSERT(!iterator.Next(item));
  180. UNIT_ASSERT(!iterator.Next(item));
  181. }
  182. }
  183. } // namespace NMiniKQL
  184. } // namespace NKikimr