mkql_block_func.cpp 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. #include "mkql_block_func.h"
  2. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  3. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. #include <yql/essentials/minikql/mkql_type_builder.h>
  7. #include <arrow/compute/cast.h>
  8. namespace NKikimr {
  9. namespace NMiniKQL {
  10. namespace {
  11. const TKernel& ResolveKernel(const IBuiltinFunctionRegistry& builtins, const TString& funcName, const TVector<TType*>& inputTypes, TType* returnType) {
  12. std::vector<NUdf::TDataTypeId> argTypes;
  13. for (const auto& t : inputTypes) {
  14. auto asBlockType = AS_TYPE(TBlockType, t);
  15. bool isOptional;
  16. auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional);
  17. argTypes.push_back(dataType->GetSchemeType());
  18. }
  19. NUdf::TDataTypeId returnTypeId;
  20. {
  21. auto asBlockType = AS_TYPE(TBlockType, returnType);
  22. bool isOptional;
  23. auto dataType = UnpackOptionalData(asBlockType->GetItemType(), isOptional);
  24. returnTypeId = dataType->GetSchemeType();
  25. }
  26. auto kernel = builtins.FindKernel(funcName, argTypes.data(), argTypes.size(), returnTypeId);
  27. MKQL_ENSURE(kernel, "Can't find kernel for " << funcName);
  28. return *kernel;
  29. }
  30. } // namespace
  31. IComputationNode* WrapBlockFunc(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  32. MKQL_ENSURE(callable.GetInputsCount() >= 1, "Expected at least 1 arg");
  33. const auto funcNameData = AS_VALUE(TDataLiteral, callable.GetInput(0));
  34. const auto funcName = TString(funcNameData->AsValue().AsStringRef());
  35. TComputationNodePtrVector argsNodes;
  36. TVector<TType*> argsTypes;
  37. const auto callableType = callable.GetType();
  38. for (ui32 i = 1; i < callable.GetInputsCount(); ++i) {
  39. argsNodes.push_back(LocateNode(ctx.NodeLocator, callable, i));
  40. argsTypes.push_back(callableType->GetArgumentType(i));
  41. }
  42. const TKernel& kernel = ResolveKernel(*ctx.FunctionRegistry.GetBuiltins(), funcName, argsTypes, callableType->GetReturnType());
  43. if (kernel.IsPolymorphic()) {
  44. auto arrowKernel = kernel.MakeArrowKernel(argsTypes, callableType->GetReturnType());
  45. return new TBlockFuncNode(ctx.Mutables, funcName, std::move(argsNodes), argsTypes, *arrowKernel, arrowKernel, kernel.Family.FunctionOptions);
  46. } else {
  47. return new TBlockFuncNode(ctx.Mutables, funcName, std::move(argsNodes), argsTypes, kernel.GetArrowKernel(), {}, kernel.Family.FunctionOptions);
  48. }
  49. }
  50. }
  51. }