Browse Source

YQL-15891 LLVM for BlockCombineAll.

a-romanov 1 year ago
parent
commit
bb268f4050
1 changed file with 230 additions and 71 deletions
  1. 230 71
      ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp

+ 230 - 71
ydb/library/yql/minikql/comp_nodes/mkql_block_agg.cpp

@@ -456,18 +456,20 @@ size_t CalcMaxBlockLenForOutput(TType* out) {
 }
 
 
-class TBlockCombineAllWrapper : public TStatefulWideFlowComputationNode<TBlockCombineAllWrapper> {
+class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper> {
+using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>;
 public:
     TBlockCombineAllWrapper(TComputationMutables& mutables, // Builds the combine-all node over `flow`; `width` is the number of input columns.
         IComputationWideFlowNode* flow,
         std::optional<ui32> filterColumn,
         size_t width,
         std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams)
-        : TStatefulWideFlowComputationNode(mutables, flow, EValueRepresentation::Any)
+        : TBaseComputation(mutables, flow, EValueRepresentation::Boxed) // state representation is now declared as Boxed (previously Any)
         , Flow_(flow)
         , FilterColumn_(filterColumn)
         , Width_(width)
         , AggsParams_(std::move(aggsParams))
+        , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width)) // presumably reserves `width` wide-field slots and returns the first index - TODO confirm
     {
         MKQL_ENSURE(Width_ > 0, "Missing block length column"); // column Width_ - 1 is read as the block length in TState::ProcessInput
     }
@@ -477,88 +479,160 @@ public:
         NUdf::TUnboxedValue*const* output) const
     {
         auto& s = GetState(state, ctx);
-        if (s.IsFinished_) {
+        if (s.IsFinished_)
             return EFetchResult::Finish;
-        }
 
-        for (;;) {
-            auto result = Flow_->FetchValues(ctx, s.ValuePointers_.data());
-            if (result == EFetchResult::Yield) {
-                return result;
-            } else if (result == EFetchResult::One) {
-                ui64 batchLength = GetBatchLength(s.Values_.data());
-                if (!batchLength) {
+        for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) {
+            switch (Flow_->FetchValues(ctx, fields)) {
+                case EFetchResult::Yield:
+                    return EFetchResult::Yield;
+                case EFetchResult::One:
+                    s.ProcessInput();
                     continue;
+                case EFetchResult::Finish:
+                    break;
+            }
+            if (s.MakeOutput()) {
+                for (size_t i = 0; i < AggsParams_.size(); ++i) {
+                    if (const auto out = output[i]) {
+                        *out = s.Pull(i);
+                    }
                 }
+                return EFetchResult::One;
+            }
+            return EFetchResult::Finish;
+        }
+    }
+#ifndef MKQL_DISABLE_CODEGEN
+    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const { // Emits LLVM IR for this node; returns the fetch-status value and one getter per aggregate.
+        auto& context = ctx.Codegen.GetContext();
 
-                std::optional<ui64> filtered;
-                if (FilterColumn_) {
-                    auto filterDatum = TArrowBlock::From(s.Values_[*FilterColumn_]).GetDatum();
-                    if (filterDatum.is_scalar()) {
-                        if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) {
-                            continue;
-                        }
-                    } else {
-                        ui64 popCount = GetBitmapPopCount(filterDatum.array());
-                        if (popCount == 0) {
-                            continue;
-                        }
+        const auto valueType = Type::getInt128Ty(context); // values and the state are handled as 128-bit integers in the IR
+        const auto ptrValueType = PointerType::getUnqual(valueType);
+        const auto statusType = Type::getInt32Ty(context); // IR type used for EFetchResult codes
+        const auto indexType = Type::getInt64Ty(context);
+        const auto flagType = Type::getInt1Ty(context);
+        const auto arrayType = ArrayType::get(valueType, Width_); // one slot per input column
+        const auto ptrValuesType = PointerType::getUnqual(arrayType);
 
-                        if (popCount < batchLength) {
-                            filtered = popCount;
-                        }
-                    }
-                }
+        TLLVMFieldsStructureState stateFields(context, Width_); // IR-side description of the leading TState fields
+        const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
+        const auto statePtrType = PointerType::getUnqual(stateType);
 
-                s.HasValues_ = true;
-                char* ptr = s.AggStates_.data();
-                for (size_t i = 0; i < s.Aggs_.size(); ++i) {
-                    if (output[i]) {
-                        s.Aggs_[i]->AddMany(ptr, s.Values_.data(), batchLength, filtered);
-                    }
+        const auto atTop = &ctx.Func->getEntryBlock().back(); // insertion point: before the last instruction of the entry block
 
-                    ptr += s.Aggs_[i]->StateSize;
-                }
-            } else {
-                s.IsFinished_ = true;
-                if (!s.HasValues_) {
-                    return EFetchResult::Finish;
-                }
+        const auto pullFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Pull)); // address of TState::Pull embedded as a constant
+        const auto pullType = FunctionType::get(valueType, {statePtrType, indexType}, false);
+        const auto pullPtr = CastInst::Create(Instruction::IntToPtr, pullFunc, PointerType::getUnqual(pullType), "pull", atTop);
 
