mkql_blocks.cpp 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344
  1. #include "mkql_blocks.h"
  2. #include <yql/essentials/minikql/computation/mkql_block_reader.h>
  3. #include <yql/essentials/minikql/computation/mkql_block_builder.h>
  4. #include <yql/essentials/minikql/computation/mkql_block_impl.h>
  5. #include <yql/essentials/minikql/computation/mkql_block_impl_codegen.h> // Y_IGNORE
  6. #include <yql/essentials/minikql/arrow/arrow_defs.h>
  7. #include <yql/essentials/minikql/arrow/arrow_util.h>
  8. #include <yql/essentials/minikql/mkql_type_builder.h>
  9. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  10. #include <yql/essentials/minikql/mkql_node_builder.h>
  11. #include <yql/essentials/minikql/mkql_node_cast.h>
  12. #include <yql/essentials/parser/pg_wrapper/interface/arrow.h>
  13. #include <arrow/scalar.h>
  14. #include <arrow/array.h>
  15. #include <arrow/datum.h>
  16. namespace NKikimr {
  17. namespace NMiniKQL {
  18. namespace {
  19. class TToBlocksWrapper : public TStatelessFlowComputationNode<TToBlocksWrapper> {
  20. public:
  21. explicit TToBlocksWrapper(IComputationNode* flow, TType* itemType)
  22. : TStatelessFlowComputationNode(flow, EValueRepresentation::Boxed)
  23. , Flow_(flow)
  24. , ItemType_(itemType)
  25. {
  26. }
  27. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  28. const auto maxLen = CalcBlockLen(CalcMaxBlockItemSize(ItemType_));
  29. const auto builder = MakeArrayBuilder(TTypeInfoHelper(), ItemType_, ctx.ArrowMemoryPool, maxLen, &ctx.Builder->GetPgBuilder());
  30. for (size_t i = 0; i < builder->MaxLength(); ++i) {
  31. auto result = Flow_->GetValue(ctx);
  32. if (result.IsSpecial()) {
  33. if (i == 0) {
  34. return result.Release();
  35. }
  36. break;
  37. }
  38. builder->Add(result);
  39. }
  40. return ctx.HolderFactory.CreateArrowBlock(builder->Build(true));
  41. }
  42. private:
  43. void RegisterDependencies() const final {
  44. FlowDependsOn(Flow_);
  45. }
  46. private:
  47. IComputationNode* const Flow_;
  48. TType* ItemType_;
  49. };
  50. class TWideToBlocksWrapper : public TStatefulWideFlowCodegeneratorNode<TWideToBlocksWrapper> {
  51. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideToBlocksWrapper>;
  52. public:
  53. TWideToBlocksWrapper(TComputationMutables& mutables,
  54. IComputationWideFlowNode* flow,
  55. TVector<TType*>&& types)
  56. : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
  57. , Flow_(flow)
  58. , Types_(std::move(types))
  59. , MaxLength_(CalcBlockLen(std::accumulate(Types_.cbegin(), Types_.cend(), 0ULL, [](size_t max, const TType* type){ return std::max(max, CalcMaxBlockItemSize(type)); })))
  60. , Width_(Types_.size())
  61. , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(Width_))
  62. {
  63. }
  64. EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
  65. TComputationContext& ctx,
  66. NUdf::TUnboxedValue*const* output) const {
  67. auto& s = GetState(state, ctx);
  68. const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
  69. if (!s.Count) {
  70. if (!s.IsFinished_) do {
  71. switch (Flow_->FetchValues(ctx, fields)) {
  72. case EFetchResult::One:
  73. for (size_t i = 0; i < Types_.size(); ++i)
  74. s.Add(s.Values[i], i);
  75. continue;
  76. case EFetchResult::Yield:
  77. return EFetchResult::Yield;
  78. case EFetchResult::Finish:
  79. s.IsFinished_ = true;
  80. break;
  81. }
  82. break;
  83. } while (++s.Rows_ < MaxLength_ && s.BuilderAllocatedSize_ <= s.MaxBuilderAllocatedSize_);
  84. if (s.Rows_)
  85. s.MakeBlocks(ctx.HolderFactory);
  86. else
  87. return EFetchResult::Finish;
  88. }
  89. const auto sliceSize = s.Slice();
  90. for (size_t i = 0; i <= Types_.size(); ++i) {
  91. if (const auto out = output[i]) {
  92. *out = s.Get(sliceSize, ctx.HolderFactory, i);
  93. }
  94. }
  95. return EFetchResult::One;
  96. }
  97. #ifndef MKQL_DISABLE_CODEGEN
  98. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  99. auto& context = ctx.Codegen.GetContext();
  100. const auto valueType = Type::getInt128Ty(context);
  101. const auto statusType = Type::getInt32Ty(context);
  102. const auto indexType = Type::getInt64Ty(context);
  103. TLLVMFieldsStructureState stateFields(context, Types_.size() + 1U);
  104. const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
  105. const auto statePtrType = PointerType::getUnqual(stateType);
  106. const auto atTop = &ctx.Func->getEntryBlock().back();
  107. const auto addFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Add));
  108. const auto addType = FunctionType::get(Type::getVoidTy(context), {statePtrType, valueType, indexType}, false);
  109. const auto addPtr = CastInst::Create(Instruction::IntToPtr, addFunc, PointerType::getUnqual(addType), "add", atTop);
  110. const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get));
  111. const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false);
  112. const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
  113. const auto heightPtr = new AllocaInst(indexType, 0U, "height_ptr", atTop);
  114. const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
  115. new StoreInst(ConstantInt::get(indexType, 0), heightPtr, atTop);
  116. new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
  117. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  118. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  119. const auto more = BasicBlock::Create(context, "more", ctx.Func);
  120. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  121. const auto read = BasicBlock::Create(context, "read", ctx.Func);
  122. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  123. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  124. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  125. const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
  126. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  127. const auto second_cond = BasicBlock::Create(context, "second_cond", ctx.Func);
  128. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  129. block = make;
  130. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  131. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  132. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideToBlocksWrapper::MakeState));
  133. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  134. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  135. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  136. BranchInst::Create(main, block);
  137. block = main;
  138. const auto state = new LoadInst(valueType, statePtr, "state", block);
  139. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  140. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  141. const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block);
  142. const auto count = new LoadInst(indexType, countPtr, "count", block);
  143. const auto none = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, ConstantInt::get(indexType, 0), "none", block);
  144. BranchInst::Create(more, fill, none, block);
  145. block = more;
  146. const auto rowsPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetRows() }, "rows_ptr", block);
  147. const auto finishedPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIsFinished() }, "is_finished_ptr", block);
  148. const auto allocatedSizePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetBuilderAllocatedSize() }, "allocated_size_ptr", block);
  149. const auto maxAllocatedSizePtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetMaxBuilderAllocatedSize() }, "max_allocated_size_ptr", block);
  150. const auto finished = new LoadInst(Type::getInt1Ty(context), finishedPtr, "finished", block);
  151. BranchInst::Create(skip, read, finished, block);
  152. block = read;
  153. const auto getres = GetNodeValues(Flow_, ctx, block);
  154. const auto way = SwitchInst::Create(getres.first, good, 2U, block);
  155. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Finish)), stop);
  156. way->addCase(ConstantInt::get(statusType, i32(EFetchResult::Yield)), over);
  157. const auto result = PHINode::Create(statusType, 3U, "result", over);
  158. result->addIncoming(getres.first, block);
  159. block = good;
  160. const auto read_rows = new LoadInst(indexType, rowsPtr, "read_rows", block);
  161. const auto increment = BinaryOperator::CreateAdd(read_rows, ConstantInt::get(indexType, 1), "increment", block);
  162. new StoreInst(increment, rowsPtr, block);
  163. for (size_t idx = 0U; idx < Types_.size(); ++idx) {
  164. const auto value = getres.second[idx](ctx, block);
  165. CallInst::Create(addType, addPtr, {stateArg, value, ConstantInt::get(indexType, idx)}, "", block);
  166. ValueCleanup(GetValueRepresentation(Types_[idx]), value, ctx, block);
  167. }
  168. const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, increment, ConstantInt::get(indexType, MaxLength_), "next", block);
  169. BranchInst::Create(second_cond, work, next, block);
  170. block = second_cond;
  171. const auto read_allocated_size = new LoadInst(indexType, allocatedSizePtr, "read_allocated_size", block);
  172. const auto read_max_allocated_size = new LoadInst(indexType, maxAllocatedSizePtr, "read_max_allocated_size", block);
  173. const auto next2 = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULE, read_allocated_size, read_max_allocated_size, "next2", block);
  174. BranchInst::Create(read, work, next2, block);
  175. block = stop;
  176. new StoreInst(ConstantInt::getTrue(context), finishedPtr, block);
  177. BranchInst::Create(skip, block);
  178. block = skip;
  179. const auto rows = new LoadInst(indexType, rowsPtr, "rows", block);
  180. const auto empty = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, rows, ConstantInt::get(indexType, 0), "empty", block);
  181. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  182. BranchInst::Create(over, work, empty, block);
  183. block = work;
  184. const auto makeBlockFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::MakeBlocks));
  185. const auto makeBlockType = FunctionType::get(indexType, {statePtrType, ctx.GetFactory()->getType()}, false);
  186. const auto makeBlockPtr = CastInst::Create(Instruction::IntToPtr, makeBlockFunc, PointerType::getUnqual(makeBlockType), "make_blocks_func", block);
  187. CallInst::Create(makeBlockType, makeBlockPtr, {stateArg, ctx.GetFactory()}, "", block);
  188. BranchInst::Create(fill, block);
  189. block = fill;
  190. const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Slice));
  191. const auto sliceType = FunctionType::get(indexType, {statePtrType}, false);
  192. const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block);
  193. const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block);
  194. new StoreInst(slice, heightPtr, block);
  195. new StoreInst(stateArg, stateOnStack, block);
  196. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
  197. BranchInst::Create(over, block);
  198. block = over;
  199. ICodegeneratorInlineWideNode::TGettersList getters(Types_.size() + 1U);
  200. for (size_t idx = 0U; idx < getters.size(); ++idx) {
  201. getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
  202. const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
  203. const auto heightArg = new LoadInst(indexType, heightPtr, "height", block);
  204. return CallInst::Create(getType, getPtr, {stateArg, heightArg, ctx.GetFactory(), ConstantInt::get(indexType, idx)}, "get", block);
  205. };
  206. }
  207. return {result, std::move(getters)};
  208. }
  209. #endif
  210. private:
  211. struct TState : public TBlockState {
  212. size_t Rows_ = 0;
  213. bool IsFinished_ = false;
  214. size_t BuilderAllocatedSize_ = 0;
  215. size_t MaxBuilderAllocatedSize_ = 0;
  216. std::vector<std::unique_ptr<IArrayBuilder>> Builders_;
  217. static const size_t MaxAllocatedFactor_ = 4;
  218. TState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const TVector<TType*>& types, size_t maxLength, NUdf::TUnboxedValue**const fields)
  219. : TBlockState(memInfo, types.size() + 1U)
  220. , Builders_(types.size())
  221. {
  222. for (size_t i = 0; i < types.size(); ++i) {
  223. fields[i] = &Values[i];
  224. Builders_[i] = MakeArrayBuilder(TTypeInfoHelper(), types[i], ctx.ArrowMemoryPool, maxLength, &ctx.Builder->GetPgBuilder(), &BuilderAllocatedSize_);
  225. }
  226. MaxBuilderAllocatedSize_ = MaxAllocatedFactor_ * BuilderAllocatedSize_;
  227. }
  228. void Add(const NUdf::TUnboxedValuePod value, size_t idx) {
  229. Builders_[idx]->Add(value);
  230. }
  231. void MakeBlocks(const THolderFactory& holderFactory) {
  232. Values.back() = holderFactory.CreateArrowBlock(arrow::Datum(std::make_shared<arrow::UInt64Scalar>(Rows_)));
  233. Rows_ = 0;
  234. BuilderAllocatedSize_ = 0;
  235. for (size_t i = 0; i < Builders_.size(); ++i) {
  236. if (const auto builder = Builders_[i].get()) {
  237. Values[i] = holderFactory.CreateArrowBlock(builder->Build(IsFinished_));
  238. }
  239. }
  240. FillArrays();
  241. }
  242. };
  243. #ifndef MKQL_DISABLE_CODEGEN
  244. class TLLVMFieldsStructureState: public TLLVMFieldsStructureBlockState {
  245. private:
  246. using TBase = TLLVMFieldsStructureBlockState;
  247. llvm::IntegerType*const RowsType;
  248. llvm::IntegerType*const IsFinishedType;
  249. llvm::IntegerType*const BuilderAllocatedSizeType;
  250. llvm::IntegerType*const MaxBuilderAllocatedSizeType;
  251. protected:
  252. using TBase::Context;
  253. public:
  254. std::vector<llvm::Type*> GetFieldsArray() {
  255. std::vector<llvm::Type*> result = TBase::GetFieldsArray();
  256. result.emplace_back(RowsType);
  257. result.emplace_back(IsFinishedType);
  258. result.emplace_back(BuilderAllocatedSizeType);
  259. result.emplace_back(MaxBuilderAllocatedSizeType);
  260. return result;
  261. }
  262. llvm::Constant* GetRows() {
  263. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields);
  264. }
  265. llvm::Constant* GetIsFinished() {
  266. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 1);
  267. }
  268. llvm::Constant* GetBuilderAllocatedSize() {
  269. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 2);
  270. }
  271. llvm::Constant* GetMaxBuilderAllocatedSize() {
  272. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + BaseFields + 3);
  273. }
  274. TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
  275. : TBase(context, width)
  276. , RowsType(Type::getInt64Ty(Context))
  277. , IsFinishedType(Type::getInt1Ty(Context))
  278. , BuilderAllocatedSizeType(Type::getInt64Ty(Context))
  279. , MaxBuilderAllocatedSizeType(Type::getInt64Ty(Context))
  280. {}
  281. };
  282. #endif
  283. void RegisterDependencies() const final {
  284. FlowDependsOn(Flow_);
  285. }
  286. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  287. state = ctx.HolderFactory.Create<TState>(ctx, Types_, MaxLength_, ctx.WideFields.data() + WideFieldsIndex_);
  288. }
  289. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  290. if (state.IsInvalid())
  291. MakeState(ctx, state);
  292. return *static_cast<TState*>(state.AsBoxed().Get());
  293. }
  294. private:
  295. IComputationWideFlowNode* const Flow_;
  296. const TVector<TType*> Types_;
  297. const size_t MaxLength_;
  298. const size_t Width_;
  299. const size_t WideFieldsIndex_;
  300. };
  301. class TFromBlocksWrapper : public TStatefulFlowCodegeneratorNode<TFromBlocksWrapper> {
  302. using TBaseComputation = TStatefulFlowCodegeneratorNode<TFromBlocksWrapper>;
  303. public:
  304. TFromBlocksWrapper(TComputationMutables& mutables, IComputationNode* flow, TType* itemType)
  305. : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
  306. , Flow_(flow)
  307. , ItemType_(itemType)
  308. {
  309. }
  310. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  311. for (auto& s = GetState(state, ctx);;) {
  312. if (auto item = s.GetValue(ctx.HolderFactory); !item.IsInvalid())
  313. return item;
  314. if (const auto input = Flow_->GetValue(ctx); input.IsSpecial())
  315. return input;
  316. else
  317. s.Reset(input);
  318. }
  319. }
  320. #ifndef MKQL_DISABLE_CODEGEN
  321. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  322. auto& context = ctx.Codegen.GetContext();
  323. const auto valueType = Type::getInt128Ty(context);
  324. const auto statePtrType = PointerType::getUnqual(StructType::get(context));
  325. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  326. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  327. const auto read = BasicBlock::Create(context, "read", ctx.Func);
  328. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  329. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  330. BranchInst::Create(make, work, IsInvalid(statePtr, block, context), block);
  331. block = make;
  332. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  333. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  334. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TFromBlocksWrapper::MakeState));
  335. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  336. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  337. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  338. BranchInst::Create(work, block);
  339. block = work;
  340. const auto state = new LoadInst(valueType, statePtr, "state", block);
  341. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  342. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  343. const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::GetValue));
  344. const auto getType = FunctionType::get(valueType, {statePtrType, ctx.GetFactory()->getType()}, false);
  345. const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", block);
  346. const auto value = CallInst::Create(getType, getPtr, {stateArg, ctx.GetFactory() }, "value", block);
  347. const auto result = PHINode::Create(valueType, 2U, "result", done);
  348. result->addIncoming(value, block);
  349. BranchInst::Create(read, done, IsInvalid(value, block, context), block);
  350. block = read;
  351. const auto input = GetNodeValue(Flow_, ctx, block);
  352. result->addIncoming(input, block);
  353. BranchInst::Create(done, init, IsSpecial(input, block, context), block);
  354. block = init;
  355. const auto setFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Reset));
  356. const auto setType = FunctionType::get(valueType, {statePtrType, valueType}, false);
  357. const auto setPtr = CastInst::Create(Instruction::IntToPtr, setFunc, PointerType::getUnqual(setType), "set", block);
  358. CallInst::Create(setType, setPtr, {stateArg, input }, "", block);
  359. ValueCleanup(EValueRepresentation::Any, input, ctx, block);
  360. BranchInst::Create(work, block);
  361. block = done;
  362. return result;
  363. }
  364. #endif
  365. private:
  366. struct TState : public TComputationValue<TState> {
  367. using TComputationValue::TComputationValue;
  368. TState(TMemoryUsageInfo* memInfo, TType* itemType, const NUdf::IPgBuilder& pgBuilder)
  369. : TComputationValue(memInfo)
  370. , Reader_(MakeBlockReader(TTypeInfoHelper(), itemType))
  371. , Converter_(MakeBlockItemConverter(TTypeInfoHelper(), itemType, pgBuilder))
  372. {
  373. }
  374. NUdf::TUnboxedValuePod GetValue(const THolderFactory& holderFactory) {
  375. for (;;) {
  376. if (Arrays_.empty()) {
  377. return NUdf::TUnboxedValuePod::Invalid();
  378. }
  379. if (Index_ < ui64(Arrays_.front()->length)) {
  380. break;
  381. }
  382. Index_ = 0;
  383. Arrays_.pop_front();
  384. }
  385. return Converter_->MakeValue(Reader_->GetItem(*Arrays_.front(), Index_++), holderFactory);
  386. }
  387. void Reset(const NUdf::TUnboxedValuePod block) {
  388. const auto& datum = TArrowBlock::From(block).GetDatum();
  389. MKQL_ENSURE(datum.is_arraylike(), "Expecting array as FromBlocks argument");
  390. MKQL_ENSURE(Arrays_.empty(), "Not all input is processed");
  391. if (datum.is_array()) {
  392. Arrays_.push_back(datum.array());
  393. } else {
  394. for (const auto& chunk : datum.chunks()) {
  395. Arrays_.push_back(chunk->data());
  396. }
  397. }
  398. Index_ = 0;
  399. }
  400. private:
  401. const std::unique_ptr<IBlockReader> Reader_;
  402. const std::unique_ptr<IBlockItemConverter> Converter_;
  403. TDeque<std::shared_ptr<arrow::ArrayData>> Arrays_;
  404. size_t Index_ = 0;
  405. };
  406. private:
  407. void RegisterDependencies() const final {
  408. FlowDependsOn(Flow_);
  409. }
  410. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  411. state = ctx.HolderFactory.Create<TState>(ItemType_, ctx.Builder->GetPgBuilder());
  412. }
  413. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  414. if (state.IsInvalid())
  415. MakeState(ctx, state);
  416. return *static_cast<TState*>(state.AsBoxed().Get());
  417. }
  418. private:
  419. IComputationNode* const Flow_;
  420. TType* ItemType_;
  421. };
  422. struct TWideFromBlocksState : public TComputationValue<TWideFromBlocksState> {
  423. size_t Count_ = 0;
  424. size_t Index_ = 0;
  425. size_t Current_ = 0;
  426. NUdf::TUnboxedValue* Pointer_ = nullptr;
  427. TUnboxedValueVector Values_;
  428. std::vector<std::unique_ptr<IBlockReader>> Readers_;
  429. std::vector<std::unique_ptr<IBlockItemConverter>> Converters_;
  430. const std::vector<arrow::ValueDescr> ValuesDescr_;
  431. TWideFromBlocksState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, const TVector<TType*>& types)
  432. : TComputationValue(memInfo)
  433. , Values_(types.size() + 1)
  434. , ValuesDescr_(ToValueDescr(types))
  435. {
  436. Pointer_ = Values_.data();
  437. const auto& pgBuilder = ctx.Builder->GetPgBuilder();
  438. for (size_t i = 0; i < types.size(); ++i) {
  439. const TType* blockItemType = AS_TYPE(TBlockType, types[i])->GetItemType();
  440. Readers_.push_back(MakeBlockReader(TTypeInfoHelper(), blockItemType));
  441. Converters_.push_back(MakeBlockItemConverter(TTypeInfoHelper(), blockItemType, pgBuilder));
  442. }
  443. }
  444. void ClearValues() {
  445. Values_.assign(Values_.size(), NUdf::TUnboxedValuePod());
  446. }
  447. NUdf::TUnboxedValuePod Get(const THolderFactory& holderFactory, size_t idx) const {
  448. TBlockItem item;
  449. const auto& datum = TArrowBlock::From(Values_[idx]).GetDatum();
  450. ARROW_DEBUG_CHECK_DATUM_TYPES(ValuesDescr_[idx], datum.descr());
  451. if (datum.is_scalar()) {
  452. item = Readers_[idx]->GetScalarItem(*datum.scalar());
  453. } else {
  454. MKQL_ENSURE(datum.is_array(), "Expecting array");
  455. item = Readers_[idx]->GetItem(*datum.array(), Current_);
  456. }
  457. return Converters_[idx]->MakeValue(item, holderFactory);
  458. }
  459. };
  460. class TWideFromBlocksFlowWrapper : public TStatefulWideFlowCodegeneratorNode<TWideFromBlocksFlowWrapper> {
  461. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TWideFromBlocksFlowWrapper>;
  462. using TState = TWideFromBlocksState;
  463. public:
  464. TWideFromBlocksFlowWrapper(TComputationMutables& mutables,
  465. IComputationWideFlowNode* flow,
  466. TVector<TType*>&& types)
  467. : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
  468. , Flow_(flow)
  469. , Types_(std::move(types))
  470. , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(Types_.size() + 1U))
  471. {}
  472. EFetchResult DoCalculate(NUdf::TUnboxedValue& state,
  473. TComputationContext& ctx,
  474. NUdf::TUnboxedValue*const* output) const
  475. {
  476. auto& s = GetState(state, ctx);
  477. const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
  478. if (s.Index_ == s.Count_) do {
  479. if (const auto result = Flow_->FetchValues(ctx, fields); result != EFetchResult::One)
  480. return result;
  481. s.Index_ = 0;
  482. s.Count_ = GetBlockCount(s.Values_.back());
  483. } while (!s.Count_);
  484. s.Current_ = s.Index_;
  485. ++s.Index_;
  486. for (size_t i = 0; i < Types_.size(); ++i)
  487. if (const auto out = output[i])
  488. *out = s.Get(ctx.HolderFactory, i);
  489. return EFetchResult::One;
  490. }
