mkql_block_if.cpp 10.0 KB

#include "mkql_block_if.h"

#include <yql/essentials/minikql/computation/mkql_block_reader.h>
#include <yql/essentials/minikql/computation/mkql_block_builder.h>
#include <yql/essentials/minikql/computation/mkql_block_impl.h>

#include <yql/essentials/minikql/arrow/arrow_defs.h>
#include <yql/essentials/minikql/arrow/arrow_util.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_node_cast.h>

namespace NKikimr {
namespace NMiniKQL {

namespace {
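// Computation node for the case when the If predicate is a scalar block: the
// predicate is evaluated once and the result of the selected branch is returned
// directly when possible; if the selected branch is a scalar while the other one
// is array-like, the scalar is broadcast to arrays of the same total length.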
class TBlockIfScalarWrapper : public TMutableComputationNode<TBlockIfScalarWrapper> {
public:
    class TArrowNode : public IArrowKernelComputationNode {
    public:
        TArrowNode(const TBlockIfScalarWrapper* parent)
            : Parent_(parent)
            , ArgsValuesDescr_(ToValueDescr(parent->ArgsTypes))
            , Kernel_(ConvertToInputTypes(parent->ArgsTypes), ConvertToOutputType(parent->ResultType), [parent](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
                *res = parent->CalculateImpl(MakeDatumProvider(batch.values[0]), MakeDatumProvider(batch.values[1]), MakeDatumProvider(batch.values[2]), *ctx->memory_pool());
                return arrow::Status::OK();
            })
        {
            Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
            Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
        }

        TStringBuf GetKernelName() const final {
            return "If";
        }

        const arrow::compute::ScalarKernel& GetArrowKernel() const {
            return Kernel_;
        }

        const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
            return ArgsValuesDescr_;
        }

        const IComputationNode* GetArgument(ui32 index) const {
            switch (index) {
            case 0:
                return Parent_->Pred;
            case 1:
                return Parent_->Then;
            case 2:
                return Parent_->Else;
            default:
                throw yexception() << "Bad argument index";
            }
        }

    private:
        const TBlockIfScalarWrapper* Parent_;
        const std::vector<arrow::ValueDescr> ArgsValuesDescr_;
        arrow::compute::ScalarKernel Kernel_;
    };
    friend class TArrowNode;

    TBlockIfScalarWrapper(TComputationMutables& mutables, IComputationNode* pred, IComputationNode* thenNode, IComputationNode* elseNode, TType* resultType,
                          bool thenIsScalar, bool elseIsScalar, const TVector<TType*>& argsTypes)
        : TMutableComputationNode(mutables)
        , Pred(pred)
        , Then(thenNode)
        , Else(elseNode)
        , ResultType(resultType)
        , ThenIsScalar(thenIsScalar)
        , ElseIsScalar(elseIsScalar)
        , ArgsTypes(argsTypes)
    {
    }

    std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
        Y_UNUSED(ctx);
        return std::make_unique<TArrowNode>(this);
    }

    arrow::Datum CalculateImpl(const TDatumProvider& predProv, const TDatumProvider& thenProv, const TDatumProvider& elseProv,
        arrow::MemoryPool& memoryPool) const {
        auto predValue = predProv();

        const bool predScalarValue = GetPrimitiveScalarValue<bool>(*predValue.scalar());
        auto result = predScalarValue ? thenProv() : elseProv();

        if (ThenIsScalar == ElseIsScalar || (predScalarValue ? !ThenIsScalar : !ElseIsScalar)) {
            // can return result as-is
            return result;
        }

        // The selected branch is a scalar while the other branch is array-like:
        // broadcast the scalar to arrays matching the other branch's chunk lengths.
        auto otherDatum = predScalarValue ? elseProv() : thenProv();
        MKQL_ENSURE(otherDatum.is_arraylike(), "Expecting array");

        std::shared_ptr<arrow::Scalar> resultScalar = result.scalar();

        TVector<std::shared_ptr<arrow::ArrayData>> resultArrays;
        auto itemType = AS_TYPE(TBlockType, ResultType)->GetItemType();
        ForEachArrayData(otherDatum, [&](const std::shared_ptr<arrow::ArrayData>& otherData) {
            auto chunk = MakeArrayFromScalar(*resultScalar, otherData->length, itemType, memoryPool);
            ForEachArrayData(chunk, [&](const auto& array) {
                resultArrays.push_back(array);
            });
        });

        return MakeArray(resultArrays);
    }

    NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
        return ctx.HolderFactory.CreateArrowBlock(CalculateImpl(MakeDatumProvider(Pred, ctx), MakeDatumProvider(Then, ctx), MakeDatumProvider(Else, ctx), ctx.ArrowMemoryPool));
    }

private:
    void RegisterDependencies() const final {
        DependsOn(Pred);
        DependsOn(Then);
        DependsOn(Else);
    }

    IComputationNode* const Pred;
    IComputationNode* const Then;
    IComputationNode* const Else;
    TType* const ResultType;
    const bool ThenIsScalar;
    const bool ElseIsScalar;
    const TVector<TType*> ArgsTypes;
};
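// Kernel implementation for an array predicate: merges the 'then' and 'else' blocks
// element-wise. The ThenIsScalar/ElseIsScalar template parameters decide at compile
// time whether each branch is read once as a scalar item or per-row from an array.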
template<bool ThenIsScalar, bool ElseIsScalar>
class TIfBlockExec {
public:
    explicit TIfBlockExec(TType* type)
        : ThenReader(MakeBlockReader(TTypeInfoHelper(), type)), ElseReader(MakeBlockReader(TTypeInfoHelper(), type)), Type(type)
    {
    }

    arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
        arrow::Datum predDatum = batch.values[0];
        arrow::Datum thenDatum = batch.values[1];
        arrow::Datum elseDatum = batch.values[2];

        TBlockItem thenItem;
        const arrow::ArrayData* thenArray = nullptr;
        if constexpr(ThenIsScalar) {
            thenItem = ThenReader->GetScalarItem(*thenDatum.scalar());
        } else {
            MKQL_ENSURE(thenDatum.is_array(), "Expecting array");
            thenArray = thenDatum.array().get();
        }

        TBlockItem elseItem;
        const arrow::ArrayData* elseArray = nullptr;
        if constexpr(ElseIsScalar) {
            elseItem = ElseReader->GetScalarItem(*elseDatum.scalar());
        } else {
            MKQL_ENSURE(elseDatum.is_array(), "Expecting array");
            elseArray = elseDatum.array().get();
        }

        MKQL_ENSURE(predDatum.is_array(), "Expecting array");
        const std::shared_ptr<arrow::ArrayData>& pred = predDatum.array();

        const size_t len = pred->length;
        auto builder = MakeArrayBuilder(TTypeInfoHelper(), Type, *ctx->memory_pool(), len, nullptr);
        const ui8* predValues = pred->GetValues<uint8_t>(1);
        for (size_t i = 0; i < len; ++i) {
            if constexpr (!ThenIsScalar) {
                thenItem = ThenReader->GetItem(*thenArray, i);
            }
            if constexpr (!ElseIsScalar) {
                elseItem = ElseReader->GetItem(*elseArray, i);
            }

            // Branch-free select: mask is all ones when the predicate value is 1
            // and all zeros when it is 0, so exactly one branch item survives.
            ui64 mask = -ui64(predValues[i]);
            ui64 low = (thenItem.Low() & mask) | (elseItem.Low() & ~mask);
            ui64 high = (thenItem.High() & mask) | (elseItem.High() & ~mask);
            builder->Add(TBlockItem{low, high});
        }

        *res = builder->Build(true);
        return arrow::Status::OK();
    }

private:
    const std::unique_ptr<IBlockReader> ThenReader;
    const std::unique_ptr<IBlockReader> ElseReader;
    TType* const Type;
};
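// Builds an arrow::compute::ScalarKernel for the given argument and result block
// types; the TIfBlockExec instance is kept alive by the shared_ptr captured in the
// kernel lambda.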
template<bool ThenIsScalar, bool ElseIsScalar>
std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockIfKernel(const TVector<TType*>& argTypes, TType* resultType) {
    using TExec = TIfBlockExec<ThenIsScalar, ElseIsScalar>;

    auto exec = std::make_shared<TExec>(AS_TYPE(TBlockType, resultType)->GetItemType());
    auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
        [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
            return exec->Exec(ctx, batch, res);
        });

    kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
    return kernel;
}

} // namespace
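// Factory for the BlockIf callable: validates the signature (bool predicate, both
// branches sharing the same item type), then either creates TBlockIfScalarWrapper
// for a scalar predicate or registers the TIfBlockExec kernel specialization that
// matches the branch shapes.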
IComputationNode* WrapBlockIf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
    MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");

    auto pred = callable.GetInput(0);
    auto thenNode = callable.GetInput(1);
    auto elseNode = callable.GetInput(2);

    auto predType = AS_TYPE(TBlockType, pred.GetStaticType());
    MKQL_ENSURE(AS_TYPE(TDataType, predType->GetItemType())->GetSchemeType() == NUdf::TDataType<bool>::Id,
                "Expected bool as first argument");

    auto thenType = AS_TYPE(TBlockType, thenNode.GetStaticType());
    auto elseType = AS_TYPE(TBlockType, elseNode.GetStaticType());
    MKQL_ENSURE(thenType->GetItemType()->IsSameType(*elseType->GetItemType()), "Different return types in branches.");

    auto predCompute = LocateNode(ctx.NodeLocator, callable, 0);
    auto thenCompute = LocateNode(ctx.NodeLocator, callable, 1);
    auto elseCompute = LocateNode(ctx.NodeLocator, callable, 2);

    bool predIsScalar = predType->GetShape() == TBlockType::EShape::Scalar;
    bool thenIsScalar = thenType->GetShape() == TBlockType::EShape::Scalar;
    bool elseIsScalar = elseType->GetShape() == TBlockType::EShape::Scalar;
    TVector<TType*> argsTypes = { predType, thenType, elseType };

    if (predIsScalar) {
        return new TBlockIfScalarWrapper(ctx.Mutables, predCompute, thenCompute, elseCompute, thenType,
                                         thenIsScalar, elseIsScalar, argsTypes);
    }

    TComputationNodePtrVector argsNodes = { predCompute, thenCompute, elseCompute };

    std::shared_ptr<arrow::compute::ScalarKernel> kernel;
    if (thenIsScalar && elseIsScalar) {
        kernel = MakeBlockIfKernel<true, true>(argsTypes, thenType);
    } else if (thenIsScalar && !elseIsScalar) {
        kernel = MakeBlockIfKernel<true, false>(argsTypes, thenType);
    } else if (!thenIsScalar && elseIsScalar) {
        kernel = MakeBlockIfKernel<false, true>(argsTypes, thenType);
    } else {
        kernel = MakeBlockIfKernel<false, false>(argsTypes, thenType);
    }

    return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
}

} // namespace NMiniKQL
} // namespace NKikimr