|
@@ -38,11 +38,11 @@ public:
|
|
|
|
|
|
NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
|
|
|
const auto maxLen = CalcBlockLen(CalcMaxBlockItemSize(ItemType_));
|
|
|
- auto builder = MakeArrayBuilder(TTypeInfoHelper(), ItemType_, ctx.ArrowMemoryPool, maxLen, &ctx.Builder->GetPgBuilder());
|
|
|
+ const auto builder = MakeArrayBuilder(TTypeInfoHelper(), ItemType_, ctx.ArrowMemoryPool, maxLen, &ctx.Builder->GetPgBuilder());
|
|
|
|
|
|
for (size_t i = 0; i < builder->MaxLength(); ++i) {
|
|
|
auto result = Flow_->GetValue(ctx);
|
|
|
- if (result.IsFinish() || result.IsYield()) {
|
|
|
+ if (result.IsSpecial()) {
|
|
|
if (i == 0) {
|
|
|
return result.Release();
|
|
|
}
|
|
@@ -53,7 +53,6 @@ public:
|
|
|
|
|
|
return ctx.HolderFactory.CreateArrowBlock(builder->Build(true));
|
|
|
}
|
|
|
-
|
|
|
private:
|
|
|
void RegisterDependencies() const final {
|
|
|
FlowDependsOn(Flow_);
|
|
@@ -63,135 +62,389 @@ private:
|
|
|
TType* ItemType_;
|
|
|
};
|
|
|
|
|
|
-class TWideToBlocksWrapper : public TStatefulWideFlowBlockComputationNode<TWideToBlocksWrapper> {
|
|
|
+class TWideToBlocksWrapper : public TStatefulWideFlowCodegeneratorNode<TWideToBlocksWrapper> {
|
|
|
+using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideToBlocksWrapper>;
|
|
|
public:
|
|
|
TWideToBlocksWrapper(TComputationMutables& mutables,
|
|
|
IComputationWideFlowNode* flow,
|
|
|
TVector<TType*>&& types)
|
|
|
- : TStatefulWideFlowBlockComputationNode(mutables, flow, types.size() + 1)
|
|
|
+ : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
|
|
|
, Flow_(flow)
|
|
|
, Types_(std::move(types))
|
|
|
+ , MaxLength_(CalcBlockLen(std::accumulate(Types_.cbegin(), Types_.cend(), 0ULL, [](size_t max, const TType* type){ return std::max(max, CalcMaxBlockItemSize(type)); })))
|
|
|
, Width_(Types_.size())
|
|
|
+ , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(Width_))
|
|
|
{
|
|
|
}
|
|
|
|
|
|
EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
|
|
|
TComputationContext& ctx,
|
|
|
- NUdf::TUnboxedValue*const* output) const
|
|
|
- {
|
|
|
+ NUdf::TUnboxedValue*const* output) const {
|
|
|
auto& s = GetState(state, ctx);
|
|
|
- if (s.IsFinished_) {
|
|
|
- return EFetchResult::Finish;
|
|
|
- }
|
|
|
-
|
|
|
- for (; s.Rows_ < s.MaxLength_; ++s.Rows_) {
|
|
|
- if (const auto result = Flow_->FetchValues(ctx, s.ValuePointers_.data()); EFetchResult::One != result) {
|
|
|
- if (EFetchResult::Finish == result) {
|
|
|
- s.IsFinished_ = true;
|
|
|
- }
|
|
|
+ const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
|
|
|
|
|
|
- if (EFetchResult::Yield == result || s.Rows_ == 0) {
|
|
|
- return result;
|
|
|
+ while (!s.Count) {
|
|
|
+ do {
|
|
|
+ if (s.IsFinished_) {
|
|
|
+ if (s.Rows_)
|
|
|
+ break;
|
|
|
+ else
|
|
|
+ return EFetchResult::Finish;
|
|
|
+ } else switch (Flow_->FetchValues(ctx, fields)) {
|
|
|
+ case EFetchResult::One:
|
|
|
+ for (size_t i = 0; i < Types_.size(); ++i)
|
|
|
+ s.Add(s.Values[i], i);
|
|
|
+ ++s.Rows_;
|
|
|
+ continue;
|
|
|
+ case EFetchResult::Yield:
|
|
|
+ return EFetchResult::Yield;
|
|
|
+ case EFetchResult::Finish:
|
|
|
+ s.IsFinished_ = true;
|
|
|
+ break;
|
|
|
}
|
|
|
+ } while (s.Rows_ < MaxLength_);
|
|
|
|
|
|
- break;
|
|
|
- }
|
|
|
- for (size_t j = 0; j < Width_; ++j) {
|
|
|
- if (output[j] != nullptr) {
|
|
|
- s.Builders_[j]->Add(s.Values_[j]);
|
|
|
- }
|
|
|
- }
|
|
|
+ s.MakeBlocks(ctx.HolderFactory);
|
|
|
}
|
|
|
|
|
|
- for (size_t i = 0; i < Width_; ++i) {
|
|
|
- if (auto* out = output[i]; out != nullptr) {
|
|
|
- *out = ctx.HolderFactory.CreateArrowBlock(s.Builders_[i]->Build(s.IsFinished_));
|
|
|
+ const auto sliceSize = s.Slice();
|
|
|
+ for (size_t i = 0; i <= Types_.size(); ++i) {
|
|
|
+ if (const auto out = output[i]) {
|
|
|
+ *out = s.Get(sliceSize, ctx.HolderFactory, i);
|
|
|
}
|
|
|
}
|
|
|
+ return EFetchResult::One;
|
|
|
+ }
|
|
|
+#ifndef MKQL_DISABLE_CODEGEN
|
|
|
+ ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
|
|
|
+ auto& context = ctx.Codegen.GetContext();
|
|
|
+
|
|
|
+ const auto valueType = Type::getInt128Ty(context);
|
|
|
+ const auto ptrValueType = PointerType::getUnqual(valueType);
|
|
|
+ const auto statusType = Type::getInt32Ty(context);
|
|
|
+ const auto indexType = Type::getInt64Ty(context);
|
|
|
+
|
|
|
+ TLLVMFieldsStructureState stateFields(context);
|
|
|
+ const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
|
|
|
+ const auto statePtrType = PointerType::getUnqual(stateType);
|
|
|
+
|
|
|
+ const auto atTop = &ctx.Func->getEntryBlock().back();
|
|
|
+
|
|
|
+ const auto addFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Add));
|
|
|
+ const auto addType = FunctionType::get(Type::getVoidTy(context), {statePtrType, valueType, indexType}, false);
|
|
|
+ const auto addPtr = CastInst::Create(Instruction::IntToPtr, addFunc, PointerType::getUnqual(addType), "add", atTop);
|
|
|
+
|
|
|
+ const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get));
|
|
|
+ const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false);
|
|
|
+ const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
|
|
|
+
|
|
|
+ const auto heightPtr = new AllocaInst(indexType, 0U, "height_ptr", atTop);
|
|
|
+ const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
|
|
|
+
|
|
|
+ new StoreInst(ConstantInt::get(indexType, 0), heightPtr, atTop);
|
|
|
+ new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
|
|
|
+
|
|
|
+ const auto make = BasicBlock::Create(context, "make", ctx.Func);
|
|
|
+ const auto main = BasicBlock::Create(context, "main", ctx.Func);
|
|
|
+ const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
|
|
|
+ const auto more = BasicBlock::Create(context, "more", ctx.Func);
|
|
|
+ const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
|
|
|
+ const auto read = BasicBlock::Create(context, "read", ctx.Func);
|
|
|
+ const auto good = BasicBlock::Create(context, "good", ctx.Func);
|
|
|
+ const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
|
|
|
+ const auto work = BasicBlock::Create(context, "work", ctx.Func);
|
|
|
+ const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
|
|
|
+ const auto over = BasicBlock::Create(context, "over", ctx.Func);
|
|
|
+
|
|
|
+ BranchInst::Create(main, make, HasValue(statePtr, block), block);
|
|
|
+ block = make;
|
|
|
+
|
|
|
+ const auto ptrType = PointerType::getUnqual(StructType::get(context));
|
|
|
+ const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
|
|
|
+ const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideToBlocksWrapper::MakeState));
|
|
|
+ const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
|
|
|
+ const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
|
|
|
+ CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
|
|
|
+ BranchInst::Create(main, block);
|
|
|
+
|
|
|
+ block = main;
|
|
|
+
|
|
|
+ const auto state = new LoadInst(valueType, statePtr, "state", block);
|
|
|
+ const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
|
|
|
+ const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
|
|
|
+ const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block);
|
|
|
|
|
|
- if (auto* out = output[Width_]; out != nullptr) {
|
|
|
- *out = ctx.HolderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(s.Rows_)));
|
|
|
+
|
|
|
+ BranchInst::Create(loop, block);
|
|
|
+
|
|
|
+ block = loop;
|
|
|
+
|
|
|
+ const auto count = new LoadInst(indexType, countPtr, "count", block);
|
|
|
+ const auto none = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, ConstantInt::get(indexType, 0), "none", block);
|
|
|
+
|
|
|
+ BranchInst::Create(more, fill, none, block);
|
|
|
+
|
|
|
+ block = more;
|
|
|
+
|
|
|
+ const auto rowsPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetRows() }, "rows_ptr", block);
|
|
|
+ const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block);
|
|
|
+
|
|
|
+ const auto rows = new LoadInst(indexType, rowsPtr, "rows", block);
|
|
|
+ const auto finished = new LoadInst(Type::getInt1Ty(context), finishedPtr, "finished", block);
|
|
|
+
|
|
|
+ BranchInst::Create(skip, read, finished, block);
|
|
|
+
|
|
|
+ block = skip;
|
|
|
+
|
|
|
+ const auto empty = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rows, ConstantInt::get(indexType, 0), "empty", block);
|
|
|
+
|
|
|
+ const auto result = PHINode::Create(statusType, 3U, "result", over);
|
|
|
+ result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
|
|
|
+
|
|
|
+ BranchInst::Create(over, work, empty, block);
|
|
|
+
|
|
|
+ block = read;
|
|
|
+
|
|
|
+ const auto getres = GetNodeValues(Flow_, ctx, block);
|
|
|
+
|
|
|
+ const auto way = SwitchInst::Create(getres.first, good, 2U, block);
|
|
|
+ way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), stop);
|
|
|
+ way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over);
|
|
|
+
|
|
|
+ result->addIncoming(getres.first, block);
|
|
|
+
|
|
|
+ block = stop;
|
|
|
+
|
|
|
+ new StoreInst(ConstantInt::getTrue(context), finishedPtr, block);
|
|
|
+ BranchInst::Create(skip, block);
|
|
|
+
|
|
|
+ block = good;
|
|
|
+
|
|
|
+ const auto increment = BinaryOperator::CreateAdd(rows, ConstantInt::get(indexType, 1), "increment", block);
|
|
|
+ new StoreInst(increment, rowsPtr, block);
|
|
|
+
|
|
|
+ for (size_t idx = 0U; idx < Types_.size(); ++idx) {
|
|
|
+ const auto value = getres.second[idx](ctx, block);
|
|
|
+ CallInst::Create(addType, addPtr, {stateArg, value, ConstantInt::get(indexType, idx)}, "", block);
|
|
|
+ ValueCleanup(GetValueRepresentation(Types_[idx]), value, ctx, block);
|
|
|
}
|
|
|
|
|
|
- s.Rows_ = 0;
|
|
|
- return EFetchResult::One;
|
|
|
+ const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, increment, ConstantInt::get(indexType, MaxLength_), "next", block);
|
|
|
+ BranchInst::Create(more, fill, next, block);
|
|
|
+
|
|
|
+ block = work;
|
|
|
+
|
|
|
+ const auto makeBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::MakeBlocks));
|
|
|
+ const auto makeBlockType = FunctionType::get(indexType, {statePtrType, ctx.GetFactory()->getType()}, false);
|
|
|
+ const auto makeBlockPtr = CastInst::Create(Instruction::IntToPtr, makeBlockFunc, PointerType::getUnqual(makeBlockType), "make_blocks_func", block);
|
|
|
+ CallInst::Create(makeBlockType, makeBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
|
|
|
+
|
|
|
+ BranchInst::Create(loop, block);
|
|
|
+
|
|
|
+ block = fill;
|
|
|
+
|
|
|
+ const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Slice));
|
|
|
+ const auto sliceType = FunctionType::get(indexType, {statePtrType}, false);
|
|
|
+ const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block);
|
|
|
+ const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block);
|
|
|
+ new StoreInst(slice, heightPtr, block);
|
|
|
+ new StoreInst(stateArg, stateOnStack, block);
|
|
|
+
|
|
|
+ result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
|
|
|
+
|
|
|
+ BranchInst::Create(over, block);
|
|
|
+
|
|
|
+ block = over;
|
|
|
+
|
|
|
+ ICodegeneratorInlineWideNode::TGettersList getters(Types_.size() + 1U);
|
|
|
+ for (size_t idx = 0U; idx < getters.size(); ++idx) {
|
|
|
+ getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
|
|
|
+ const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
|
|
|
+ const auto heightArg = new LoadInst(indexType, heightPtr, "height", block);
|
|
|
+ return CallInst::Create(getType, getPtr, {stateArg, heightArg, ctx.GetFactory(), ConstantInt::get(indexType, idx)}, "get", block);
|
|
|
+ };
|
|
|
+ }
|
|
|
+ return {result, std::move(getters)};
|
|
|
}
|
|
|
+#endif
|
|
|
private:
|
|
|
- struct TState : public TComputationValue<TState> {
|
|
|
- std::vector<NUdf::TUnboxedValue> Values_;
|
|
|
- std::vector<NUdf::TUnboxedValue*> ValuePointers_;
|
|
|
- std::vector<std::unique_ptr<IArrayBuilder>> Builders_;
|
|
|
- size_t MaxLength_;
|
|
|
+ struct TState : public TBlockState {
|
|
|
size_t Rows_ = 0;
|
|
|
bool IsFinished_ = false;
|
|
|
-
|
|
|
- TState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const TVector<TType*>& types)
|
|
|
- : TComputationValue(memInfo)
|
|
|
- , Values_(types.size())
|
|
|
- , ValuePointers_(types.size())
|
|
|
+ NUdf::TUnboxedValue**const Fields_;
|
|
|
+ std::vector<std::unique_ptr<IArrayBuilder>> Builders_;
|
|
|
+ TState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const TVector<TType*>& types, size_t maxLength, NUdf::TUnboxedValue**const fields)
|
|
|
+ : TBlockState(memInfo, types.size() + 1U)
|
|
|
+ , Fields_(fields)
|
|
|
+ , Builders_(types.size())
|
|
|
{
|
|
|
- size_t maxBlockItemSize = 0;
|
|
|
for (size_t i = 0; i < types.size(); ++i) {
|
|
|
- maxBlockItemSize = std::max(CalcMaxBlockItemSize(types[i]), maxBlockItemSize);
|
|
|
+ Fields_[i] = &Values[i];
|
|
|
+ Builders_[i] = MakeArrayBuilder(TTypeInfoHelper(), types[i], ctx.ArrowMemoryPool, maxLength, &ctx.Builder->GetPgBuilder());
|
|
|
}
|
|
|
- MaxLength_ = CalcBlockLen(maxBlockItemSize);
|
|
|
+ }
|
|
|
|
|
|
- for (size_t i = 0; i < types.size(); ++i) {
|
|
|
- ValuePointers_[i] = &Values_[i];
|
|
|
- Builders_.push_back(MakeArrayBuilder(TTypeInfoHelper(), types[i], ctx.ArrowMemoryPool, MaxLength_, &ctx.Builder->GetPgBuilder()));
|
|
|
+ void Add(const NUdf::TUnboxedValuePod value, size_t idx) {
|
|
|
+ Builders_[idx]->Add(value);
|
|
|
+ }
|
|
|
+
|
|
|
+ void MakeBlocks(const THolderFactory& holderFactory) {
|
|
|
+ Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(Rows_)));
|
|
|
+ Rows_ = 0;
|
|
|
+
|
|
|
+ for (size_t i = 0; i < Builders_.size(); ++i) {
|
|
|
+ if (const auto builder = Builders_[i].get()) {
|
|
|
+ Values[i] = holderFactory.CreateArrowBlock(builder->Build(IsFinished_));
|
|
|
+ }
|
|
|
}
|
|
|
+
|
|
|
+ FillArrays();
|
|
|
}
|
|
|
};
|
|
|
+#ifndef MKQL_DISABLE_CODEGEN
|
|
|
+ class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState>> {
|
|
|
+ private:
|
|
|
+ using TBase = TLLVMFieldsStructure<TComputationValue<TState>>;
|
|
|
+ llvm::IntegerType*const CountType;
|
|
|
+ llvm::ArrayType*const SkipSpaceType;
|
|
|
+ llvm::IntegerType*const RowsType;
|
|
|
+ llvm::IntegerType*const IsFinishedType;
|
|
|
+ protected:
|
|
|
+ using TBase::Context;
|
|
|
+ public:
|
|
|
+ std::vector<llvm::Type*> GetFieldsArray() {
|
|
|
+ std::vector<llvm::Type*> result = TBase::GetFields();
|
|
|
+ result.emplace_back(CountType);
|
|
|
+ result.emplace_back(SkipSpaceType);
|
|
|
+ result.emplace_back(RowsType);
|
|
|
+ result.emplace_back(IsFinishedType);
|
|
|
+ return result;
|
|
|
+ }
|
|
|
|
|
|
-private:
|
|
|
+ llvm::Constant* GetCount() {
|
|
|
+ return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
|
|
|
+ }
|
|
|
+
|
|
|
+ llvm::Constant* GetRows() {
|
|
|
+ return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 2);
|
|
|
+ }
|
|
|
+
|
|
|
+ llvm::Constant* GetIsFinished() {
|
|
|
+ return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3);
|
|
|
+ }
|
|
|
+
|
|
|
+ TLLVMFieldsStructureState(llvm::LLVMContext& context)
|
|
|
+ : TBase(context)
|
|
|
+ , CountType(Type::getInt64Ty(Context))
|
|
|
+ , SkipSpaceType(ArrayType::get(Type::getInt128Ty(Context), 3U)) // Skip std::vectors Values & Arrays
|
|
|
+ , RowsType(Type::getInt64Ty(Context))
|
|
|
+ , IsFinishedType(Type::getInt1Ty(Context))
|
|
|
+ {}
|
|
|
+ };
|
|
|
+#endif
|
|
|
void RegisterDependencies() const final {
|
|
|
FlowDependsOn(Flow_);
|
|
|
}
|
|
|
|
|
|
+ void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
|
|
|
+ state = ctx.HolderFactory.Create<TState>(ctx, Types_, MaxLength_, ctx.WideFields.data() + WideFieldsIndex_);
|
|
|
+ }
|
|
|
+
|
|
|
TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
|
|
|
- if (!state.HasValue()) {
|
|
|
- state = ctx.HolderFactory.Create<TState>(ctx, Types_);
|
|
|
- }
|
|
|
+ if (!state.HasValue())
|
|
|
+ MakeState(ctx, state);
|
|
|
return *static_cast<TState*>(state.AsBoxed().Get());
|
|
|
}
|
|
|
|
|
|
private:
|
|
|
- IComputationWideFlowNode* Flow_;
|
|
|
+ IComputationWideFlowNode* const Flow_;
|
|
|
const TVector<TType*> Types_;
|
|
|
+ const size_t MaxLength_;
|
|
|
const size_t Width_;
|
|
|
+ const size_t WideFieldsIndex_;
|
|
|
};
|
|
|
|
|
|
-class TFromBlocksWrapper : public TMutableComputationNode<TFromBlocksWrapper> {
|
|
|
+class TFromBlocksWrapper : public TStatefulFlowCodegeneratorNode<TFromBlocksWrapper> {
|
|
|
+using TBaseComputation = TStatefulFlowCodegeneratorNode<TFromBlocksWrapper>;
|
|
|
public:
|
|
|
TFromBlocksWrapper(TComputationMutables& mutables, IComputationNode* flow, TType* itemType)
|
|
|
- : TMutableComputationNode(mutables)
|
|
|
+ : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
|
|
|
, Flow_(flow)
|
|
|
, ItemType_(itemType)
|
|
|
- , StateIndex_(mutables.CurValueIndex++)
|
|
|
{
|
|
|
}
|
|
|
|
|
|
- NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
|
|
|
- auto& state = GetState(ctx);
|
|
|
-
|
|
|
- for (;;) {
|
|
|
- auto item = state.GetValue(ctx);
|
|
|
- if (item) {
|
|
|
- return *item;
|
|
|
- }
|
|
|
-
|
|
|
- auto input = Flow_->GetValue(ctx);
|
|
|
- if (input.IsFinish()) {
|
|
|
- return NUdf::TUnboxedValue::MakeFinish();
|
|
|
- }
|
|
|
- if (input.IsYield()) {
|
|
|
- return NUdf::TUnboxedValue::MakeYield();
|
|
|
- }
|
|
|
+ NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
|
|
|
+ for (auto& s = GetState(state, ctx);;) {
|
|
|
+ if (auto item = s.GetValue(ctx.HolderFactory); !item.IsInvalid())
|
|
|
+ return item;
|
|
|
|
|
|
- state.Reset(TArrowBlock::From(input).GetDatum());
|
|
|
+ if (const auto input = Flow_->GetValue(ctx).Release(); input.IsSpecial())
|
|
|
+ return input;
|
|
|
+ else
|
|
|
+ s.Reset(input);
|
|
|
}
|
|
|
}
|
|
|
+#ifndef MKQL_DISABLE_CODEGEN
|
|
|
+ Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
|
|
|
+ auto& context = ctx.Codegen.GetContext();
|
|
|
|
|
|
+ const auto valueType = Type::getInt128Ty(context);
|
|
|
+ const auto statePtrType = PointerType::getUnqual(StructType::get(context));
|
|
|
+
|
|
|
+ const auto make = BasicBlock::Create(context, "make", ctx.Func);
|
|
|
+ const auto work = BasicBlock::Create(context, "work", ctx.Func);
|
|
|
+ const auto read = BasicBlock::Create(context, "read", ctx.Func);
|
|
|
+ const auto init = BasicBlock::Create(context, "init", ctx.Func);
|
|
|
+ const auto done = BasicBlock::Create(context, "done", ctx.Func);
|
|
|
+
|
|
|
+ BranchInst::Create(work, make, HasValue(statePtr, block), block);
|
|
|
+ block = make;
|
|
|
+
|
|
|
+ const auto ptrType = PointerType::getUnqual(StructType::get(context));
|
|
|
+ const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
|
|
|
+ const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TFromBlocksWrapper::MakeState));
|
|
|
+ const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
|
|
|
+ const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
|
|
|
+ CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
|
|
|
+ BranchInst::Create(work, block);
|
|
|
+
|
|
|
+ block = work;
|
|
|
+
|
|
|
+ const auto state = new LoadInst(valueType, statePtr, "state", block);
|
|
|
+ const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
|
|
|
+ const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
|
|
|
+
|
|
|
+ const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::GetValue));
|
|
|
+ const auto getType = FunctionType::get(valueType, {statePtrType, ctx.GetFactory()->getType()}, false);
|
|
|
+ const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", block);
|
|
|
+ const auto value = CallInst::Create(getType, getPtr, {stateArg, ctx.GetFactory() }, "value", block);
|
|
|
+
|
|
|
+ const auto result = PHINode::Create(valueType, 2U, "result", done);
|
|
|
+ result->addIncoming(value, block);
|
|
|
+
|
|
|
+ BranchInst::Create(read, done, IsInvalid(value, block), block);
|
|
|
+
|
|
|
+ block = read;
|
|
|
+
|
|
|
+ const auto input = GetNodeValue(Flow_, ctx, block);
|
|
|
+ result->addIncoming(input, block);
|
|
|
+
|
|
|
+ BranchInst::Create(done, init, IsSpecial(input, block), block);
|
|
|
+
|
|
|
+ block = init;
|
|
|
+
|
|
|
+ const auto setFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Reset));
|
|
|
+ const auto setType = FunctionType::get(valueType, {statePtrType, valueType}, false);
|
|
|
+ const auto setPtr = CastInst::Create(Instruction::IntToPtr, setFunc, PointerType::getUnqual(setType), "set", block);
|
|
|
+ CallInst::Create(setType, setPtr, {stateArg, input }, "", block);
|
|
|
+
|
|
|
+ BranchInst::Create(work, block);
|
|
|
+
|
|
|
+ block = done;
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+#endif
|
|
|
private:
|
|
|
struct TState : public TComputationValue<TState> {
|
|
|
using TComputationValue::TComputationValue;
|
|
@@ -203,10 +456,10 @@ private:
|
|
|
{
|
|
|
}
|
|
|
|
|
|
- TMaybe<NUdf::TUnboxedValuePod> GetValue(TComputationContext& ctx) {
|
|
|
+ NUdf::TUnboxedValuePod GetValue(const THolderFactory& holderFactory) {
|
|
|
for (;;) {
|
|
|
if (Arrays_.empty()) {
|
|
|
- return {};
|
|
|
+ return NUdf::TUnboxedValuePod::Invalid();
|
|
|
}
|
|
|
if (Index_ < ui64(Arrays_.front()->length)) {
|
|
|
break;
|
|
@@ -214,16 +467,18 @@ private:
|
|
|
Index_ = 0;
|
|
|
Arrays_.pop_front();
|
|
|
}
|
|
|
- return Converter_->MakeValue(Reader_->GetItem(*Arrays_.front(), Index_++), ctx.HolderFactory);
|
|
|
+ return Converter_->MakeValue(Reader_->GetItem(*Arrays_.front(), Index_++), holderFactory);
|
|
|
}
|
|
|
|
|
|
- void Reset(const arrow::Datum& datum) {
|
|
|
+ void Reset(const NUdf::TUnboxedValuePod block) {
|
|
|
+ const NUdf::TUnboxedValue v(block);
|
|
|
+ const auto& datum = TArrowBlock::From(v).GetDatum();
|
|
|
MKQL_ENSURE(datum.is_arraylike(), "Expecting array as FromBlocks argument");
|
|
|
MKQL_ENSURE(Arrays_.empty(), "Not all input is processed");
|
|
|
if (datum.is_array()) {
|
|
|
Arrays_.push_back(datum.array());
|
|
|
} else {
|
|
|
- for (auto& chunk : datum.chunks()) {
|
|
|
+ for (const auto& chunk : datum.chunks()) {
|
|
|
Arrays_.push_back(chunk->data());
|
|
|
}
|
|
|
}
|
|
@@ -239,21 +494,22 @@ private:
|
|
|
|
|
|
private:
|
|
|
void RegisterDependencies() const final {
|
|
|
- this->DependsOn(Flow_);
|
|
|
+ FlowDependsOn(Flow_);
|
|
|
}
|
|
|
|
|
|
- TState& GetState(TComputationContext& ctx) const {
|
|
|
- auto& result = ctx.MutableValues[StateIndex_];
|
|
|
- if (!result.HasValue()) {
|
|
|
- result = ctx.HolderFactory.Create<TState>(ItemType_, ctx.Builder->GetPgBuilder());
|
|
|
- }
|
|
|
- return *static_cast<TState*>(result.AsBoxed().Get());
|
|
|
+ void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
|
|
|
+ state = ctx.HolderFactory.Create<TState>(ItemType_, ctx.Builder->GetPgBuilder());
|
|
|
+ }
|
|
|
+
|
|
|
+ TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
|
|
|
+ if (!state.HasValue())
|
|
|
+ MakeState(ctx, state);
|
|
|
+ return *static_cast<TState*>(state.AsBoxed().Get());
|
|
|
}
|
|
|
|
|
|
private:
|
|
|
IComputationNode* const Flow_;
|
|
|
TType* ItemType_;
|
|
|
- const ui32 StateIndex_;
|
|
|
};
|
|
|
|
|
|
class TWideFromBlocksWrapper : public TStatefulWideFlowCodegeneratorNode<TWideFromBlocksWrapper> {
|
|
@@ -294,12 +550,15 @@ public:
|
|
|
ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
|
|
|
auto& context = ctx.Codegen.GetContext();
|
|
|
|
|
|
+ const auto width = Types_.size();
|
|
|
const auto valueType = Type::getInt128Ty(context);
|
|
|
const auto ptrValueType = PointerType::getUnqual(valueType);
|
|
|
const auto statusType = Type::getInt32Ty(context);
|
|
|
const auto indexType = Type::getInt64Ty(context);
|
|
|
+ const auto arrayType = ArrayType::get(valueType, width);
|
|
|
+ const auto ptrValuesType = PointerType::getUnqual(ArrayType::get(valueType, width));
|
|
|
|
|
|
- TLLVMFieldsStructureState stateFields(context);
|
|
|
+ TLLVMFieldsStructureState stateFields(context, width);
|
|
|
const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
|
|
|
const auto statePtrType = PointerType::getUnqual(stateType);
|
|
|
|
|
@@ -320,7 +579,6 @@ public:
|
|
|
const auto main = BasicBlock::Create(context, "main", ctx.Func);
|
|
|
const auto more = BasicBlock::Create(context, "more", ctx.Func);
|
|
|
const auto good = BasicBlock::Create(context, "good", ctx.Func);
|
|
|
- const auto save = BasicBlock::Create(context, "save", ctx.Func);
|
|
|
const auto work = BasicBlock::Create(context, "work", ctx.Func);
|
|
|
const auto over = BasicBlock::Create(context, "over", ctx.Func);
|
|
|
|
|
@@ -353,14 +611,9 @@ public:
|
|
|
|
|
|
block = more;
|
|
|
|
|
|
- const auto pointerPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "pointer_ptr", block);
|
|
|
- const auto pointer = new LoadInst(ptrValueType, pointerPtr, "pointer", block);
|
|
|
-
|
|
|
- std::vector<Value*> pointers(Types_.size());
|
|
|
- for (size_t idx = 0U; idx < pointers.size(); ++idx) {
|
|
|
- pointers[idx] = GetElementPtrInst::CreateInBounds(valueType, pointer, { ConstantInt::get(Type::getInt32Ty(context), idx) }, (TString("ptr_") += ToString(idx)).c_str(), block);
|
|
|
- SafeUnRefUnboxed(pointers[idx], ctx, block);
|
|
|
- }
|
|
|
+ const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
|
|
|
+ const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
|
|
|
+ SafeUnRefUnboxed(values, ctx, block);
|
|
|
|
|
|
const auto getres = GetNodeValues(Flow_, ctx, block);
|
|
|
|
|
@@ -375,24 +628,13 @@ public:
|
|
|
|
|
|
const auto countValue = getres.second.back()(ctx, block);
|
|
|
const auto height = CallInst::Create(getCount, { WrapArgumentForWindows(countValue, ctx, block) }, "height", block);
|
|
|
- CleanupBoxed(countValue, ctx, block);
|
|
|
|
|
|
new StoreInst(height, countPtr, block);
|
|
|
new StoreInst(ConstantInt::get(indexType, 0), indexPtr, block);
|
|
|
|
|
|
const auto empty = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, ConstantInt::get(indexType, 0), height, "empty", block);
|
|
|
|
|
|
- BranchInst::Create(more, save, empty, block);
|
|
|
-
|
|
|
- block = save;
|
|
|
-
|
|
|
- for (size_t idx = 0U; idx < pointers.size(); ++idx) {
|
|
|
- const auto value = getres.second[idx](ctx, block);
|
|
|
- AddRefBoxed(value, ctx, block);
|
|
|
- new StoreInst(value, pointers[idx], block);
|
|
|
- }
|
|
|
-
|
|
|
- BranchInst::Create(work, block);
|
|
|
+ BranchInst::Create(more, work, empty, block);
|
|
|
|
|
|
block = work;
|
|
|
|
|
@@ -409,11 +651,34 @@ public:
|
|
|
|
|
|
block = over;
|
|
|
|
|
|
- ICodegeneratorInlineWideNode::TGettersList getters(Types_.size());
|
|
|
+ ICodegeneratorInlineWideNode::TGettersList getters(width);
|
|
|
for (size_t idx = 0U; idx < getters.size(); ++idx) {
|
|
|
- getters[idx] = [idx, getType, getPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
|
|
|
+ getters[idx] = [idx, width, getType, getPtr, valueType, indexType, arrayType, ptrValuesType, stateType, statePtrType, stateOnStack, getBlocks = getres.second](const TCodegenContext& ctx, BasicBlock*& block) {
|
|
|
+ auto& context = ctx.Codegen.GetContext();
|
|
|
+ const auto init = BasicBlock::Create(context, "init", ctx.Func);
|
|
|
+ const auto call = BasicBlock::Create(context, "call", ctx.Func);
|
|
|
+
|
|
|
+ TLLVMFieldsStructureState stateFields(context, width);
|
|
|
+
|
|
|
const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
|
|
|
- return CallInst::Create(getType, getPtr, {stateArg, ctx.GetFactory(), ConstantInt::get(indexType, idx)}, "get", block);
|
|
|
+ const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
|
|
|
+ const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
|
|
|
+ const auto index = ConstantInt::get(indexType, idx);
|
|
|
+ const auto pointer = GetElementPtrInst::CreateInBounds(arrayType, values, { ConstantInt::get(indexType, 0), index }, "pointer", block);
|
|
|
+
|
|
|
+ BranchInst::Create(call, init, HasValue(pointer, block), block);
|
|
|
+
|
|
|
+ block = init;
|
|
|
+
|
|
|
+ const auto value = getBlocks[idx](ctx, block);
|
|
|
+ new StoreInst(value, pointer, block);
|
|
|
+ AddRefBoxed(value, ctx, block);
|
|
|
+
|
|
|
+ BranchInst::Create(call, block);
|
|
|
+
|
|
|
+ block = call;
|
|
|
+
|
|
|
+ return CallInst::Create(getType, getPtr, {stateArg, ctx.GetFactory(), index}, "get", block);
|
|
|
};
|
|
|
}
|
|
|
return {result, std::move(getters)};
|
|
@@ -493,12 +758,12 @@ private:
|
|
|
return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3);
|
|
|
}
|
|
|
|
|
|
- TLLVMFieldsStructureState(llvm::LLVMContext& context)
|
|
|
+ TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
|
|
|
: TBase(context)
|
|
|
, CountType(Type::getInt64Ty(Context))
|
|
|
, IndexType(Type::getInt64Ty(Context))
|
|
|
, CurrentType(Type::getInt64Ty(Context))
|
|
|
- , PointerType(PointerType::getUnqual(Type::getInt128Ty(Context)))
|
|
|
+ , PointerType(PointerType::getUnqual(ArrayType::get(Type::getInt128Ty(Context), width)))
|
|
|
{}
|
|
|
};
|
|
|
#endif
|