// mkql_callable.cpp (6.4 KB, 184 lines in its original layout)
  1. #include "mkql_callable.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. class TCallableWrapper : public TCustomValueCodegeneratorNode<TCallableWrapper> {
  9. typedef TCustomValueCodegeneratorNode<TCallableWrapper> TBaseComputation;
  10. private:
  11. class TValue : public TComputationValue<TValue> {
  12. public:
  13. TValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, IComputationNode* resultNode,
  14. const TComputationExternalNodePtrVector& argNodes)
  15. : TComputationValue(memInfo)
  16. , CompCtx(compCtx)
  17. , ResultNode(resultNode)
  18. , ArgNodes(argNodes)
  19. {}
  20. private:
  21. NUdf::TUnboxedValue Run(const NUdf::IValueBuilder*, const NUdf::TUnboxedValuePod* args) const override
  22. {
  23. for (const auto node : ArgNodes) {
  24. node->SetValue(CompCtx, NUdf::TUnboxedValuePod(*args++));
  25. }
  26. return ResultNode->GetValue(CompCtx);
  27. }
  28. TComputationContext& CompCtx;
  29. IComputationNode *const ResultNode;
  30. const TComputationExternalNodePtrVector ArgNodes;
  31. };
  32. class TCodegenValue : public TComputationValue<TCodegenValue> {
  33. public:
  34. using TBase = TComputationValue<TCodegenValue>;
  35. using TRunPtr = NUdf::TUnboxedValuePod (*)(TComputationContext*, const NUdf::TUnboxedValuePod*);
  36. TCodegenValue(TMemoryUsageInfo* memInfo, TRunPtr run, TComputationContext* ctx)
  37. : TBase(memInfo)
  38. , RunFunc(run)
  39. , Ctx(ctx)
  40. {}
  41. private:
  42. NUdf::TUnboxedValue Run(const NUdf::IValueBuilder*, const NUdf::TUnboxedValuePod* args) const override {
  43. return RunFunc(Ctx, args);
  44. }
  45. const TRunPtr RunFunc;
  46. TComputationContext* const Ctx;
  47. };
  48. public:
  49. TCallableWrapper(TComputationMutables& mutables, IComputationNode* resultNode, TComputationExternalNodePtrVector&& argNodes)
  50. : TBaseComputation(mutables)
  51. , ResultNode(resultNode)
  52. , ArgNodes(std::move(argNodes))
  53. {
  54. }
  55. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  56. #ifndef MKQL_DISABLE_CODEGEN
  57. if (ctx.ExecuteLLVM && Run)
  58. return ctx.HolderFactory.Create<TCodegenValue>(Run, &ctx);
  59. #endif
  60. return ctx.HolderFactory.Create<TValue>(ctx, ResultNode, ArgNodes);
  61. }
  62. private:
  63. void RegisterDependencies() const final {
  64. for (const auto& arg : ArgNodes) {
  65. Own(arg);
  66. }
  67. DependsOn(ResultNode);
  68. }
  69. #ifndef MKQL_DISABLE_CODEGEN
  70. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  71. RunFunc = GenerateRun(codegen);
  72. codegen.ExportSymbol(RunFunc);
  73. }
  74. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  75. if (RunFunc)
  76. Run = reinterpret_cast<TRunPtr>(codegen.GetPointerToFunction(RunFunc));
  77. }
  78. Function* GenerateRun(NYql::NCodegen::ICodegen& codegen) const {
  79. auto& module = codegen.GetModule();
  80. auto& context = codegen.GetContext();
  81. const auto& name = TBaseComputation::MakeName("Run");
  82. if (const auto f = module.getFunction(name.c_str()))
  83. return f;
  84. const auto valueType = Type::getInt128Ty(context);
  85. const auto argsType = ArrayType::get(valueType, ArgNodes.size());
  86. const auto contextType = GetCompContextType(context);
  87. const auto funcType = codegen.GetEffectiveTarget() != NYql::NCodegen::ETarget::Windows ?
  88. FunctionType::get(valueType, {PointerType::getUnqual(contextType), PointerType::getUnqual(argsType)}, false):
  89. FunctionType::get(Type::getVoidTy(context), {PointerType::getUnqual(valueType), PointerType::getUnqual(contextType), PointerType::getUnqual(argsType)}, false);
  90. TCodegenContext ctx(codegen);
  91. ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
  92. DISubprogramAnnotator annotator(ctx, ctx.Func);
  93. auto args = ctx.Func->arg_begin();
  94. const auto resultArg = codegen.GetEffectiveTarget() == NYql::NCodegen::ETarget::Windows ? &*args++ : nullptr;
  95. if (resultArg) {
  96. resultArg->addAttr(Attribute::StructRet);
  97. resultArg->addAttr(Attribute::NoAlias);
  98. }
  99. ctx.Ctx = &*args;
  100. const auto argsPtr = &*++args;
  101. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  102. auto block = main;
  103. const auto arguments = new LoadInst(argsType, argsPtr, "arguments", block);
  104. unsigned i = 0U;
  105. for (const auto node : ArgNodes) {
  106. const auto arg = ExtractValueInst::Create(arguments, {i++}, "arg", block);
  107. const auto codegenArgNode = dynamic_cast<ICodegeneratorExternalNode*>(node);
  108. MKQL_ENSURE(codegenArgNode, "Argument must be codegenerator node.");
  109. codegenArgNode->CreateSetValue(ctx, block, arg);
  110. }
  111. const auto result = GetNodeValue(ResultNode, ctx, block);
  112. if (resultArg) {
  113. new StoreInst(result, resultArg, block);
  114. ReturnInst::Create(context, block);
  115. } else {
  116. ReturnInst::Create(context, result, block);
  117. }
  118. return ctx.Func;
  119. }
  120. using TRunPtr = TCodegenValue::TRunPtr;
  121. Function* RunFunc = nullptr;
  122. TRunPtr Run = nullptr;
  123. #endif
  124. IComputationNode *const ResultNode;
  125. const TComputationExternalNodePtrVector ArgNodes;
  126. };
  127. }
  128. IComputationNode* WrapCallable(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  129. MKQL_ENSURE(callable.GetInputsCount() > 0U, "Expected at least one argument");
  130. const auto argsCount = callable.GetInputsCount() - 1U;
  131. const auto resultNode = LocateNode(ctx.NodeLocator, callable, argsCount);
  132. TComputationExternalNodePtrVector argNodes(argsCount);
  133. for (ui32 i = 0U; i < argsCount; ++i) {
  134. const auto listItem = AS_CALLABLE("Arg", callable.GetInput(i));
  135. MKQL_ENSURE(listItem->GetType()->GetName() == "Arg", "Wrong Callable arguments");
  136. MKQL_ENSURE(listItem->GetInputsCount() == 0, "Wrong Callable arguments");
  137. MKQL_ENSURE(listItem->GetType()->IsMergeDisabled(), "Merge mode is not disabled");
  138. argNodes[i] = LocateExternalNode(ctx.NodeLocator, callable, i);
  139. }
  140. return new TCallableWrapper(ctx.Mutables, resultNode, std::move(argNodes));
  141. }
  142. }
  143. }