mkql_blocks.cpp 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449
  1. #include "mkql_blocks.h"
  2. #include <yql/essentials/minikql/computation/mkql_block_reader.h>
  3. #include <yql/essentials/minikql/computation/mkql_block_builder.h>
  4. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  5. #include <yql/essentials/minikql/computation/mkql_block_impl_codegen.h> // Y_IGNORE
  6. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  7. #include <yql/essentials/minikql/arrow/arrow_util.h>
  8. #include <yql/essentials/minikql/mkql_type_builder.h>
  9. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  10. #include <yql/essentials/minikql/mkql_node_builder.h>
  11. #include <yql/essentials/minikql/mkql_node_cast.h>
  12. #include <yql/essentials/parser/pg_wrapper/interface/arrow.h>
  13. #include <arrow/scalar.h>
  14. #include <arrow/array.h>
  15. #include <arrow/datum.h>
  16. namespace NKikimr {
  17. namespace NMiniKQL {
  18. namespace {
  19. class TToBlocksWrapper : public TStatelessFlowComputationNode<TToBlocksWrapper> {
  20. public:
  21. explicit TToBlocksWrapper(IComputationNode* flow, TType* itemType)
  22. : TStatelessFlowComputationNode(flow, EValueRepresentation::Boxed)
  23. , Flow_(flow)
  24. , ItemType_(itemType)
  25. {
  26. }
  27. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  28. const auto maxLen = CalcBlockLen(CalcMaxBlockItemSize(ItemType_));
  29. const auto builder = MakeArrayBuilder(TTypeInfoHelper(), ItemType_, ctx.ArrowMemoryPool, maxLen, &ctx.Builder->GetPgBuilder());
  30. for (size_t i = 0; i < builder->MaxLength(); ++i) {
  31. auto result = Flow_->GetValue(ctx);
  32. if (result.IsSpecial()) {
  33. if (i == 0) {
  34. return result.Release();
  35. }
  36. break;
  37. }
  38. builder->Add(result);
  39. }
  40. return ctx.HolderFactory.CreateArrowBlock(builder->Build(true));
  41. }
  42. private:
  43. void RegisterDependencies() const final {
  44. FlowDependsOn(Flow_);
  45. }
  46. private:
  47. IComputationNode* const Flow_;
  48. TType* ItemType_;
  49. };
  50. struct TWideToBlocksState : public TBlockState {
  51. size_t Rows_ = 0;
  52. bool IsFinished_ = false;
  53. size_t BuilderAllocatedSize_ = 0;
  54. size_t MaxBuilderAllocatedSize_ = 0;
  55. std::vector<std::unique_ptr<IArrayBuilder>> Builders_;
  56. static const size_t MaxAllocatedFactor_ = 4;
  57. TWideToBlocksState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const TVector<TType*>& types, size_t maxLength)
  58. : TBlockState(memInfo, types.size() + 1U)
  59. , Builders_(types.size())
  60. {
  61. for (size_t i = 0; i < types.size(); ++i) {
  62. Builders_[i] = MakeArrayBuilder(TTypeInfoHelper(), types[i], ctx.ArrowMemoryPool, maxLength, &ctx.Builder->GetPgBuilder(), &BuilderAllocatedSize_);
  63. }
  64. MaxBuilderAllocatedSize_ = MaxAllocatedFactor_ * BuilderAllocatedSize_;
  65. }
  66. void Add(const NUdf::TUnboxedValuePod value, size_t idx) {
  67. Builders_[idx]->Add(value);
  68. }
  69. void MakeBlocks(const THolderFactory& holderFactory) {
  70. Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(Rows_)));
  71. Rows_ = 0;
  72. BuilderAllocatedSize_ = 0;
  73. for (size_t i = 0; i < Builders_.size(); ++i) {
  74. if (const auto builder = Builders_[i].get()) {
  75. Values[i] = holderFactory.CreateArrowBlock(builder->Build(IsFinished_));
  76. }
  77. }
  78. FillArrays();
  79. }
  80. };
  81. class TWideToBlocksFlowWrapper : public TStatefulWideFlowCodegeneratorNode<TWideToBlocksFlowWrapper> {
  82. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideToBlocksFlowWrapper>;
  83. using TState = TWideToBlocksState;
  84. public:
  85. TWideToBlocksFlowWrapper(TComputationMutables& mutables,
  86. IComputationWideFlowNode* flow,
  87. TVector<TType*>&& types)
  88. : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
  89. , Flow_(flow)
  90. , Types_(std::move(types))
  91. , MaxLength_(CalcBlockLen(std::accumulate(Types_.cbegin(), Types_.cend(), 0ULL, [](size_t max, const TType* type){ return std::max(max, CalcMaxBlockItemSize(type)); })))
  92. , Width_(Types_.size())
  93. , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(Width_))
  94. {
  95. }
  96. EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
  97. TComputationContext& ctx,
  98. NUdf::TUnboxedValue*const* output) const {
  99. auto& s = GetState(state, ctx);
  100. const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
  101. if (!s.Count) {
  102. if (!s.IsFinished_) do {
  103. switch (Flow_->FetchValues(ctx, fields)) {
  104. case EFetchResult::One:
  105. for (size_t i = 0; i < Types_.size(); ++i)
  106. s.Add(s.Values[i], i);
  107. continue;
  108. case EFetchResult::Yield:
  109. return EFetchResult::Yield;
  110. case EFetchResult::Finish:
  111. s.IsFinished_ = true;
  112. break;
  113. }
  114. break;
  115. } while (++s.Rows_ < MaxLength_ && s.BuilderAllocatedSize_ <= s.MaxBuilderAllocatedSize_);
  116. if (s.Rows_)
  117. s.MakeBlocks(ctx.HolderFactory);
  118. else
  119. return EFetchResult::Finish;
  120. }
  121. const auto sliceSize = s.Slice();
  122. for (size_t i = 0; i <= Types_.size(); ++i) {
  123. if (const auto out = output[i]) {
  124. *out = s.Get(sliceSize, ctx.HolderFactory, i);
  125. }
  126. }
  127. return EFetchResult::One;
  128. }
  129. #ifndef MKQL_DISABLE_CODEGEN
  130. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  131. auto& context = ctx.Codegen.GetContext();
  132. const auto valueType = Type::getInt128Ty(context);
  133. const auto statusType = Type::getInt32Ty(context);
  134. const auto indexType = Type::getInt64Ty(context);
  135. TLLVMFieldsStructureState stateFields(context, Types_.size() + 1U);
  136. const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
  137. const auto statePtrType = PointerType::getUnqual(stateType);
  138. const auto atTop = &ctx.Func->getEntryBlock().back();
  139. const auto addFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Add));
  140. const auto addType = FunctionType::get(Type::getVoidTy(context), {statePtrType, valueType, indexType}, false);
  141. const auto addPtr = CastInst::Create(Instruction::IntToPtr, addFunc, PointerType::getUnqual(addType), "add", atTop);
  142. const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get));
  143. const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false);
  144. const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
  145. const auto heightPtr = new AllocaInst(indexType, 0U, "height_ptr", atTop);
  146. const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
  147. new StoreInst(ConstantInt::get(indexType, 0), heightPtr, atTop);
  148. new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
  149. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  150. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  151. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  152. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  153. const auto read = BasicBlock::Create(context, "read", ctx.Func);
  154. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  155. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  156. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  157. const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
  158. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  159. const auto second_cond = BasicBlock::Create(context, "second_cond", ctx.Func);
  160. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  161. block = make;
  162. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  163. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  164. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideToBlocksFlowWrapper::MakeState));
  165. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  166. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  167. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  168. BranchInst::Create(main, block);
  169. block = main;
  170. const auto state = new LoadInst(valueType, statePtr, "state", block);
  171. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  172. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  173. const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block);
  174. const auto count = new LoadInst(indexType, countPtr, "count", block);
  175. const auto none = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, ConstantInt::get(indexType, 0), "none", block);
  176. BranchInst::Create(more, fill, none, block);
  177. block = more;
  178. const auto rowsPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetRows() }, "rows_ptr", block);
  179. const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block);
  180. const auto allocatedSizePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetBuilderAllocatedSize() }, "allocated_size_ptr", block);
  181. const auto maxAllocatedSizePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetMaxBuilderAllocatedSize() }, "max_allocated_size_ptr", block);
  182. const auto finished = new LoadInst(Type::getInt1Ty(context), finishedPtr, "finished", block);
  183. BranchInst::Create(skip, read, finished, block);
  184. block = read;
  185. const auto getres = GetNodeValues(Flow_, ctx, block);
  186. const auto way = SwitchInst::Create(getres.first, good, 2U, block);
  187. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), stop);
  188. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over);
  189. const auto result = PHINode::Create(statusType, 3U, "result", over);
  190. result->addIncoming(getres.first, block);
  191. block = good;
  192. const auto read_rows = new LoadInst(indexType, rowsPtr, "read_rows", block);
  193. const auto increment = BinaryOperator::CreateAdd(read_rows, ConstantInt::get(indexType, 1), "increment", block);
  194. new StoreInst(increment, rowsPtr, block);
  195. for (size_t idx = 0U; idx < Types_.size(); ++idx) {
  196. const auto value = getres.second[idx](ctx, block);
  197. CallInst::Create(addType, addPtr, {stateArg, value, ConstantInt::get(indexType, idx)}, "", block);
  198. ValueCleanup(GetValueRepresentation(Types_[idx]), value, ctx, block);
  199. }
  200. const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, increment, ConstantInt::get(indexType, MaxLength_), "next", block);
  201. BranchInst::Create(second_cond, work, next, block);
  202. block = second_cond;
  203. const auto read_allocated_size = new LoadInst(indexType, allocatedSizePtr, "read_allocated_size", block);
  204. const auto read_max_allocated_size = new LoadInst(indexType, maxAllocatedSizePtr, "read_max_allocated_size", block);
  205. const auto next2 = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULE, read_allocated_size, read_max_allocated_size, "next2", block);
  206. BranchInst::Create(read, work, next2, block);
  207. block = stop;
  208. new StoreInst(ConstantInt::getTrue(context), finishedPtr, block);
  209. BranchInst::Create(skip, block);
  210. block = skip;
  211. const auto rows = new LoadInst(indexType, rowsPtr, "rows", block);
  212. const auto empty = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rows, ConstantInt::get(indexType, 0), "empty", block);
  213. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  214. BranchInst::Create(over, work, empty, block);
  215. block = work;
  216. const auto makeBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::MakeBlocks));
  217. const auto makeBlockType = FunctionType::get(indexType, {statePtrType, ctx.GetFactory()->getType()}, false);
  218. const auto makeBlockPtr = CastInst::Create(Instruction::IntToPtr, makeBlockFunc, PointerType::getUnqual(makeBlockType), "make_blocks_func", block);
  219. CallInst::Create(makeBlockType, makeBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
  220. BranchInst::Create(fill, block);
  221. block = fill;
  222. const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Slice));
  223. const auto sliceType = FunctionType::get(indexType, {statePtrType}, false);
  224. const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block);
  225. const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block);
  226. new StoreInst(slice, heightPtr, block);
  227. new StoreInst(stateArg, stateOnStack, block);
  228. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
  229. BranchInst::Create(over, block);
  230. block = over;
  231. ICodegeneratorInlineWideNode::TGettersList getters(Types_.size() + 1U);
  232. for (size_t idx = 0U; idx < getters.size(); ++idx) {
  233. getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
  234. const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
  235. const auto heightArg = new LoadInst(indexType, heightPtr, "height", block);
  236. return CallInst::Create(getType, getPtr, {stateArg, heightArg, ctx.GetFactory(), ConstantInt::get(indexType, idx)}, "get", block);
  237. };
  238. }
  239. return {result, std::move(getters)};
  240. }
  241. #endif
  242. private:
  243. #ifndef MKQL_DISABLE_CODEGEN
  244. class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState {
  245. private:
  246. using TBase = TLLVMFieldsStructureBlockState;
  247. llvm::IntegerType*const RowsType;
  248. llvm::IntegerType*const IsFinishedType;
  249. llvm::IntegerType*const BuilderAllocatedSizeType;
  250. llvm::IntegerType*const MaxBuilderAllocatedSizeType;
  251. protected:
  252. using TBase::Context;
  253. public:
  254. std::vector<llvm::Type*> GetFieldsArray() {
  255. std::vector<llvm::Type*> result = TBase::GetFieldsArray();
  256. result.emplace_back(RowsType);
  257. result.emplace_back(IsFinishedType);
  258. result.emplace_back(BuilderAllocatedSizeType);
  259. result.emplace_back(MaxBuilderAllocatedSizeType);
  260. return result;
  261. }
  262. llvm::Constant* GetRows() {
  263. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields);
  264. }
  265. llvm::Constant* GetIsFinished() {
  266. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 1);
  267. }
  268. llvm::Constant* GetBuilderAllocatedSize() {
  269. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 2);
  270. }
  271. llvm::Constant* GetMaxBuilderAllocatedSize() {
  272. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 3);
  273. }
  274. TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
  275. : TBase(context, width)
  276. , RowsType(Type::getInt64Ty(Context))
  277. , IsFinishedType(Type::getInt1Ty(Context))
  278. , BuilderAllocatedSizeType(Type::getInt64Ty(Context))
  279. , MaxBuilderAllocatedSizeType(Type::getInt64Ty(Context))
  280. {}
  281. };
  282. #endif
  283. void RegisterDependencies() const final {
  284. FlowDependsOn(Flow_);
  285. }
  286. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  287. state = ctx.HolderFactory.Create<TState>(ctx, Types_, MaxLength_);
  288. }
  289. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  290. if (state.IsInvalid()) {
  291. MakeState(ctx, state);
  292. auto& s = *static_cast<TState*>(state.AsBoxed().Get());
  293. const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
  294. for (size_t i = 0; i < Width_; ++i)
  295. fields[i] = &s.Values[i];
  296. return s;
  297. }
  298. return *static_cast<TState*>(state.AsBoxed().Get());
  299. }
  300. private:
  301. IComputationWideFlowNode* const Flow_;
  302. const TVector<TType*> Types_;
  303. const size_t MaxLength_;
  304. const size_t Width_;
  305. const size_t WideFieldsIndex_;
  306. };
  307. class TWideToBlocksStreamWrapper : public TMutableComputationNode<TWideToBlocksStreamWrapper>
  308. {
  309. using TBaseComputation = TMutableComputationNode<TWideToBlocksStreamWrapper>;
  310. using TState = TWideToBlocksState;
  311. public:
  312. TWideToBlocksStreamWrapper(TComputationMutables& mutables,
  313. IComputationNode* stream,
  314. TVector<TType*>&& types)
  315. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  316. , Stream_(stream)
  317. , Types_(std::move(types))
  318. , MaxLength_(CalcBlockLen(std::accumulate(Types_.cbegin(), Types_.cend(), 0ULL, [](size_t max, const TType* type){ return std::max(max, CalcMaxBlockItemSize(type)); })))
  319. {}
  320. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const
  321. {
  322. const auto state = ctx.HolderFactory.Create<TState>(ctx, Types_, MaxLength_);
  323. return ctx.HolderFactory.Create<TStreamValue>(ctx.HolderFactory,
  324. std::move(state),
  325. std::move(Stream_->GetValue(ctx)),
  326. MaxLength_);
  327. }
  328. private:
  329. class TStreamValue : public TComputationValue<TStreamValue> {
  330. using TBase = TComputationValue<TStreamValue>;
  331. public:
  332. TStreamValue(TMemoryUsageInfo* memInfo, const THolderFactory& holderFactory,
  333. NUdf::TUnboxedValue&& blockState, NUdf::TUnboxedValue&& stream,
  334. const size_t maxLength)
  335. : TBase(memInfo)
  336. , BlockState_(blockState)
  337. , Stream_(stream)
  338. , MaxLength_(maxLength)
  339. , HolderFactory_(holderFactory)
  340. {}
  341. private:
  342. NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
  343. auto& blockState = *static_cast<TState*>(BlockState_.AsBoxed().Get());
  344. auto* inputFields = blockState.Pointer_;
  345. const size_t inputWidth = blockState.Values.size() - 1;
  346. if (!blockState.Count) {
  347. if (!blockState.IsFinished_) do {
  348. switch (Stream_.WideFetch(inputFields, inputWidth)) {
  349. case NUdf::EFetchStatus::Ok:
  350. for (size_t i = 0; i < inputWidth; i++)
  351. blockState.Add(blockState.Values[i], i);
  352. continue;
  353. case NUdf::EFetchStatus::Yield:
  354. return NUdf::EFetchStatus::Yield;
  355. case NUdf::EFetchStatus::Finish:
  356. blockState.IsFinished_ = true;
  357. break;
  358. }
  359. break;
  360. } while (++blockState.Rows_ < MaxLength_ && blockState.BuilderAllocatedSize_ <= blockState.MaxBuilderAllocatedSize_);
  361. if (blockState.Rows_)
  362. blockState.MakeBlocks(HolderFactory_);
  363. else
  364. return NUdf::EFetchStatus::Finish;
  365. }
  366. const auto sliceSize = blockState.Slice();
  367. for (size_t i = 0; i < width; i++) {
  368. output[i] = blockState.Get(sliceSize, HolderFactory_, i);
  369. }
  370. return NUdf::EFetchStatus::Ok;
  371. }
  372. NUdf::TUnboxedValue BlockState_;
  373. NUdf::TUnboxedValue Stream_;
  374. const size_t MaxLength_;
  375. const THolderFactory& HolderFactory_;
  376. };
  377. void RegisterDependencies() const final {
  378. this->DependsOn(Stream_);
  379. }
  380. IComputationNode* const Stream_;
  381. const TVector<TType*> Types_;
  382. const size_t MaxLength_;
  383. };
  384. class TFromBlocksWrapper : public TStatefulFlowCodegeneratorNode<TFromBlocksWrapper> {
  385. using TBaseComputation = TStatefulFlowCodegeneratorNode<TFromBlocksWrapper>;
  386. public:
  387. TFromBlocksWrapper(TComputationMutables& mutables, IComputationNode* flow, TType* itemType)
  388. : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
  389. , Flow_(flow)
  390. , ItemType_(itemType)
  391. {
  392. }
  393. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  394. for (auto& s = GetState(state, ctx);;) {
  395. if (auto item = s.GetValue(ctx.HolderFactory); !item.IsInvalid())
  396. return item;
  397. if (const auto input = Flow_->GetValue(ctx); input.IsSpecial())
  398. return input;
  399. else
  400. s.Reset(input);
  401. }
  402. }
  403. #ifndef MKQL_DISABLE_CODEGEN
  404. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  405. auto& context = ctx.Codegen.GetContext();
  406. const auto valueType = Type::getInt128Ty(context);
  407. const auto statePtrType = PointerType::getUnqual(StructType::get(context));
  408. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  409. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  410. const auto read = BasicBlock::Create(context, "read", ctx.Func);
  411. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  412. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  413. BranchInst::Create(make, work, IsInvalid(statePtr, block, context), block);
  414. block = make;
  415. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  416. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  417. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TFromBlocksWrapper::MakeState));
  418. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  419. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  420. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  421. BranchInst::Create(work, block);
  422. block = work;
  423. const auto state = new LoadInst(valueType, statePtr, "state", block);
  424. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  425. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  426. const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::GetValue));
  427. const auto getType = FunctionType::get(valueType, {statePtrType, ctx.GetFactory()->getType()}, false);
  428. const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", block);
  429. const auto value = CallInst::Create(getType, getPtr, {stateArg, ctx.GetFactory() }, "value", block);
  430. const auto result = PHINode::Create(valueType, 2U, "result", done);
  431. result->addIncoming(value, block);
  432. BranchInst::Create(read, done, IsInvalid(value, block, context), block);
  433. block = read;
  434. const auto input = GetNodeValue(Flow_, ctx, block);
  435. result->addIncoming(input, block);
  436. BranchInst::Create(done, init, IsSpecial(input, block, context), block);
  437. block = init;
  438. const auto setFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Reset));
  439. const auto setType = FunctionType::get(valueType, {statePtrType, valueType}, false);
  440. const auto setPtr = CastInst::Create(Instruction::IntToPtr, setFunc, PointerType::getUnqual(setType), "set", block);
  441. CallInst::Create(setType, setPtr, {stateArg, input }, "", block);
  442. ValueCleanup(EValueRepresentation::Any, input, ctx, block);
  443. BranchInst::Create(work, block);
  444. block = done;
  445. return result;
  446. }
  447. #endif
  448. private:
  449. struct TState : public TComputationValue<TState> {
  450. using TComputationValue::TComputationValue;
  451. TState(TMemoryUsageInfo* memInfo, TType* itemType, const NUdf::IPgBuilder& pgBuilder)
  452. : TComputationValue(memInfo)
  453. , Reader_(MakeBlockReader(TTypeInfoHelper(), itemType))
  454. , Converter_(MakeBlockItemConverter(TTypeInfoHelper(), itemType, pgBuilder))
  455. {
  456. }
  457. NUdf::TUnboxedValuePod GetValue(const THolderFactory& holderFactory) {
  458. for (;;) {
  459. if (Arrays_.empty()) {
  460. return NUdf::TUnboxedValuePod::Invalid();
  461. }
  462. if (Index_ < ui64(Arrays_.front()->length)) {
  463. break;
  464. }
  465. Index_ = 0;
  466. Arrays_.pop_front();
  467. }
  468. return Converter_->MakeValue(Reader_->GetItem(*Arrays_.front(), Index_++), holderFactory);
  469. }
  470. void Reset(const NUdf::TUnboxedValuePod block) {
  471. const auto& datum = TArrowBlock::From(block).GetDatum();
  472. MKQL_ENSURE(datum.is_arraylike(), "Expecting array as FromBlocks argument");
  473. MKQL_ENSURE(Arrays_.empty(), "Not all input is processed");
  474. if (datum.is_array()) {
  475. Arrays_.push_back(datum.array());
  476. } else {
  477. for (const auto& chunk : datum.chunks()) {
  478. Arrays_.push_back(chunk->data());
  479. }
  480. }
  481. Index_ = 0;
  482. }
  483. private:
  484. const std::unique_ptr<IBlockReader> Reader_;
  485. const std::unique_ptr<IBlockItemConverter> Converter_;
  486. TDeque<std::shared_ptr<arrow::ArrayData>> Arrays_;
  487. size_t Index_ = 0;
  488. };
  489. private:
  490. void RegisterDependencies() const final {
  491. FlowDependsOn(Flow_);
  492. }
  493. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  494. state = ctx.HolderFactory.Create<TState>(ItemType_, ctx.Builder->GetPgBuilder());
  495. }
  496. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  497. if (state.IsInvalid())
  498. MakeState(ctx, state);
  499. return *static_cast<TState*>(state.AsBoxed().Get());
  500. }
  501. private:
  502. IComputationNode* const Flow_;
  503. TType* ItemType_;
  504. };
  505. struct TWideFromBlocksState : public TComputationValue<TWideFromBlocksState> {
  506. size_t Count_ = 0;
  507. size_t Index_ = 0;
  508. size_t Current_ = 0;
  509. NUdf::TUnboxedValue* Pointer_ = nullptr;
  510. TUnboxedValueVector Values_;
  511. std::vector<std::unique_ptr<IBlockReader>> Readers_;
  512. std::vector<std::unique_ptr<IBlockItemConverter>> Converters_;
  513. const std::vector<arrow::ValueDescr> ValuesDescr_;
  514. TWideFromBlocksState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const TVector<TType*>& types)
  515. : TComputationValue(memInfo)
  516. , Values_(types.size() + 1)
  517. , ValuesDescr_(ToValueDescr(types))
  518. {
  519. Pointer_ = Values_.data();
  520. const auto& pgBuilder = ctx.Builder->GetPgBuilder();
  521. for (size_t i = 0; i < types.size(); ++i) {
  522. const TType* blockItemType = AS_TYPE(TBlockType, types[i])->GetItemType();
  523. Readers_.push_back(MakeBlockReader(TTypeInfoHelper(), blockItemType));
  524. Converters_.push_back(MakeBlockItemConverter(TTypeInfoHelper(), blockItemType, pgBuilder));
  525. }
  526. }
  527. void ClearValues() {
  528. Values_.assign(Values_.size(), NUdf::TUnboxedValuePod());
  529. }
  530. NUdf::TUnboxedValuePod Get(const THolderFactory& holderFactory, size_t idx) const {
  531. TBlockItem item;
  532. const auto& datum = TArrowBlock::From(Values_[idx]).GetDatum();
  533. ARROW_DEBUG_CHECK_DATUM_TYPES(ValuesDescr_[idx], datum.descr());
  534. if (datum.is_scalar()) {
  535. item = Readers_[idx]->GetScalarItem(*datum.scalar());
  536. } else {
  537. MKQL_ENSURE(datum.is_array(), "Expecting array");
  538. item = Readers_[idx]->GetItem(*datum.array(), Current_);
  539. }
  540. return Converters_[idx]->MakeValue(item, holderFactory);
  541. }
  542. };