#ifndef MKQL_DISABLE_CODEGEN
    // Codegen counterpart of DoCalculate. The emitted IR:
    //  * creates the state through MakeState while the state value is still invalid;
    //  * when Index_ == Count_ (every row of the current tuple has been returned), clears
    //    the cached values, fetches the next tuple from Flow_, reads its length from the
    //    last column with GetBlockCount and repeats while that length is 0;
    //  * stores Current_ = Index_ and increments Index_;
    //  * returns one getter per item type that yields row Current_ through TState::Get.
    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto width = Types_.size();
        const auto valueType = Type::getInt128Ty(context);
        const auto statusType = Type::getInt32Ty(context);
        const auto indexType = Type::getInt64Ty(context);
        const auto arrayType = ArrayType::get(valueType, width);
        // Pointer to the state's value array (same type as PointerType::getUnqual(arrayType)).
        const auto ptrValuesType = PointerType::getUnqual(ArrayType::get(valueType, width));
        TLLVMFieldsStructureState stateFields(context, width);
        const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
        const auto statePtrType = PointerType::getUnqual(stateType);
        // TState::Get(state, holderFactory, columnIndex), called through its raw address.
        const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::Get));
        const auto getType = FunctionType::get(valueType, {statePtrType, ctx.GetFactory()->getType(), indexType}, false);
        const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", &ctx.Func->getEntryBlock().back());
        // Entry-block slot through which the getters read the state pointer. It starts as
        // null and is written only on the path that produces a row (block "work").
        const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", &ctx.Func->getEntryBlock().back());
        new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, &ctx.Func->getEntryBlock().back());
        // Expose the host function GetBlockCount to the JIT module under its own name.
        const auto name = "GetBlockCount";
        ctx.Codegen.AddGlobalMapping(name, reinterpret_cast<const void*>(&GetBlockCount));
        const auto getCountType = FunctionType::get(indexType, { valueType }, false);
        const auto getCount = ctx.Codegen.GetModule().getOrInsertFunction(name, getCountType);
        const auto make = BasicBlock::Create(context, "make", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        const auto more = BasicBlock::Create(context, "more", ctx.Func);
        const auto good = BasicBlock::Create(context, "good", ctx.Func);
        const auto work = BasicBlock::Create(context, "work", ctx.Func);
        const auto over = BasicBlock::Create(context, "over", ctx.Func);
        BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
        // make: the state value is still invalid -> call this->MakeState(ctx, *statePtr).
        block = make;
        const auto ptrType = PointerType::getUnqual(StructType::get(context));
        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
        const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TWideFromBlocksFlowWrapper::MakeState));
        const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
        const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
        CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
        BranchInst::Create(main, block);
        // main: recover the TState* from the low 64 bits of the boxed state value.
        block = main;
        const auto state = new LoadInst(valueType, statePtr, "state", block);
        const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
        const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
        const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block);
        const auto indexPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIndex() }, "index_ptr", block);
        const auto count = new LoadInst(indexType, countPtr, "count", block);
        const auto index = new LoadInst(indexType, indexPtr, "index", block);
        // Index_ == Count_: the current tuple is exhausted, fetch the next one.
        const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, index, "next", block);
        BranchInst::Create(more, work, next, block);
        // more: drop the cached column values, then pull the next tuple from Flow_.
        block = more;
        const auto clearFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::ClearValues));
        const auto clearType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false);
        const auto clearPtr = CastInst::Create(Instruction::IntToPtr, clearFunc, PointerType::getUnqual(clearType), "clear", block);
        CallInst::Create(clearType, clearPtr, {stateArg}, "", block);
        const auto getres = GetNodeValues(Flow_, ctx, block);
        // Upstream statuses <= Yield are passed to the caller unchanged (DoCalculate likewise
        // returns every non-One result as is).
        const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
        const auto result = PHINode::Create(statusType, 2U, "result", over);
        result->addIncoming(getres.first, block);
        BranchInst::Create(over, good, special, block);
        // good: Count_ = length read from the last column, Index_ = 0.
        block = good;
        const auto countValue = getres.second.back()(ctx, block);
        const auto height = CallInst::Create(getCount, { countValue }, "height", block);
        ValueCleanup(EValueRepresentation::Any, countValue, ctx, block);
        new StoreInst(height, countPtr, block);
        new StoreInst(ConstantInt::get(indexType, 0), indexPtr, block);
        // A tuple of length 0 has no rows: fetch again.
        const auto empty = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, ConstantInt::get(indexType, 0), height, "empty", block);
        BranchInst::Create(more, work, empty, block);
        // work: Current_ = Index_, ++Index_, publish the state pointer for the getters.
        block = work;
        const auto current = new LoadInst(indexType, indexPtr, "current", block);
        const auto currentPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCurrent() }, "current_ptr", block);
        new StoreInst(current, currentPtr, block);
        const auto increment = BinaryOperator::CreateAdd(current, ConstantInt::get(indexType, 1), "increment", block);
        new StoreInst(increment, indexPtr, block);
        new StoreInst(stateArg, stateOnStack, block);
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
        BranchInst::Create(over, block);
        block = over;
        // Getter for column idx: if slot idx of the state's value array is still empty, fill it
        // from the upstream getter and add a reference to the stored value (AddRefBoxed),
        // then call TState::Get(state, factory, idx).
        ICodegeneratorInlineWideNode::TGettersList getters(width);
        for (size_t idx = 0U; idx < getters.size(); ++idx) {
            getters[idx] = [idx, width, getType, getPtr, indexType, arrayType, ptrValuesType, stateType, statePtrType, stateOnStack, getBlocks = getres.second](const TCodegenContext& ctx, BasicBlock*& block) {
                auto& context = ctx.Codegen.GetContext();
                const auto init = BasicBlock::Create(context, "init", ctx.Func);
                const auto call = BasicBlock::Create(context, "call", ctx.Func);
                TLLVMFieldsStructureState stateFields(context, width);
                const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
                const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
                const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
                const auto index = ConstantInt::get(indexType, idx);
                const auto pointer = GetElementPtrInst::CreateInBounds(arrayType, values, { ConstantInt::get(indexType, 0), index }, "pointer", block);
                BranchInst::Create(call, init, HasValue(pointer, block, context), block);
                block = init;
                const auto value = getBlocks[idx](ctx, block);
                new StoreInst(value, pointer, block);
                AddRefBoxed(value, ctx, block);
                BranchInst::Create(call, block);
                block = call;
                return CallInst::Create(getType, getPtr, {stateArg, ctx.GetFactory(), index}, "get", block);
            };
        }
        return {result, std::move(getters)};
    }
