mkql_extend.cpp 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632
  1. #include "mkql_extend.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/computation/mkql_llvm_base.h> // Y_IGNORE
  5. #include <yql/essentials/minikql/computation/mkql_custom_list.h>
  6. #include <yql/essentials/minikql/mkql_node_cast.h>
  7. #include <util/string/cast.h>
  8. #include <queue>
  9. namespace NKikimr {
  10. namespace NMiniKQL {
  11. namespace {
  12. class TState : public TComputationValue<TState> {
  13. public:
  14. ssize_t Index;
  15. std::queue<ssize_t> Queue;
  16. TState(TMemoryUsageInfo* memInfo, ssize_t count)
  17. : TComputationValue<TState>(memInfo)
  18. {
  19. while (count)
  20. Queue.push(--count);
  21. Index = Queue.front();
  22. }
  23. void NextFlow() {
  24. Queue.push(Queue.front());
  25. Queue.pop();
  26. Index = Queue.front();
  27. }
  28. void FlowOver() {
  29. Queue.pop();
  30. Index = Queue.empty() ? -1LL : Queue.front();
  31. }
  32. };
  33. #ifndef MKQL_DISABLE_CODEGEN
  34. class TLLVMFieldsStructureState: public TLLVMFieldsStructure<TComputationValue<TState>> {
  35. private:
  36. using TBase = TLLVMFieldsStructure<TComputationValue<TState>>;
  37. llvm::IntegerType*const IndexType;
  38. protected:
  39. using TBase::Context;
  40. public:
  41. std::vector<llvm::Type*> GetFieldsArray() {
  42. auto result = TBase::GetFields();
  43. result.emplace_back(IndexType);
  44. return result;
  45. }
  46. llvm::Constant* GetIndex() {
  47. return ConstantInt::get(Type::getInt32Ty(Context), TBase::GetFieldsCount());
  48. }
  49. TLLVMFieldsStructureState(llvm::LLVMContext& context)
  50. : TBase(context), IndexType(Type::getInt64Ty(Context))
  51. {}
  52. };
  53. #endif
  54. class TExtendWideFlowWrapper : public TStatefulWideFlowCodegeneratorNode<TExtendWideFlowWrapper> {
  55. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TExtendWideFlowWrapper>;
  56. public:
  57. TExtendWideFlowWrapper(TComputationMutables& mutables, TComputationWideFlowNodePtrVector&& flows, size_t width)
  58. : TBaseComputation(mutables, this, EValueRepresentation::Boxed)
  59. , Flows_(std::move(flows)), Width_(width)
  60. {
  61. #ifdef MKQL_DISABLE_CODEGEN
  62. Y_UNUSED(Width_);
  63. #endif
  64. }
  65. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  66. auto& s = GetState(state, ctx);
  67. while (s.Index >= 0) {
  68. switch (Flows_[s.Index]->FetchValues(ctx, output)) {
  69. case EFetchResult::One:
  70. return EFetchResult::One;
  71. case EFetchResult::Yield:
  72. s.NextFlow();
  73. return EFetchResult::Yield;
  74. case EFetchResult::Finish:
  75. s.FlowOver();
  76. break;
  77. }
  78. }
  79. return EFetchResult::Finish;
  80. }
  81. #ifndef MKQL_DISABLE_CODEGEN
  82. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  83. auto& context = ctx.Codegen.GetContext();
  84. const auto valueType = Type::getInt128Ty(context);
  85. const auto indexType = Type::getInt64Ty(context);
  86. const auto statusType = Type::getInt32Ty(context);
  87. const auto arrayType = ArrayType::get(valueType, Width_);
  88. const auto arrayPtr = new AllocaInst(arrayType, 0, "array_ptr", &ctx.Func->getEntryBlock().back());
  89. TLLVMFieldsStructureState stateFields(context);
  90. const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
  91. const auto statePtrType = PointerType::getUnqual(stateType);
  92. const auto funcType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false);
  93. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  94. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  95. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  96. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  97. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  98. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  99. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  100. block = make;
  101. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  102. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  103. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TExtendWideFlowWrapper::MakeState));
  104. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  105. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  106. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  107. BranchInst::Create(main, block);
  108. block = main;
  109. const auto state = new LoadInst(valueType, statePtr, "state", block);
  110. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  111. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  112. const auto indexPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIndex() }, "index_ptr", block);
  113. BranchInst::Create(loop, main);
  114. block = loop;
  115. const auto index = new LoadInst(indexType, indexPtr, "index", block);
  116. const auto result = PHINode::Create(statusType, Flows_.size() + 2U, "result", done);
  117. const auto select = SwitchInst::Create(index, done, Flows_.size(), block);
  118. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  119. for (auto i = 0U; i < Flows_.size(); ++i) {
  120. const auto flow = BasicBlock::Create(context, (TString("flow_") += ToString(i)).c_str(), ctx.Func);
  121. const auto save = BasicBlock::Create(context, (TString("save_") += ToString(i)).c_str(), ctx.Func);
  122. select->addCase(ConstantInt::get(indexType, i), flow);
  123. block = flow;
  124. const auto getres = GetNodeValues(Flows_[i], ctx, block);
  125. const auto way = SwitchInst::Create(getres.first, save, 2U, block);
  126. way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), over);
  127. way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), next);
  128. block = save;
  129. Value* values = UndefValue::get(arrayType);
  130. for (auto idx = 0U; idx < Width_; ++idx) {
  131. const auto value = getres.second[idx](ctx, block);
  132. values = InsertValueInst::Create(values, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
  133. }
  134. new StoreInst(values, arrayPtr, block);
  135. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
  136. BranchInst::Create(done, block);
  137. }
  138. block = next;
  139. const auto nextFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::NextFlow));
  140. const auto nextPtr = CastInst::Create(Instruction::IntToPtr, nextFunc, PointerType::getUnqual(funcType), "next_ptr", block);
  141. CallInst::Create(funcType, nextPtr, {stateArg}, "", block);
  142. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
  143. BranchInst::Create(done, block);
  144. block = over;
  145. const auto overFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::FlowOver));
  146. const auto overPtr = CastInst::Create(Instruction::IntToPtr, overFunc, PointerType::getUnqual(funcType), "over_ptr", block);
  147. CallInst::Create(funcType, overPtr, {stateArg}, "", block);
  148. BranchInst::Create(loop, block);
  149. block = done;
  150. ICodegeneratorInlineWideNode::TGettersList getters(Width_);
  151. for (size_t idx = 0U; idx < getters.size(); ++idx) {
  152. getters[idx] = [idx, valueType, arrayType, arrayPtr, indexType](const TCodegenContext& ctx, BasicBlock*& block) {
  153. Y_UNUSED(ctx);
  154. const auto valuePtr = GetElementPtrInst::CreateInBounds(arrayType, arrayPtr, { ConstantInt::get(indexType, 0), ConstantInt::get(indexType, idx)}, "value_ptr", block);
  155. return new LoadInst(valueType, valuePtr, "value", block);
  156. };
  157. }
  158. return {result, std::move(getters)};
  159. }
  160. #endif
  161. private:
  162. void RegisterDependencies() const final {
  163. for (auto& flow : Flows_) {
  164. FlowDependsOn(flow);
  165. }
  166. }
  167. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  168. state = ctx.HolderFactory.Create<TState>(Flows_.size());
  169. }
  170. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  171. if (state.IsInvalid())
  172. MakeState(ctx, state);
  173. return *static_cast<TState*>(state.AsBoxed().Get());
  174. }
  175. const TComputationWideFlowNodePtrVector Flows_;
  176. const size_t Width_;
  177. };
  178. class TExtendFlowWrapper : public TStatefulFlowCodegeneratorNode<TExtendFlowWrapper> {
  179. typedef TStatefulFlowCodegeneratorNode<TExtendFlowWrapper> TBaseComputation;
  180. public:
  181. TExtendFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, TComputationNodePtrVector&& flows)
  182. : TBaseComputation(mutables, this, kind, EValueRepresentation::Boxed), Flows(flows)
  183. {}
  184. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  185. auto& s = GetState(state, ctx);
  186. while (s.Index >= 0) {
  187. auto item = Flows[s.Index]->GetValue(ctx);
  188. if (item.IsYield())
  189. s.NextFlow();
  190. if (item.IsFinish())
  191. s.FlowOver();
  192. else
  193. return item.Release();
  194. }
  195. return NUdf::TUnboxedValuePod::MakeFinish();
  196. }
  197. #ifndef MKQL_DISABLE_CODEGEN
  198. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  199. auto& context = ctx.Codegen.GetContext();
  200. const auto valueType = Type::getInt128Ty(context);
  201. const auto indexType = Type::getInt64Ty(context);
  202. TLLVMFieldsStructureState stateFields(context);
  203. const auto stateType = StructType::get(context, stateFields.GetFieldsArray());
  204. const auto statePtrType = PointerType::getUnqual(stateType);
  205. const auto funcType = FunctionType::get(Type::getVoidTy(context), {statePtrType}, false);
  206. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  207. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  208. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  209. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  210. const auto over = BasicBlock::Create(context, "over", ctx.Func);
  211. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  212. BranchInst::Create(make, main, IsInvalid(statePtr, block, context), block);
  213. block = make;
  214. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  215. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  216. const auto makeFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TExtendFlowWrapper::MakeState));
  217. const auto makeType = FunctionType::get(Type::getVoidTy(context), {self->getType(), ctx.Ctx->getType(), statePtr->getType()}, false);
  218. const auto makeFuncPtr = CastInst::Create(Instruction::IntToPtr, makeFunc, PointerType::getUnqual(makeType), "function", block);
  219. CallInst::Create(makeType, makeFuncPtr, {self, ctx.Ctx, statePtr}, "", block);
  220. BranchInst::Create(main, block);
  221. block = main;
  222. const auto state = new LoadInst(valueType, statePtr, "state", block);
  223. const auto half = CastInst::Create(Instruction::Trunc, state, Type::getInt64Ty(context), "half", block);
  224. const auto stateArg = CastInst::Create(Instruction::IntToPtr, half, statePtrType, "state_arg", block);
  225. const auto indexPtr = GetElementPtrInst::CreateInBounds(stateType, stateArg, { stateFields.This(), stateFields.GetIndex() }, "index_ptr", block);
  226. BranchInst::Create(loop, main);
  227. block = loop;
  228. const auto index = new LoadInst(indexType, indexPtr, "index", block);
  229. const auto result = PHINode::Create(valueType, Flows.size() + 2U, "result", done);
  230. const auto select = SwitchInst::Create(index, done, Flows.size(), block);
  231. result->addIncoming(GetFinish(context), block);
  232. for (auto i = 0U; i < Flows.size(); ++i) {
  233. const auto flow = BasicBlock::Create(context, (TString("flow_") += ToString(i)).c_str(), ctx.Func);
  234. select->addCase(ConstantInt::get(indexType, i), flow);
  235. block = flow;
  236. const auto item = GetNodeValue(Flows[i], ctx, block);
  237. result->addIncoming(item, block);
  238. const auto way = SwitchInst::Create(item, done, 2U, block);
  239. way->addCase(GetFinish(context), over);
  240. way->addCase(GetYield(context), next);
  241. }
  242. block = next;
  243. const auto nextFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::NextFlow));
  244. const auto nextPtr = CastInst::Create(Instruction::IntToPtr, nextFunc, PointerType::getUnqual(funcType), "next_ptr", block);
  245. CallInst::Create(funcType, nextPtr, {stateArg}, "", block);
  246. result->addIncoming(GetYield(context), block);
  247. BranchInst::Create(done, block);
  248. block = over;
  249. const auto overFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TState::FlowOver));
  250. const auto overPtr = CastInst::Create(Instruction::IntToPtr, overFunc, PointerType::getUnqual(funcType), "over_ptr", block);
  251. CallInst::Create(funcType, overPtr, {stateArg}, "", block);
  252. BranchInst::Create(loop, block);
  253. block = done;
  254. return result;
  255. }
  256. #endif
  257. private:
  258. void MakeState(TComputationContext& ctx, NUdf::TUnboxedValue& state) const {
  259. state = ctx.HolderFactory.Create<TState>(Flows.size());
  260. }
  261. TState& GetState(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  262. if (state.IsInvalid())
  263. MakeState(ctx, state);
  264. return *static_cast<TState*>(state.AsBoxed().Get());
  265. }
  266. void RegisterDependencies() const final {
  267. std::for_each(Flows.cbegin(), Flows.cend(), std::bind(&TExtendFlowWrapper::FlowDependsOn, this, std::placeholders::_1));
  268. }
  269. const TComputationNodePtrVector Flows;
  270. };
  271. class TOrderedExtendWideFlowWrapper : public TStatefulWideFlowCodegeneratorNode<TOrderedExtendWideFlowWrapper> {
  272. using TBaseComputation = TStatefulWideFlowCodegeneratorNode<TOrderedExtendWideFlowWrapper>;
  273. public:
  274. TOrderedExtendWideFlowWrapper(TComputationMutables& mutables, TComputationWideFlowNodePtrVector&& flows, size_t width)
  275. : TBaseComputation(mutables, this, EValueRepresentation::Embedded)
  276. , Flows_(std::move(flows)), Width_(width)
  277. {
  278. #ifdef MKQL_DISABLE_CODEGEN
  279. Y_UNUSED(Width_);
  280. #endif
  281. }
  282. EFetchResult DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx, NUdf::TUnboxedValue*const* output) const {
  283. for (ui64 index = state.IsInvalid() ? 0ULL : state.Get<ui64>(); index < Flows_.size(); ++index) {
  284. if (const auto result = Flows_[index]->FetchValues(ctx, output); EFetchResult::Finish != result) {
  285. state = NUdf::TUnboxedValuePod(index);
  286. return result;
  287. }
  288. }
  289. return EFetchResult::Finish;
  290. }
  291. #ifndef MKQL_DISABLE_CODEGEN
  292. ICodegeneratorInlineWideNode::TGenerateResult DoGenGetValues(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  293. auto& context = ctx.Codegen.GetContext();
  294. const auto valueType = Type::getInt128Ty(context);
  295. const auto indexType = Type::getInt64Ty(context);
  296. const auto statusType = Type::getInt32Ty(context);
  297. const auto arrayType = ArrayType::get(valueType, Width_);
  298. const auto arrayPtr = new AllocaInst(arrayType, 0, "array_ptr", &ctx.Func->getEntryBlock().back());
  299. TLLVMFieldsStructureState stateFields(context);
  300. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  301. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  302. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  303. const auto load = new LoadInst(valueType, statePtr, "load", block);
  304. const auto start = SelectInst::Create(IsInvalid(load, block, context), ConstantInt::get(indexType, 0ULL), GetterFor<ui64>(load, context, block), "start", block);
  305. const auto index = PHINode::Create(indexType, 2U, "index", main);
  306. index->addIncoming(start, block);
  307. BranchInst::Create(main, block);
  308. block = main;
  309. const auto result = PHINode::Create(statusType, Flows_.size() + 2U, "result", done);
  310. const auto select = SwitchInst::Create(index, done, Flows_.size(), block);
  311. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), block);
  312. for (auto i = 0U; i < Flows_.size(); ++i) {
  313. const auto flow = BasicBlock::Create(context, (TString("flow_") += ToString(i)).c_str(), ctx.Func);
  314. const auto save = BasicBlock::Create(context, (TString("save_") += ToString(i)).c_str(), ctx.Func);
  315. select->addCase(ConstantInt::get(indexType, i), flow);
  316. block = flow;
  317. const auto getres = GetNodeValues(Flows_[i], ctx, block);
  318. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), block);
  319. const auto way = SwitchInst::Create(getres.first, save, 2U, block);
  320. way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Finish)), next);
  321. way->addCase(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::Yield)), done);
  322. block = save;
  323. Value* values = UndefValue::get(arrayType);
  324. for (auto idx = 0U; idx < Width_; ++idx) {
  325. const auto value = getres.second[idx](ctx, block);
  326. values = InsertValueInst::Create(values, value, {idx}, (TString("value_") += ToString(idx)).c_str(), block);
  327. }
  328. new StoreInst(values, arrayPtr, block);
  329. result->addIncoming(ConstantInt::get(statusType, static_cast<i32>(EFetchResult::One)), block);
  330. BranchInst::Create(done, block);
  331. }
  332. block = next;
  333. const auto plus = BinaryOperator::CreateAdd(index, ConstantInt::get(indexType, 1ULL), "plus", block);
  334. index->addIncoming(plus, block);
  335. BranchInst::Create(main, block);
  336. block = done;
  337. new StoreInst(SetterFor<ui64>(index, context, block), statePtr, block);
  338. ICodegeneratorInlineWideNode::TGettersList getters(Width_);
  339. for (size_t idx = 0U; idx < getters.size(); ++idx) {
  340. getters[idx] = [idx, valueType, arrayType, arrayPtr, indexType](const TCodegenContext& ctx, BasicBlock*& block) {
  341. Y_UNUSED(ctx);
  342. const auto valuePtr = GetElementPtrInst::CreateInBounds(arrayType, arrayPtr, { ConstantInt::get(indexType, 0), ConstantInt::get(indexType, idx)}, "value_ptr", block);
  343. return new LoadInst(valueType, valuePtr, "value", block);
  344. };
  345. }
  346. return {result, std::move(getters)};
  347. }
  348. #endif
  349. private:
  350. void RegisterDependencies() const final {
  351. for (auto& flow : Flows_) {
  352. FlowDependsOn(flow);
  353. }
  354. }
  355. const TComputationWideFlowNodePtrVector Flows_;
  356. const size_t Width_;
  357. };
  358. class TOrderedExtendFlowWrapper : public TStatefulFlowCodegeneratorNode<TOrderedExtendFlowWrapper> {
  359. using TBaseComputation = TStatefulFlowCodegeneratorNode<TOrderedExtendFlowWrapper>;
  360. public:
  361. TOrderedExtendFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, TComputationNodePtrVector&& flows)
  362. : TBaseComputation(mutables, this, kind, EValueRepresentation::Embedded), Flows_(flows)
  363. {}
  364. NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  365. for (ui64 index = state.IsInvalid() ? 0ULL : state.Get<ui64>(); index < Flows_.size(); ++index) {
  366. const auto item = Flows_[index]->GetValue(ctx);
  367. if (!item.IsFinish()) {
  368. state = NUdf::TUnboxedValuePod(index);
  369. return item;
  370. }
  371. }
  372. return NUdf::TUnboxedValuePod::MakeFinish();
  373. }
  374. #ifndef MKQL_DISABLE_CODEGEN
  375. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  376. auto& context = ctx.Codegen.GetContext();
  377. const auto valueType = Type::getInt128Ty(context);
  378. const auto indexType = Type::getInt64Ty(context);
  379. const auto load = new LoadInst(valueType, statePtr, "load", block);
  380. const auto state = SelectInst::Create(IsInvalid(load, block, context), ConstantInt::get(indexType, 0ULL), GetterFor<ui64>(load, context, block), "index", block);
  381. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  382. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  383. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  384. const auto result = PHINode::Create(valueType, Flows_.size() + 1U, "result", done);
  385. const auto index = PHINode::Create(indexType, 2U, "index", main);
  386. index->addIncoming(state, block);
  387. BranchInst::Create(main, block);
  388. block = main;
  389. const auto select = SwitchInst::Create(index, done, Flows_.size(), block);
  390. result->addIncoming(GetFinish(context), block);
  391. for (auto i = 0U; i < Flows_.size(); ++i) {
  392. const auto flow = BasicBlock::Create(context, "flow", ctx.Func);
  393. select->addCase(ConstantInt::get(indexType, i), flow);
  394. block = flow;
  395. const auto item = GetNodeValue(Flows_[i], ctx, block);
  396. result->addIncoming(item, block);
  397. BranchInst::Create(next, done, IsFinish(item, block, context), block);
  398. }
  399. block = next;
  400. const auto plus = BinaryOperator::CreateAdd(index, ConstantInt::get(indexType, 1ULL), "plus", block);
  401. index->addIncoming(plus, block);
  402. BranchInst::Create(main, block);
  403. block = done;
  404. new StoreInst(SetterFor<ui64>(index, context, block), statePtr, block);
  405. return result;
  406. }
  407. #endif
  408. private:
  409. void RegisterDependencies() const final {
  410. std::for_each(Flows_.cbegin(), Flows_.cend(), std::bind(&TOrderedExtendFlowWrapper::FlowDependsOn, this, std::placeholders::_1));
  411. }
  412. const TComputationNodePtrVector Flows_;
  413. };
  414. template <bool IsStream>
  415. class TOrderedExtendWrapper : public TMutableCodegeneratorNode<TOrderedExtendWrapper<IsStream>> {
  416. using TBaseComputation = TMutableCodegeneratorNode<TOrderedExtendWrapper<IsStream>>;
  417. public:
  418. TOrderedExtendWrapper(TComputationMutables& mutables, TComputationNodePtrVector&& lists)
  419. : TBaseComputation(mutables, EValueRepresentation::Boxed)
  420. , Lists(std::move(lists))
  421. {
  422. }
  423. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  424. TUnboxedValueVector values;
  425. values.reserve(Lists.size());
  426. std::transform(Lists.cbegin(), Lists.cend(), std::back_inserter(values),
  427. std::bind(&IComputationNode::GetValue, std::placeholders::_1, std::ref(ctx))
  428. );
  429. return IsStream ?
  430. ctx.HolderFactory.ExtendStream(values.data(), values.size()):
  431. ctx.HolderFactory.ExtendList<false>(values.data(), values.size());
  432. }
  433. #ifndef MKQL_DISABLE_CODEGEN
  434. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  435. auto& context = ctx.Codegen.GetContext();
  436. const auto valueType = Type::getInt128Ty(context);
  437. const auto sizeType = Type::getInt64Ty(context);
  438. const auto size = ConstantInt::get(sizeType, Lists.size());
  439. const auto arrayType = ArrayType::get(valueType, Lists.size());
  440. const auto array = *this->Stateless || ctx.AlwaysInline ?
  441. new AllocaInst(arrayType, 0U, "array", &ctx.Func->getEntryBlock().back()):
  442. new AllocaInst(arrayType, 0U, "array", block);
  443. for (size_t i = 0U; i < Lists.size(); ++i) {
  444. const auto ptr = GetElementPtrInst::CreateInBounds(arrayType, array, {ConstantInt::get(sizeType, 0), ConstantInt::get(sizeType, i)}, (TString("ptr_") += ToString(i)).c_str(), block);
  445. GetNodeValue(ptr, Lists[i], ctx, block);
  446. }
  447. const auto factory = ctx.GetFactory();
  448. const auto func = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(IsStream ? &THolderFactory::ExtendStream : &THolderFactory::ExtendList<false>));
  449. const auto funType = FunctionType::get(valueType, {factory->getType(), array->getType(), size->getType()}, false);
  450. const auto funcPtr = CastInst::Create(Instruction::IntToPtr, func, PointerType::getUnqual(funType), "function", block);
  451. const auto res = CallInst::Create(funType, funcPtr, {factory, array, size}, "res", block);
  452. return res;
  453. }
  454. #endif
  455. private:
  456. void RegisterDependencies() const final {
  457. std::for_each(Lists.cbegin(), Lists.cend(), std::bind(&TOrderedExtendWrapper::DependsOn, this, std::placeholders::_1));
  458. }
  459. const TComputationNodePtrVector Lists;
  460. };
  461. template<bool Ordered>
  462. IComputationNode* WrapExtendT(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  463. MKQL_ENSURE(callable.GetInputsCount() >= 1, "Expected at least 1 list");
  464. const auto type = callable.GetType()->GetReturnType();
  465. TComputationNodePtrVector flows;
  466. flows.reserve(callable.GetInputsCount());
  467. for (ui32 i = 0; i < callable.GetInputsCount(); ++i) {
  468. flows.emplace_back(LocateNode(ctx.NodeLocator, callable, i));
  469. }
  470. if (type->IsFlow()) {
  471. if (dynamic_cast<IComputationWideFlowNode*>(flows.front())) {
  472. const auto width = GetWideComponentsCount(AS_TYPE(TFlowType, callable.GetType()->GetReturnType()));
  473. TComputationWideFlowNodePtrVector wideFlows;
  474. wideFlows.reserve(callable.GetInputsCount());
  475. for (ui32 i = 0; i < callable.GetInputsCount(); ++i) {
  476. wideFlows.emplace_back(dynamic_cast<IComputationWideFlowNode*>(flows[i]));
  477. MKQL_ENSURE_S(wideFlows.back());
  478. }
  479. if constexpr (Ordered)
  480. return new TOrderedExtendWideFlowWrapper(ctx.Mutables, std::move(wideFlows), width);
  481. else
  482. return new TExtendWideFlowWrapper(ctx.Mutables, std::move(wideFlows), width);
  483. }
  484. if constexpr (Ordered)
  485. return new TOrderedExtendFlowWrapper(ctx.Mutables, GetValueRepresentation(AS_TYPE(TFlowType, type)->GetItemType()), std::move(flows));
  486. else
  487. return new TExtendFlowWrapper(ctx.Mutables, GetValueRepresentation(AS_TYPE(TFlowType, type)->GetItemType()), std::move(flows));
  488. } else if (type->IsStream()) {
  489. return new TOrderedExtendWrapper<true>(ctx.Mutables, std::move(flows));
  490. } else if (type->IsList()) {
  491. return new TOrderedExtendWrapper<false>(ctx.Mutables, std::move(flows));
  492. }
  493. THROW yexception() << "Expected either flow, list or stream.";
  494. }
  495. }
  496. IComputationNode* WrapExtend(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  497. return WrapExtendT<false>(callable, ctx);
  498. }
  499. IComputationNode* WrapOrderedExtend(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  500. return WrapExtendT<true>(callable, ctx);
  501. }
  502. }
  503. }