// Wide-flow node that expands a wide flow of blocks into a wide flow of rows.
// Each input row holds one block per entry of Types_, followed by a trailing
// block-length value (hence Types_.size() + 1 wide fields). Every output row
// consists of row Current_ of each block.
class TWideFromBlocksFlowWrapper : public TStatefulWideFlowCodegeneratorNode<TWideFromBlocksFlowWrapper> {
    using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideFromBlocksFlowWrapper>;
    using TState = TWideFromBlocksState;
public:
    TWideFromBlocksFlowWrapper(TComputationMutables& mutables,
        IComputationWideFlowNode* flow,
        TVector<TType*>&& types)
        : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
        , Flow_(flow)
        , Types_(std::move(types))
        // One wide-field slot per data column plus one for the block length.
        , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(Types_.size() + 1U))
    {}
    // Interpreted path. When the current block is exhausted, fetch input rows
    // until a non-empty block arrives (or the flow yields/finishes). Then emit
    // the next row, materializing only the outputs that were requested
    // (non-null pointers in `output`).
    EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
        TComputationContext& ctx,
        NUdf::TUnboxedValue*const* output) const
    {
        auto& s = GetState(state, ctx);
        // GetState pointed these slots at s.Values_, so FetchValues writes straight into the state.
        const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
        if (s.Index_ == s.Count_) do {
            if (const auto result = Flow_->FetchValues(ctx, fields); result != EFetchResult::One)
                return result;
            s.Index_ = 0;
            s.Count_ = GetBlockCount(s.Values_.back());
        } while (!s.Count_); // skip zero-length blocks
        s.Current_ = s.Index_;
        ++s.Index_;
        for (size_t i = 0; i < Types_.size(); ++i)
            if (const auto out = output[i])
                *out = s.Get(ctx.HolderFactory, i);
        return EFetchResult::One;
    }
