mkql_invoke.cpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. #include "mkql_invoke.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_node_builder.h>
  6. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  9. namespace {
  10. template<bool IsOptional>
  11. class TUnaryArgInvokeBase {
  12. protected:
  13. TUnaryArgInvokeBase(TStringBuf name, const TFunctionDescriptor& descr)
  14. : Name(name), Descriptor(descr)
  15. {}
  16. NUdf::TUnboxedValuePod DoCalc(const NUdf::TUnboxedValuePod& arg) const {
  17. if (IsOptional && !arg) {
  18. return {};
  19. }
  20. return Descriptor.Function(&arg);
  21. }
  22. #ifndef MKQL_DISABLE_CODEGEN
  23. Value* DoGenGetValue(const TCodegenContext& ctx, Value* arg, BasicBlock*& block) const {
  24. if (IsOptional) {
  25. auto& context = ctx.Codegen.GetContext();
  26. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  27. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  28. const auto result = PHINode::Create(arg->getType(), 2U, "result", done);
  29. result->addIncoming(arg, block);
  30. BranchInst::Create(good, done, IsExists(arg, block, context), block);
  31. block = good;
  32. const auto out = reinterpret_cast<TGeneratorPtr>(Descriptor.Generator)(&arg, ctx, block);
  33. result->addIncoming(out, block);
  34. BranchInst::Create(done, block);
  35. block = done;
  36. return result;
  37. } else {
  38. return reinterpret_cast<TGeneratorPtr>(Descriptor.Generator)(&arg, ctx, block);
  39. }
  40. }
  41. #endif
  42. const TStringBuf Name;
  43. const TFunctionDescriptor Descriptor;
  44. };
  45. template<bool IsOptional>
  46. class TSimpleUnaryArgInvokeWrapper : public TDecoratorCodegeneratorNode<TSimpleUnaryArgInvokeWrapper<IsOptional>>, private TUnaryArgInvokeBase<IsOptional> {
  47. typedef TDecoratorCodegeneratorNode<TSimpleUnaryArgInvokeWrapper<IsOptional>> TBaseComputation;
  48. public:
  49. TSimpleUnaryArgInvokeWrapper(TStringBuf name, const TFunctionDescriptor& descr, IComputationNode* arg)
  50. : TBaseComputation(arg), TUnaryArgInvokeBase<IsOptional>(name, descr)
  51. {}
  52. NUdf::TUnboxedValuePod DoCalculate(TComputationContext&, const NUdf::TUnboxedValuePod& arg) const {
  53. return this->DoCalc(arg);
  54. }
  55. #ifndef MKQL_DISABLE_CODEGEN
  56. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* arg, BasicBlock*& block) const {
  57. return this->DoGenGetValue(ctx, arg, block);
  58. }
  59. #endif
  60. private:
  61. TString DebugString() const final {
  62. return TBaseComputation::DebugString() + "(" + this->Name + ")" ;
  63. }
  64. };
  65. template<bool IsOptional>
  66. class TDefaultUnaryArgInvokeWrapper : public TMutableCodegeneratorNode<TDefaultUnaryArgInvokeWrapper<IsOptional>>, private TUnaryArgInvokeBase<IsOptional> {
  67. typedef TMutableCodegeneratorNode<TDefaultUnaryArgInvokeWrapper<IsOptional>> TBaseComputation;
  68. public:
  69. TDefaultUnaryArgInvokeWrapper(TComputationMutables& mutables, EValueRepresentation kind, TStringBuf name, const TFunctionDescriptor& descr, IComputationNode* arg)
  70. : TBaseComputation(mutables, kind), TUnaryArgInvokeBase<IsOptional>(name, descr), Arg(arg)
  71. {}
  72. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  73. return this->DoCalc(Arg->GetValue(ctx));
  74. }
  75. #ifndef MKQL_DISABLE_CODEGEN
  76. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  77. const auto arg = GetNodeValue(Arg, ctx, block);
  78. return this->DoGenGetValue(ctx, arg, block);
  79. }
  80. #endif
  81. private:
  82. void RegisterDependencies() const final {
  83. this->DependsOn(Arg);
  84. }
  85. TString DebugString() const final {
  86. return TBaseComputation::DebugString() + "(" + this->Name + ")" ;
  87. }
  88. IComputationNode *const Arg;
  89. };
  90. class TBinaryInvokeWrapper : public TBinaryCodegeneratorNode<TBinaryInvokeWrapper> {
  91. typedef TBinaryCodegeneratorNode<TBinaryInvokeWrapper> TBaseComputation;
  92. public:
  93. TBinaryInvokeWrapper(TStringBuf name, const TFunctionDescriptor& descr, IComputationNode* left, IComputationNode* right, EValueRepresentation kind = EValueRepresentation::Embedded)
  94. : TBaseComputation(left, right, kind), Name(name), Descriptor(descr)
  95. {
  96. }
  97. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const {
  98. const std::array<NUdf::TUnboxedValue, 2U> args {{Left->GetValue(compCtx), Right->GetValue(compCtx)}};
  99. return Descriptor.Function(args.data());
  100. }
  101. #ifndef MKQL_DISABLE_CODEGEN
  102. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  103. const std::array<Value*, 2U> args {{GetNodeValue(Left, ctx, block), GetNodeValue(Right, ctx, block)}};
  104. return reinterpret_cast<TGeneratorPtr>(Descriptor.Generator)(args.data(), ctx, block);
  105. }
  106. #endif
  107. private:
  108. TString DebugString() const final {
  109. return TBaseComputation::DebugString() + "(" + Name + ")" ;
  110. }
  111. const TStringBuf Name;
  112. const TFunctionDescriptor Descriptor;
  113. };
  114. template<size_t Size>
  115. class TInvokeWrapper : public TMutableCodegeneratorNode<TInvokeWrapper<Size>> {
  116. typedef TMutableCodegeneratorNode<TInvokeWrapper<Size>> TBaseComputation;
  117. public:
  118. TInvokeWrapper(TComputationMutables& mutables, EValueRepresentation kind, TStringBuf name, const TFunctionDescriptor& descr, TComputationNodePtrVector&& argNodes)
  119. : TBaseComputation(mutables, kind)
  120. , Name(name), Descriptor(descr)
  121. , ArgNodes(std::move(argNodes))
  122. {
  123. }
  124. NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const {
  125. std::array<NUdf::TUnboxedValue, Size> values;
  126. std::transform(ArgNodes.cbegin(), ArgNodes.cend(), values.begin(),
  127. std::bind(&IComputationNode::GetValue, std::placeholders::_1, std::ref(ctx))
  128. );
  129. return Descriptor.Function(values.data());
  130. }
  131. #ifndef MKQL_DISABLE_CODEGEN
  132. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  133. std::array<Value*, Size> values;
  134. std::transform(ArgNodes.cbegin(), ArgNodes.cend(), values.begin(),
  135. [&](IComputationNode* node) { return GetNodeValue(node, ctx, block); }
  136. );
  137. return reinterpret_cast<TGeneratorPtr>(Descriptor.Generator)(values.data(), ctx, block);
  138. }
  139. #endif
  140. private:
  141. void RegisterDependencies() const final {
  142. std::for_each(ArgNodes.cbegin(), ArgNodes.cend(), std::bind(&TInvokeWrapper::DependsOn, this, std::placeholders::_1));
  143. }
  144. TString DebugString() const final {
  145. return TBaseComputation::DebugString() + "(" + Name + ")" ;
  146. }
  147. const TStringBuf Name;
  148. const TFunctionDescriptor Descriptor;
  149. const TComputationNodePtrVector ArgNodes;
  150. };
  151. }
  152. IComputationNode* WrapInvoke(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  153. MKQL_ENSURE(callable.GetInputsCount() >= 2U && callable.GetInputsCount() <= 4U, "Expected from one to three arguments.");
  154. const auto returnType = callable.GetType()->GetReturnType();
  155. const auto inputsCount = callable.GetInputsCount();
  156. std::array<TArgType, 4U> argsTypes;
  157. TComputationNodePtrVector argNodes;
  158. argNodes.reserve(inputsCount - 1U);
  159. argsTypes.front().first = UnpackOptionalData(returnType, argsTypes.front().second)->GetSchemeType();
  160. for (ui32 i = 1U; i < inputsCount; ++i) {
  161. argsTypes[i].first = UnpackOptionalData(callable.GetInput(i), argsTypes[i].second)->GetSchemeType();
  162. argNodes.emplace_back(LocateNode(ctx.NodeLocator, callable, i));
  163. }
  164. const auto funcName = AS_VALUE(TDataLiteral, callable.GetInput(0))->AsValue().AsStringRef();
  165. const auto funcDesc = ctx.FunctionRegistry.GetBuiltins()->GetBuiltin(funcName, argsTypes.data(), inputsCount);
  166. const auto returnKind = GetValueRepresentation(returnType);
  167. switch (argNodes.size()) {
  168. case 1U:
  169. if (EValueRepresentation::Embedded == returnKind) {
  170. return new TSimpleUnaryArgInvokeWrapper<false>(funcName, funcDesc, argNodes.front());
  171. } else {
  172. return new TDefaultUnaryArgInvokeWrapper<false>(ctx.Mutables, returnKind, funcName, funcDesc, argNodes.front());
  173. }
  174. case 2U:
  175. if (EValueRepresentation::Embedded == returnKind) {
  176. return new TBinaryInvokeWrapper(funcName, funcDesc, argNodes.front(), argNodes.back());
  177. }
  178. return new TInvokeWrapper<2U>(ctx.Mutables, returnKind, funcName, funcDesc, std::move(argNodes));
  179. case 3U:
  180. return new TInvokeWrapper<3U>(ctx.Mutables, returnKind, funcName, funcDesc, std::move(argNodes));
  181. default:
  182. Y_ABORT("Too wide invoke.");
  183. }
  184. }
  185. }
  186. }