#endif
private:
#ifndef MKQL_DISABLE_CODEGEN
    // LLVM view of TComputationValue<TState> used by the generated code: the base
    // structure's fields followed by four trailing fields — Count (i64), Index (i64),
    // Current (i64) and Pointer (pointer to [width x i128]).
    // NOTE(review): the trailing order presumably has to mirror the declaration order of
    // Count_/Index_/Current_/Pointer_ in TWideFromBlocksState (declared outside this view);
    // keep both sides in sync — confirm.
    class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState>> {
    private:
        using TBase = TLLVMFieldsStructure<TComputationValue<TState>>;
        llvm::IntegerType*const CountType;
        llvm::IntegerType*const IndexType;
        llvm::IntegerType*const CurrentType;
        llvm::PointerType*const PointerType;
    protected:
        using TBase::Context;
    public:
        // Base fields, then Count, Index, Current, Pointer (in that order).
        std::vector<llvm::Type*> GetFieldsArray() {
            std::vector<llvm::Type*> result = TBase::GetFields();
            result.emplace_back(CountType);
            result.emplace_back(IndexType);
            result.emplace_back(CurrentType);
            result.emplace_back(PointerType);
            return result;
        }
        // GEP indices of the trailing fields, counted after the base structure's fields.
        llvm::Constant* GetCount() {
            return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 0);
        }
        llvm::Constant* GetIndex() {
            return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 1);
        }
        llvm::Constant* GetCurrent() {
            return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 2);
        }
        llvm::Constant* GetPointer() {
            return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount() + 3);
        }
        // width: number of item types; it sizes the array that Pointer points to.
        TLLVMFieldsStructureState(llvm::LLVMContext& context, size_t width)
            : TBase(context)
            , CountType(Type::getInt64Ty(Context))
            , IndexType(Type::getInt64Ty(Context))
            , CurrentType(Type::getInt64Ty(Context))
            , PointerType(PointerType::getUnqual(ArrayType::get(Type::getInt128Ty(Context), width)))
        {}
    };