#ifndef MKQL_DISABLE_CODEGEN
    // LLVM counterpart of DoCalculate. Count_/Index_/Current_ are read and
    // written directly through TLLVMFieldsStructureState offsets. The returned
    // getters load a column block into the state slot on first use (the slots
    // are cleared before each fetch) and then call TState::Get.
    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto width = Types_.size();
        const auto valueType = Type::getInt128Ty(context);
        const auto statusType = Type::getInt32Ty(context);
        const auto indexType = Type::getInt64Ty(context);
        const auto arrayType = ArrayType::get(valueType, width);
        const auto ptrValuesType = PointerType::getUnqual(ArrayType::get(valueType, width));
        TLLVMFieldsStructureState stateFields(context, width);
        const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
        const auto statePtrType = PointerType::getUnqual(stateType);
        // Address of TState::Get, cast to a callable pointer in the entry block.
        const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get));
        const auto getType = FunctionType::get(valueType, {statePtrType, ctx.GetFactory()->getType(), indexType}, false);
        const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", &ctx.Func->getEntryBlock().back());
        // Stack slot holding the state pointer for the getters. Initialized to null in the entry block.
        const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", &ctx.Func->getEntryBlock().back());
        new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, &ctx.Func->getEntryBlock().back());
        // GetBlockCount is bound by name through a global mapping.
        const auto name = "GetBlockCount";
        ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&GetBlockCount));
        const auto getCountType = FunctionType::get(indexType, { valueType }, false);
        const auto getCount = ctx.Codegen.GetModule().getOrInsertFunction(name, getCountType);
        const auto make = BasicBlock::Create(context, "make", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        const auto more = BasicBlock::Create(context, "more", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        const auto work = BasicBlock::Create(context, "work", ctx.Func);
        const auto over = BasicBlock::Create(context, "over", ctx.Func);
        // make: create the state via MakeState on first use.
        BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
        block = make;
        const auto ptrType = PointerType::getUnqual(StructType::get(context));
        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
        const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideFromBlocksFlowWrapper::MakeState));
        const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
        const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
        CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
        BranchInst::Create(main, block);
        // main: extract the state pointer (low 64 bits of the value) and compare Count_ with Index_.
        block = main;
        const auto state = new LoadInst(valueType, statePtr, "state", block);
        const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
        const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
        const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block);
        const auto indexPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIndex() }, "index_ptr", block);
        const auto count = new LoadInst(indexType, countPtr, "count", block);
        const auto index = new LoadInst(indexType, indexPtr, "index", block);
        const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, index, "next", block);
        BranchInst::Create(more, work, next, block);
        // more: block exhausted. Clear the cached column values and fetch the next input row.
        block = more;
        const auto clearFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ClearValues));
        const auto clearType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false);
        const auto clearPtr = CastInst::Create(Instruction::IntToPtr, clearFunc, PointerType::getUnqual(clearType), "clear", block);
        CallInst::Create(clearType, clearPtr, {stateArg}, "", block);
        const auto getres = GetNodeValues(Flow_, ctx, block);
        // Yield/Finish (<= Yield) are forwarded unchanged to `over`.
        const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
        const auto result = PHINode::Create(statusType, 2U, "result", over);
        result->addIncoming(getres.first, block);
        BranchInst::Create(over, good, special, block);
        // good: read the length column, store it as Count_ and reset Index_. Empty blocks loop back to `more`.
        block = good;
        const auto countValue = getres.second.back()(ctx, block);
        const auto height = CallInst::Create(getCount, { countValue }, "height", block);
        // The length value is not stored in the state, so it is cleaned up here.
        ValueCleanup(EValueRepresentation::Any, countValue, ctx, block);
        new StoreInst(height, countPtr, block);
        new StoreInst(ConstantInt::get(indexType, 0), indexPtr, block);
        const auto empty = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, ConstantInt::get(indexType, 0), height, "empty", block);
        BranchInst::Create(more, work, empty, block);
        // work: Current_ = Index_++, then publish the state pointer for the getters.
        block = work;
        const auto current = new LoadInst(indexType, indexPtr, "current", block);
        const auto currentPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCurrent() }, "current_ptr", block);
        new StoreInst(current, currentPtr, block);
        const auto increment = BinaryOperator::CreateAdd(current, ConstantInt::get(indexType, 1), "increment", block);
        new StoreInst(increment, indexPtr, block);
        new StoreInst(stateArg, stateOnStack, block);
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
        BranchInst::Create(over, block);
        block = over;
        // Per-column getters. If the state slot for column idx is still empty
        // since the last ClearValues, load the fetched block into it (taking a
        // reference for the stored copy). Then call TState::Get for the current row.
        ICodegeneratorInlineWideNode::TGettersList getters(width);
        for (size_t idx = 0U; idx < getters.size(); ++idx) {
            getters[idx] = [idx, width, getType, getPtr, indexType, arrayType, ptrValuesType, stateType, statePtrType, stateOnStack, getBlocks = getres.second](const TCodegenContext& ctx, BasicBlock*& block) {
                auto& context = ctx.Codegen.GetContext();
                const auto init = BasicBlock::Create(context, "init", ctx.Func);
                const auto call = BasicBlock::Create(context, "call", ctx.Func);
                TLLVMFieldsStructureState stateFields(context, width);
                const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
                const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
                const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
                const auto index = ConstantInt::get(indexType, idx);
                const auto pointer = GetElementPtrInst::CreateInBounds(arrayType, values, { ConstantInt::get(indexType, 0), index }, "pointer", block);
                BranchInst::Create(call, init, HasValue(pointer, block, context), block);
                block = init;
                const auto value = getBlocks[idx](ctx, block);
                new StoreInst(value, pointer, block);
                AddRefBoxed(value, ctx, block);
                BranchInst::Create(call, block);
                block = call;
                return CallInst::Create(getType, getPtr, {stateArg, ctx.GetFactory(), index}, "get", block);
            };
        }
        return {result, std::move(getters)};
    }
