// mkql_callable.cpp (5.8 KB)
  1. #include "mkql_callable.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. class TCallableWrapper : public TCustomValueCodegeneratorNode<TCallableWrapper> {
  9. typedef TCustomValueCodegeneratorNode<TCallableWrapper> TBaseComputation;
  10. private:
  11. class TValue : public TComputationValue<TValue> {
  12. public:
  13. TValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, IComputationNode* resultNode,
  14. const TComputationExternalNodePtrVector& argNodes)
  15. : TComputationValue(memInfo)
  16. , CompCtx(compCtx)
  17. , ResultNode(resultNode)
  18. , ArgNodes(argNodes)
  19. {}
  20. private:
  21. NUdf::TUnboxedValue Run(const NUdf::IValueBuilder*, const NUdf::TUnboxedValuePod* args) const override
  22. {
  23. for (const auto node : ArgNodes) {
  24. node->SetValue(CompCtx, NUdf::TUnboxedValuePod(*args++));
  25. }
  26. return ResultNode->GetValue(CompCtx);
  27. }
  28. TComputationContext& CompCtx;
  29. IComputationNode *const ResultNode;
  30. const TComputationExternalNodePtrVector ArgNodes;
  31. };
  32. class TCodegenValue : public TComputationValue<TCodegenValue> {
  33. public:
  34. using TBase = TComputationValue<TCodegenValue>;
  35. using TRunPtr = NUdf::TUnboxedValuePod (*)(TComputationContext*, const NUdf::TUnboxedValuePod*);
  36. TCodegenValue(TMemoryUsageInfo* memInfo, TRunPtr run, TComputationContext* ctx)
  37. : TBase(memInfo)
  38. , RunFunc(run)
  39. , Ctx(ctx)
  40. {}
  41. private:
  42. NUdf::TUnboxedValue Run(const NUdf::IValueBuilder*, const NUdf::TUnboxedValuePod* args) const override {
  43. return RunFunc(Ctx, args);
  44. }
  45. const TRunPtr RunFunc;
  46. TComputationContext* const Ctx;
  47. };
  48. public:
  49. TCallableWrapper(TComputationMutables& mutables, IComputationNode* resultNode, TComputationExternalNodePtrVector&& argNodes)
  50. : TBaseComputation(mutables)
  51. , ResultNode(resultNode)
  52. , ArgNodes(std::move(argNodes))
  53. {
  54. }
  55. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  56. #ifndef MKQL_DISABLE_CODEGEN
  57. if (ctx.ExecuteLLVM && Run)
  58. return ctx.HolderFactory.Create<TCodegenValue>(Run, &ctx);
  59. #endif
  60. return ctx.HolderFactory.Create<TValue>(ctx, ResultNode, ArgNodes);
  61. }
  62. private:
  63. void RegisterDependencies() const final {
  64. for (const auto& arg : ArgNodes) {
  65. Own(arg);
  66. }
  67. DependsOn(ResultNode);
  68. }
  69. #ifndef MKQL_DISABLE_CODEGEN
  70. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  71. RunFunc = GenerateRun(codegen);
  72. codegen.ExportSymbol(RunFunc);
  73. }
  74. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  75. if (RunFunc)
  76. Run = reinterpret_cast<TRunPtr>(codegen.GetPointerToFunction(RunFunc));
  77. }
  78. Function* GenerateRun(NYql::NCodegen::ICodegen& codegen) const {
  79. auto& module = codegen.GetModule();
  80. auto& context = codegen.GetContext();
  81. const auto& name = TBaseComputation::MakeName("Run");
  82. if (const auto f = module.getFunction(name.c_str()))
  83. return f;
  84. const auto valueType = Type::getInt128Ty(context);
  85. const auto argsType = ArrayType::get(valueType, ArgNodes.size());
  86. const auto contextType = GetCompContextType(context);
  87. const auto funcType =
  88. FunctionType::get(valueType, {PointerType::getUnqual(contextType), PointerType::getUnqual(argsType)}, false);
  89. TCodegenContext ctx(codegen);
  90. ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
  91. DISubprogramAnnotator annotator(ctx, ctx.Func);
  92. auto args = ctx.Func->arg_begin();
  93. ctx.Ctx = &*args;
  94. const auto argsPtr = &*++args;
  95. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  96. auto block = main;
  97. const auto arguments = new LoadInst(argsType, argsPtr, "arguments", block);
  98. unsigned i = 0U;
  99. for (const auto node : ArgNodes) {
  100. const auto arg = ExtractValueInst::Create(arguments, {i++}, "arg", block);
  101. const auto codegenArgNode = dynamic_cast<ICodegeneratorExternalNode*>(node);
  102. MKQL_ENSURE(codegenArgNode, "Argument must be codegenerator node.");
  103. codegenArgNode->CreateSetValue(ctx, block, arg);
  104. }
  105. const auto result = GetNodeValue(ResultNode, ctx, block);
  106. ReturnInst::Create(context, result, block);
  107. return ctx.Func;
  108. }
  109. using TRunPtr = TCodegenValue::TRunPtr;
  110. Function* RunFunc = nullptr;
  111. TRunPtr Run = nullptr;
  112. #endif
  113. IComputationNode *const ResultNode;
  114. const TComputationExternalNodePtrVector ArgNodes;
  115. };
  116. }
  117. IComputationNode* WrapCallable(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  118. MKQL_ENSURE(callable.GetInputsCount() > 0U, "Expected at least one argument");
  119. const auto argsCount = callable.GetInputsCount() - 1U;
  120. const auto resultNode = LocateNode(ctx.NodeLocator, callable, argsCount);
  121. TComputationExternalNodePtrVector argNodes(argsCount);
  122. for (ui32 i = 0U; i < argsCount; ++i) {
  123. const auto listItem = AS_CALLABLE("Arg", callable.GetInput(i));
  124. MKQL_ENSURE(listItem->GetType()->GetName() == "Arg", "Wrong Callable arguments");
  125. MKQL_ENSURE(listItem->GetInputsCount() == 0, "Wrong Callable arguments");
  126. MKQL_ENSURE(listItem->GetType()->IsMergeDisabled(), "Merge mode is not disabled");
  127. argNodes[i] = LocateExternalNode(ctx.NodeLocator, callable, i);
  128. }
  129. return new TCallableWrapper(ctx.Mutables, resultNode, std::move(argNodes));
  130. }
  131. }
  132. }