// mkql_block_getelem.cpp
  1. #include "mkql_block_getelem.h"
  2. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. class TBlockGetElementExec {
  9. public:
  10. TBlockGetElementExec(const std::shared_ptr<arrow::DataType>& returnArrowType, ui32 index, bool isOptional, bool needExternalOptional)
  11. : ReturnArrowType(returnArrowType)
  12. , Index(index)
  13. , IsOptional(isOptional)
  14. , NeedExternalOptional(needExternalOptional)
  15. {}
  16. arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
  17. arrow::Datum inputDatum = batch.values[0];
  18. if (inputDatum.is_scalar()) {
  19. if (inputDatum.scalar()->is_valid) {
  20. const auto& structScalar = arrow::internal::checked_cast<const arrow::StructScalar&>(*inputDatum.scalar());
  21. *res = arrow::Datum(structScalar.value[Index]);
  22. } else {
  23. *res = arrow::Datum(arrow::MakeNullScalar(ReturnArrowType));
  24. }
  25. } else {
  26. const auto& array = inputDatum.array();
  27. auto child = array->child_data[Index];
  28. if (NeedExternalOptional) {
  29. auto newArrayData = arrow::ArrayData::Make(ReturnArrowType, array->length, { array->buffers[0] });
  30. newArrayData->child_data.push_back(child);
  31. *res = arrow::Datum(newArrayData);
  32. } else if (!IsOptional || !array->buffers[0]) {
  33. *res = arrow::Datum(child);
  34. } else {
  35. auto newArrayData = child->Copy();
  36. if (!newArrayData->buffers[0]) {
  37. newArrayData->buffers[0] = array->buffers[0];
  38. } else {
  39. auto buffer = AllocateBitmapWithReserve(array->length + array->offset, ctx->memory_pool());
  40. arrow::internal::BitmapAnd(child->GetValues<uint8_t>(0, 0), child->offset, array->GetValues<uint8_t>(0, 0), array->offset, array->length, array->offset, buffer->mutable_data());
  41. newArrayData->buffers[0] = buffer;
  42. }
  43. newArrayData->SetNullCount(arrow::kUnknownNullCount);
  44. *res = arrow::Datum(newArrayData);
  45. }
  46. }
  47. return arrow::Status::OK();
  48. }
  49. private:
  50. const std::shared_ptr<arrow::DataType> ReturnArrowType;
  51. const ui32 Index;
  52. const bool IsOptional;
  53. const bool NeedExternalOptional;
  54. };
  55. std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockGetElementKernel(const TVector<TType*>& argTypes, TType* resultType,
  56. ui32 index, bool isOptional, bool needExternalOptional) {
  57. std::shared_ptr<arrow::DataType> returnArrowType;
  58. MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type");
  59. auto exec = std::make_shared<TBlockGetElementExec>(returnArrowType, index, isOptional, needExternalOptional);
  60. auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
  61. [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  62. return exec->Exec(ctx, batch, res);
  63. });
  64. kernel->null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
  65. return kernel;
  66. }
  67. TType* GetElementType(const TStructType* structType, ui32 index) {
  68. MKQL_ENSURE(index < structType->GetMembersCount(), "Bad member index");
  69. return structType->GetMemberType(index);
  70. }
  71. TType* GetElementType(const TTupleType* tupleType, ui32 index) {
  72. MKQL_ENSURE(index < tupleType->GetElementsCount(), "Bad tuple index");
  73. return tupleType->GetElementType(index);
  74. }
  75. template<typename ObjectType>
  76. IComputationNode* WrapBlockGetElement(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  77. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected two args.");
  78. auto inputObject = callable.GetInput(0);
  79. auto blockType = AS_TYPE(TBlockType, inputObject.GetStaticType());
  80. bool isOptional;
  81. auto objectType = AS_TYPE(ObjectType, UnpackOptional(blockType->GetItemType(), isOptional));
  82. auto indexData = AS_VALUE(TDataLiteral, callable.GetInput(1));
  83. auto index = indexData->AsValue().Get<ui32>();
  84. auto childType = GetElementType(objectType, index);
  85. bool needExternalOptional = isOptional && childType->IsVariant();
  86. auto objectNode = LocateNode(ctx.NodeLocator, callable, 0);
  87. TComputationNodePtrVector argsNodes = { objectNode };
  88. TVector<TType*> argsTypes = { blockType };
  89. auto kernel = MakeBlockGetElementKernel(argsTypes, callable.GetType()->GetReturnType(), index, isOptional, needExternalOptional);
  90. return new TBlockFuncNode(ctx.Mutables, callable.GetType()->GetName(), std::move(argsNodes), argsTypes, *kernel, kernel);
  91. }
  92. } // namespace
  93. IComputationNode* WrapBlockMember(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  94. return WrapBlockGetElement<TStructType>(callable, ctx);
  95. }
  96. IComputationNode* WrapBlockNth(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  97. return WrapBlockGetElement<TTupleType>(callable, ctx);
  98. }
  99. } // namespace NMiniKQL
  100. } // namespace NKikimr