|
@@ -456,18 +456,20 @@ size_t CalcMaxBlockLenForOutput(TType* out) {
|
|
|
}
|
|
|
|
|
|
|
|
|
-class TBlockCombineAllWrapper : public TStatefulWideFlowComputationNode<TBlockCombineAllWrapper> {
|
|
|
+class TBlockCombineAllWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper> {
|
|
|
+using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockCombineAllWrapper>;
|
|
|
public:
|
|
|
TBlockCombineAllWrapper(TComputationMutables& mutables,
|
|
|
IComputationWideFlowNode* flow,
|
|
|
std::optional<ui32> filterColumn,
|
|
|
size_t width,
|
|
|
std::vector<TAggParams<IBlockAggregatorCombineAll>>&& aggsParams)
|
|
|
- : TStatefulWideFlowComputationNode(mutables, flow, EValueRepresentation::Any)
|
|
|
+ : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
|
|
|
, Flow_(flow)
|
|
|
, FilterColumn_(filterColumn)
|
|
|
, Width_(width)
|
|
|
, AggsParams_(std::move(aggsParams))
|
|
|
+ , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(width))
|
|
|
{
|
|
|
MKQL_ENSURE(Width_ > 0, "Missing block length column");
|
|
|
}
|
|
@@ -477,88 +479,160 @@ public:
|
|
|
NUdf::TUnboxedValue*const* output) const
|
|
|
{
|
|
|
auto& s = GetState(state, ctx);
|
|
|
- if (s.IsFinished_) {
|
|
|
+ if (s.IsFinished_)
|
|
|
return EFetchResult::Finish;
|
|
|
- }
|
|
|
|
|
|
- for (;;) {
|
|
|
- auto result = Flow_->FetchValues(ctx, s.ValuePointers_.data());
|
|
|
- if (result == EFetchResult::Yield) {
|
|
|
- return result;
|
|
|
- } else if (result == EFetchResult::One) {
|
|
|
- ui64 batchLength = GetBatchLength(s.Values_.data());
|
|
|
- if (!batchLength) {
|
|
|
+ for (const auto fields = ctx.WideFields.data() + WideFieldsIndex_;;) {
|
|
|
+ switch (Flow_->FetchValues(ctx, fields)) {
|
|
|
+ case EFetchResult::Yield:
|
|
|
+ return EFetchResult::Yield;
|
|
|
+ case EFetchResult::One:
|
|
|
+ s.ProcessInput();
|
|
|
continue;
|
|
|
+ case EFetchResult::Finish:
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (s.MakeOutput()) {
|
|
|
+ for (size_t i = 0; i < AggsParams_.size(); ++i) {
|
|
|
+ if (const auto out = output[i]) {
|
|
|
+ *out = s.Pull(i);
|
|
|
+ }
|
|
|
}
|
|
|
+ return EFetchResult::One;
|
|
|
+ }
|
|
|
+ return EFetchResult::Finish;
|
|
|
+ }
|
|
|
+ }
|
|
|
+#ifndef MKQL_DISABLE_CODEGEN
|
|
|
+ ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
|
|
|
+ auto& context = ctx.Codegen.GetContext();
|
|
|
|
|
|
- std::optional<ui64> filtered;
|
|
|
- if (FilterColumn_) {
|
|
|
- auto filterDatum = TArrowBlock::From(s.Values_[*FilterColumn_]).GetDatum();
|
|
|
- if (filterDatum.is_scalar()) {
|
|
|
- if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- } else {
|
|
|
- ui64 popCount = GetBitmapPopCount(filterDatum.array());
|
|
|
- if (popCount == 0) {
|
|
|
- continue;
|
|
|
- }
|
|
|
+ const auto valueType = Type::getInt128Ty(context);
|
|
|
+ const auto ptrValueType = PointerType::getUnqual(valueType);
|
|
|
+ const auto statusType = Type::getInt32Ty(context);
|
|
|
+ const auto indexType = Type::getInt64Ty(context);
|
|
|
+ const auto flagType = Type::getInt1Ty(context);
|
|
|
+ const auto arrayType = ArrayType::get(valueType, Width_);
|
|
|
+ const auto ptrValuesType = PointerType::getUnqual(arrayType);
|
|
|
|
|
|
- if (popCount < batchLength) {
|
|
|
- filtered = popCount;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
+ TLLVMFieldsStructureState stateFields(context, Width_);
|
|
|
+ const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
|
|
|
+ const auto statePtrType = PointerType::getUnqual(stateType);
|
|
|
|
|
|
- s.HasValues_ = true;
|
|
|
- char* ptr = s.AggStates_.data();
|
|
|
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
|
|
|
- if (output[i]) {
|
|
|
- s.Aggs_[i]->AddMany(ptr, s.Values_.data(), batchLength, filtered);
|
|
|
- }
|
|
|
+ const auto atTop = &ctx.Func->getEntryBlock().back();
|
|
|
|
|
|
- ptr += s.Aggs_[i]->StateSize;
|
|
|
- }
|
|
|
- } else {
|
|
|
- s.IsFinished_ = true;
|
|
|
- if (!s.HasValues_) {
|
|
|
- return EFetchResult::Finish;
|
|
|
- }
|
|
|
+ const auto pullFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Pull));
|
|
|
+ const auto pullType = FunctionType::get(valueType, {statePtrType, indexType}, false);
|
|
|
+ const auto pullPtr = CastInst::Create(Instruction::IntToPtr, pullFunc, PointerType::getUnqual(pullType), "pull", atTop);
|
|
|
|
|
|
- char* ptr = s.AggStates_.data();
|
|
|
- for (size_t i = 0; i < s.Aggs_.size(); ++i) {
|
|
|
- if (auto* out = output[i]; out != nullptr) {
|
|
|
- *out = s.Aggs_[i]->FinishOne(ptr);
|
|
|
- s.Aggs_[i]->DestroyState(ptr);
|
|
|
- }
|
|
|
+ const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
|
|
|
+ new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
|
|
|
|
|
|
- ptr += s.Aggs_[i]->StateSize;
|
|
|
- }
|
|
|
+ const auto make = BasicBlock::Create(context, "make", ctx.Func);
|
|
|
+ const auto main = BasicBlock::Create(context, "main", ctx.Func);
|
|
|
+ const auto read = BasicBlock::Create(context, "read", ctx.Func);
|
|
|
+ const auto good = BasicBlock::Create(context, "good", ctx.Func);
|
|
|
+ const auto work = BasicBlock::Create(context, "work", ctx.Func);
|
|
|
+ const auto over = BasicBlock::Create(context, "over", ctx.Func);
|
|
|
|
|
|
- return EFetchResult::One;
|
|
|
- }
|
|
|
+ BranchInst::Create(main, make, HasValue(statePtr, block), block);
|
|
|
+ block = make;
|
|
|
+
|
|
|
+ const auto ptrType = PointerType::getUnqual(StructType::get(context));
|
|
|
+ const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
|
|
|
+ const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockCombineAllWrapper::MakeState));
|
|
|
+ const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), statePtr->getType(), ctx.Ctx->getType()}, false);
|
|
|
+ const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
|
|
|
+ CallInst::Create(makeType, makeFuncPtr, {self, statePtr, ctx.Ctx}, "", block);
|
|
|
+
|
|
|
+ BranchInst::Create(main, block);
|
|
|
+
|
|
|
+ block = main;
|
|
|
+
|
|
|
+ const auto state = new LoadInst(valueType, statePtr, "state", block);
|
|
|
+ const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
|
|
|
+ const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
|
|
|
+
|
|
|
+ const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block);
|
|
|
+ const auto finished = new LoadInst(flagType, finishedPtr, "finished", block);
|
|
|
+
|
|
|
+ const auto result = PHINode::Create(statusType, 3U, "result", over);
|
|
|
+ result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
|
|
|
+
|
|
|
+ BranchInst::Create(over, read, finished, block);
|
|
|
+
|
|
|
+ block = read;
|
|
|
+
|
|
|
+ const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
|
|
|
+ const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
|
|
|
+ SafeUnRefUnboxed(values, ctx, block);
|
|
|
+
|
|
|
+ const auto getres = GetNodeValues(Flow_, ctx, block);
|
|
|
+ result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
|
|
|
+
|
|
|
+ const auto way = SwitchInst::Create(getres.first, good, 2U, block);
|
|
|
+ way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), work);
|
|
|
+ way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over);
|
|
|
+
|
|
|
+ block = good;
|
|
|
+
|
|
|
+ Value* array = UndefValue::get(arrayType);
|
|
|
+ for (auto idx = 0U; idx < getres.second.size(); ++idx) {
|
|
|
+ const auto value = getres.second[idx](ctx, block);
|
|
|
+ AddRefBoxed(value, ctx, block);
|
|
|
+ array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
|
|
|
}
|
|
|
+ new StoreInst(array, values, block);
|
|
|
|
|
|
- return EFetchResult::Finish;
|
|
|
- }
|
|
|
+ const auto processBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ProcessInput));
|
|
|
+ const auto processBlockType = FunctionType::get(Type::getVoidTy(context), {statePtrType, ctx.GetFactory()->getType()}, false);
|
|
|
+ const auto processBlockPtr = CastInst::Create(Instruction::IntToPtr, processBlockFunc, PointerType::getUnqual(processBlockType), "process_inputs_func", block);
|
|
|
+ CallInst::Create(processBlockType, processBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
|
|
|
+
|
|
|
+ BranchInst::Create(read, block);
|
|
|
+
|
|
|
+ block = work;
|
|
|
+
|
|
|
+ const auto makeOutputFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::MakeOutput));
|
|
|
+ const auto makeOutputType = FunctionType::get(flagType, {statePtrType, ctx.GetFactory()->getType()}, false);
|
|
|
+ const auto makeOutputPtr = CastInst::Create(Instruction::IntToPtr, makeOutputFunc, PointerType::getUnqual(makeOutputType), "make_output_func", block);
|
|
|
+ const auto hasData = CallInst::Create(makeOutputType, makeOutputPtr, {stateArg, ctx.GetFactory()}, "make_output", block);
|
|
|
+ const auto output = SelectInst::Create(hasData, ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), "output", block);
|
|
|
+ new StoreInst(stateArg, stateOnStack, block);
|
|
|
+
|
|
|
+ result->addIncoming(output, block);
|
|
|
+ BranchInst::Create(over, block);
|
|
|
+
|
|
|
+ block = over;
|
|
|
|
|
|
+ ICodegeneratorInlineWideNode::TGettersList getters(AggsParams_.size());
|
|
|
+ for (size_t idx = 0U; idx < getters.size(); ++idx) {
|
|
|
+ getters[idx] = [idx, pullType, pullPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
|
|
|
+ const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
|
|
|
+ return CallInst::Create(pullType, pullPtr, {stateArg, ConstantInt::get(indexType, idx)}, "pull", block);
|
|
|
+ };
|
|
|
+ }
|
|
|
+ return {result, std::move(getters)};
|
|
|
+ }
|
|
|
+#endif
|
|
|
private:
|
|
|
struct TState : public TComputationValue<TState> {
|
|
|
- std::vector<NUdf::TUnboxedValue> Values_;
|
|
|
- std::vector<NUdf::TUnboxedValue*> ValuePointers_;
|
|
|
- std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_;
|
|
|
+ NUdf::TUnboxedValue* Pointer_ = nullptr;
|
|
|
bool IsFinished_ = false;
|
|
|
bool HasValues_ = false;
|
|
|
+ TUnboxedValueVector Values_;
|
|
|
+ std::vector<std::unique_ptr<IBlockAggregatorCombineAll>> Aggs_;
|
|
|
std::vector<char> AggStates_;
|
|
|
+ const std::optional<ui32> FilterColumn_;
|
|
|
+ const size_t Width_;
|
|
|
|
|
|
- TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32>, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx)
|
|
|
+ TState(TMemoryUsageInfo* memInfo, size_t width, std::optional<ui32> filterColumn, const std::vector<TAggParams<IBlockAggregatorCombineAll>>& params, TComputationContext& ctx)
|
|
|
: TComputationValue(memInfo)
|
|
|
- , Values_(width)
|
|
|
- , ValuePointers_(width)
|
|
|
+ , Values_(std::max(width, params.size()))
|
|
|
+ , FilterColumn_(filterColumn)
|
|
|
+ , Width_(width)
|
|
|
{
|
|
|
- for (size_t i = 0; i < width; ++i) {
|
|
|
- ValuePointers_[i] = &Values_[i];
|
|
|
- }
|
|
|
+ Pointer_ = Values_.data();
|
|
|
|
|
|
ui32 totalStateSize = 0;
|
|
|
for (const auto& p : params) {
|
|
@@ -574,29 +648,114 @@ private:
|
|
|
ptr += agg->StateSize;
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ void ProcessInput() {
|
|
|
+ const ui64 batchLength = TArrowBlock::From(Values_[Width_ - 1U]).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
|
|
|
+ if (!batchLength) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ std::optional<ui64> filtered;
|
|
|
+ if (FilterColumn_) {
|
|
|
+ const auto filterDatum = TArrowBlock::From(Values_[*FilterColumn_]).GetDatum();
|
|
|
+ if (filterDatum.is_scalar()) {
|
|
|
+ if (!filterDatum.scalar_as<arrow::UInt8Scalar>().value) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ const ui64 popCount = GetBitmapPopCount(filterDatum.array());
|
|
|
+ if (popCount == 0) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (popCount < batchLength) {
|
|
|
+ filtered = popCount;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ HasValues_ = true;
|
|
|
+ char* ptr = AggStates_.data();
|
|
|
+ for (size_t i = 0; i < Aggs_.size(); ++i) {
|
|
|
+ Aggs_[i]->AddMany(ptr, Values_.data(), batchLength, filtered);
|
|
|
+ ptr += Aggs_[i]->StateSize;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ bool MakeOutput() {
|
|
|
+ IsFinished_ = true;
|
|
|
+ if (!HasValues_)
|
|
|
+ return false;
|
|
|
+
|
|
|
+ char* ptr = AggStates_.data();
|
|
|
+ for (size_t i = 0; i < Aggs_.size(); ++i) {
|
|
|
+ Values_[i] = Aggs_[i]->FinishOne(ptr);
|
|
|
+ Aggs_[i]->DestroyState(ptr);
|
|
|
+ ptr += Aggs_[i]->StateSize;
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ NUdf::TUnboxedValuePod Pull(size_t index) {
|
|
|
+ return Values_[index].Release();
|
|
|
+ }
|
|
|
};
|
|
|
+#ifndef MKQL_DISABLE_CODEGEN
|
|
|
+ class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TBlockState>> {
|
|
|
+ private:
|
|
|
+ using TBase = TLLVMFieldsStructure<TComputationValue<TBlockState>>;
|
|
|
+ llvm::PointerType*const PointerType;
|
|
|
+ llvm::IntegerType*const IsFinishedType;
|
|
|
+ public:
|
|
|
+ std::vector<llvm::Type*> GetFieldsArray() {
|
|
|
+ std::vector<llvm::Type*> result = TBase::GetFields();
|
|
|
+ result.emplace_back(PointerType);
|
|
|
+ result.emplace_back(IsFinishedType);
|
|
|
+ return result;
|
|
|
+ }
|
|
|
|
|
|
-private:
|
|
|
+ llvm::Constant* GetPointer() {
|
|
|
+ return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
|
|
|
+ }
|
|
|
+
|
|
|
+ llvm::Constant* GetIsFinished() {
|
|
|
+ return llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
|
|
|
+ }
|
|
|
+
|
|
|
+ TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
|
|
|
+ : TBase(context)
|
|
|
+ , PointerType(llvm::PointerType::getUnqual(llvm::ArrayType::get(llvm::Type::getInt128Ty(Context), width)))
|
|
|
+ , IsFinishedType(llvm::Type::getInt1Ty(Context))
|
|
|
+ {}
|
|
|
+ };
|
|
|
+#endif
|
|
|
void RegisterDependencies() const final {
|
|
|
FlowDependsOn(Flow_);
|
|
|
}
|
|
|
|
|
|
+ void MakeState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
|
|
|
+ state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx);
|
|
|
+ }
|
|
|
+
|
|
|
TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
|
|
|
if (!state.HasValue()) {
|
|
|
- state = ctx.HolderFactory.Create<TState>(Width_, FilterColumn_, AggsParams_, ctx);
|
|
|
+ MakeState(state, ctx);
|
|
|
+
|
|
|
+ auto& s = *static_cast<TState*>(state.AsBoxed().Get());
|
|
|
+ const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
|
|
|
+ for (size_t i = 0; i < s.Values_.size(); ++i) {
|
|
|
+ fields[i] = &s.Values_[i];
|
|
|
+ }
|
|
|
+ return s;
|
|
|
}
|
|
|
return *static_cast<TState*>(state.AsBoxed().Get());
|
|
|
}
|
|
|
-
|
|
|
- ui64 GetBatchLength(const NUdf::TUnboxedValue* columns) const {
|
|
|
- return TArrowBlock::From(columns[Width_ - 1]).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
|
|
|
- }
|
|
|
-
|
|
|
private:
|
|
|
- IComputationWideFlowNode* Flow_;
|
|
|
- std::optional<ui32> FilterColumn_;
|
|
|
+ IComputationWideFlowNode *const Flow_;
|
|
|
+ const std::optional<ui32> FilterColumn_;
|
|
|
const size_t Width_;
|
|
|
const std::vector<TAggParams<IBlockAggregatorCombineAll>> AggsParams_;
|
|
|
+ const size_t WideFieldsIndex_;
|
|
|
};
|
|
|
|
|
|
template <typename T>
|