-                char* ptr = s.AggStates_.data();
-                for (size_t i = 0; i < s.Aggs_.size(); ++i) {
-                    if (auto* out = output[i]; out != nullptr) {
-                        *out = s.Aggs_[i]->FinishOne(ptr);
-                        s.Aggs_[i]->DestroyState(ptr);
-                    }
+        const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop); // stack slot the getters load the state pointer from
+        new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop); // starts as null; only the `work` block stores a real pointer
 
-                    ptr += s.Aggs_[i]->StateSize;
-                }
+        const auto make = BasicBlock::Create(context, "make", ctx.Func); // creates the state
+        const auto main = BasicBlock::Create(context, "main", ctx.Func); // checks the finished flag
+        const auto read = BasicBlock::Create(context, "read", ctx.Func); // fetches from the upstream flow
+        const auto good = BasicBlock::Create(context, "good", ctx.Func); // a batch arrived: store it and aggregate
+        const auto work = BasicBlock::Create(context, "work", ctx.Func); // upstream finished: produce the output
+        const auto over = BasicBlock::Create(context, "over", ctx.Func); // common exit holding the result PHI
 
-                return EFetchResult::One;
-            }
+        BranchInst::Create(main, make, HasValue(statePtr, block), block); // go to `main` when the condition holds, otherwise to `make`
+        block = make;
+
+        const auto ptrType = PointerType::getUnqual(StructType::get(context));
+        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block); // `this` embedded as a constant address
+        const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockCombineAllWrapper::MakeState));
+        const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false);
+        const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
+        CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block); // MakeState(this, state, ctx)
+
+        BranchInst::Create(main, block);
+
+        block = main;
+
+        const auto state = new LoadInst(valueType, statePtr, "state", block);
+        const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block); // low 64 bits of the state value are used as the TState address
+        const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
+
+        const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block);
+        const auto finished = new LoadInst(flagType, finishedPtr, "finished", block);
+
+        const auto result = PHINode::Create(statusType, 3U, "result", over); // three predecessors: main, read, work
+        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block); // already finished -> Finish
+
+        BranchInst::Create(over, read, finished, block);
+
+        block = read;
+
+        const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
+        const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
+        SafeUnRefUnboxed(values, ctx, block); // presumably releases what the previous batch left in the array - TODO confirm
+
+        const auto getres = GetNodeValues(Flow_, ctx, block); // status value plus per-column getters of the upstream flow
+        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block); // edge taken by the Yield case below
+
+        const auto way = SwitchInst::Create(getres.first, good, 2U, block); // default (a batch was fetched) -> good
+        way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), work);
+        way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over);
+
+        block = good;
+
+        Value* array = UndefValue::get(arrayType);
+        for (auto idx = 0U; idx < getres.second.size(); ++idx) { // materialise every upstream column into the aggregate array value
+            const auto value = getres.second[idx](ctx, block);
+            AddRefBoxed(value, ctx, block);
+            array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
         }
