// Computation nodes behind the "Invoke" callable: WrapInvoke() looks up a
// builtin function by name and argument types and wraps its descriptor in a
// node specialised by argument count and by the representation of the result.
//
// NOTE(review): this text appears to have lost its original line breaks and
// the contents of most angle-bracket lists (#include targets, template
// parameter/argument lists, reinterpret_cast target types, std::array
// arguments). As written, the `// Y_IGNORE` comment also swallows the rest of
// its physical line, and the #include/#ifndef/#endif directives no longer
// start their own lines. Restore the file from version control before
// building; the notes below describe only what the remaining text shows.
//
// TUnaryArgInvokeBase
//   Shared state and logic for the one-argument wrappers; stores Name and
//   Descriptor.
//   * DoCalc(arg): when IsOptional is set and `arg` is empty, returns an empty
//     value without calling the builtin; otherwise returns
//     Descriptor.Function(&arg).
//   * DoGenGetValue(ctx, arg, block) (codegen builds only): when IsOptional is
//     set, creates the blocks "good" and "done" and a PHI node "result" in
//     "done". The incoming block branches on IsExists(arg): to "good", where
//     Descriptor.Generator is called and its output added to the PHI, or
//     straight to "done", in which case the PHI yields the original `arg`.
//     `block` is left pointing at "done". When IsOptional is not set, the
//     generator is called directly.
//   NOTE(review): Name is kept as a TStringBuf; presumably the referenced
//   characters must outlive the node - confirm against the caller.
//
// TSimpleUnaryArgInvokeWrapper
//   One-argument node whose DoCalculate/DoGenerateGetValue receive the argument
//   value directly and forward it to DoCalc/DoGenGetValue. WrapInvoke() picks
//   it when the return representation is EValueRepresentation::Embedded.
//
// TDefaultUnaryArgInvokeWrapper
//   One-argument node that keeps the argument node in Arg, declares it via
//   DependsOn(Arg), and obtains the value itself (Arg->GetValue(ctx) /
//   GetNodeValue(Arg, ...)) before forwarding to DoCalc/DoGenGetValue.
//   WrapInvoke() picks it for non-Embedded return representations.
//
// TBinaryInvokeWrapper
//   Two-argument node: evaluates Left and Right into a two-element array and
//   passes array.data() to Descriptor.Function (or Descriptor.Generator in
//   codegen). The `kind` constructor parameter defaults to Embedded.
//
// TInvokeWrapper
//   Node over a fixed-size array of argument values: std::transform fills the
//   array from ArgNodes, then Descriptor.Function / Descriptor.Generator is
//   called with array.data(). RegisterDependencies() declares every node in
//   ArgNodes. Instantiated by WrapInvoke() as TInvokeWrapper<2U> and <3U>.
//
// WrapInvoke(callable, ctx)
//   * Requires 2..4 inputs ("Expected from one to three arguments."): input 0
//     is a data literal holding the function name, inputs 1.. are arguments.
//   * argsTypes[0] is filled from the return type and argsTypes[i] from input
//     i; `.first` receives the scheme type, `.second` is the out-parameter of
//     UnpackOptionalData (presumably the "is optional" flag - confirm).
//   * The descriptor comes from
//     ctx.FunctionRegistry.GetBuiltins()->GetBuiltin(name, argsTypes.data(), inputsCount).
//   * Dispatch on the number of argument nodes:
//       1 -> TSimpleUnaryArgInvokeWrapper if the result is Embedded,
//            otherwise TDefaultUnaryArgInvokeWrapper;
//       2 -> TBinaryInvokeWrapper if the result is Embedded,
//            otherwise TInvokeWrapper<2U>;
//       3 -> TInvokeWrapper<3U>;
//       anything else -> Y_ABORT("Too wide invoke.").
//   The returned node is allocated with `new`; ownership is not visible here.
#include "mkql_invoke.h" #include #include // Y_IGNORE #include #include #include namespace NKikimr { namespace NMiniKQL { namespace { template class TUnaryArgInvokeBase { protected: TUnaryArgInvokeBase(TStringBuf name, const TFunctionDescriptor& descr) : Name(name), Descriptor(descr) {} NUdf::TUnboxedValuePod DoCalc(const NUdf::TUnboxedValuePod& arg) const { if (IsOptional && !arg) { return {}; } return Descriptor.Function(&arg); } #ifndef MKQL_DISABLE_CODEGEN Value* DoGenGetValue(const TCodegenContext& ctx, Value* arg, BasicBlock*& block) const { if (IsOptional) { auto& context = ctx.Codegen.GetContext(); const auto good = BasicBlock::Create(context, "good", ctx.Func); const auto done = BasicBlock::Create(context, "done", ctx.Func); const auto result = PHINode::Create(arg->getType(), 2U, "result", done); result->addIncoming(arg, block); BranchInst::Create(good, done, IsExists(arg, block, context), block); block = good; const auto out = reinterpret_cast(Descriptor.Generator)(&arg, ctx, block); result->addIncoming(out, block); BranchInst::Create(done, block); block = done; return result; } else { return reinterpret_cast(Descriptor.Generator)(&arg, ctx, block); } } #endif const TStringBuf Name; const TFunctionDescriptor Descriptor; }; template class TSimpleUnaryArgInvokeWrapper : public TDecoratorCodegeneratorNode>, private TUnaryArgInvokeBase { typedef TDecoratorCodegeneratorNode> TBaseComputation; public: TSimpleUnaryArgInvokeWrapper(TStringBuf name, const TFunctionDescriptor& descr, IComputationNode* arg) : TBaseComputation(arg), TUnaryArgInvokeBase(name, descr) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext&, const NUdf::TUnboxedValuePod& arg) const { return this->DoCalc(arg); } #ifndef MKQL_DISABLE_CODEGEN Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* arg, BasicBlock*& block) const { return this->DoGenGetValue(ctx, arg, block); } #endif private: TString DebugString() const final { return TBaseComputation::DebugString() + "(" + 
this->Name + ")" ; } }; template class TDefaultUnaryArgInvokeWrapper : public TMutableCodegeneratorNode>, private TUnaryArgInvokeBase { typedef TMutableCodegeneratorNode> TBaseComputation; public: TDefaultUnaryArgInvokeWrapper(TComputationMutables& mutables, EValueRepresentation kind, TStringBuf name, const TFunctionDescriptor& descr, IComputationNode* arg) : TBaseComputation(mutables, kind), TUnaryArgInvokeBase(name, descr), Arg(arg) {} NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const { return this->DoCalc(Arg->GetValue(ctx)); } #ifndef MKQL_DISABLE_CODEGEN Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { const auto arg = GetNodeValue(Arg, ctx, block); return this->DoGenGetValue(ctx, arg, block); } #endif private: void RegisterDependencies() const final { this->DependsOn(Arg); } TString DebugString() const final { return TBaseComputation::DebugString() + "(" + this->Name + ")" ; } IComputationNode *const Arg; }; class TBinaryInvokeWrapper : public TBinaryCodegeneratorNode { typedef TBinaryCodegeneratorNode TBaseComputation; public: TBinaryInvokeWrapper(TStringBuf name, const TFunctionDescriptor& descr, IComputationNode* left, IComputationNode* right, EValueRepresentation kind = EValueRepresentation::Embedded) : TBaseComputation(left, right, kind), Name(name), Descriptor(descr) { } NUdf::TUnboxedValuePod DoCalculate(TComputationContext& compCtx) const { const std::array args {{Left->GetValue(compCtx), Right->GetValue(compCtx)}}; return Descriptor.Function(args.data()); } #ifndef MKQL_DISABLE_CODEGEN Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { const std::array args {{GetNodeValue(Left, ctx, block), GetNodeValue(Right, ctx, block)}}; return reinterpret_cast(Descriptor.Generator)(args.data(), ctx, block); } #endif private: TString DebugString() const final { return TBaseComputation::DebugString() + "(" + Name + ")" ; } const TStringBuf Name; const TFunctionDescriptor Descriptor; 
}; template class TInvokeWrapper : public TMutableCodegeneratorNode> { typedef TMutableCodegeneratorNode> TBaseComputation; public: TInvokeWrapper(TComputationMutables& mutables, EValueRepresentation kind, TStringBuf name, const TFunctionDescriptor& descr, TComputationNodePtrVector&& argNodes) : TBaseComputation(mutables, kind) , Name(name), Descriptor(descr) , ArgNodes(std::move(argNodes)) { } NUdf::TUnboxedValue DoCalculate(TComputationContext& ctx) const { std::array values; std::transform(ArgNodes.cbegin(), ArgNodes.cend(), values.begin(), std::bind(&IComputationNode::GetValue, std::placeholders::_1, std::ref(ctx)) ); return Descriptor.Function(values.data()); } #ifndef MKQL_DISABLE_CODEGEN Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const { std::array values; std::transform(ArgNodes.cbegin(), ArgNodes.cend(), values.begin(), [&](IComputationNode* node) { return GetNodeValue(node, ctx, block); } ); return reinterpret_cast(Descriptor.Generator)(values.data(), ctx, block); } #endif private: void RegisterDependencies() const final { std::for_each(ArgNodes.cbegin(), ArgNodes.cend(), std::bind(&TInvokeWrapper::DependsOn, this, std::placeholders::_1)); } TString DebugString() const final { return TBaseComputation::DebugString() + "(" + Name + ")" ; } const TStringBuf Name; const TFunctionDescriptor Descriptor; const TComputationNodePtrVector ArgNodes; }; } IComputationNode* WrapInvoke(TCallable& callable, const TComputationNodeFactoryContext& ctx) { MKQL_ENSURE(callable.GetInputsCount() >= 2U && callable.GetInputsCount() <= 4U, "Expected from one to three arguments."); const auto returnType = callable.GetType()->GetReturnType(); const auto inputsCount = callable.GetInputsCount(); std::array argsTypes; TComputationNodePtrVector argNodes; argNodes.reserve(inputsCount - 1U); argsTypes.front().first = UnpackOptionalData(returnType, argsTypes.front().second)->GetSchemeType(); for (ui32 i = 1U; i < inputsCount; ++i) { argsTypes[i].first = 
UnpackOptionalData(callable.GetInput(i), argsTypes[i].second)->GetSchemeType(); argNodes.emplace_back(LocateNode(ctx.NodeLocator, callable, i)); } const auto funcName = AS_VALUE(TDataLiteral, callable.GetInput(0))->AsValue().AsStringRef(); const auto funcDesc = ctx.FunctionRegistry.GetBuiltins()->GetBuiltin(funcName, argsTypes.data(), inputsCount); const auto returnKind = GetValueRepresentation(returnType); switch (argNodes.size()) { case 1U: if (EValueRepresentation::Embedded == returnKind) { return new TSimpleUnaryArgInvokeWrapper(funcName, funcDesc, argNodes.front()); } else { return new TDefaultUnaryArgInvokeWrapper(ctx.Mutables, returnKind, funcName, funcDesc, argNodes.front()); } case 2U: if (EValueRepresentation::Embedded == returnKind) { return new TBinaryInvokeWrapper(funcName, funcDesc, argNodes.front(), argNodes.back()); } return new TInvokeWrapper<2U>(ctx.Mutables, returnKind, funcName, funcDesc, std::move(argNodes)); case 3U: return new TInvokeWrapper<3U>(ctx.Mutables, returnKind, funcName, funcDesc, std::move(argNodes)); default: Y_ABORT("Too wide invoke."); } } } }