#endif
  631. void RegisterDependencies() const final {
  632. FlowDependsOn(Flow_);
  633. }
  634. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  635. state = ctx.HolderFactory.Create<TState>(ctx, Types_);
  636. }
  637. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  638. if (state.IsInvalid()) {
  639. MakeState(ctx, state);
  640. const auto s = static_cast<TState*>(state.AsBoxed().Get());
  641. auto**const fields = ctx.WideFields.data() + WideFieldsIndex_;
  642. for (size_t i = 0; i <= Types_.size(); ++i) {
  643. fields[i] = &s->Values_[i];
  644. }
  645. return *s;
  646. }
  647. return *static_cast<TState*>(state.AsBoxed().Get());
  648. }
    // Upstream wide flow; its last column carries the block length (DoCalculate reads it
    // from Values_.back() through GetBlockCount).
    IComputationWideFlowNode* const Flow_;
    // Item types of the produced columns: one per data column of Flow_.
    const TVector<TType*> Types_;
    // First of this node's slots in ctx.WideFields (GetState wires Types_.size() + 1 of them).
    const size_t WideFieldsIndex_;
};
  653. class TWideFromBlocksStreamWrapper : public TMutableComputationNode<TWideFromBlocksStreamWrapper>
  654. {
  655. using TBaseComputation = TMutableComputationNode<TWideFromBlocksStreamWrapper>;
  656. using TState = TWideFromBlocksState;
  657. public:
  658. TWideFromBlocksStreamWrapper(TComputationMutables& mutables,
  659. IComputationNode* stream,
  660. TVector<TType*>&& types)
  661. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  662. , Stream_(stream)
  663. , Types_(std::move(types))
  664. {}
  665. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const
  666. {
  667. const auto state = ctx.HolderFactory.Create<TState>(ctx, Types_);
  668. return ctx.HolderFactory.Create<TStreamValue>(ctx.HolderFactory,
  669. std::move(state),
  670. std::move(Stream_->GetValue(ctx)));
  671. }
  672. private:
  673. class TStreamValue : public TComputationValue<TStreamValue> {
  674. using TBase = TComputationValue<TStreamValue>;
  675. public:
  676. TStreamValue(TMemoryUsageInfo* memInfo, const THolderFactory& holderFactory,
  677. NUdf::TUnboxedValue&& blockState, NUdf::TUnboxedValue&& stream)
  678. : TBase(memInfo)
  679. , BlockState_(blockState)
  680. , Stream_(stream)
  681. , HolderFactory_(holderFactory)
  682. {}
  683. private:
  684. NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
  685. auto& blockState = *static_cast<TState*>(BlockState_.AsBoxed().Get());
  686. auto* inputFields = blockState.Pointer_;
  687. const size_t inputWidth = blockState.Values_.size();
  688. if (blockState.Index_ == blockState.Count_) do {
  689. if (const auto result = Stream_.WideFetch(inputFields, inputWidth); result != NUdf::EFetchStatus::Ok)
  690. return result;
  691. blockState.Index_ = 0;
  692. blockState.Count_ = GetBlockCount(blockState.Values_.back());
  693. } while (!blockState.Count_);
  694. blockState.Current_ = blockState.Index_++;
  695. for (size_t i = 0; i < width; i++) {
  696. output[i] = blockState.Get(HolderFactory_, i);
  697. }
  698. return NUdf::EFetchStatus::Ok;
  699. }
  700. NUdf::TUnboxedValue BlockState_;
  701. NUdf::TUnboxedValue Stream_;
  702. const THolderFactory& HolderFactory_;
  703. };
  704. void RegisterDependencies() const final {
  705. this->DependsOn(Stream_);
  706. }
  707. IComputationNode* const Stream_;
  708. const TVector<TType*> Types_;
  709. };
  710. class TPrecomputedArrowNode : public IArrowKernelComputationNode {
  711. public:
  712. TPrecomputedArrowNode(const arrow::Datum& datum, TStringBuf kernelName)
  713. : Kernel_({}, datum.type(), [datum](arrow::compute::KernelContext*, const arrow::compute::ExecBatch&, arrow::Datum* res) {
  714. *res = datum;
  715. return arrow::Status::OK();
  716. })
  717. , KernelName_(kernelName)
  718. {
  719. Kernel_.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE;
  720. Kernel_.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE;
  721. }
  722. TStringBuf GetKernelName() const final {
  723. return KernelName_;
  724. }
  725. const arrow::compute::ScalarKernel& GetArrowKernel() const {
  726. return Kernel_;
  727. }
  728. const std::vector<arrow::ValueDescr>& GetArgsDesc() const {
  729. return EmptyDesc_;
  730. }
  731. const IComputationNode* GetArgument(ui32 index) const {
  732. Y_UNUSED(index);
  733. ythrow yexception() << "No input arguments";
  734. }
  735. private:
  736. arrow::compute::ScalarKernel Kernel_;
  737. const TStringBuf KernelName_;
  738. const std::vector<arrow::ValueDescr> EmptyDesc_;
  739. };
  740. class TAsScalarWrapper : public TMutableCodegeneratorNode<TAsScalarWrapper> {
  741. using TBaseComputation = TMutableCodegeneratorNode<TAsScalarWrapper>;
  742. public:
  743. TAsScalarWrapper(TComputationMutables& mutables, IComputationNode* arg, TType* type)
  744. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  745. , Arg_(arg)
  746. , Type_(type)
  747. {
  748. std::shared_ptr<arrow::DataType> arrowType;
  749. MKQL_ENSURE(ConvertArrowType(Type_, arrowType), "Unsupported type of scalar");
  750. }
  751. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  752. return AsScalar(Arg_->GetValue(ctx).Release(), ctx);
  753. }
  754. #ifndef MKQL_DISABLE_CODEGEN
  755. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  756. auto& context = ctx.Codegen.GetContext();
  757. const auto value = GetNodeValue(Arg_, ctx, block);
  758. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  759. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  760. const auto asScalarFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TAsScalarWrapper::AsScalar));
  761. const auto asScalarType = FunctionType::get(Type::getInt128Ty(context), {self->getType(), value->getType(), ctx.Ctx->getType()}, false);
  762. const auto asScalarFuncPtr = CastInst::Create(Instruction::IntToPtr, asScalarFunc, PointerType::getUnqual(asScalarType), "function", block);
  763. return CallInst::Create(asScalarType, asScalarFuncPtr, {self, value, ctx.Ctx}, "scalar", block);
  764. }
  765. #endif
  766. private:
  767. std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
  768. return std::make_unique<TPrecomputedArrowNode>(DoAsScalar(Arg_->GetValue(ctx).Release(), ctx), "AsScalar");
  769. }
  770. arrow::Datum DoAsScalar(const NUdf::TUnboxedValuePod value, TComputationContext& ctx) const {
  771. const NUdf::TUnboxedValue v(value);
  772. return ConvertScalar(Type_, v, ctx.ArrowMemoryPool);
  773. }
  774. NUdf::TUnboxedValuePod AsScalar(const NUdf::TUnboxedValuePod value, TComputationContext& ctx) const {
  775. return ctx.HolderFactory.CreateArrowBlock(DoAsScalar(value, ctx));
  776. }
  777. void RegisterDependencies() const final {
  778. DependsOn(Arg_);
  779. }
  780. IComputationNode* const Arg_;
  781. TType* Type_;
  782. };
  783. class TReplicateScalarWrapper : public TMutableCodegeneratorNode<TReplicateScalarWrapper> {
  784. using TBaseComputation = TMutableCodegeneratorNode<TReplicateScalarWrapper>;
  785. public:
  786. TReplicateScalarWrapper(TComputationMutables& mutables, IComputationNode* value, IComputationNode* count, TType* type)
  787. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  788. , Value_(value)
  789. , Count_(count)
  790. , Type_(type)
  791. {
  792. std::shared_ptr<arrow::DataType> arrowType;
  793. MKQL_ENSURE(ConvertArrowType(Type_, arrowType), "Unsupported type of scalar");
  794. }
  795. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  796. const auto value = Value_->GetValue(ctx).Release();
  797. const auto count = Count_->GetValue(ctx).Release();
  798. return Replicate(value, count, ctx);
  799. }
  800. #ifndef MKQL_DISABLE_CODEGEN
  801. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  802. auto& context = ctx.Codegen.GetContext();
  803. const auto value = GetNodeValue(Value_, ctx, block);
  804. const auto count = GetNodeValue(Count_, ctx, block);
  805. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  806. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  807. const auto replicateFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TReplicateScalarWrapper::Replicate));
  808. const auto replicateType = FunctionType::get(Type::getInt128Ty(context), {self->getType(), value->getType(), count->getType(), ctx.Ctx->getType()}, false);
  809. const auto replicateFuncPtr = CastInst::Create(Instruction::IntToPtr, replicateFunc, PointerType::getUnqual(replicateType), "function", block);
  810. return CallInst::Create(replicateType, replicateFuncPtr, {self, value, count, ctx.Ctx}, "replicate", block);
  811. }
  812. #endif
  813. private:
  814. std::unique_ptr<IArrowKernelComputationNode> PrepareArrowKernelComputationNode(TComputationContext& ctx) const final {
  815. const auto value = Value_->GetValue(ctx).Release();
  816. const auto count = Count_->GetValue(ctx).Release();
  817. return std::make_unique<TPrecomputedArrowNode>(DoReplicate(value, count, ctx), "ReplicateScalar");
  818. }
  819. arrow::Datum DoReplicate(const NUdf::TUnboxedValuePod val, const NUdf::TUnboxedValuePod cnt, TComputationContext& ctx) const {
  820. const auto value = TArrowBlock::From(val).GetDatum().scalar();
  821. const ui64 count = TArrowBlock::From(cnt).GetDatum().scalar_as<arrow::UInt64Scalar>().value;
  822. const auto reader = MakeBlockReader(TTypeInfoHelper(), Type_);
  823. const auto builder = MakeArrayBuilder(TTypeInfoHelper(), Type_, ctx.ArrowMemoryPool, count, &ctx.Builder->GetPgBuilder());
  824. TBlockItem item = reader->GetScalarItem(*value);
  825. builder->Add(item, count);
  826. return builder->Build(true);
  827. }
  828. NUdf::TUnboxedValuePod Replicate(const NUdf::TUnboxedValuePod value, const NUdf::TUnboxedValuePod count, TComputationContext& ctx) const {
  829. return ctx.HolderFactory.CreateArrowBlock(DoReplicate(value, count, ctx));
  830. }
  831. void RegisterDependencies() const final {
  832. DependsOn(Value_);
  833. DependsOn(Count_);
  834. }
  835. IComputationNode* const Value_;
  836. IComputationNode* const Count_;
  837. TType* Type_;
  838. };