+        new StoreInst(array, values, block); // write the whole batch into the state's value storage
 
-        return EFetchResult::Finish;
-    }
+        const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ProcessInput));
+        const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false); // NOTE(review): declares a factory parameter, but TState::ProcessInput() takes none - confirm the extra argument is intended
+        const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block);
+        CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
+
+        BranchInst::Create(read, block); // loop: fetch the next batch
+
+        block = work;
+
+        const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::MakeOutput));
+        const auto makeOutputType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false); // NOTE(review): same extra factory parameter; TState::MakeOutput() takes none - confirm
+        const auto makeOutputPtr = CastInst::Create(Instruction::IntToPtr, makeOutputFunc, PointerType::getUnqual(makeOutputType), "make_output_func", block);
+        const auto hasData = CallInst::Create(makeOutputType, makeOutputPtr, {stateArg, ctx.GetFactory()}, "make_output", block);
+        const auto output = SelectInst::Create(hasData, ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), "output", block); // One when MakeOutput returned true, else Finish
+        new StoreInst(stateArg, stateOnStack, block); // publish the state pointer for the getters below
+
+        result->addIncoming(output, block);
+        BranchInst::Create(over, block);
+
+        block = over;
 
+        ICodegeneratorInlineWideNode::TGettersList getters(AggsParams_.size()); // one getter per aggregate
+        for (size_t idx = 0U; idx < getters.size(); ++idx) {
+            getters[idx] = [idx, pullType, pullPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) { // each getter emits a call to TState::Pull(idx)
+                const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
+                return CallInst::Create(pullType, pullPtr, {stateArg, ConstantInt::get(indexType, idx)}, "pull", block);
+            };
+        }
+        return {result, std::move(getters)};
+    }
+#endif
 private:
     struct TState : public TComputationValue<TState> {
-        std::vector<NUdf::TUnboxedValue> Values_;
-        std::vector<NUdf::TUnboxedValue*> ValuePointers_;
-        std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_;
+        NUdf::TUnboxedValue* Pointer_ = nullptr;
         bool IsFinished_ = false;
         bool HasValues_ = false;
+        TUnboxedValueVector Values_;
+        std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_;
         std::vector<char> AggStates_;
+        const std::optional<ui32> FilterColumn_;
+        const size_t Width_;
 
-        TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32>, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx)
+        TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx)
             : TComputationValue(memInfo)
-            , Values_(width)
-            , ValuePointers_(width)
+            , Values_(std::max(width, params.size()))
+            , FilterColumn_(filterColumn)
+            , Width_(width)
         {
-            for (size_t i = 0; i < width; ++i) {
-                ValuePointers_[i] = &Values_[i];
-            }
+            Pointer_ = Values_.data();
 
             ui32 totalStateSize = 0;
             for (const auto& p : params) {
@@ -574,29 +648,114 @@ private:
                 ptr += agg->StateSize;
             }
         }
+
+        void ProcessInput() { // Feeds the batch currently stored in Values_ into every aggregator.
+            const ui64 batchLength = TArrowBlock::From(Values_[Width_ - 1U]).GetDatum().scalar_as<arrow::UInt64Scalar>().value; // last input column is read as a UInt64 scalar block length
+            if (!batchLength) {
+                return; // empty batch: nothing to aggregate
+            }
+
+            std::optional<ui64> filtered; // set only when the filter keeps some, but not all, rows
+            if (FilterColumn_) {
+                const auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum();
+                if (filterDatum.is_scalar()) {
+                    if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) {
+                        return; // scalar filter is zero: skip the whole batch
+                    }
+                } else {
+                    const ui64 popCount = GetBitmapPopCount(filterDatum.array());
+                    if (popCount == 0) {
+                        return; // filter array has a zero pop count: skip the whole batch
+                    }
+
+                    if (popCount < batchLength) {
+                        filtered = popCount;
+                    }
+                }
+            }
+
+            HasValues_ = true; // MakeOutput() produces a result only when this is set
+            char* ptr = AggStates_.data(); // aggregator states are laid out back to back in AggStates_
+            for (size_t i = 0; i < Aggs_.size(); ++i) {
+                Aggs_[i]->AddMany(ptr, Values_.data(), batchLength, filtered);
+                ptr += Aggs_[i]->StateSize;
+            }
+        }
+
+        bool MakeOutput() { // Marks the state finished and finalizes every aggregator; returns false when no batch was aggregated.
+            IsFinished_ = true; // set unconditionally, including on the early return below
+            if (!HasValues_)
+                return false;
+
+            char* ptr = AggStates_.data();
+            for (size_t i = 0; i < Aggs_.size(); ++i) {
+                Values_[i] = Aggs_[i]->FinishOne(ptr); // result i replaces whatever Values_[i] held; Pull(i) hands it out later
+                Aggs_[i]->DestroyState(ptr);
+                ptr += Aggs_[i]->StateSize;
+            }
+            return true;
+        }
+
+        NUdf::TUnboxedValuePod Pull(size_t index) { // Hands out the value stored at Values_[index].
+            return Values_[index].Release(); // result of Release() is returned as a pod; presumably ownership moves to the caller - TODO confirm
+        }
     };
