// mkql_function_metadata.h
  1. #pragma once
  #include <yql/essentials/minikql/defs.h>
  #include <yql/essentials/minikql/mkql_node.h>
  #include <yql/essentials/public/udf/udf_value.h>
  #include <util/digest/numeric.h>
  #include <util/generic/vector.h>
  #include <arrow/compute/kernel.h>

  #include <memory>
  #include <string_view>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  8. namespace NKikimr {
  9. namespace NMiniKQL {
  10. using TFunctionPtr = NUdf::TUnboxedValuePod (*)(const NUdf::TUnboxedValuePod* args);
  11. struct TFunctionParamMetadata {
  12. enum EFlags : ui16 {
  13. FlagIsNullable = 0x01,
  14. };
  15. TFunctionParamMetadata() = default;
  16. TFunctionParamMetadata(NUdf::TDataTypeId schemeType, ui32 flags)
  17. : SchemeType(schemeType)
  18. , Flags(flags)
  19. {}
  20. bool IsNullable() const {
  21. return Flags & FlagIsNullable;
  22. }
  23. NUdf::TDataTypeId SchemeType = 0;
  24. ui16 Flags = 0;
  25. };
  26. struct TFunctionDescriptor {
  27. TFunctionDescriptor() = default;
  28. TFunctionDescriptor(const TFunctionParamMetadata* resultAndArgs, TFunctionPtr function, void* generator = nullptr)
  29. : ResultAndArgs(resultAndArgs)
  30. , Function(function)
  31. , Generator(generator)
  32. {}
  33. const TFunctionParamMetadata* ResultAndArgs = nullptr; // ends with SchemeType zero
  34. TFunctionPtr Function = nullptr;
  35. void *Generator = nullptr;
  36. };
  37. using TFunctionParamMetadataList = std::vector<TFunctionParamMetadata>;
  38. using TArgType = std::pair<NUdf::TDataTypeId, bool>; // type with optional flag
  39. using TDescriptionList = std::vector<TFunctionDescriptor>;
  40. using TFunctionsMap = std::unordered_map<TString, TDescriptionList>;
  41. class TKernel;
  42. class TKernelFamily {
  43. public:
  44. const arrow::compute::FunctionOptions* FunctionOptions;
  45. TKernelFamily(const arrow::compute::FunctionOptions* functionOptions = nullptr)
  46. : FunctionOptions(functionOptions)
  47. {}
  48. virtual ~TKernelFamily() = default;
  49. virtual const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const = 0;
  50. virtual TVector<const TKernel*> GetAllKernels() const = 0;
  51. };
  52. class TKernel {
  53. public:
  54. enum class ENullMode {
  55. Default,
  56. AlwaysNull,
  57. AlwaysNotNull
  58. };
  59. const TKernelFamily& Family;
  60. const std::vector<NUdf::TDataTypeId> ArgTypes;
  61. const NUdf::TDataTypeId ReturnType;
  62. const ENullMode NullMode;
  63. TKernel(const TKernelFamily& family, const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType, ENullMode nullMode)
  64. : Family(family)
  65. , ArgTypes(argTypes)
  66. , ReturnType(returnType)
  67. , NullMode(nullMode)
  68. {
  69. }
  70. virtual const arrow::compute::ScalarKernel& GetArrowKernel() const = 0;
  71. virtual std::shared_ptr<arrow::compute::ScalarKernel> MakeArrowKernel(const TVector<TType*>& argTypes, TType* resultType) const = 0;
  72. virtual bool IsPolymorphic() const = 0;
  73. virtual ~TKernel() = default;
  74. };
  75. using TKernelMapKey = std::pair<std::vector<NUdf::TDataTypeId>, NUdf::TDataTypeId>;
  76. struct TTypeHasher {
  77. std::size_t operator()(const TKernelMapKey& s) const noexcept {
  78. size_t r = 0;
  79. for (const auto& x : s.first) {
  80. r = CombineHashes<size_t>(r, x);
  81. }
  82. r = CombineHashes<size_t>(r, s.second);
  83. return r;
  84. }
  85. };
  86. using TKernelMap = std::unordered_map<TKernelMapKey, std::unique_ptr<TKernel>, TTypeHasher>;
  87. using TKernelFamilyMap = std::unordered_map<TString, std::unique_ptr<TKernelFamily>>;
  88. class TKernelFamilyBase : public TKernelFamily
  89. {
  90. public:
  91. TKernelFamilyBase(const arrow::compute::FunctionOptions* functionOptions = nullptr);
  92. const TKernel* FindKernel(const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const final;
  93. TVector<const TKernel*> GetAllKernels() const final;
  94. void Adopt(const std::vector<NUdf::TDataTypeId>& argTypes, NUdf::TDataTypeId returnType, std::unique_ptr<TKernel>&& kernel);
  95. private:
  96. TKernelMap KernelMap;
  97. };
  98. class IBuiltinFunctionRegistry: public TThrRefBase, private TNonCopyable
  99. {
  100. public:
  101. typedef TIntrusivePtr<IBuiltinFunctionRegistry> TPtr;
  102. virtual ui64 GetMetadataEtag() const = 0;
  103. virtual void PrintInfoTo(IOutputStream& out) const = 0;
  104. virtual void Register(const std::string_view& name, const TFunctionDescriptor& description) = 0;
  105. virtual bool HasBuiltin(const std::string_view& name) const = 0;
  106. virtual void RegisterAll(TFunctionsMap&& functions, TFunctionParamMetadataList&& arguments) = 0;
  107. virtual const TFunctionsMap& GetFunctions() const = 0;
  108. virtual TFunctionDescriptor GetBuiltin(const std::string_view& name, const std::pair<NUdf::TDataTypeId, bool>* argTypes, size_t argTypesCount) const = 0;
  109. virtual const TKernel* FindKernel(const std::string_view& name, const NUdf::TDataTypeId* argTypes, size_t argTypesCount, NUdf::TDataTypeId returnType) const = 0;
  110. virtual void RegisterKernelFamily(const std::string_view& name, std::unique_ptr<TKernelFamily>&& family) = 0;
  111. virtual TVector<std::pair<TString, const TKernelFamily*>> GetAllKernelFamilies() const = 0;
  112. };
  113. }
  114. }