// Wide-flow node that hands out each fetched tuple of blocks in slices. A new tuple is
// pulled from Flow_ only when the state's Count is 0; every call returns one slice of
// each of the Width_ columns (TBlockState::Get with the length returned by Slice()).
class TBlockExpandChunkedWrapper : public TStatefulWideFlowCodegeneratorNode<TBlockExpandChunkedWrapper> {
    using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TBlockExpandChunkedWrapper>;
public:
    // width: number of columns of the wide flow; that many WideFields slots are reserved.
    TBlockExpandChunkedWrapper(TComputationMutables& mutables, IComputationWideFlowNode* flow, size_t width)
        : TBaseComputation(mutables, flow, EValueRepresentation::Boxed)
        , Flow_(flow)
        , Width_(width)
        , WideFieldsIndex_(mutables.IncrementWideFieldsIndex(Width_))
    {
    }

    // Interpreted path. Non-One results of Flow_ are returned unchanged; null entries of
    // `output` are skipped.
    EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
        auto& s = GetState(state, ctx);
        if (!s.Count) {
            // Nothing left of the previous tuple: fetch the next one straight into s.Values
            // (GetState pointed the WideFields slots there), then prepare it.
            const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
            s.ClearValues();
            if (const auto result = Flow_->FetchValues(ctx, fields); result != EFetchResult::One)
                return result;
            s.FillArrays();
        }
        // NOTE(review): Slice() presumably also consumes the returned length from Count
        // (TBlockState is declared outside this view) — confirm.
        const auto sliceSize = s.Slice();
        for (size_t i = 0; i < Width_; ++i) {
            if (const auto out = output[i]) {
                *out = s.Get(sliceSize, ctx.HolderFactory, i);
            }
        }
        return EFetchResult::One;
    }