+#ifndef MKQL_DISABLE_CODEGEN
+    class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> { // LLVM-side field list for the state: base fields plus a values pointer and a finished flag. NOTE(review): base is parameterised with TBlockState, which is not visible in this view, while TState derives from TComputationValue<TState> - confirm
+    private:
+        using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>;
+        llvm::PointerType*const PointerType; // pointer to an array of `width` 128-bit values
+        llvm::IntegerType*const IsFinishedType; // 1-bit flag
+    public:
+        std::vector<llvm::Type*> GetFieldsArray() { // base fields followed by the two fields added here, in this order
+            std::vector<llvm::Type*> result = TBase::GetFields();
+            result.emplace_back(PointerType);
+            result.emplace_back(IsFinishedType);
+            return result; // NOTE(review): this order has to mirror TState's leading members (Pointer_, then IsFinished_) - confirm when reordering either side
+        }

-private:
+        llvm::Constant* GetPointer() { // i32 index of the values-pointer field
+            return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
+        }
+
+        llvm::Constant* GetIsFinished() { // i32 index of the finished-flag field
+            return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
+        }
+
+        TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
+            : TBase(context)
+            , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width)))
+            , IsFinishedType(llvm::Type::getInt1Ty(Context))
+        {}
+    };
+#endif
     void RegisterDependencies() const final { // Declares the upstream flow as this node's only dependency.
         FlowDependsOn(Flow_);
     }
 
+    void MakeState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { // Creates a fresh TState in `state`; also called from the IR generated by DoGenGetValues.
+        state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx);
+    }
+
     TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const { // Returns the node state, creating it (and wiring the wide-field pointers) on first use.
         if (!state.HasValue()) {
-            state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx);
+            MakeState(state, ctx);
+
+            auto& s = *static_cast<TState*>(state.AsBoxed().Get());
+            const auto fields = ctx.WideFields.data() + WideFieldsIndex_; // this node's slots start at WideFieldsIndex_
+            for (size_t i = 0; i < s.Values_.size(); ++i) { // NOTE(review): Values_ holds max(width, params.size()) entries but the constructor reserves only `width` slots - with more aggregates than input columns this may write past them; confirm
+                fields[i] = &s.Values_[i]; // DoCalculate passes `fields` to Flow_->FetchValues, so fetched columns land directly in Values_
+            }
+            return s;
         }
         return *static_cast<TState*>(state.AsBoxed().Get());
     }
-
-    ui64 GetBatchLength(const NUdf::TUnboxedValue* columns) const {
-        return TArrowBlock::From(columns[Width_ - 1]).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
-    }
-
 private:
-    IComputationWideFlowNode* Flow_;
-    std::optional<ui32> FilterColumn_;
+    IComputationWideFlowNode *const Flow_;
+    const std::optional<ui32> FilterColumn_;
     const size_t Width_;
     const std::vector<TAggParams<IBlockAggregatorCombineAll>> AggsParams_;
+    const size_t WideFieldsIndex_;
 };
 
 template <typename T>