mkql_udf.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. #include "mkql_udf.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  4. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  5. #include <yql/essentials/minikql/mkql_node_cast.h>
  6. #include <yql/essentials/minikql/computation/mkql_validate.h>
  7. #include <yql/essentials/minikql/mkql_function_registry.h>
  8. #include <yql/essentials/minikql/mkql_node_printer.h>
  9. #include <yql/essentials/minikql/mkql_type_builder.h>
  10. #include <yql/essentials/minikql/mkql_utils.h>
  11. #include <yql/essentials/utils/yql_panic.h>
  12. namespace NKikimr {
  13. namespace NMiniKQL {
  14. namespace {
  15. constexpr size_t TypeDiffLimit = 1000;
  16. TString TruncateTypeDiff(const TString& s) {
  17. if (s.size() < TypeDiffLimit) {
  18. return s;
  19. }
  20. return s.substr(0,TypeDiffLimit) + "...";
  21. }
  22. template<class TValidatePolicy, class TValidateMode>
  23. class TSimpleUdfWrapper: public TMutableComputationNode<TSimpleUdfWrapper<TValidatePolicy,TValidateMode>> {
  24. using TBaseComputation = TMutableComputationNode<TSimpleUdfWrapper<TValidatePolicy,TValidateMode>>;
  25. public:
  26. TSimpleUdfWrapper(
  27. TComputationMutables& mutables,
  28. TString&& functionName,
  29. TString&& typeConfig,
  30. NUdf::TSourcePosition pos,
  31. const TCallableType* callableType,
  32. TType* userType)
  33. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  34. , FunctionName(std::move(functionName))
  35. , TypeConfig(std::move(typeConfig))
  36. , Pos(pos)
  37. , CallableType(callableType)
  38. , UserType(userType)
  39. {
  40. this->Stateless = false;
  41. }
  42. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  43. ui32 flags = 0;
  44. TFunctionTypeInfo funcInfo;
  45. const auto status = ctx.HolderFactory.GetFunctionRegistry()->FindFunctionTypeInfo(
  46. ctx.TypeEnv, ctx.TypeInfoHelper, ctx.CountersProvider, FunctionName, UserType->IsVoid() ? nullptr : UserType,
  47. TypeConfig, flags, Pos, ctx.SecureParamsProvider, &funcInfo);
  48. if (!status.IsOk()) {
  49. UdfTerminate((TStringBuilder() << Pos << " Failed to find UDF function " << FunctionName << ", reason: "
  50. << status.GetError()).c_str());
  51. }
  52. if (!funcInfo.Implementation) {
  53. UdfTerminate((TStringBuilder() << Pos << " UDF implementation is not set for function " << FunctionName).c_str());
  54. }
  55. NUdf::TUnboxedValue udf(NUdf::TUnboxedValuePod(funcInfo.Implementation.Release()));
  56. TValidate<TValidatePolicy,TValidateMode>::WrapCallable(CallableType, udf, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
  57. return udf.Release();
  58. }
  59. private:
  60. void RegisterDependencies() const final {}
  61. const TString FunctionName;
  62. const TString TypeConfig;
  63. const NUdf::TSourcePosition Pos;
  64. const TCallableType *const CallableType;
  65. TType *const UserType;
  66. };
  67. class TUdfRunCodegeneratorNode: public TSimpleUdfWrapper<TValidateErrorPolicyNone, TValidateModeLazy<TValidateErrorPolicyNone>>
  68. #ifndef MKQL_DISABLE_CODEGEN
  69. , public ICodegeneratorRunNode
  70. #endif
  71. {
  72. public:
  73. TUdfRunCodegeneratorNode(
  74. TComputationMutables& mutables,
  75. TString&& functionName,
  76. TString&& typeConfig,
  77. NUdf::TSourcePosition pos,
  78. const TCallableType* callableType,
  79. TType* userType,
  80. TString&& moduleIRUniqID,
  81. TString&& moduleIR,
  82. TString&& fuctioNameIR,
  83. NUdf::TUniquePtr<NUdf::IBoxedValue>&& impl)
  84. : TSimpleUdfWrapper(mutables, std::move(functionName), std::move(typeConfig), pos, callableType, userType)
  85. , ModuleIRUniqID(std::move(moduleIRUniqID))
  86. , ModuleIR(std::move(moduleIR))
  87. , IRFunctionName(std::move(fuctioNameIR))
  88. , Impl(std::move(impl))
  89. {}
  90. #ifndef MKQL_DISABLE_CODEGEN
  91. void CreateRun(const TCodegenContext& ctx, BasicBlock*& block, Value* result, Value* args) const final {
  92. ctx.Codegen.LoadBitCode(ModuleIR, ModuleIRUniqID);
  93. auto& context = ctx.Codegen.GetContext();
  94. const auto type = Type::getInt128Ty(context);
  95. YQL_ENSURE(result->getType() == PointerType::getUnqual(type));
  96. const auto data = ConstantInt::get(Type::getInt64Ty(context), reinterpret_cast<ui64>(Impl.Get()));
  97. const auto ptrStructType = PointerType::getUnqual(StructType::get(context));
  98. const auto boxed = CastInst::Create(Instruction::IntToPtr, data, ptrStructType, "boxed", block);
  99. const auto builder = ctx.GetBuilder();
  100. const auto funType = FunctionType::get(Type::getVoidTy(context), {boxed->getType(), result->getType(), builder->getType(), args->getType()}, false);
  101. const auto runFunc = ctx.Codegen.GetModule().getOrInsertFunction(llvm::StringRef(IRFunctionName.data(), IRFunctionName.size()), funType);
  102. CallInst::Create(runFunc, {boxed, result, builder, args}, "", block);
  103. }
  104. #endif
  105. private:
  106. const TString ModuleIRUniqID;
  107. const TString ModuleIR;
  108. const TString IRFunctionName;
  109. const NUdf::TUniquePtr<NUdf::IBoxedValue> Impl;
  110. };
  111. template<class TValidatePolicy, class TValidateMode>
  112. class TUdfWrapper: public TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolicy,TValidateMode>> {
  113. using TBaseComputation = TMutableCodegeneratorPtrNode<TUdfWrapper<TValidatePolicy,TValidateMode>>;
  114. public:
  115. TUdfWrapper(
  116. TComputationMutables& mutables,
  117. TString&& functionName,
  118. TString&& typeConfig,
  119. NUdf::TSourcePosition pos,
  120. IComputationNode* runConfigNode,
  121. const TCallableType* callableType,
  122. TType* userType)
  123. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  124. , FunctionName(std::move(functionName))
  125. , TypeConfig(std::move(typeConfig))
  126. , Pos(pos)
  127. , RunConfigNode(runConfigNode)
  128. , CallableType(callableType)
  129. , UserType(userType)
  130. , UdfIndex(mutables.CurValueIndex++)
  131. {
  132. this->Stateless = false;
  133. }
  134. NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
  135. auto& udf = ctx.MutableValues[UdfIndex];
  136. if (!udf.HasValue()) {
  137. MakeUdf(ctx, udf);
  138. }
  139. const auto runConfig = RunConfigNode->GetValue(ctx);
  140. auto callable = udf.Run(ctx.Builder, &runConfig);
  141. Wrap(callable);
  142. return callable;
  143. }
  144. #ifndef MKQL_DISABLE_CODEGEN
  145. void DoGenerateGetValue(const TCodegenContext& ctx, Value* pointer, BasicBlock*& block) const {
  146. auto& context = ctx.Codegen.GetContext();
  147. const auto valueType = Type::getInt128Ty(context);
  148. const auto udfPtr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), UdfIndex)}, "udf_ptr", block);
  149. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  150. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  151. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  152. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  153. BranchInst::Create(main, make, HasValue(udfPtr, block, context), block);
  154. block = make;
  155. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TUdfWrapper::MakeUdf));
  156. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), udfPtr->getType()}, false);
  157. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  158. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, udfPtr}, "", block);
  159. BranchInst::Create(main, block);
  160. block = main;
  161. GetNodeValue(pointer, RunConfigNode, ctx, block);
  162. const auto conf = new LoadInst(valueType, pointer, "conf", block);
  163. const auto udf = new LoadInst(valueType, udfPtr, "udf", block);
  164. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Run>(pointer, udf, ctx.Codegen, block, ctx.GetBuilder(), pointer);
  165. ValueUnRef(RunConfigNode->GetRepresentation(), conf, ctx, block);
  166. const auto wrap = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TUdfWrapper::Wrap));
  167. const auto funType = FunctionType::get(Type::getVoidTy(context), {self->getType(), pointer->getType()}, false);
  168. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, wrap, PointerType::getUnqual(funType), "function", block);
  169. CallInst::Create(funType, doFuncPtr, {self, pointer}, "", block);
  170. }
  171. #endif
  172. private:
  173. void MakeUdf(TComputationContext& ctx, NUdf::TUnboxedValue& udf) const {
  174. ui32 flags = 0;
  175. TFunctionTypeInfo funcInfo;
  176. const auto status = ctx.HolderFactory.GetFunctionRegistry()->FindFunctionTypeInfo(
  177. ctx.TypeEnv, ctx.TypeInfoHelper, ctx.CountersProvider, FunctionName, UserType->IsVoid() ? nullptr : UserType,
  178. TypeConfig, flags, Pos, ctx.SecureParamsProvider, &funcInfo);
  179. if (!status.IsOk()) {
  180. UdfTerminate((TStringBuilder() << Pos << " Failed to find UDF function " << FunctionName << ", reason: "
  181. << status.GetError()).c_str());
  182. }
  183. if (!funcInfo.Implementation) {
  184. UdfTerminate((TStringBuilder() << Pos << " UDF implementation is not set for function " << FunctionName).c_str());
  185. }
  186. udf = NUdf::TUnboxedValuePod(funcInfo.Implementation.Release());
  187. }
  188. void Wrap(NUdf::TUnboxedValue& callable) const {
  189. TValidate<TValidatePolicy,TValidateMode>::WrapCallable(CallableType, callable, TStringBuilder() << "FunctionWithConfig<" << FunctionName << ">");
  190. }
  191. void RegisterDependencies() const final {
  192. this->DependsOn(RunConfigNode);
  193. }
  194. const TString FunctionName;
  195. const TString TypeConfig;
  196. const NUdf::TSourcePosition Pos;
  197. IComputationNode* const RunConfigNode;
  198. const TCallableType* CallableType;
  199. TType* const UserType;
  200. const ui32 UdfIndex;
  201. };
  202. template<bool Simple, class TValidatePolicy, class TValidateMode>
  203. using TWrapper = std::conditional_t<Simple, TSimpleUdfWrapper<TValidatePolicy, TValidateMode>, TUdfWrapper<TValidatePolicy, TValidateMode>>;
  204. template<bool Simple, typename...TArgs>
  205. inline IComputationNode* CreateUdfWrapper(const TComputationNodeFactoryContext& ctx, TArgs&&...args)
  206. {
  207. switch (ctx.ValidateMode) {
  208. case NUdf::EValidateMode::None:
  209. return new TWrapper<Simple, TValidateErrorPolicyNone,TValidateModeLazy<TValidateErrorPolicyNone>>(ctx.Mutables, std::forward<TArgs>(args)...);
  210. case NUdf::EValidateMode::Lazy:
  211. if (ctx.ValidatePolicy == NUdf::EValidatePolicy::Fail) {
  212. return new TWrapper<Simple, TValidateErrorPolicyFail,TValidateModeLazy<TValidateErrorPolicyFail>>(ctx.Mutables, std::forward<TArgs>(args)...);
  213. } else {
  214. return new TWrapper<Simple, TValidateErrorPolicyThrow,TValidateModeLazy<TValidateErrorPolicyThrow>>(ctx.Mutables, std::forward<TArgs>(args)...);
  215. }
  216. case NUdf::EValidateMode::Greedy:
  217. if (ctx.ValidatePolicy == NUdf::EValidatePolicy::Fail) {
  218. return new TWrapper<Simple, TValidateErrorPolicyFail,TValidateModeGreedy<TValidateErrorPolicyFail>>(ctx.Mutables, std::forward<TArgs>(args)...);
  219. } else {
  220. return new TWrapper<Simple, TValidateErrorPolicyThrow,TValidateModeGreedy<TValidateErrorPolicyThrow>>(ctx.Mutables, std::forward<TArgs>(args)...);
  221. }
  222. default:
  223. Y_ABORT("Unexpected validate mode: %u", static_cast<unsigned>(ctx.ValidateMode));
  224. };
  225. }
  226. }
  227. IComputationNode* WrapUdf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  228. MKQL_ENSURE(callable.GetInputsCount() == 4 || callable.GetInputsCount() == 7, "Expected 4 or 7 arguments");
  229. const auto funcNameNode = callable.GetInput(0);
  230. const auto userTypeNode = callable.GetInput(1);
  231. const auto typeCfgNode = callable.GetInput(2);
  232. const auto runCfgNode = callable.GetInput(3);
  233. MKQL_ENSURE(userTypeNode.IsImmediate(), "Expected immediate node");
  234. MKQL_ENSURE(userTypeNode.GetStaticType()->IsType(), "Expected type");
  235. TString funcName(AS_VALUE(TDataLiteral, funcNameNode)->AsValue().AsStringRef());
  236. TString typeConfig(AS_VALUE(TDataLiteral, typeCfgNode)->AsValue().AsStringRef());
  237. NUdf::TSourcePosition pos;
  238. if (callable.GetInputsCount() == 7) {
  239. pos.File_ = AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().AsStringRef();
  240. pos.Row_ = AS_VALUE(TDataLiteral, callable.GetInput(5))->AsValue().Get<ui32>();
  241. pos.Column_ = AS_VALUE(TDataLiteral, callable.GetInput(6))->AsValue().Get<ui32>();
  242. }
  243. ui32 flags = 0;
  244. TFunctionTypeInfo funcInfo;
  245. const auto userType = static_cast<TType*>(userTypeNode.GetNode());
  246. const auto status = ctx.FunctionRegistry.FindFunctionTypeInfo(
  247. ctx.Env, ctx.TypeInfoHelper, ctx.CountersProvider, funcName, userType->IsVoid() ? nullptr : userType,
  248. typeConfig, flags, pos, ctx.SecureParamsProvider, &funcInfo);
  249. if (!status.IsOk()) {
  250. UdfTerminate((TStringBuilder() << pos << " Failed to find UDF function " << funcName << ", reason: "
  251. << status.GetError()).c_str());
  252. }
  253. if (!funcInfo.FunctionType->IsConvertableTo(*callable.GetType()->GetReturnType(), true)) {
  254. TString diff = TStringBuilder() << "type mismatch, expected return type: " << PrintNode(callable.GetType()->GetReturnType(), true) <<
  255. ", actual:" << PrintNode(funcInfo.FunctionType, true);
  256. UdfTerminate((TStringBuilder() << pos << " UDF Function '" << funcName << "' " << TruncateTypeDiff(diff)).c_str());
  257. }
  258. if (!funcInfo.Implementation) {
  259. UdfTerminate((TStringBuilder() << pos << " UDF implementation is not set for function " << funcName).c_str());
  260. }
  261. const auto runConfigType = funcInfo.RunConfigType;
  262. if (!runConfigType->IsSameType(*runCfgNode.GetStaticType())) {
  263. TString diff = TStringBuilder() << "run config type mismatch, expected: " << PrintNode(runCfgNode.GetStaticType(), true) <<
  264. ", actual:" << PrintNode(runConfigType, true);
  265. UdfTerminate((TStringBuilder() << pos << " UDF Function '" << funcName << "' " << TruncateTypeDiff(diff)).c_str());
  266. }
  267. if (runConfigType->IsVoid()) {
  268. if (ctx.ValidateMode == NUdf::EValidateMode::None && funcInfo.ModuleIR && funcInfo.IRFunctionName) {
  269. return new TUdfRunCodegeneratorNode(
  270. ctx.Mutables, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType,
  271. std::move(funcInfo.ModuleIRUniqID), std::move(funcInfo.ModuleIR), std::move(funcInfo.IRFunctionName), std::move(funcInfo.Implementation)
  272. );
  273. }
  274. return CreateUdfWrapper<true>(ctx, std::move(funcName), std::move(typeConfig), pos, funcInfo.FunctionType, userType);
  275. }
  276. const auto runCfgCompNode = LocateNode(ctx.NodeLocator, *runCfgNode.GetNode());
  277. return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, runCfgCompNode, funcInfo.FunctionType, userType);
  278. }
  279. IComputationNode* WrapScriptUdf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  280. MKQL_ENSURE(callable.GetInputsCount() == 4 || callable.GetInputsCount() == 7, "Expected 4 or 7 arguments");
  281. const auto funcNameNode = callable.GetInput(0);
  282. const auto userTypeNode = callable.GetInput(1);
  283. const auto typeConfigNode = callable.GetInput(2);
  284. const auto programNode = callable.GetInput(3);
  285. MKQL_ENSURE(userTypeNode.IsImmediate() && userTypeNode.GetStaticType()->IsType(), "Expected immediate type");
  286. TString funcName(AS_VALUE(TDataLiteral, funcNameNode)->AsValue().AsStringRef());
  287. TString typeConfig(AS_VALUE(TDataLiteral, typeConfigNode)->AsValue().AsStringRef());
  288. NUdf::TSourcePosition pos;
  289. if (callable.GetInputsCount() == 7) {
  290. pos.File_ = AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().AsStringRef();
  291. pos.Row_ = AS_VALUE(TDataLiteral, callable.GetInput(5))->AsValue().Get<ui32>();
  292. pos.Column_ = AS_VALUE(TDataLiteral, callable.GetInput(6))->AsValue().Get<ui32>();
  293. }
  294. const auto userType = static_cast<TType*>(userTypeNode.GetNode());
  295. ui32 flags = 0;
  296. TFunctionTypeInfo funcInfo;
  297. const auto status = ctx.FunctionRegistry.FindFunctionTypeInfo(
  298. ctx.Env, ctx.TypeInfoHelper, ctx.CountersProvider, funcName, userType,
  299. typeConfig, flags, pos, ctx.SecureParamsProvider, &funcInfo);
  300. if (!status.IsOk()) {
  301. UdfTerminate((TStringBuilder() << pos << " Failed to find UDF function " << funcName << ", reason: "
  302. << status.GetError()).c_str());
  303. }
  304. if (!funcInfo.Implementation) {
  305. UdfTerminate((TStringBuilder() << pos << " UDF implementation is not set for function " << funcName).c_str());
  306. }
  307. if (funcInfo.FunctionType) {
  308. UdfTerminate((TStringBuilder() << pos << " UDF function type exists for function " << funcName).c_str());
  309. }
  310. const auto callableType = callable.GetType();
  311. MKQL_ENSURE(callableType->GetKind() == TType::EKind::Callable, "Expected callable type in callable type info");
  312. const auto callableResultType = callableType->GetReturnType();
  313. MKQL_ENSURE(callableResultType->GetKind() == TType::EKind::Callable, "Expected callable type in result of script wrapper");
  314. const auto funcTypeInfo = static_cast<TCallableType*>(callableResultType);
  315. const auto programCompNode = LocateNode(ctx.NodeLocator, *programNode.GetNode());
  316. return CreateUdfWrapper<false>(ctx, std::move(funcName), std::move(typeConfig), pos, programCompNode, funcTypeInfo, userType);
  317. }
  318. }
  319. }