#endif
private:
#ifndef MKQL_DISABLE_CODEGEN
    // LLVM struct view of TWideFromBlocksState: the TComputationValue base
    // fields, followed by Count_, Index_, Current_ (i64) and Pointer_ (pointer
    // to a [width x i128] array). These indices must match the member order
    // of TWideFromBlocksState.
    class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState>> {
    private:
        using TBase = TLLVMFieldsStructure<TComputationValue<TState>>;
        llvm::IntegerType*const CountType;
        llvm::IntegerType*const IndexType;
        llvm::IntegerType*const CurrentType;
        llvm::PointerType*const PointerType;
    protected:
        using TBase::Context;
    public:
        std::vector<llvm::Type*> GetFieldsArray() {
            std::vector<llvm::Type*> result = TBase::GetFields();
            result.emplace_back(CountType);
            result.emplace_back(IndexType);
            result.emplace_back(CurrentType);
            result.emplace_back(PointerType);
            return result;
        }
        llvm::Constant* GetCount() {
            return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
        }
        llvm::Constant* GetIndex() {
            return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
        }
        llvm::Constant* GetCurrent() {
            return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 2);
        }
        llvm::Constant* GetPointer() {
            return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3);
        }
        TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
            : TBase(context)
            , CountType(Type::getInt64Ty(Context))
            , IndexType(Type::getInt64Ty(Context))
            , CurrentType(Type::getInt64Ty(Context))
            , PointerType(PointerType::getUnqual(ArrayType::get(Type::getInt128Ty(Context), width)))
        {}
    };
