123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314 |
- #include "mkql_if.h"
- #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
- #include <yql/essentials/minikql/mkql_node_cast.h>
- #include <yql/essentials/minikql/mkql_node_builder.h>
- namespace NKikimr {
- namespace NMiniKQL {
- namespace {
- template<bool IsOptional>
- class TIfWrapper : public TMutableCodegeneratorNode<TIfWrapper<IsOptional>> {
- using TBaseComputation = TMutableCodegeneratorNode<TIfWrapper<IsOptional>>;
- public:
- TIfWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* predicate, IComputationNode* thenBranch, IComputationNode* elseBranch)
- : TBaseComputation(mutables, kind)
- , Predicate(predicate)
- , ThenBranch(thenBranch)
- , ElseBranch(elseBranch)
- {}
- NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
- const auto& predicate = Predicate->GetValue(ctx);
- if (IsOptional && !predicate) {
- return NUdf::TUnboxedValuePod();
- }
- return (predicate.Get<bool>() ? ThenBranch : ElseBranch)->GetValue(ctx).Release();
- }
- #ifndef MKQL_DISABLE_CODEGEN
- Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
- auto& context = ctx.Codegen.GetContext();
- const auto then = BasicBlock::Create(context, "then", ctx.Func);
- const auto elsb = BasicBlock::Create(context, "else", ctx.Func);
- const auto done = BasicBlock::Create(context, "done", ctx.Func);
- const auto value = GetNodeValue(Predicate, ctx, block);
- const auto result = PHINode::Create(value->getType(), IsOptional ? 3U : 2U, "result", done);
- if (IsOptional) {
- result->addIncoming(value, block);
- const auto good = BasicBlock::Create(context, "good", ctx.Func);
- BranchInst::Create(done, good, IsEmpty(value, block, context), block);
- block = good;
- }
- const auto cast = CastInst::Create(Instruction::Trunc, value, Type::getInt1Ty(context), "bool", block);
- BranchInst::Create(then, elsb, cast, block);
- {
- block = then;
- const auto left = GetNodeValue(ThenBranch, ctx, block);
- result->addIncoming(left, block);
- BranchInst::Create(done, block);
- }
- {
- block = elsb;
- const auto right = GetNodeValue(ElseBranch, ctx, block);
- result->addIncoming(right, block);
- BranchInst::Create(done, block);
- }
- block = done;
- return result;
- }
- #endif
- private:
- void RegisterDependencies() const final {
- this->DependsOn(Predicate);
- this->DependsOn(ThenBranch);
- this->DependsOn(ElseBranch);
- }
- IComputationNode* const Predicate;
- IComputationNode* const ThenBranch;
- IComputationNode* const ElseBranch;
- };
- template<bool IsOptional>
- class TFlowIfWrapper : public TStatefulFlowCodegeneratorNode<TFlowIfWrapper<IsOptional>> {
- using TBaseComputation = TStatefulFlowCodegeneratorNode<TFlowIfWrapper<IsOptional>>;
- public:
- TFlowIfWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* predicate, IComputationNode* thenBranch, IComputationNode* elseBranch)
- : TBaseComputation(mutables, nullptr, kind)
- , Predicate(predicate)
- , ThenBranch(thenBranch)
- , ElseBranch(elseBranch)
- {}
- NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
- if (state.IsInvalid()) {
- state = Predicate->GetValue(ctx);
- }
- if (IsOptional && !state) {
- return NUdf::TUnboxedValuePod::MakeFinish();
- }
- return (state.Get<bool>() ? ThenBranch : ElseBranch)->GetValue(ctx).Release();
- }
- #ifndef MKQL_DISABLE_CODEGEN
- Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
- auto& context = ctx.Codegen.GetContext();
- const auto init = BasicBlock::Create(context, "init", ctx.Func);
- const auto test = BasicBlock::Create(context, "test", ctx.Func);
- const auto then = BasicBlock::Create(context, "then", ctx.Func);
- const auto elsb = BasicBlock::Create(context, "else", ctx.Func);
- const auto done = BasicBlock::Create(context, "done", ctx.Func);
- BranchInst::Create(init, test, IsInvalid(statePtr, block, context), block);
- block = init;
- GetNodeValue(statePtr, Predicate, ctx, block);
- BranchInst::Create(test, block);
- block = test;
- const auto valueType = Type::getInt128Ty(context);
- const auto state = new LoadInst(valueType, statePtr, "state", block);
- const auto result = PHINode::Create(valueType, IsOptional ? 3U : 2U, "result", done);
- if (IsOptional) {
- result->addIncoming(state, block);
- const auto good = BasicBlock::Create(context, "good", ctx.Func);
- BranchInst::Create(done, good, IsEmpty(state, block, context), block);
- block = good;
- }
- const auto cast = CastInst::Create(Instruction::Trunc, state, Type::getInt1Ty(context), "bool", block);
- BranchInst::Create(then, elsb, cast, block);
- {
- block = then;
- const auto left = GetNodeValue(ThenBranch, ctx, block);
- result->addIncoming(left, block);
- BranchInst::Create(done, block);
- }
- {
- block = elsb;
- const auto right = GetNodeValue(ElseBranch, ctx, block);
- result->addIncoming(right, block);
- BranchInst::Create(done, block);
- }
- block = done;
- return result;
- }
- #endif
- private:
- void RegisterDependencies() const final {
- if (const auto flow = this->FlowDependsOnBoth(ThenBranch, ElseBranch))
- this->DependsOn(flow, Predicate);
- }
- IComputationNode* const Predicate;
- IComputationNode* const ThenBranch;
- IComputationNode* const ElseBranch;
- };
- class TWideIfWrapper : public TStatefulWideFlowCodegeneratorNode<TWideIfWrapper> {
- using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideIfWrapper>;
- public:
- TWideIfWrapper(TComputationMutables& mutables, IComputationNode* predicate, IComputationWideFlowNode* thenBranch, IComputationWideFlowNode* elseBranch)
- : TBaseComputation(mutables, nullptr, EValueRepresentation::Embedded)
- , Predicate(predicate)
- , ThenBranch(thenBranch)
- , ElseBranch(elseBranch)
- {}
- EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
- if (state.IsInvalid()) {
- state = Predicate->GetValue(ctx);
- }
- return (state.Get<bool>() ? ThenBranch : ElseBranch)->FetchValues(ctx, output);
- }
- #ifndef MKQL_DISABLE_CODEGEN
- TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
- auto& context = ctx.Codegen.GetContext();
- const auto init = BasicBlock::Create(context, "init", ctx.Func);
- const auto test = BasicBlock::Create(context, "test", ctx.Func);
- const auto then = BasicBlock::Create(context, "then", ctx.Func);
- const auto elsb = BasicBlock::Create(context, "else", ctx.Func);
- const auto done = BasicBlock::Create(context, "done", ctx.Func);
- BranchInst::Create(init, test, IsInvalid(statePtr, block, context), block);
- block = init;
- GetNodeValue(statePtr, Predicate, ctx, block);
- BranchInst::Create(test, block);
- block = test;
- const auto valueType = Type::getInt128Ty(context);
- const auto state = new LoadInst(valueType, statePtr, "state", block);
- const auto result = PHINode::Create(Type::getInt32Ty(context), 2, "result", done);
- const auto cast = CastInst::Create(Instruction::Trunc, state, Type::getInt1Ty(context), "bool", block);
- BranchInst::Create(then, elsb, cast, block);
- block = then;
- const auto left = GetNodeValues(ThenBranch, ctx, block);
- result->addIncoming(left.first, block);
- BranchInst::Create(done, block);
- block = elsb;
- const auto right = GetNodeValues(ElseBranch, ctx, block);
- result->addIncoming(right.first, block);
- BranchInst::Create(done, block);
- block = done;
- MKQL_ENSURE(left.second.size() == right.second.size(), "Expected same width of flows.");
- TGettersList getters;
- getters.reserve(left.second.size());
- const auto index = static_cast<const IComputationNode*>(this)->GetIndex();
- size_t idx = 0U;
- std::generate_n(std::back_inserter(getters), right.second.size(), [&]() {
- const auto i = idx++;
- return [index, lget = left.second[i], rget = right.second[i]](const TCodegenContext& ctx, BasicBlock*& block) {
- auto& context = ctx.Codegen.GetContext();
- const auto then = BasicBlock::Create(context, "then", ctx.Func);
- const auto elsb = BasicBlock::Create(context, "elsb", ctx.Func);
- const auto done = BasicBlock::Create(context, "done", ctx.Func);
- const auto valueType = Type::getInt128Ty(context);
- const auto result = PHINode::Create(valueType, 2, "result", done);
- const auto statePtr = GetElementPtrInst::CreateInBounds(valueType, ctx.GetMutables(), {ConstantInt::get(Type::getInt32Ty(context), index)}, "state_ptr", block);
- const auto state = new LoadInst(valueType, statePtr, "state", block);
- const auto trunc = CastInst::Create(Instruction::Trunc, state, Type::getInt1Ty(context), "trunc", block);
- BranchInst::Create(then, elsb, trunc, block);
- block = then;
- result->addIncoming(lget(ctx, block), block);
- BranchInst::Create(done, block);
- block = elsb;
- result->addIncoming(rget(ctx, block), block);
- BranchInst::Create(done, block);
- block = done;
- return result;
- };
- });
- return {result, std::move(getters)};
- }
- #endif
- private:
- void RegisterDependencies() const final {
- if (const auto flow = FlowDependsOnBoth(ThenBranch, ElseBranch))
- DependsOn(flow, Predicate);
- }
- IComputationNode* const Predicate;
- IComputationWideFlowNode* const ThenBranch;
- IComputationWideFlowNode* const ElseBranch;
- };
- }
- IComputationNode* WrapIf(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
- MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");
- bool isOptional;
- const auto predicateType = UnpackOptionalData(callable.GetInput(0), isOptional);
- MKQL_ENSURE(predicateType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool or optional of bool.");
- const auto predicate = LocateNode(ctx.NodeLocator, callable, 0);
- const auto thenBranch = LocateNode(ctx.NodeLocator, callable, 1);
- const auto elseBranch = LocateNode(ctx.NodeLocator, callable, 2);
- const auto type = callable.GetType()->GetReturnType();
- if (type->IsFlow()) {
- const auto thenWide = dynamic_cast<IComputationWideFlowNode*>(thenBranch);
- const auto elseWide = dynamic_cast<IComputationWideFlowNode*>(elseBranch);
- if (thenWide && elseWide && !isOptional)
- return new TWideIfWrapper(ctx.Mutables, predicate, thenWide, elseWide);
- else if (!thenWide && !elseWide) {
- if (isOptional)
- return new TFlowIfWrapper<true>(ctx.Mutables, GetValueRepresentation(type), predicate, thenBranch, elseBranch);
- else
- return new TFlowIfWrapper<false>(ctx.Mutables, GetValueRepresentation(type), predicate, thenBranch, elseBranch);
- }
- } else {
- if (isOptional) {
- return new TIfWrapper<true>(ctx.Mutables, GetValueRepresentation(type), predicate, thenBranch, elseBranch);
- } else {
- return new TIfWrapper<false>(ctx.Mutables, GetValueRepresentation(type), predicate, thenBranch, elseBranch);
- }
- }
- THROW yexception() << "Wrong signature.";
- }
- }
- }
|