123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- #include "mkql_wide_chain_map.h"
- #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
- #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
- #include <yql/essentials/minikql/computation/mkql_custom_list.h>
- #include <yql/essentials/minikql/mkql_node_cast.h>
- #include <yql/essentials/utils/cast.h>
- namespace NKikimr {
- namespace NMiniKQL {
- using NYql::EnsureDynamicCast;
- namespace {
- // Wide-flow node that produces one output row per row fetched from `Flow`.
- // For the first successfully fetched row the output columns are taken from `InitItems`;
- // for every following row they are taken from `UpdateItems`.
- //  * `Inputs`  - external nodes that receive the columns of the fetched row.
- //  * `Outputs` - external nodes that keep the previously produced row
- //                (presumably the state arguments of the update lambda - confirm against the callable builder).
- // The TPasstroughtMap members are built by GetPasstroughtMapOneToOne(a, b); from the way they are used here,
- // entry [i] is an optional index j (presumably meaning "a[i] is the very same node as b[j]" - confirm in the helper).
- // They let both the interpreter and the codegen path skip evaluating items that merely forward another node.
- class TWideChain1MapWrapper : public TStatefulWideFlowCodegeneratorNode<TWideChain1MapWrapper> {
-     using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideChain1MapWrapper>;
- public:
-     // Stores the node lists, precomputes the six passthrough maps and reserves per-context storage:
-     //  * Inputs.size() slots in the wide fields array (pointer table handed to Flow->FetchValues);
-     //  * Outputs.size() consecutive mutable values starting at TempStateIndex (scratch for CalculateOther).
-     TWideChain1MapWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow,
-         TComputationExternalNodePtrVector&& inputs,
-         TComputationNodePtrVector&& initItems,
-         TComputationExternalNodePtrVector&& outputs,
-         TComputationNodePtrVector&& updateItems)
-         : TBaseComputation(mutables, flow, EValueRepresentation::Embedded)
-         , Flow(flow)
-         , Inputs(std::move(inputs))
-         , InitItems(std::move(initItems))
-         , Outputs(std::move(outputs))
-         , UpdateItems(std::move(updateItems))
-         , InputsOnInit(GetPasstroughtMapOneToOne(Inputs, InitItems))
-         , InputsOnUpdate(GetPasstroughtMapOneToOne(Inputs, UpdateItems))
-         , InitOnInputs(GetPasstroughtMapOneToOne(InitItems, Inputs))
-         , UpdateOnInputs(GetPasstroughtMapOneToOne(UpdateItems, Inputs))
-         , OutputsOnUpdate(GetPasstroughtMapOneToOne(Outputs, UpdateItems))
-         , UpdateOnOutputs(GetPasstroughtMapOneToOne(UpdateItems, Outputs))
-         , WideFieldsIndex(mutables.IncrementWideFieldsIndex(Inputs.size()))
-         // std::exchange yields the old CurValueIndex and advances it by Outputs.size().
-         , TempStateIndex(std::exchange(mutables.CurValueIndex, mutables.CurValueIndex + Outputs.size()))
-     {}
-     // Interpreter entry point. An invalid `state` means that no row has been processed yet.
-     // Returns whatever the underlying fetch returned; `output[i]` may be nullptr for columns the caller does not need.
-     EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
-         if (state.IsInvalid()) {
-             const auto result = CalculateFirst(ctx, output);
-             // Leave the state invalid unless a row was really consumed: otherwise a non-One answer to the
-             // first fetch would route the next call to CalculateOther with Outputs never seeded from InitItems.
-             // This mirrors the generated code, which stores the non-invalid state only in its "init" block.
-             if (EFetchResult::One == result)
-                 state = NUdf::TUnboxedValuePod();
-             return result;
-         }
-         return CalculateOther(ctx, output);
-     }
- #ifndef MKQL_DISABLE_CODEGEN
-     // Emits LLVM IR equivalent to DoCalculate and returns the fetch status value together with one getter per output column.
-     ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
-         auto& context = ctx.Codegen.GetContext();
-         const auto flagType = Type::getInt1Ty(context);
-         // i1 slot in the entry block remembering whether the current row took the "init" path; read by lazy getters below.
-         const auto flagPtr = new AllocaInst(flagType, 0U, "flag_ptr", &ctx.Func->getEntryBlock().back());
-         const auto good = BasicBlock::Create(context, "good", ctx.Func);
-         const auto done = BasicBlock::Create(context, "done", ctx.Func);
-         const auto getres = GetNodeValues(Flow, ctx, block);
-         // A non-positive status is not a row (the interpreter path returns everything except EFetchResult::One as is).
-         const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), 0), "special", block);
-         BranchInst::Create(done, good, special, block);
-         block = good;
-         // Materialize an input only if something depends on it or it is not forwarded by both item lists.
-         for (auto i = 0U; i < Inputs.size(); ++i)
-             if (Inputs[i]->GetDependencesCount() > 0U || !InputsOnInit[i] || !InputsOnUpdate[i])
-                 EnsureDynamicCast<ICodegeneratorExternalNode*>(Inputs[i])->CreateSetValue(ctx, block, getres.second[i](ctx, block));
-         const auto init = BasicBlock::Create(context, "init", ctx.Func);
-         const auto next = BasicBlock::Create(context, "next", ctx.Func);
-         const auto flag = IsInvalid(statePtr, block, context);
-         new StoreInst(flag, flagPtr, block);
-         BranchInst::Create(init, next, flag, block);
-         // First row: seed the output nodes from the init items (or directly from the forwarded input getter).
-         block = init;
-         for (auto i = 0U; i < Outputs.size(); ++i) {
-             if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
-                 const auto& map = InitOnInputs[i];
-                 const auto value = map ? getres.second[*map](ctx, block) : GetNodeValue(InitItems[i], ctx, block);
-                 EnsureDynamicCast<ICodegeneratorExternalNode*>(Outputs[i])->CreateSetValue(ctx, block, value);
-             }
-         }
-         // The state stops being invalid only here, after a row has been fetched and the outputs seeded.
-         new StoreInst(GetEmpty(context), statePtr, block);
-         BranchInst::Create(done, block);
-         // Subsequent rows: generate every needed update value first, then overwrite the output nodes,
-         // so that no update expression is generated after an output node has been reassigned.
-         block = next;
-         std::vector<Value*> outputs(Outputs.size(), nullptr);
-         for (auto i = 0U; i < outputs.size(); ++i) {
-             // Skipped when the output has no dependents and is either not an update item or forwarded to its own slot.
-             if (const auto& dep = OutputsOnUpdate[i]; Outputs[i]->GetDependencesCount() > 0U || (dep && *dep != i)) {
-                 const auto& map = UpdateOnInputs[i];
-                 outputs[i] = map ? getres.second[*map](ctx, block) : GetNodeValue(UpdateItems[i], ctx, block);
-             }
-         }
-         for (auto i = 0U; i < outputs.size(); ++i)
-             if (const auto out = outputs[i])
-                 EnsureDynamicCast<ICodegeneratorExternalNode*>(Outputs[i])->CreateSetValue(ctx, block, out);
-         BranchInst::Create(done, block);
-         block = done;
-         ICodegeneratorInlineWideNode::TGettersList result;
-         result.reserve(Outputs.size());
-         for (auto i = 0U; i < Outputs.size(); ++i) {
-             if (const auto& one = InitOnInputs[i], two = UpdateOnInputs[i]; one && two && *one == *two)
-                 // Both item lists forward the same input column: reuse the upstream getter.
-                 result.emplace_back(getres.second[*two]);
-             else if (Outputs[i]->GetDependencesCount() > 0 || OutputsOnUpdate[i])
-                 // The value was stored into the output node above: read it back from there.
-                 result.emplace_back([output = Outputs[i]] (const TCodegenContext& ctx, BasicBlock*& block) { return GetNodeValue(output, ctx, block); });
-             else
-                 // Nothing was stored for this column: evaluate it on demand, choosing init or update by the saved flag.
-                 result.emplace_back([this, i, source = getres.second, flagPtr, flagType] (const TCodegenContext& ctx, BasicBlock*& block) {
-                     auto& context = ctx.Codegen.GetContext();
-                     const auto init = BasicBlock::Create(context, "init", ctx.Func);
-                     const auto next = BasicBlock::Create(context, "next", ctx.Func);
-                     const auto done = BasicBlock::Create(context, "done", ctx.Func);
-                     const auto result = PHINode::Create(Type::getInt128Ty(context), 2U, "result", done);
-                     const auto flag = new LoadInst(flagType, flagPtr, "flag", block);
-                     BranchInst::Create(init, next, flag, block);
-                     block = init;
-                     if (const auto& map = InitOnInputs[i])
-                         result->addIncoming(source[*map](ctx, block), block);
-                     else
-                         result->addIncoming(GetNodeValue(InitItems[i], ctx, block), block);
-                     BranchInst::Create(done, block);
-                     block = next;
-                     if (const auto& map = UpdateOnInputs[i])
-                         result->addIncoming(source[*map](ctx, block), block);
-                     else
-                         result->addIncoming(GetNodeValue(UpdateItems[i], ctx, block), block);
-                     BranchInst::Create(done, block);
-                     block = done;
-                     return result;
-                 });
-         }
-         return {getres.first, std::move(result)};
-     }
- #endif
- private:
-     // Processes the first row: fetches it, seeds the output nodes from InitItems and fills the caller's output slots.
-     EFetchResult CalculateFirst(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
-         auto** fields = ctx.WideFields.data() + WideFieldsIndex;
-         // Choose the destination of every fetched column.
-         for (auto i = 0U; i < Inputs.size(); ++i) {
-             if (const auto& map = InputsOnInit[i]; map && !Inputs[i]->GetDependencesCount()) {
-                 // The input is forwarded as init item *map and nothing else reads it.
-                 if (const auto& to = UpdateOnOutputs[*map]) {
-                     // Update item *map is an output node itself: fetch straight into that node's storage.
-                     fields[i] = &Outputs[*to]->RefValue(ctx);
-                     continue;
-                 } else if (const auto out = output[*map]) {
-                     // Otherwise fetch straight into the caller's slot, when the caller asked for this column.
-                     fields[i] = out;
-                     continue;
-                 }
-             } else {
-                 fields[i] = &Inputs[i]->RefValue(ctx);
-                 continue;
-             }
-             // Forwarded column that nobody needs: do not fetch it.
-             fields[i] = nullptr;
-         }
-         if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
-             return result;
-         // Seed the output nodes that are read later, unless the init item is a forwarded input without dependents
-         // (such inputs were routed by the loop above instead of being stored in their own node).
-         for (auto i = 0U; i < Outputs.size(); ++i) {
-             if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
-                 if (const auto& map = InitOnInputs[i]; !map || Inputs[*map]->GetDependencesCount() > 0U) {
-                     Outputs[i]->SetValue(ctx, InitItems[i]->GetValue(ctx));
-                 }
-             }
-         }
-         // Fill the requested output columns.
-         for (auto i = 0U; i < Outputs.size(); ++i) {
-             if (const auto out = output[i]) {
-                 if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i])
-                     *out = Outputs[i]->GetValue(ctx);
-                 else {
-                     if (const auto& map = InitOnInputs[i]) {
-                         if (const auto from = *map; !Inputs[from]->GetDependencesCount()) {
-                             // The column was fetched directly into slot `first`; copy only if that is another slot.
-                             // NOTE(review): assumes InputsOnInit[from] is set and output[first] is non-null here - confirm.
-                             if (const auto first = *InputsOnInit[from]; first != i)
-                                 *out = *output[first];
-                             continue;
-                         }
-                     }
-                     *out = InitItems[i]->GetValue(ctx);
-                 }
-             }
-         }
-         return EFetchResult::One;
-     }
-     // Processes every row after the first one: fetches it, recomputes the output nodes from UpdateItems
-     // and fills the caller's output slots.
-     EFetchResult CalculateOther(TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
-         auto** fields = ctx.WideFields.data() + WideFieldsIndex;
-         for (auto i = 0U; i < Inputs.size(); ++i) {
-             if (const auto& map = InputsOnUpdate[i]; map && !Inputs[i]->GetDependencesCount()) {
-                 // Forwarded as update item *map and not read otherwise: fetch into the caller's slot if requested.
-                 if (const auto out = output[*map]) {
-                     fields[i] = out;
-                     continue;
-                 }
-             } else {
-                 fields[i] = &Inputs[i]->RefValue(ctx);
-                 continue;
-             }
-             fields[i] = nullptr;
-         }
-         if (const auto result = Flow->FetchValues(ctx, fields); EFetchResult::One != result)
-             return result;
-         // Phase 1: evaluate all update items into scratch slots while the output nodes still hold the previous row.
-         for (auto i = 0U; i < Outputs.size(); ++i) {
-             if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
-                 if (const auto& map = UpdateOnInputs[i]; !map || Inputs[*map]->GetDependencesCount() > 0U) {
-                     ctx.MutableValues[TempStateIndex + i] = UpdateItems[i]->GetValue(ctx);
-                 }
-             }
-         }
-         // Phase 2: move the new values into the output nodes (same conditions as phase 1).
-         for (auto i = 0U; i < Outputs.size(); ++i) {
-             if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i]) {
-                 if (const auto& map = UpdateOnInputs[i]; !map || Inputs[*map]->GetDependencesCount() > 0U) {
-                     Outputs[i]->SetValue(ctx, std::move(ctx.MutableValues[TempStateIndex + i]));
-                 }
-             }
-         }
-         // Fill the requested output columns.
-         for (auto i = 0U; i < Outputs.size(); ++i) {
-             if (const auto out = output[i]) {
-                 if (Outputs[i]->GetDependencesCount() > 0U || OutputsOnUpdate[i])
-                     *out = Outputs[i]->GetValue(ctx);
-                 else {
-                     if (const auto& map = UpdateOnInputs[i]) {
-                         if (const auto from = *map; !Inputs[from]->GetDependencesCount()) {
-                             // The column was fetched directly into slot `first`; copy only if that is another slot.
-                             // NOTE(review): assumes InputsOnUpdate[from] is set and output[first] is non-null here - confirm.
-                             if (const auto first = *InputsOnUpdate[from]; first != i)
-                                 *out = *output[first];
-                             continue;
-                         }
-                     }
-                     *out = UpdateItems[i]->GetValue(ctx);
-                 }
-             }
-         }
-         return EFetchResult::One;
-     }
-     // When FlowDependsOn(Flow) yields a node: registers Inputs and Outputs via Own, InitItems and UpdateItems via DependsOn.
-     void RegisterDependencies() const final {
-         if (const auto flow = FlowDependsOn(Flow)) {
-             std::for_each(Inputs.cbegin(), Inputs.cend(), std::bind(&TWideChain1MapWrapper::Own, flow, std::placeholders::_1));
-             std::for_each(Outputs.cbegin(), Outputs.cend(), std::bind(&TWideChain1MapWrapper::Own, flow, std::placeholders::_1));
-             std::for_each(InitItems.cbegin(), InitItems.cend(), std::bind(&TWideChain1MapWrapper::DependsOn, flow, std::placeholders::_1));
-             std::for_each(UpdateItems.cbegin(), UpdateItems.cend(), std::bind(&TWideChain1MapWrapper::DependsOn, flow, std::placeholders::_1));
-         }
-     }
-     IComputationWideFlowNode* const Flow;
-     const TComputationExternalNodePtrVector Inputs;
-     const TComputationNodePtrVector InitItems;
-     const TComputationExternalNodePtrVector Outputs;
-     const TComputationNodePtrVector UpdateItems;
-     // Naming: XOnY is built as GetPasstroughtMapOneToOne(X, Y), i.e. it is indexed by X and holds indexes into Y.
-     const TPasstroughtMap InputsOnInit, InputsOnUpdate, InitOnInputs, UpdateOnInputs, OutputsOnUpdate, UpdateOnOutputs;
-     // First of Inputs.size() pointer slots in ctx.WideFields.
-     const ui32 WideFieldsIndex;
-     // First of Outputs.size() scratch values in ctx.MutableValues (used by CalculateOther).
-     const ui32 TempStateIndex;
- };
- }
- // Factory for the WideChain1Map callable.
- // Expected callable inputs, in order:
- //   [0]                          the wide flow;
- //   inputWidth  entries          external nodes receiving the fetched row;
- //   outputWidth entries          init items;
- //   outputWidth entries          external nodes holding the previous output row;
- //   outputWidth entries          update items.
- // Fails with "Wrong signature." on a count mismatch and throws yexception if input 0 is not a wide flow node.
- IComputationNode* WrapWideChain1Map(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
-     MKQL_ENSURE(callable.GetInputsCount() > 0U, "Expected argument.");
-     const auto inputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetInput(0U).GetStaticType()));
-     const auto outputWidth = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetType()->GetReturnType()));
-     MKQL_ENSURE(callable.GetInputsCount() == inputWidth + outputWidth * 3U + 1U, "Wrong signature.");
-     const auto wide = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0U));
-     if (!wide) {
-         THROW yexception() << "Expected wide flow.";
-     }
-     // Position of the first element of every group inside the callable's input list.
-     const auto firstInputArg = 1U;
-     const auto firstInitItem = inputWidth + 1U;
-     const auto firstOutputArg = firstInitItem + outputWidth;
-     const auto firstUpdateItem = firstOutputArg + outputWidth;
-     TComputationNodePtrVector initItems;
-     initItems.reserve(outputWidth);
-     for (auto i = 0U; i < outputWidth; ++i) {
-         initItems.push_back(LocateNode(ctx.NodeLocator, callable, firstInitItem + i));
-     }
-     TComputationNodePtrVector updateItems;
-     updateItems.reserve(outputWidth);
-     for (auto i = 0U; i < outputWidth; ++i) {
-         updateItems.push_back(LocateNode(ctx.NodeLocator, callable, firstUpdateItem + i));
-     }
-     TComputationExternalNodePtrVector inputArgs;
-     inputArgs.reserve(inputWidth);
-     for (auto i = 0U; i < inputWidth; ++i) {
-         inputArgs.push_back(LocateExternalNode(ctx.NodeLocator, callable, firstInputArg + i));
-     }
-     TComputationExternalNodePtrVector outputArgs;
-     outputArgs.reserve(outputWidth);
-     for (auto i = 0U; i < outputWidth; ++i) {
-         outputArgs.push_back(LocateExternalNode(ctx.NodeLocator, callable, firstOutputArg + i));
-     }
-     return new TWideChain1MapWrapper(ctx.Mutables, wide, std::move(inputArgs), std::move(initItems), std::move(outputArgs), std::move(updateItems));
- }
- }
- }
|