#endif
    void RegisterDependencies() const final {
        FlowDependsOn(Flow_);
    }
    void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
        state = ctx.HolderFactory.Create<TState>(ctx, Types_);
    }
    // Interpreted-path accessor. On first use it creates the state and points
    // this node's wide-field slots (data columns + length) at its Values_.
    TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
        if (state.IsInvalid()) {
            MakeState(ctx, state);
            const auto s = static_cast<TState*>(state.AsBoxed().Get());
            auto**const fields = ctx.WideFields.data() + WideFieldsIndex_;
            for (size_t i = 0; i <= Types_.size(); ++i) {
                fields[i] = &s->Values_[i];
            }
            return *s;
        }
        return *static_cast<TState*>(state.AsBoxed().Get());
    }
    IComputationWideFlowNode* const Flow_;
    const TVector<TType*> Types_;
    const size_t WideFieldsIndex_;
};
  736. class TWideFromBlocksStreamWrapper : public TMutableComputationNode<TWideFromBlocksStreamWrapper>
  737. {
  738. using TBaseComputation = TMutableComputationNode<TWideFromBlocksStreamWrapper>;
  739. using TState = TWideFromBlocksState;
  740. public:
  741. TWideFromBlocksStreamWrapper(TComputationMutables& mutables,
  742. IComputationNode* stream,
  743. TVector<TType*>&& types)
  744. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  745. , Stream_(stream)
  746. , Types_(std::move(types))
  747. {}
  748. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const
  749. {
  750. const auto state = ctx.HolderFactory.Create<TState>(ctx, Types_);
  751. return ctx.HolderFactory.Create<TStreamValue>(ctx.HolderFactory,
  752. std::move(state),
  753. std::move(Stream_->GetValue(ctx)));
  754. }
  755. private:
  756. class TStreamValue : public TComputationValue<TStreamValue> {
  757. using TBase = TComputationValue<TStreamValue>;
  758. public:
  759. TStreamValue(TMemoryUsageInfo* memInfo, const THolderFactory& holderFactory,
  760. NUdf::TUnboxedValue&& blockState, NUdf::TUnboxedValue&& stream)
  761. : TBase(memInfo)
  762. , BlockState_(blockState)
  763. , Stream_(stream)
  764. , HolderFactory_(holderFactory)
  765. {}
  766. private:
  767. NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
  768. auto& blockState = *static_cast<TState*>(BlockState_.AsBoxed().Get());
  769. auto* inputFields = blockState.Pointer_;
  770. const size_t inputWidth = blockState.Values_.size();
  771. if (blockState.Index_ == blockState.Count_) do {
  772. if (const auto result = Stream_.WideFetch(inputFields, inputWidth); result != NUdf::EFetchStatus::Ok)
  773. return result;
  774. blockState.Index_ = 0;
  775. blockState.Count_ = GetBlockCount(blockState.Values_.back());
  776. } while (!blockState.Count_);
  777. blockState.Current_ = blockState.Index_++;
  778. for (size_t i = 0; i < width; i++) {
  779. output[i] = blockState.Get(HolderFactory_, i);
  780. }
  781. return NUdf::EFetchStatus::Ok;
  782. }
  783. NUdf::TUnboxedValue BlockState_;
  784. NUdf::TUnboxedValue Stream_;
  785. const THolderFactory& HolderFactory_;
  786. };
  787. void RegisterDependencies() const final {
  788. this->DependsOn(Stream_);
  789. }
  790. IComputationNode* const Stream_;
  791. const TVector<TType*> Types_;
  792. };
  793. class TPrecomputedArrowNode : public IArrowKernelComputationNode {
  794. public:
  795. TPrecomputedArrowNode(const arrow::Datum& datum, TStringBuf kernelName)
  796. : Kernel_({}, datum.type(), [datum](arrow::compute::KernelContext*, const arrow::compute::ExecBatch&, arrow::Datum* res) {
  797. *res = datum;
  798. return arrow::Status::OK();
  799. })
  800. , KernelName_(kernelName)
  801. {
  802. Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
  803. Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
  804. }
  805. TStringBuf GetKernelName() const final {
  806. return KernelName_;
  807. }
  808. const arrow::compute::ScalarKernel& GetArrowKernel() const {
  809. return Kernel_;
  810. }
  811. const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
  812. return EmptyDesc_;
  813. }
  814. const IComputationNode* GetArgument(ui32 index) const {
  815. Y_UNUSED(index);
  816. ythrow yexception() << "No input arguments";
  817. }
  818. private:
  819. arrow::compute::ScalarKernel Kernel_;
  820. const TStringBuf KernelName_;
  821. const std::vector<arrow::ValueDescr> EmptyDesc_;
  822. };
  823. class TAsScalarWrapper : public TMutableCodegeneratorNode<TAsScalarWrapper> {
  824. using TBaseComputation = TMutableCodegeneratorNode<TAsScalarWrapper>;
  825. public:
  826. TAsScalarWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* type)
  827. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  828. , Arg_(arg)
  829. , Type_(type)
  830. {
  831. std::shared_ptr<arrow::DataType> arrowType;
  832. MKQL_ENSURE(ConvertArrowType(Type_, arrowType), "Unsupported type of scalar");
  833. }
  834. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  835. return AsScalar(Arg_->GetValue(ctx).Release(), ctx);
  836. }
  837. #ifndef MKQL_DISABLE_CODEGEN
  838. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  839. auto& context = ctx.Codegen.GetContext();
  840. const auto value = GetNodeValue(Arg_, ctx, block);
  841. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  842. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  843. const auto asScalarFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TAsScalarWrapper::AsScalar));
  844. const auto asScalarType = FunctionType::get(Type::getInt128Ty(context), {self->getType(), value->getType(), ctx.Ctx->getType()}, false);
  845. const auto asScalarFuncPtr = CastInst::Create(Instruction::IntToPtr, asScalarFunc, PointerType::getUnqual(asScalarType), "function", block);
  846. return CallInst::Create(asScalarType, asScalarFuncPtr, {self, value, ctx.Ctx}, "scalar", block);
  847. }
  848. #endif
  849. private:
  850. std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
  851. return std::make_unique<TPrecomputedArrowNode>(DoAsScalar(Arg_->GetValue(ctx).Release(), ctx), "AsScalar");
  852. }
  853. arrow::Datum DoAsScalar(const NUdf::TUnboxedValuePod value, TComputationContext& ctx) const {
  854. const NUdf::TUnboxedValue v(value);
  855. return ConvertScalar(Type_, v, ctx.ArrowMemoryPool);
  856. }
  857. NUdf::TUnboxedValuePod AsScalar(const NUdf::TUnboxedValuePod value, TComputationContext& ctx) const {
  858. return ctx.HolderFactory.CreateArrowBlock(DoAsScalar(value, ctx));
  859. }
  860. void RegisterDependencies() const final {
  861. DependsOn(Arg_);
  862. }
  863. IComputationNode* const Arg_;
  864. TType* Type_;
  865. };
  866. class TReplicateScalarWrapper : public TMutableCodegeneratorNode<TReplicateScalarWrapper> {
  867. using TBaseComputation = TMutableCodegeneratorNode<TReplicateScalarWrapper>;
  868. public:
  869. TReplicateScalarWrapper(TComputationMutables& mutables, IComputationNode* value, IComputationNode* count, TType* type)
  870. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  871. , Value_(value)
  872. , Count_(count)
  873. , Type_(type)
  874. {
  875. std::shared_ptr<arrow::DataType> arrowType;
  876. MKQL_ENSURE(ConvertArrowType(Type_, arrowType), "Unsupported type of scalar");
  877. }
  878. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  879. const auto value = Value_->GetValue(ctx).Release();
  880. const auto count = Count_->GetValue(ctx).Release();
  881. return Replicate(value, count, ctx);
  882. }
  883. #ifndef MKQL_DISABLE_CODEGEN
  884. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  885. auto& context = ctx.Codegen.GetContext();
  886. const auto value = GetNodeValue(Value_, ctx, block);
  887. const auto count = GetNodeValue(Count_, ctx, block);
  888. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  889. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  890. const auto replicateFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TReplicateScalarWrapper::Replicate));
  891. const auto replicateType = FunctionType::get(Type::getInt128Ty(context), {self->getType(), value->getType(), count->getType(), ctx.Ctx->getType()}, false);
  892. const auto replicateFuncPtr = CastInst::Create(Instruction::IntToPtr, replicateFunc, PointerType::getUnqual(replicateType), "function", block);
  893. return CallInst::Create(replicateType, replicateFuncPtr, {self, value, count, ctx.Ctx}, "replicate", block);
  894. }
  895. #endif
  896. private:
  897. std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
  898. const auto value = Value_->GetValue(ctx).Release();
  899. const auto count = Count_->GetValue(ctx).Release();
  900. return std::make_unique<TPrecomputedArrowNode>(DoReplicate(value, count, ctx), "ReplicateScalar");
  901. }
  902. arrow::Datum DoReplicate(const NUdf::TUnboxedValuePod val, const NUdf::TUnboxedValuePod cnt, TComputationContext& ctx) const {
  903. const auto value = TArrowBlock::From(val).GetDatum().scalar();
  904. const ui64 count = TArrowBlock::From(cnt).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
  905. const auto reader = MakeBlockReader(TTypeInfoHelper(), Type_);
  906. const auto builder = MakeArrayBuilder(TTypeInfoHelper(), Type_, ctx.ArrowMemoryPool, count, &ctx.Builder->GetPgBuilder());
  907. TBlockItem item = reader->GetScalarItem(*value);
  908. builder->Add(item, count);
  909. return builder->Build(true);
  910. }
  911. NUdf::TUnboxedValuePod Replicate(const NUdf::TUnboxedValuePod value, const NUdf::TUnboxedValuePod count, TComputationContext& ctx) const {
  912. return ctx.HolderFactory.CreateArrowBlock(DoReplicate(value, count, ctx));
  913. }
  914. void RegisterDependencies() const final {
  915. DependsOn(Value_);
  916. DependsOn(Count_);
  917. }
  918. IComputationNode* const Value_;
  919. IComputationNode* const Count_;
  920. TType* Type_;
  921. };
