// mkql_block_container.cpp
#include "mkql_block_container.h"

#include <yql/essentials/minikql/computation/mkql_block_impl.h>
#include <yql/essentials/minikql/arrow/arrow_defs.h>
#include <yql/essentials/minikql/arrow/arrow_util.h>
#include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
#include <yql/essentials/minikql/mkql_node_cast.h>
#include <yql/essentials/minikql/mkql_node_builder.h>

#include <arrow/util/bitmap_ops.h>
namespace NKikimr {
namespace NMiniKQL {
namespace {
  12. class TBlockAsContainerExec {
  13. public:
  14. TBlockAsContainerExec(const TVector<TType*>& argTypes, const std::shared_ptr<arrow::DataType>& returnArrowType)
  15. : ArgTypes(argTypes)
  16. , ReturnArrowType(returnArrowType)
  17. {}
  18. arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
  19. bool allScalars = true;
  20. size_t length = 0;
  21. for (const auto& x : batch.values) {
  22. if (!x.is_scalar()) {
  23. allScalars = false;
  24. length = x.array()->length;
  25. break;
  26. }
  27. }
  28. if (allScalars) {
  29. // return scalar too
  30. std::vector<std::shared_ptr<arrow::Scalar>> arrowValue;
  31. for (const auto& x : batch.values) {
  32. arrowValue.emplace_back(x.scalar());
  33. }
  34. *res = arrow::Datum(std::make_shared<arrow::StructScalar>(arrowValue, ReturnArrowType));
  35. return arrow::Status::OK();
  36. }
  37. auto newArrayData = arrow::ArrayData::Make(ReturnArrowType, length, { nullptr }, 0, 0);
  38. MKQL_ENSURE(ArgTypes.size() == batch.values.size(), "Mismatch batch columns");
  39. for (ui32 i = 0; i < batch.values.size(); ++i) {
  40. const auto& datum = batch.values[i];
  41. if (datum.is_scalar()) {
  42. // expand scalar to array
  43. auto expandedArray = MakeArrayFromScalar(*datum.scalar(), length, AS_TYPE(TBlockType, ArgTypes[i])->GetItemType(), *ctx->memory_pool());
  44. newArrayData->child_data.push_back(expandedArray.array());
  45. } else {
  46. newArrayData->child_data.push_back(datum.array());
  47. }
  48. }
  49. *res = arrow::Datum(newArrayData);
  50. return arrow::Status::OK();
  51. }
  52. private:
  53. const TVector<TType*> ArgTypes;
  54. const std::shared_ptr<arrow::DataType> ReturnArrowType;
  55. };
  56. std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockAsContainerKernel(const TVector<TType*>& argTypes, TType* resultType) {
  57. std::shared_ptr<arrow::DataType> returnArrowType;
  58. MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type");
  59. auto exec = std::make_shared<TBlockAsContainerExec>(argTypes, returnArrowType);
  60. auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
  61. [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  62. return exec->Exec(ctx, batch, res);
  63. });
  64. kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
  65. return kernel;
  66. }
} // namespace
  68. IComputationNode* WrapBlockAsContainer(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  69. TComputationNodePtrVector argsNodes;
  70. TVector<TType*> argsTypes;
  71. for (ui32 i = 0; i < callable.GetInputsCount(); ++i) {
  72. argsNodes.push_back(LocateNode(ctx.NodeLocator, callable, i));
  73. argsTypes.push_back(callable.GetInput(i).GetStaticType());
  74. }
  75. auto kernel = MakeBlockAsContainerKernel(argsTypes, callable.GetType()->GetReturnType());
  76. return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
  77. }
} // namespace NMiniKQL
} // namespace NKikimr