mkql_block_just.cpp 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #include "mkql_block_just.h"
  2. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  3. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  4. #include <yql/essentials/minikql/arrow/arrow_util.h>
  5. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  6. #include <yql/essentials/minikql/mkql_node_cast.h>
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  9. namespace {
  10. template<bool Trivial>
  11. class TJustBlockExec {
  12. public:
  13. TJustBlockExec(const std::shared_ptr<arrow::DataType>& returnArrowType)
  14. : ReturnArrowType(returnArrowType)
  15. {}
  16. arrow::Status Exec(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
  17. arrow::Datum inputDatum = batch.values[0];
  18. if (Trivial) {
  19. *res = inputDatum;
  20. return arrow::Status::OK();
  21. }
  22. if (inputDatum.is_scalar()) {
  23. std::vector<std::shared_ptr<arrow::Scalar>> arrowValue;
  24. arrowValue.emplace_back(inputDatum.scalar());
  25. *res = arrow::Datum(std::make_shared<arrow::StructScalar>(arrowValue, ReturnArrowType));
  26. } else {
  27. auto array = inputDatum.array();
  28. auto newArrayData = arrow::ArrayData::Make(ReturnArrowType, array->length, { nullptr }, 0, 0);
  29. newArrayData->child_data.push_back(array);
  30. *res = arrow::Datum(newArrayData);
  31. }
  32. return arrow::Status::OK();
  33. }
  34. private:
  35. const std::shared_ptr<arrow::DataType> ReturnArrowType;
  36. };
  37. template<bool Trivial>
  38. std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockJustKernel(const TVector<TType*>& argTypes, TType* resultType) {
  39. using TExec = TJustBlockExec<Trivial>;
  40. std::shared_ptr<arrow::DataType> returnArrowType;
  41. MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type");
  42. auto exec = std::make_shared<TExec>(returnArrowType);
  43. auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
  44. [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  45. return exec->Exec(ctx, batch, res);
  46. });
  47. kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
  48. return kernel;
  49. }
  50. } // namespace
  51. IComputationNode* WrapBlockJust(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  52. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args");
  53. auto data = callable.GetInput(0);
  54. auto dataType = AS_TYPE(TBlockType, data.GetStaticType());
  55. auto itemType = dataType->GetItemType();
  56. auto dataCompute = LocateNode(ctx.NodeLocator, callable, 0);
  57. TComputationNodePtrVector argsNodes = { dataCompute };
  58. TVector<TType*> argsTypes = { dataType };
  59. std::shared_ptr<arrow::compute::ScalarKernel> kernel;
  60. if (itemType->IsOptional() || itemType->IsVariant()) {
  61. kernel = MakeBlockJustKernel<false>(argsTypes, callable.GetType()->GetReturnType());
  62. } else {
  63. kernel = MakeBlockJustKernel<true>(argsTypes, callable.GetType()->GetReturnType());
  64. }
  65. return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
  66. }
  67. }
  68. }