// Wide-flow node that re-emits each fetched wide row of blocks as a sequence
// of slices.
// NOTE(review): TBlockState is defined outside this chunk. Judging only by the
// calls below: FillArrays() prepares the fetched Values for slicing, Slice()
// returns the size of the next slice, Count is zero when a new input row is
// needed, and Get(sliceSize, factory, i) produces column i of the current
// slice. Confirm against its definition.
class TBlockExpandChunkedWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockExpandChunkedWrapper> {
    using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockExpandChunkedWrapper>;
public:
    TBlockExpandChunkedWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, size_t width)
        : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
        , Flow_(flow)
        , Width_(width)
        , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(Width_))
    {
    }
    // Interpreted path. Fetch a new input row only when the state reports
    // Count == 0. Then emit one slice, materializing only the requested
    // (non-null) outputs.
    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        auto& s = GetState(state, ctx);
        if (!s.Count) {
            // GetState pointed these slots at s.Values, so FetchValues fills the state directly.
            const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
            s.ClearValues();
            if (const auto result = Flow_->FetchValues(ctx, fields); result != EFetchResult::One)
                return result;
            s.FillArrays();
        }
        const auto sliceSize = s.Slice();
        for (size_t i = 0; i < Width_; ++i) {
            if (const auto out = output[i]) {
                *out = s.Get(sliceSize, ctx.HolderFactory, i);
            }
        }
        return EFetchResult::One;
    }
