mkql_block_exists.cpp 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #include "mkql_exists.h"
  2. #include <yql/essentials/minikql/arrow/arrow_util.h>
  3. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. class TBlockExistsExec {
  9. public:
  10. arrow::Status Exec(arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) const {
  11. const auto& input = batch.values[0];
  12. MKQL_ENSURE(input.is_array(), "Expected array");
  13. const auto& arr = *input.array();
  14. auto nullCount = arr.GetNullCount();
  15. if (nullCount == arr.length) {
  16. *res = MakeFalseArray(ctx->memory_pool(), arr.length);
  17. } else if (nullCount == 0) {
  18. *res = MakeTrueArray(ctx->memory_pool(), arr.length);
  19. } else {
  20. *res = MakeBitmapArray(ctx->memory_pool(), arr.length, arr.offset,
  21. arr.buffers[0]->data());
  22. }
  23. return arrow::Status::OK();
  24. }
  25. };
  26. std::shared_ptr<arrow::compute::ScalarKernel> MakeBlockExistsKernel(const TVector<TType*>& argTypes, TType* resultType) {
  27. std::shared_ptr<arrow::DataType> returnArrowType;
  28. MKQL_ENSURE(ConvertArrowType(AS_TYPE(TBlockType, resultType)->GetItemType(), returnArrowType), "Unsupported arrow type");
  29. // Ensure the result Arrow type (i.e. boolean) is Arrow UInt8Type.
  30. Y_DEBUG_ABORT_UNLESS(returnArrowType == arrow::uint8());
  31. auto exec = std::make_shared<TBlockExistsExec>();
  32. auto kernel = std::make_shared<arrow::compute::ScalarKernel>(ConvertToInputTypes(argTypes), ConvertToOutputType(resultType),
  33. [exec](arrow::compute::KernelContext* ctx, const arrow::compute::ExecBatch& batch, arrow::Datum* res) {
  34. return exec->Exec(ctx, batch, res);
  35. });
  36. kernel->null_handling = arrow::compute::NullHandling::OUTPUT_NOT_NULL;
  37. return kernel;
  38. }
  39. } // namespace
  40. IComputationNode* WrapBlockExists(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  41. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 arg");
  42. auto compute = LocateNode(ctx.NodeLocator, callable, 0);
  43. TComputationNodePtrVector argsNodes = { compute };
  44. TVector<TType*> argsTypes = { callable.GetInput(0).GetStaticType() };
  45. auto kernel = MakeBlockExistsKernel(argsTypes, callable.GetType()->GetReturnType());
  46. return new TBlockFuncNode(ctx.Mutables, "Exists", std::move(argsNodes), argsTypes, *kernel, kernel);
  47. }
  48. } // namespace NMiniKQL
  49. } // namespace NKikimr