#ifndef MKQL_DISABLE_CODEGEN
    // Codegen counterpart of DoCalculate: create the state if needed; when Count == 0,
    // clear it, fetch a tuple from Flow_, store the fetched values (with an added
    // reference each) into the state's value array and call FillArrays; then call
    // Slice() and keep its result in an entry-block slot for the getters.
    ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
        auto& context = ctx.Codegen.GetContext();
        const auto valueType = Type::getInt128Ty(context);
        const auto statusType = Type::getInt32Ty(context);
        const auto indexType = Type::getInt64Ty(context);
        const auto arrayType = ArrayType::get(valueType, Width_);
        const auto ptrValuesType = PointerType::getUnqual(arrayType);
        TLLVMFieldsStructureBlockState stateFields(context, Width_);
        const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
        const auto statePtrType = PointerType::getUnqual(stateType);
        const auto atTop = &ctx.Func->getEntryBlock().back();
        // TBlockState::Get(state, sliceSize, holderFactory, columnIndex), called through its raw address.
        const auto getFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::Get));
        const auto getType = FunctionType::get(valueType, {statePtrType, indexType, ctx.GetFactory()->getType(), indexType}, false);
        const auto getPtr = CastInst::Create(Instruction::IntToPtr, getFunc, PointerType::getUnqual(getType), "get", atTop);
        // Entry-block slots read by the getters: the current slice length and the state pointer.
        const auto heightPtr = new AllocaInst(indexType, 0U, "height_ptr", atTop);
        const auto stateOnStack = new AllocaInst(statePtrType, 0U, "state_on_stack", atTop);
        new StoreInst(ConstantInt::get(indexType, 0), heightPtr, atTop);
        new StoreInst(ConstantPointerNull::get(statePtrType), stateOnStack, atTop);
        const auto make = BasicBlock::Create(context, "make", ctx.Func);
        const auto main = BasicBlock::Create(context, "main", ctx.Func);
        const auto read = BasicBlock::Create(context, "read", ctx.Func);
        const auto work = BasicBlock::Create(context, "work", ctx.Func);
        const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
        const auto over = BasicBlock::Create(context, "over", ctx.Func);
        BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
        // make: the state value is still invalid -> call this->MakeState(ctx, *statePtr).
        block = make;
        const auto ptrType = PointerType::getUnqual(StructType::get(context));
        const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
        const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockExpandChunkedWrapper::MakeState));
        const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
        const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
        CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
        BranchInst::Create(main, block);
        // main: recover the TBlockState* from the low 64 bits of the boxed state value.
        block = main;
        const auto state = new LoadInst(valueType, statePtr, "state", block);
        const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
        const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
        const auto countPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetCount() }, "count_ptr", block);
        const auto count = new LoadInst(indexType, countPtr, "count", block);
        // Count == 0: nothing left of the previous tuple, read a new one.
        const auto next = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, count, ConstantInt::get(indexType, 0), "next", block);
        BranchInst::Create(read, fill, next, block);
        // read: ClearValues, then fetch from Flow_.
        block = read;
        const auto clearFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::ClearValues));
        const auto clearType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false);
        const auto clearPtr = CastInst::Create(Instruction::IntToPtr, clearFunc, PointerType::getUnqual(clearType), "clear", block);
        CallInst::Create(clearType, clearPtr, {stateArg}, "", block);
        const auto getres = GetNodeValues(Flow_, ctx, block);
        // Upstream statuses <= Yield are passed to the caller unchanged.
        const auto special = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, getres.first, ConstantInt::get(getres.first->getType(), static_cast<i32>(EFetchResult::Yield)), "special", block);
        const auto result = PHINode::Create(statusType, 2U, "result", over);
        result->addIncoming(getres.first, block);
        BranchInst::Create(over, work, special, block);
        // work: materialize every fetched column, add a reference to each, store them as one
        // array into the state's value buffer, then call FillArrays.
        block = work;
        const auto valuesPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetPointer() }, "values_ptr", block);
        const auto values = new LoadInst(ptrValuesType, valuesPtr, "values", block);
        Value* array = UndefValue::get(arrayType);
        for (auto idx = 0U; idx < getres.second.size(); ++idx) {
            const auto value = getres.second[idx](ctx, block);
            AddRefBoxed(value, ctx, block);
            array = InsertValueInst::Create(array, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
        }
        new StoreInst(array, values, block);
        const auto fillArraysFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::FillArrays));
        const auto fillArraysType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false);
        const auto fillArraysPtr = CastInst::Create(Instruction::IntToPtr, fillArraysFunc, PointerType::getUnqual(fillArraysType), "fill_arrays_func", block);
        CallInst::Create(fillArraysType, fillArraysPtr, {stateArg}, "", block);
        BranchInst::Create(fill, block);
        // fill: take the next slice and publish its length and the state pointer to the getters.
        block = fill;
        const auto sliceFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TBlockState::Slice));
        const auto sliceType = FunctionType::get(indexType, {statePtrType}, false);
        const auto slicePtr = CastInst::Create(Instruction::IntToPtr, sliceFunc, PointerType::getUnqual(sliceType), "slice_func", block);
        const auto slice = CallInst::Create(sliceType, slicePtr, {stateArg}, "slice", block);
        new StoreInst(slice, heightPtr, block);
        new StoreInst(stateArg, stateOnStack, block);
        result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
        BranchInst::Create(over, block);
        block = over;
        // Getter for column idx: TBlockState::Get(state, sliceLength, factory, idx).
        ICodegeneratorInlineWideNode::TGettersList getters(Width_);
        for (size_t idx = 0U; idx < getters.size(); ++idx) {
            getters[idx] = [idx, getType, getPtr, heightPtr, indexType, statePtrType, stateOnStack](const TCodegenContext& ctx, BasicBlock*& block) {
                const auto stateArg = new LoadInst(statePtrType, stateOnStack, "state", block);
                const auto heightArg = new LoadInst(indexType, heightPtr, "height", block);
                return CallInst::Create(getType, getPtr, {stateArg, heightArg, ctx.GetFactory(), ConstantInt::get(indexType, idx)}, "get", block);
            };
        }
        return {result, std::move(getters)};
    }