#ifndef MKQL_DISABLE_CODEGEN
    // LLVM counterpart of DoCalculate. The slice size and state pointer are
    // kept in stack slots allocated in the entry block, and each returned
    // getter calls TBlockState::Get(state, sliceSize, factory, idx).
    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto valueType = Type::getInt128Ty(context);
        const auto statusType = Type::getInt32Ty(context);
        const auto indexType = Type::getInt64Ty(context);
        const auto arrayType = ArrayType::get(valueType, Width_);
        const auto ptrValuesType = PointerType::getUnqual(arrayType);
        TLLVMFieldsStructureBlockState stateFields(context, Width_);
        const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
        const auto statePtrType = PointerType::getUnqual(stateType);
        // Insertion point: before the last instruction of the function's entry block.
        const auto atTop = &ctx.Func->getEntryBlock().back();
        const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::Get));
        const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false);
        const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
        // Stack slots read by the getters: current slice size (init 0) and state pointer (init null).
        const auto heightPtr = new AllocaInst(indexType, 0U, "height_ptr", atTop);
        const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
        new StoreInst(ConstantInt::get(indexType, 0), heightPtr, atTop);
        new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
        const auto make = BasicBlock::Create(context, "make", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        const auto read = BasicBlock::Create(context, "read", ctx.Func);
        const auto work = BasicBlock::Create(context, "work", ctx.Func);
        const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
        const auto over = BasicBlock::Create(context, "over", ctx.Func);
        // make: create the state via MakeState on first use.
        BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
        block = make;
        const auto ptrType = PointerType::getUnqual(StructType::get(context));
        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
        const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockExpandChunkedWrapper::MakeState));
        const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
        const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
        CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
        BranchInst::Create(main, block);
        // main: extract the state pointer (low 64 bits) and go to `read` when Count == 0, else straight to `fill`.
        block = main;
        const auto state = new LoadInst(valueType, statePtr, "state", block);
        const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
        const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
        const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block);
        const auto count = new LoadInst(indexType, countPtr, "count", block);
        const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, ConstantInt::get(indexType, 0), "next", block);
        BranchInst::Create(read, fill, next, block);
        // read: clear the stored values and fetch a new input row. Yield/Finish go straight to `over`.
        block = read;
        const auto clearFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::ClearValues));
        const auto clearType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false);
        const auto clearPtr = CastInst::Create(Instruction::IntToPtr, clearFunc, PointerType::getUnqual(clearType), "clear", block);
        CallInst::Create(clearType, clearPtr, {stateArg}, "", block);
        const auto getres = GetNodeValues(Flow_, ctx, block);
        const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
        const auto result = PHINode::Create(statusType, 2U, "result", over);
        result->addIncoming(getres.first, block);
        BranchInst::Create(over, work, special, block);
        // work: store every fetched value into the state's value array (taking
        // a reference for each stored copy), then call FillArrays.
        block = work;
        const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
        const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
        Value* array = UndefValue::get(arrayType);
        for (auto idx = 0U; idx < getres.second.size(); ++idx) {
            const auto value = getres.second[idx](ctx, block);
            AddRefBoxed(value, ctx, block);
            array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
        }
        new StoreInst(array, values, block);
        const auto fillArraysFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::FillArrays));
        const auto fillArraysType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false);
        const auto fillArraysPtr = CastInst::Create(Instruction::IntToPtr, fillArraysFunc, PointerType::getUnqual(fillArraysType), "fill_arrays_func", block);
        CallInst::Create(fillArraysType, fillArraysPtr, {stateArg}, "", block);
        BranchInst::Create(fill, block);
        // fill: take the next slice and publish its size and the state pointer for the getters.
        block = fill;
        const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::Slice));
        const auto sliceType = FunctionType::get(indexType, {statePtrType}, false);
        const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block);
        const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block);
        new StoreInst(slice, heightPtr, block);
        new StoreInst(stateArg, stateOnStack, block);
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
        BranchInst::Create(over, block);
        block = over;
        // Per-column getters: call TBlockState::Get(state, sliceSize, factory, idx).
        ICodegeneratorInlineWideNode::TGettersList getters(Width_);
        for (size_t idx = 0U; idx < getters.size(); ++idx) {
            getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
                const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
                const auto heightArg = new LoadInst(indexType, heightPtr, "height", block);
                return CallInst::Create(getType, getPtr, {stateArg, heightArg, ctx.GetFactory(), ConstantInt::get(indexType, idx)}, "get", block);
            };
        }
        return {result, std::move(getters)};
    }
