mkql_udf.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. #include "mkql_udf.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. #include <yql/essentials/minikql/computation/mkql_validate.h>
  7. #include <yql/essentials/minikql/mkql_function_registry.h>
  8. #include <yql/essentials/minikql/mkql_node_printer.h>
  9. #include <yql/essentials/minikql/mkql_type_builder.h>
  10. #include <yql/essentials/minikql/mkql_utils.h>
  11. #include <yql/essentials/utils/yql_panic.h>
  12. namespace NKikimr {
  13. namespace NMiniKQL {
  14. namespace {
  15. template<class TValidatePolicy, class TValidateMode>
  16. class TSimpleUdfWrapper: public TMutableComputationNode<TSimpleUdfWrapper<TValidatePolicy,TValidateMode>> {
  17. using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePolicy,TValidateMode>>;
  18. public:
  19. TSimpleUdfWrapper(
  20. TComputationMutables& mutables,
  21. TString&& functionName,
  22. TString&& typeConfig,
  23. NUdf::TSourcePosition pos,
  24. const TCallableType* callableType,
  25. TType* userType)
  26. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  27. , FunctionName(std::move(functionName))
  28. , TypeConfig(std::move(typeConfig))
  29. , Pos(pos)
  30. , CallableType(callableType)
  31. , UserType(userType)
  32. {
  33. this->Stateless = false;
  34. }
  35. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  36. ui32 flags = 0;
  37. TFunctionTypeInfo funcInfo;
  38. const auto status = ctx.HolderFactory.GetFunctionRegistry()->FindFunctionTypeInfo(
  39. ctx.TypeEnv, ctx.TypeInfoHelper, ctx.CountersProvider, FunctionName, UserType->IsVoid() ? nullptr : UserType,
  40. TypeConfig, flags, Pos, ctx.SecureParamsProvider, &funcInfo);
  41. MKQL_ENSURE(status.IsOk(), status.GetError());
  42. MKQL_ENSURE(funcInfo.Implementation, "UDF implementation is not set for function " << FunctionName);
  43. NUdf::TUnboxedValue udf(NUdf::TUnboxedValuePod(funcInfo.Implementation.Release()));
  44. TValidate<TValidatePolicy,TValidateMode>::WrapCallable(CallableType, udf, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
  45. return udf.Release();
  46. }
  47. private:
  48. void RegisterDependencies() const final {}
  49. const TString FunctionName;
  50. const TString TypeConfig;
  51. const NUdf::TSourcePosition Pos;
  52. const TCallableType *const CallableType;
  53. TType *const UserType;
  54. };
  55. class TUdfRunCodegeneratorNode: public TSimpleUdfWrapper<TValidateErrorPolicyNone, TValidateModeLazy<TValidateErrorPolicyNone>>
  56. #ifndef MKQL_DISABLE_CODEGEN
  57. , public ICodegeneratorRunNode
  58. #endif
  59. {
  60. public:
  61. TUdfRunCodegeneratorNode(
  62. TComputationMutables& mutables,
  63. TString&& functionName,
  64. TString&& typeConfig,
  65. NUdf::TSourcePosition pos,
  66. const TCallableType* callableType,
  67. TType* userType,
  68. TString&& moduleIRUniqID,
  69. TString&& moduleIR,
  70. TString&& fuctioNameIR,
  71. NUdf::TUniquePtr<NUdf::IBoxedValue>&& impl)
  72. : TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, userType)
  73. , ModuleIRUniqID(std::move(moduleIRUniqID))
  74. , ModuleIR(std::move(moduleIR))
  75. , IRFunctionName(std::move(fuctioNameIR))
  76. , Impl(std::move(impl))
  77. {}
  78. #ifndef MKQL_DISABLE_CODEGEN
  79. void CreateRun(const TCodegenContext& ctx, BasicBlock*& block, Value* result, Value* args) const final {
  80. ctx.Codegen.LoadBitCode(ModuleIR, ModuleIRUniqID);
  81. auto& context = ctx.Codegen.GetContext();
  82. const auto type = Type::getInt128Ty(context);
  83. YQL_ENSURE(result->getType() == PointerType::getUnqual(type));
  84. const auto data = ConstantInt::get(Type::getInt64Ty(context), reinterpret_cast<ui64>(Impl.Get()));
  85. const auto ptrStructType = PointerType::getUnqual(StructType::get(context));
  86. const auto boxed = CastInst::Create(Instruction::IntToPtr, data, ptrStructType, "boxed", block);
  87. const auto builder = ctx.GetBuilder();
  88. const auto funType = FunctionType::get(Type::getVoidTy(context), {boxed->getType(), result->getType(), builder->getType(), args->getType()}, false);
  89. const auto runFunc = ctx.Codegen.GetModule().getOrInsertFunction(llvm::StringRef(IRFunctionName.data(), IRFunctionName.size()), funType);
  90. CallInst::Create(runFunc, {boxed, result, builder, args}, "", block);
  91. }
  92. #endif
  93. private:
  94. const TString ModuleIRUniqID;
  95. const TString ModuleIR;
  96. const TString IRFunctionName;
  97. const NUdf::TUniquePtr<NUdf::IBoxedValue> Impl;
  98. };
  99. template<class TValidatePolicy, class TValidateMode>
  100. class TUdfWrapper: public TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolicy,TValidateMode>> {
  101. using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolicy,TValidateMode>>;
  102. public:
  103. TUdfWrapper(
  104. TComputationMutables& mutables,
  105. TString&& functionName,
  106. TString&& typeConfig,
  107. NUdf::TSourcePosition pos,
  108. IComputationNode* runConfigNode,
  109. const TCallableType* callableType,
  110. TType* userType)
  111. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  112. , FunctionName(std::move(functionName))
  113. , TypeConfig(std::move(typeConfig))
  114. , Pos(pos)
  115. , RunConfigNode(runConfigNode)
  116. , CallableType(callableType)
  117. , UserType(userType)
  118. , UdfIndex(mutables.CurValueIndex++)
  119. {
  120. this->Stateless = false;
  121. }
  122. NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
  123. auto& udf = ctx.MutableValues[UdfIndex];
  124. if (!udf.HasValue()) {
  125. MakeUdf(ctx, udf);
  126. }
  127. const auto runConfig = RunConfigNode->GetValue(ctx);
  128. auto callable = udf.Run(ctx.Builder, &runConfig);
  129. Wrap(callable);
  130. return callable;
  131. }
  132. #ifndef MKQL_DISABLE_CODEGEN
  133. void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
  134. auto& context = ctx.Codegen.GetContext();
  135. const auto valueType = Type::getInt128Ty(context);
  136. const auto udfPtr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), UdfIndex)}, "udf_ptr", block);
  137. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  138. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  139. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  140. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  141. BranchInst::Create(main, make, HasValue(udfPtr, block, context), block);
  142. block = make;
  143. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TUdfWrapper::MakeUdf));
  144. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), udfPtr->getType()}, false);
  145. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  146. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, udfPtr}, "", block);
  147. BranchInst::Create(main, block);
  148. block = main;
  149. GetNodeValue(pointer, RunConfigNode, ctx, block);
  150. const auto conf = new LoadInst(valueType, pointer, "conf", block);
  151. const auto udf = new LoadInst(valueType, udfPtr, "udf", block);
  152. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen, block, ctx.GetBuilder(), pointer);
  153. ValueUnRef(RunConfigNode->GetRepresentation(), conf, ctx, block);
  154. const auto wrap = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TUdfWrapper::Wrap));
  155. const auto funType = FunctionType::get(Type::getVoidTy(context), {self->getType(), pointer->getType()}, false);
  156. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, wrap, PointerType::getUnqual(funType), "function", block);
  157. CallInst::Create(funType, doFuncPtr, {self, pointer}, "", block);
  158. }
  159. #endif
  160. private:
  161. void MakeUdf(TComputationContext& ctx, NUdf::TUnboxedValue& udf) const {
  162. ui32 flags = 0;
  163. TFunctionTypeInfo funcInfo;
  164. const auto status = ctx.HolderFactory.GetFunctionRegistry()->FindFunctionTypeInfo(
  165. ctx.TypeEnv, ctx.TypeInfoHelper, ctx.CountersProvider, FunctionName, UserType->IsVoid() ? nullptr : UserType,
  166. TypeConfig, flags, Pos, ctx.SecureParamsProvider, &funcInfo);
  167. MKQL_ENSURE(status.IsOk(), status.GetError());
  168. MKQL_ENSURE(funcInfo.Implementation, "UDF implementation is not set for function " << FunctionName);
  169. udf = NUdf::TUnboxedValuePod(funcInfo.Implementation.Release());
  170. }
  171. void Wrap(NUdf::TUnboxedValue& callable) const {
  172. MKQL_ENSURE(bool(callable), "Returned empty value in function: " << FunctionName);
  173. TValidate<TValidatePolicy,TValidateMode>::WrapCallable(CallableType, callable, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
  174. }
  175. void RegisterDependencies() const final {
  176. this->DependsOn(RunConfigNode);
  177. }
  178. const TString FunctionName;
  179. const TString TypeConfig;
  180. const NUdf::TSourcePosition Pos;
  181. IComputationNode* const RunConfigNode;
  182. const TCallableType* CallableType;
  183. TType* const UserType;
  184. const ui32 UdfIndex;
  185. };
  186. template<bool Simple, class TValidatePolicy, class TValidateMode>
  187. using TWrapper = std::conditional_t<Simple, TSimpleUdfWrapper<TValidatePolicy, TValidateMode>, TUdfWrapper<TValidatePolicy, TValidateMode>>;
  188. template<bool Simple, typename...TArgs>
  189. inline IComputationNode* CreateUdfWrapper(const TComputationNodeFactoryContext& ctx, TArgs&&...args)
  190. {
  191. switch (ctx.ValidateMode) {
  192. case NUdf::EValidateMode::None:
  193. return new TWrapper<Simple, TValidateErrorPolicyNone,TValidateModeLazy<TValidateErrorPolicyNone>>(ctx.Mutables, std::forward<TArgs>(args)...);
  194. case NUdf::EValidateMode::Lazy:
  195. if (ctx.ValidatePolicy == NUdf::EValidatePolicy::Fail) {
  196. return new TWrapper<Simple, TValidateErrorPolicyFail,TValidateModeLazy<TValidateErrorPolicyFail>>(ctx.Mutables, std::forward<TArgs>(args)...);
  197. } else {
  198. return new TWrapper<Simple, TValidateErrorPolicyThrow,TValidateModeLazy<TValidateErrorPolicyThrow>>(ctx.Mutables, std::forward<TArgs>(args)...);
  199. }
  200. case NUdf::EValidateMode::Greedy:
  201. if (ctx.ValidatePolicy == NUdf::EValidatePolicy::Fail) {
  202. return new TWrapper<Simple, TValidateErrorPolicyFail,TValidateModeGreedy<TValidateErrorPolicyFail>>(ctx.Mutables, std::forward<TArgs>(args)...);
  203. } else {
  204. return new TWrapper<Simple, TValidateErrorPolicyThrow,TValidateModeGreedy<TValidateErrorPolicyThrow>>(ctx.Mutables, std::forward<TArgs>(args)...);
  205. }
  206. default:
  207. Y_ABORT("Unexpected validate mode: %u", static_cast<unsigned>(ctx.ValidateMode));
  208. };
  209. }
  210. }
  211. IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  212. MKQL_ENSURE(callable.GetInputsCount() == 4 || callable.GetInputsCount() == 7, "Expected 4 or 7 arguments");
  213. const auto funcNameNode = callable.GetInput(0);
  214. const auto userTypeNode = callable.GetInput(1);
  215. const auto typeCfgNode = callable.GetInput(2);
  216. const auto runCfgNode = callable.GetInput(3);
  217. MKQL_ENSURE(userTypeNode.IsImmediate(), "Expected immediate node");
  218. MKQL_ENSURE(userTypeNode.GetStaticType()->IsType(), "Expected type");
  219. TString funcName(AS_VALUE(TDataLiteral, funcNameNode)->AsValue().AsStringRef());
  220. TString typeConfig(AS_VALUE(TDataLiteral, typeCfgNode)->AsValue().AsStringRef());
  221. NUdf::TSourcePosition pos;
  222. if (callable.GetInputsCount() == 7) {
  223. pos.File_ = AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().AsStringRef();
  224. pos.Row_ = AS_VALUE(TDataLiteral, callable.GetInput(5))->AsValue().Get<ui32>();
  225. pos.Column_ = AS_VALUE(TDataLiteral, callable.GetInput(6))->AsValue().Get<ui32>();
  226. }
  227. ui32 flags = 0;
  228. TFunctionTypeInfo funcInfo;
  229. const auto userType = static_cast<TType*>(userTypeNode.GetNode());
  230. const auto status = ctx.FunctionRegistry.FindFunctionTypeInfo(
  231. ctx.Env, ctx.TypeInfoHelper, ctx.CountersProvider, funcName, userType->IsVoid() ? nullptr : userType,
  232. typeConfig, flags, pos, ctx.SecureParamsProvider, &funcInfo);
  233. MKQL_ENSURE(status.IsOk(), status.GetError());
  234. MKQL_ENSURE(funcInfo.FunctionType->IsConvertableTo(*callable.GetType()->GetReturnType(), true),
  235. "Function '" << funcName << "' type mismatch, expected return type: " << PrintNode(callable.GetType()->GetReturnType(), true) <<
  236. ", actual:" << PrintNode(funcInfo.FunctionType, true));
  237. MKQL_ENSURE(funcInfo.Implementation, "UDF implementation is not set for function " << funcName);
  238. const auto runConfigType = funcInfo.RunConfigType;
  239. const bool typesMatch = runConfigType->IsSameType(*runCfgNode.GetStaticType());
  240. MKQL_ENSURE(typesMatch, "RunConfig '" << funcName << "' type mismatch, expected: " << PrintNode(runCfgNode.GetStaticType(), true) <<
  241. ", actual: " << PrintNode(runConfigType, true));
  242. if (runConfigType->IsVoid()) {
  243. if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName) {
  244. return new TUdfRunCodegeneratorNode(
  245. ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType,
  246. std::move(funcInfo.ModuleIRUniqID), std::move(funcInfo.ModuleIR), std::move(funcInfo.IRFunctionName), std::move(funcInfo.Implementation)
  247. );
  248. }
  249. return CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType);
  250. }
  251. const auto runCfgCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
  252. return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, funcInfo.FunctionType, userType);
  253. }
  254. IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  255. MKQL_ENSURE(callable.GetInputsCount() == 4 || callable.GetInputsCount() == 7, "Expected 4 or 7 arguments");
  256. const auto funcNameNode = callable.GetInput(0);
  257. const auto userTypeNode = callable.GetInput(1);
  258. const auto typeConfigNode = callable.GetInput(2);
  259. const auto programNode = callable.GetInput(3);
  260. MKQL_ENSURE(userTypeNode.IsImmediate() && userTypeNode.GetStaticType()->IsType(), "Expected immediate type");
  261. TString funcName(AS_VALUE(TDataLiteral, funcNameNode)->AsValue().AsStringRef());
  262. TString typeConfig(AS_VALUE(TDataLiteral, typeConfigNode)->AsValue().AsStringRef());
  263. NUdf::TSourcePosition pos;
  264. if (callable.GetInputsCount() == 7) {
  265. pos.File_ = AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().AsStringRef();
  266. pos.Row_ = AS_VALUE(TDataLiteral, callable.GetInput(5))->AsValue().Get<ui32>();
  267. pos.Column_ = AS_VALUE(TDataLiteral, callable.GetInput(6))->AsValue().Get<ui32>();
  268. }
  269. const auto userType = static_cast<TType*>(userTypeNode.GetNode());
  270. ui32 flags = 0;
  271. TFunctionTypeInfo funcInfo;
  272. const auto status = ctx.FunctionRegistry.FindFunctionTypeInfo(
  273. ctx.Env, ctx.TypeInfoHelper, ctx.CountersProvider, funcName, userType,
  274. typeConfig, flags, pos, ctx.SecureParamsProvider, &funcInfo);
  275. MKQL_ENSURE(status.IsOk(), status.GetError());
  276. MKQL_ENSURE(funcInfo.Implementation, "UDF implementation is not set");
  277. MKQL_ENSURE(!funcInfo.FunctionType, "Function type info is exist for same kind script, it's better use it");
  278. const auto callableType = callable.GetType();
  279. MKQL_ENSURE(callableType->GetKind() == TType::EKind::Callable, "Expected callable type in callable type info");
  280. const auto callableResultType = callableType->GetReturnType();
  281. MKQL_ENSURE(callableResultType->GetKind() == TType::EKind::Callable, "Expected callable type in result of script wrapper");
  282. const auto funcTypeInfo = static_cast<TCallableType*>(callableResultType);
  283. const auto programCompNode = LocateNode(ctx.NodeLocator, *programNode.GetNode());
  284. return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, programCompNode, funcTypeInfo, userType);
  285. }
  286. }
  287. }