#endif
private:
    // Stores a new TBlockState for Width_ columns into `state`. Also called from generated code.
    void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
        state = ctx.HolderFactory.Create<TBlockState>(Width_);
    }
    // Returns the state, creating it on first use and pointing this node's WideFields slots
    // at s.Values so FetchValues writes directly into the state.
    TBlockState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
        if (state.IsInvalid()) {
            MakeState(ctx, state);
            auto& s = *static_cast<TBlockState*>(state.AsBoxed().Get());
            const auto fields = ctx.WideFields.data() + WideFieldsIndex_;
            for (size_t i = 0; i < Width_; ++i)
                fields[i] = &s.Values[i];
            return s;
        }
        return *static_cast<TBlockState*>(state.AsBoxed().Get());
    }
    void RegisterDependencies() const final {
        FlowDependsOn(Flow_);
    }
    IComputationWideFlowNode* const Flow_;
    // Number of columns (including any block-length column the flow carries).
    const size_t Width_;
    // First of this node's Width_ slots in ctx.WideFields.
    const size_t WideFieldsIndex_;
};
  976. class TBlockExpandChunkedStreamWrapper : public TMutableComputationNode<TBlockExpandChunkedStreamWrapper> {
  977. using TBaseComputation = TMutableComputationNode<TBlockExpandChunkedStreamWrapper>;
  978. class TExpanderState : public TComputationValue<TExpanderState> {
  979. using TBase = TComputationValue<TExpanderState>;
  980. public:
  981. TExpanderState(TMemoryUsageInfo* memInfo, TComputationContext& ctx, NUdf::TUnboxedValue&& stream, size_t width)
  982. : TBase(memInfo), HolderFactory_(ctx.HolderFactory), State_(ctx.HolderFactory.Create<TBlockState>(width)), Stream_(stream) {}
  983. NUdf::EFetchStatus WideFetch(NUdf::TUnboxedValue* output, ui32 width) {
  984. auto& s = *static_cast<TBlockState*>(State_.AsBoxed().Get());
  985. if (!s.Count) {
  986. s.ClearValues();
  987. auto result = Stream_.WideFetch(s.Values.data(), width);
  988. if (NUdf::EFetchStatus::Ok != result) {
  989. return result;
  990. }
  991. s.FillArrays();
  992. }
  993. const auto sliceSize = s.Slice();
  994. for (size_t i = 0; i < width; ++i) {
  995. output[i] = s.Get(sliceSize, HolderFactory_, i);
  996. }
  997. return NUdf::EFetchStatus::Ok;
  998. }
  999. private:
  1000. const THolderFactory& HolderFactory_;
  1001. NUdf::TUnboxedValue State_;
  1002. NUdf::TUnboxedValue Stream_;
  1003. };
  1004. public:
  1005. TBlockExpandChunkedStreamWrapper(TComputationMutables& mutables, IComputationNode* stream, size_t width)
  1006. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  1007. , Stream_(stream)
  1008. , Width_(width) {}
  1009. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  1010. return ctx.HolderFactory.Create<TExpanderState>(ctx, std::move(Stream_->GetValue(ctx)), Width_);
  1011. }
  1012. void RegisterDependencies() const override {}
  1013. private:
  1014. IComputationNode* const Stream_;
  1015. const size_t Width_;
  1016. };
  1017. } // namespace
  1018. IComputationNode* WrapToBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1019. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1020. const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
  1021. return new TToBlocksWrapper(LocateNode(ctx.NodeLocator, callable, 0), flowType->GetItemType());
  1022. }
  1023. IComputationNode* WrapWideToBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1024. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1025. const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
  1026. const auto wideComponents = GetWideComponents(flowType);
  1027. TVector<TType*> items(wideComponents.begin(), wideComponents.end());
  1028. const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
  1029. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  1030. return new TWideToBlocksWrapper(ctx.Mutables, wideFlow, std::move(items));
  1031. }
  1032. IComputationNode* WrapFromBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1033. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1034. const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
  1035. const auto blockType = AS_TYPE(TBlockType, flowType->GetItemType());
  1036. return new TFromBlocksWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0), blockType->GetItemType());
  1037. }
  1038. IComputationNode* WrapWideFromBlocks(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1039. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1040. const auto inputType = callable.GetInput(0).GetStaticType();
  1041. MKQL_ENSURE(inputType->IsStream() || inputType->IsFlow(),
  1042. "Expected either WideStream or WideFlow as an input");
  1043. const auto yieldsStream = callable.GetType()->GetReturnType()->IsStream();
  1044. MKQL_ENSURE(yieldsStream == inputType->IsStream(),
  1045. "Expected both input and output have to be either WideStream or WideFlow");
  1046. const auto wideComponents = GetWideComponents(inputType);
  1047. MKQL_ENSURE(wideComponents.size() > 0, "Expected at least one column");
  1048. TVector<TType*> items;
  1049. for (ui32 i = 0; i < wideComponents.size() - 1; ++i) {
  1050. items.push_back(AS_TYPE(TBlockType, wideComponents[i]));
  1051. }
  1052. const auto wideFlowOrStream = LocateNode(ctx.NodeLocator, callable, 0);
  1053. if (yieldsStream) {
  1054. const auto wideStream = wideFlowOrStream;
  1055. return new TWideFromBlocksStreamWrapper(ctx.Mutables, wideStream, std::move(items));
  1056. }
  1057. // FIXME: Drop the branch below, when the time comes.
  1058. const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(wideFlowOrStream);
  1059. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  1060. return new TWideFromBlocksFlowWrapper(ctx.Mutables, wideFlow, std::move(items));
  1061. }
  1062. IComputationNode* WrapAsScalar(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1063. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1064. return new TAsScalarWrapper(ctx.Mutables, LocateNode(ctx.NodeLocator, callable, 0), callable.GetInput(0).GetStaticType());
  1065. }
  1066. IComputationNode* WrapReplicateScalar(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1067. MKQL_ENSURE(callable.GetInputsCount() == 2, "Expected 2 args, got " << callable.GetInputsCount());
  1068. const auto valueType = AS_TYPE(TBlockType, callable.GetInput(0).GetStaticType());
  1069. MKQL_ENSURE(valueType->GetShape() == TBlockType::EShape::Scalar, "Expecting scalar as first arg");
  1070. const auto value = LocateNode(ctx.NodeLocator, callable, 0);
  1071. const auto count = LocateNode(ctx.NodeLocator, callable, 1);
  1072. return new TReplicateScalarWrapper(ctx.Mutables, value, count, valueType->GetItemType());
  1073. }
  1074. IComputationNode* WrapBlockExpandChunked(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  1075. MKQL_ENSURE(callable.GetInputsCount() == 1, "Expected 1 args, got " << callable.GetInputsCount());
  1076. if (callable.GetInput(0).GetStaticType()->IsStream()) {
  1077. const auto streamType = AS_TYPE(TStreamType, callable.GetInput(0).GetStaticType());
  1078. const auto wideComponents = GetWideComponents(streamType);
  1079. const auto computation = dynamic_cast<IComputationNode*>(LocateNode(ctx.NodeLocator, callable, 0));
  1080. MKQL_ENSURE(computation != nullptr, "Expected computation node");
  1081. return new TBlockExpandChunkedStreamWrapper(ctx.Mutables, computation, wideComponents.size());
  1082. } else {
  1083. const auto flowType = AS_TYPE(TFlowType, callable.GetInput(0).GetStaticType());
  1084. const auto wideComponents = GetWideComponents(flowType);
  1085. const auto wideFlow = dynamic_cast<IComputationWideFlowNode*>(LocateNode(ctx.NodeLocator, callable, 0));
  1086. MKQL_ENSURE(wideFlow != nullptr, "Expected wide flow node");
  1087. return new TBlockExpandChunkedWrapper(ctx.Mutables, wideFlow, wideComponents.size());
  1088. }
  1089. }
  1090. }
  1091. }