#endif
private:
    void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
        state = ctx.HolderFactory.Create<TBlockState>(Width_);
    }
    // Interpreted-path accessor. On first use it creates the state and points
    // this node's Width_ wide-field slots at its Values.
    TBlockState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
        if (state.IsInvalid()) {
            MakeState(ctx, state);
            auto& s = *static_cast<TBlockState*>(state.AsBoxed().Get());
            const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
            for (size_t i = 0; i < Width_; ++i)
                fields[i] = &s.Values[i];
            return s;
        }
        return *static_cast<TBlockState*>(state.AsBoxed().Get());
    }
    void RegisterDependencies() const final {
        FlowDependsOn(Flow_);
    }
    IComputationWideFlowNode* const Flow_;
    const size_t Width_;
    const size_t WideFieldsIndex_;
};
  1059. class TBlockExpandChunkedStreamWrapper : public TMutableComputationNode<TBlockExpandChunkedStreamWrapper> {
  1060. using TBaseComputation = TMutableComputationNode<TBlockExpandChunkedStreamWrapper>;
  1061. class TExpanderState : public TComputationValue<TExpanderState> {
  1062. using TBase = TComputationValue<TExpanderState>;
  1063. public:
  1064. TExpanderState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, NUdf::TUnboxedValue&& stream, size_t width)
  1065. : TBase(memInfo), HolderFactory_(ctx.HolderFactory), State_(ctx.HolderFactory.Create<TBlockState>(width)), Stream_(stream) {}
  1066. NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
  1067. auto& s = *static_cast<TBlockState*>(State_.AsBoxed().Get());
  1068. if (!s.Count) {
  1069. s.ClearValues();
  1070. auto result = Stream_.WideFetch(s.Values.data(), width);
  1071. if (NUdf::EFetchStatus::Ok != result) {
  1072. return result;
  1073. }
  1074. s.FillArrays();
  1075. }
  1076. const auto sliceSize = s.Slice();
  1077. for (size_t i = 0; i < width; ++i) {
  1078. output[i] = s.Get(sliceSize, HolderFactory_, i);
  1079. }
  1080. return NUdf::EFetchStatus::Ok;
  1081. }
  1082. private:
  1083. const THolderFactory& HolderFactory_;
  1084. NUdf::TUnboxedValue State_;
  1085. NUdf::TUnboxedValue Stream_;
  1086. };
  1087. public:
  1088. TBlockExpandChunkedStreamWrapper(TComputationMutables& mutables, IComputationNode* stream, size_t width)
  1089. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  1090. , Stream_(stream)
  1091. , Width_(width) {}
  1092. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  1093. return ctx.HolderFactory.Create<TExpanderState>(ctx, std::move(Stream_->GetValue(ctx)), Width_);
  1094. }
  1095. void RegisterDependencies() const override {}
  1096. private:
  1097. IComputationNode* const Stream_;
  1098. const size_t Width_;
  1099. };
  1100. } // namespace
  1101. IComputationNode* WrapToBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1102. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1103. const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
  1104. return new TToBlocksWrapper(LocateNode(ctx.NodeLocator, callable, 0), flowType->GetItemType());
  1105. }
  1106. IComputationNode* WrapWideToBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1107. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1108. const auto inputType = callable.GetInput(0).GetStaticType();
  1109. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(),
  1110. "Expected either WideStream or WideFlow as an input");
  1111. const auto yieldsStream = callable.GetType()->GetReturnType()->IsStream();
  1112. MKQL_ENSURE(yieldsStream == inputType->IsStream(),
  1113. "Expected both input and output have to be either WideStream or WideFlow");
  1114. const auto wideComponents = GetWideComponents(inputType);
  1115. TVector<TType*> items(wideComponents.begin(), wideComponents.end());
  1116. const auto wideFlowOrStream = LocateNode(ctx.NodeLocator, callable, 0);
  1117. if (yieldsStream) {
  1118. const auto wideStream = wideFlowOrStream;
  1119. return new TWideToBlocksStreamWrapper(ctx.Mutables, wideStream, std::move(items));
  1120. }
  1121. // FIXME: Drop the branch below, when the time comes.
  1122. const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(wideFlowOrStream);
  1123. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  1124. return new TWideToBlocksFlowWrapper(ctx.Mutables, wideFlow, std::move(items));
  1125. }
  1126. IComputationNode* WrapFromBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1127. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1128. const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
  1129. const auto blockType = AS_TYPE(TBlockType, flowType->GetItemType());
  1130. return new TFromBlocksWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0), blockType->GetItemType());
  1131. }
  1132. IComputationNode* WrapWideFromBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1133. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1134. const auto inputType = callable.GetInput(0).GetStaticType();
  1135. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(),
  1136. "Expected either WideStream or WideFlow as an input");
  1137. const auto yieldsStream = callable.GetType()->GetReturnType()->IsStream();
  1138. MKQL_ENSURE(yieldsStream == inputType->IsStream(),
  1139. "Expected both input and output have to be either WideStream or WideFlow");
  1140. const auto wideComponents = GetWideComponents(inputType);
  1141. MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
  1142. TVector<TType*> items;
  1143. for (ui32 i = 0; i < wideComponents.size() - 1; ++i) {
  1144. items.push_back(AS_TYPE(TBlockType, wideComponents[i]));
  1145. }
  1146. const auto wideFlowOrStream = LocateNode(ctx.NodeLocator, callable, 0);
  1147. if (yieldsStream) {
  1148. const auto wideStream = wideFlowOrStream;
  1149. return new TWideFromBlocksStreamWrapper(ctx.Mutables, wideStream, std::move(items));
  1150. }
  1151. // FIXME: Drop the branch below, when the time comes.
  1152. const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(wideFlowOrStream);
  1153. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  1154. return new TWideFromBlocksFlowWrapper(ctx.Mutables, wideFlow, std::move(items));
  1155. }
  1156. IComputationNode* WrapAsScalar(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1157. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1158. return new TAsScalarWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0), callable.GetInput(0).GetStaticType());
  1159. }
  1160. IComputationNode* WrapReplicateScalar(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1161. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args, got " << callable.GetInputsCount());
  1162. const auto valueType = AS_TYPE(TBlockType, callable.GetInput(0).GetStaticType());
  1163. MKQL_ENSURE(valueType->GetShape() == TBlockType::EShape::Scalar, "Expecting scalar as first arg");
  1164. const auto value = LocateNode(ctx.NodeLocator, callable, 0);
  1165. const auto count = LocateNode(ctx.NodeLocator, callable, 1);
  1166. return new TReplicateScalarWrapper(ctx.Mutables, value, count, valueType->GetItemType());
  1167. }
  1168. IComputationNode* WrapBlockExpandChunked(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1169. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1170. if (callable.GetInput(0).GetStaticType()->IsStream()) {
  1171. const auto streamType = AS_TYPE(TStreamType, callable.GetInput(0).GetStaticType());
  1172. const auto wideComponents = GetWideComponents(streamType);
  1173. const auto computation = dynamic_cast<IComputationNode*>(LocateNode(ctx.NodeLocator, callable, 0));
  1174. MKQL_ENSURE(computation != nullptr, "Expected computation node");
  1175. return new TBlockExpandChunkedStreamWrapper(ctx.Mutables, computation, wideComponents.size());
  1176. } else {
  1177. const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
  1178. const auto wideComponents = GetWideComponents(flowType);
  1179. const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
  1180. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  1181. return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size());
  1182. }
  1183. }
  1184. }
  1185. }