mkql_while.cpp 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698
  1. #include "mkql_while.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. namespace NKikimr {
  6. namespace NMiniKQL {
  7. namespace {
  8. template <bool SkipOrTake, bool Inclusive>
  9. class TWhileFlowWrapper : public TStatefulFlowCodegeneratorNode<TWhileFlowWrapper<SkipOrTake, Inclusive>> {
  10. using TBaseComputation = TStatefulFlowCodegeneratorNode<TWhileFlowWrapper<SkipOrTake, Inclusive>>;
  11. public:
  12. TWhileFlowWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow, IComputationExternalNode* item, IComputationNode* predicate)
  13. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Embedded), Flow(flow), Item(item), Predicate(predicate)
  14. {}
  15. NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  16. if (state.HasValue() && state.Get<bool>()) {
  17. return SkipOrTake ? Flow->GetValue(ctx) : NUdf::TUnboxedValue(NUdf::TUnboxedValuePod::MakeFinish());
  18. }
  19. if constexpr (SkipOrTake) {
  20. do if (auto item = Flow->GetValue(ctx); item.IsSpecial())
  21. return item;
  22. else
  23. Item->SetValue(ctx, std::move(item));
  24. while (Predicate->GetValue(ctx).template Get<bool>());
  25. state = NUdf::TUnboxedValuePod(true);
  26. return Inclusive ? Flow->GetValue(ctx) : Item->GetValue(ctx);
  27. } else {
  28. if (auto item = Flow->GetValue(ctx); item.IsSpecial())
  29. return item;
  30. else
  31. Item->SetValue(ctx, std::move(item));
  32. if (Predicate->GetValue(ctx).template Get<bool>()) {
  33. return Item->GetValue(ctx);
  34. }
  35. state = NUdf::TUnboxedValuePod(true);
  36. return Inclusive ? Item->GetValue(ctx).Release() : NUdf::TUnboxedValuePod::MakeFinish();
  37. }
  38. }
  39. #ifndef MKQL_DISABLE_CODEGEN
  40. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  41. auto& context = ctx.Codegen.GetContext();
  42. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  43. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  44. const auto valueType = Type::getInt128Ty(context);
  45. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  46. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  47. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  48. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  49. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  50. const auto result = PHINode::Create(valueType, SkipOrTake ? 3U : 4U, "result", done);
  51. const auto state = new LoadInst(valueType, statePtr, "state", block);
  52. const auto finished = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, state, GetTrue(context), "finished", block);
  53. BranchInst::Create(skip, work, finished, block);
  54. block = work;
  55. const auto item = GetNodeValue(Flow, ctx, block);
  56. result->addIncoming(item, block);
  57. BranchInst::Create(done, good, IsSpecial(item, block), block);
  58. block = good;
  59. codegenItem->CreateSetValue(ctx, block, item);
  60. const auto pred = GetNodeValue(Predicate, ctx, block);
  61. const auto bit = CastInst::Create(Instruction::Trunc, pred, Type::getInt1Ty(context), "bit", block);
  62. if constexpr (SkipOrTake) {
  63. BranchInst::Create(work, stop, bit, block);
  64. } else {
  65. result->addIncoming(item, block);
  66. BranchInst::Create(done, stop, bit, block);
  67. }
  68. block = stop;
  69. new StoreInst(GetTrue(context), statePtr, block);
  70. const auto last = Inclusive ?
  71. (SkipOrTake ? GetNodeValue(Flow, ctx, block) : item):
  72. (SkipOrTake ? item : GetFinish(context));
  73. result->addIncoming(last, block);
  74. BranchInst::Create(done, block);
  75. block = skip;
  76. const auto res = SkipOrTake ? GetNodeValue(Flow, ctx, block) : GetFinish(context);
  77. result->addIncoming(res, block);
  78. BranchInst::Create(done, block);
  79. block = done;
  80. return result;
  81. }
  82. #endif
  83. private:
  84. void RegisterDependencies() const final {
  85. if (const auto flow = this->FlowDependsOn(Flow)) {
  86. this->Own(flow, Item);
  87. this->DependsOn(flow, Predicate);
  88. }
  89. }
  90. IComputationNode* const Flow;
  91. IComputationExternalNode* const Item;
  92. IComputationNode* const Predicate;
  93. };
  94. template <bool SkipOrTake, bool Inclusive, bool IsStream>
  95. class TBaseWhileWrapper {
  96. protected:
  97. class TListValue : public TCustomListValue {
  98. public:
  99. class TIterator : public TComputationValue<TIterator> {
  100. public:
  101. TIterator(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& iter, IComputationExternalNode* item, IComputationNode* predicate)
  102. : TComputationValue<TIterator>(memInfo)
  103. , CompCtx(compCtx)
  104. , Iter(std::move(iter))
  105. , Item(item)
  106. , Predicate(predicate)
  107. {}
  108. private:
  109. bool Next(NUdf::TUnboxedValue& value) override {
  110. if (FilterWorkFinished) {
  111. return SkipOrTake ? Iter.Next(value) : false;
  112. }
  113. if constexpr (SkipOrTake) {
  114. while (Iter.Next(Item->RefValue(CompCtx))) {
  115. if (!Predicate->GetValue(CompCtx).template Get<bool>()) {
  116. FilterWorkFinished = true;
  117. if constexpr (Inclusive) {
  118. return Iter.Next(value);
  119. } else {
  120. value = Item->GetValue(CompCtx);
  121. return true;
  122. }
  123. }
  124. }
  125. } else {
  126. if (Iter.Next(Item->RefValue(CompCtx))) {
  127. if (Predicate->GetValue(CompCtx).template Get<bool>()) {
  128. value = Item->GetValue(CompCtx);
  129. return true;
  130. } else {
  131. FilterWorkFinished = true;
  132. if constexpr (Inclusive) {
  133. value = Item->GetValue(CompCtx);
  134. return true;
  135. }
  136. }
  137. }
  138. }
  139. return false;
  140. }
  141. TComputationContext& CompCtx;
  142. const NUdf::TUnboxedValue Iter;
  143. IComputationExternalNode* const Item;
  144. IComputationNode* const Predicate;
  145. bool FilterWorkFinished = false;
  146. };
  147. TListValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, const NUdf::TUnboxedValue& list, IComputationExternalNode* item, IComputationNode* predicate)
  148. : TCustomListValue(memInfo)
  149. , CompCtx(compCtx)
  150. , List(list)
  151. , Item(item)
  152. , Predicate(predicate)
  153. {}
  154. private:
  155. NUdf::TUnboxedValue GetListIterator() const override {
  156. return CompCtx.HolderFactory.Create<TIterator>(CompCtx, List.GetListIterator(), Item, Predicate);
  157. }
  158. TComputationContext& CompCtx;
  159. const NUdf::TUnboxedValue List;
  160. IComputationExternalNode* const Item;
  161. IComputationNode* const Predicate;
  162. };
  163. class TStreamValue : public TComputationValue<TStreamValue> {
  164. public:
  165. using TBase = TComputationValue<TStreamValue>;
  166. TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, const NUdf::TUnboxedValue& stream, IComputationExternalNode* item, IComputationNode* predicate)
  167. : TBase(memInfo)
  168. , CompCtx(compCtx)
  169. , Stream(stream)
  170. , Item(item)
  171. , Predicate(predicate)
  172. {
  173. }
  174. private:
  175. ui32 GetTraverseCount() const override {
  176. return 1;
  177. }
  178. NUdf::TUnboxedValue GetTraverseItem(ui32 index) const override {
  179. Y_UNUSED(index);
  180. return Stream;
  181. }
  182. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override {
  183. if (FilterWorkFinished) {
  184. return SkipOrTake ? Stream.Fetch(result) : NUdf::EFetchStatus::Finish;
  185. }
  186. if constexpr (SkipOrTake) {
  187. for (;;) {
  188. if (const auto status = Stream.Fetch(Item->RefValue(CompCtx)); status != NUdf::EFetchStatus::Ok) {
  189. return status;
  190. }
  191. if (!Predicate->GetValue(CompCtx).template Get<bool>()) {
  192. FilterWorkFinished = true;
  193. if constexpr (Inclusive) {
  194. return Stream.Fetch(result);
  195. } else {
  196. result = Item->GetValue(CompCtx);
  197. return NUdf::EFetchStatus::Ok;
  198. }
  199. }
  200. }
  201. } else {
  202. switch (const auto status = Stream.Fetch(Item->RefValue(CompCtx))) {
  203. case NUdf::EFetchStatus::Yield:
  204. return status;
  205. case NUdf::EFetchStatus::Ok:
  206. if (Predicate->GetValue(CompCtx).template Get<bool>()) {
  207. result = Item->GetValue(CompCtx);
  208. return NUdf::EFetchStatus::Ok;
  209. }
  210. case NUdf::EFetchStatus::Finish:
  211. break;
  212. }
  213. FilterWorkFinished = true;
  214. if constexpr (Inclusive) {
  215. result = Item->GetValue(CompCtx);
  216. return NUdf::EFetchStatus::Ok;
  217. } else {
  218. return NUdf::EFetchStatus::Finish;
  219. }
  220. }
  221. }
  222. TComputationContext& CompCtx;
  223. const NUdf::TUnboxedValue Stream;
  224. IComputationExternalNode* const Item;
  225. IComputationNode* const Predicate;
  226. bool FilterWorkFinished = false;
  227. };
  228. #ifndef MKQL_DISABLE_CODEGEN
  229. class TStreamCodegenWhileValue : public TStreamCodegenStatefulValueT<> {
  230. public:
  231. TStreamCodegenWhileValue(TMemoryUsageInfo* memInfo, TFetchPtr fetch, TComputationContext* ctx, NUdf::TUnboxedValue&& stream)
  232. : TStreamCodegenStatefulValueT(memInfo, fetch, ctx, std::move(stream))
  233. {}
  234. private:
  235. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
  236. return State ?
  237. (SkipOrTake ? Stream.Fetch(result) : NUdf::EFetchStatus::Finish):
  238. TStreamCodegenStatefulValueT::Fetch(result);
  239. }
  240. };
  241. class TCodegenIteratorWhile : public TCodegenStatefulIterator<> {
  242. public:
  243. TCodegenIteratorWhile(TMemoryUsageInfo* memInfo, TNextPtr next, TComputationContext* ctx, NUdf::TUnboxedValue&& iterator, const NUdf::TUnboxedValue& init)
  244. : TCodegenStatefulIterator(memInfo, next, ctx, std::move(iterator), init)
  245. {}
  246. private:
  247. bool Next(NUdf::TUnboxedValue& value) final {
  248. return State ?
  249. (SkipOrTake ? Iterator.Next(value) : false):
  250. TCodegenStatefulIterator::Next(value);
  251. }
  252. };
  253. using TCustomListCodegenWhileValue = TCustomListCodegenStatefulValueT<TCodegenIteratorWhile>;
  254. #endif
  255. TBaseWhileWrapper(IComputationNode* list, IComputationExternalNode* item, IComputationNode* predicate)
  256. : List(list), Item(item), Predicate(predicate)
  257. {}
  258. #ifndef MKQL_DISABLE_CODEGEN
  259. Function* GenerateFilter(NYql::NCodegen::ICodegen& codegen, const TString& name) const {
  260. auto& module = codegen.GetModule();
  261. auto& context = codegen.GetContext();
  262. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  263. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  264. if (const auto f = module.getFunction(name.c_str()))
  265. return f;
  266. const auto valueType = Type::getInt128Ty(context);
  267. const auto containerType = codegen.GetEffectiveTarget() == NYql::NCodegen::ETarget::Windows ? static_cast<Type*>(PointerType::getUnqual(valueType)) : static_cast<Type*>(valueType);
  268. const auto contextType = GetCompContextType(context);
  269. const auto statusType = IsStream ? Type::getInt32Ty(context) : Type::getInt1Ty(context);
  270. const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, PointerType::getUnqual(valueType), PointerType::getUnqual(valueType)}, false);
  271. TCodegenContext ctx(codegen);
  272. ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
  273. DISubprogramAnnotator annotator(ctx, ctx.Func);
  274. auto args = ctx.Func->arg_begin();
  275. ctx.Ctx = &*args;
  276. const auto containerArg = &*++args;
  277. const auto statePtr = &*++args;
  278. const auto valuePtr = &*++args;
  279. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  280. auto block = main;
  281. const auto container = codegen.GetEffectiveTarget() == NYql::NCodegen::ETarget::Windows ?
  282. new LoadInst(valueType, containerArg, "load_container", false, block) : static_cast<Value*>(containerArg);
  283. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  284. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  285. const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
  286. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  287. const auto loop = SkipOrTake ? BasicBlock::Create(context, "loop", ctx.Func) : nullptr;
  288. if (loop) {
  289. BranchInst::Create(loop, block);
  290. block = loop;
  291. }
  292. const auto itemPtr = codegenItem->CreateRefValue(ctx, block);
  293. const auto status = IsStream ?
  294. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, container, codegen, block, itemPtr):
  295. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(statusType, container, codegen, block, itemPtr);
  296. const auto icmp = IsStream ?
  297. CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Ok)), "cond", block) : status;
  298. BranchInst::Create(good, done, icmp, block);
  299. block = good;
  300. const auto item = new LoadInst(valueType, itemPtr, "item", block);
  301. const auto predicate = GetNodeValue(Predicate, ctx, block);
  302. const auto boolPred = CastInst::Create(Instruction::Trunc, predicate, Type::getInt1Ty(context), "bool", block);
  303. BranchInst::Create(SkipOrTake ? loop : pass, stop, boolPred, block);
  304. block = stop;
  305. new StoreInst(GetTrue(context), statePtr, block);
  306. if constexpr (SkipOrTake) {
  307. if constexpr (Inclusive) {
  308. const auto last = IsStream ?
  309. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, container, codegen, block, valuePtr):
  310. CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Next>(statusType, container, codegen, block, valuePtr);
  311. ReturnInst::Create(context, last, block);
  312. } else {
  313. BranchInst::Create(pass, block);
  314. }
  315. } else {
  316. if constexpr (Inclusive) {
  317. BranchInst::Create(pass, block);
  318. } else {
  319. ReturnInst::Create(context, IsStream ? ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Finish)) : ConstantInt::getFalse(context), block);
  320. }
  321. }
  322. block = pass;
  323. SafeUnRefUnboxed(valuePtr, ctx, block);
  324. new StoreInst(item, valuePtr, block);
  325. ValueAddRef(Item->GetRepresentation(), valuePtr, ctx, block);
  326. BranchInst::Create(done, block);
  327. block = done;
  328. ReturnInst::Create(context, status, block);
  329. return ctx.Func;
  330. }
  331. using TFilterPtr = std::conditional_t<IsStream, typename TStreamCodegenWhileValue::TFetchPtr, typename TCustomListCodegenWhileValue::TNextPtr>;
  332. Function* FilterFunc = nullptr;
  333. TFilterPtr Filter = nullptr;
  334. #endif
  335. IComputationNode* const List;
  336. IComputationExternalNode* const Item;
  337. IComputationNode* const Predicate;
  338. };
  339. template <bool SkipOrTake, bool Inclusive>
  340. class TStreamWhileWrapper : public TCustomValueCodegeneratorNode<TStreamWhileWrapper<SkipOrTake, Inclusive>>,
  341. private TBaseWhileWrapper<SkipOrTake, Inclusive, true> {
  342. typedef TBaseWhileWrapper<SkipOrTake, Inclusive, true> TBaseWrapper;
  343. typedef TCustomValueCodegeneratorNode<TStreamWhileWrapper<SkipOrTake, Inclusive>> TBaseComputation;
  344. public:
  345. TStreamWhileWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationExternalNode* item, IComputationNode* predicate)
  346. : TBaseComputation(mutables), TBaseWrapper(list, item, predicate)
  347. {}
  348. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  349. #ifndef MKQL_DISABLE_CODEGEN
  350. if (ctx.ExecuteLLVM && this->Filter)
  351. return ctx.HolderFactory.Create<typename TBaseWrapper::TStreamCodegenWhileValue>(this->Filter, &ctx, this->List->GetValue(ctx));
  352. #endif
  353. return ctx.HolderFactory.Create<typename TBaseWrapper::TStreamValue>(ctx, this->List->GetValue(ctx), this->Item, this->Predicate);
  354. }
  355. private:
  356. void RegisterDependencies() const final {
  357. this->DependsOn(this->List);
  358. this->Own(this->Item);
  359. this->DependsOn(this->Predicate);
  360. }
  361. #ifndef MKQL_DISABLE_CODEGEN
  362. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  363. this->FilterFunc = this->GenerateFilter(codegen, TBaseComputation::MakeName("Fetch"));
  364. codegen.ExportSymbol(this->FilterFunc);
  365. }
  366. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  367. if (this->FilterFunc)
  368. this->Filter = reinterpret_cast<typename TBaseWrapper::TFilterPtr>(codegen.GetPointerToFunction(this->FilterFunc));
  369. }
  370. #endif
  371. };
  372. template <bool SkipOrTake, bool Inclusive>
  373. class TListWhileWrapper : public TBothWaysCodegeneratorNode<TListWhileWrapper<SkipOrTake, Inclusive>>,
  374. private TBaseWhileWrapper<SkipOrTake, Inclusive, false> {
  375. typedef TBaseWhileWrapper<SkipOrTake, Inclusive, false> TBaseWrapper;
  376. typedef TBothWaysCodegeneratorNode<TListWhileWrapper<SkipOrTake, Inclusive>> TBaseComputation;
  377. public:
  378. TListWhileWrapper(TComputationMutables& mutables, IComputationNode* list, IComputationExternalNode* item, IComputationNode* predicate)
  379. : TBaseComputation(mutables), TBaseWrapper(list, item, predicate)
  380. {}
  381. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  382. auto list = this->List->GetValue(ctx);
  383. if (const auto elements = list.GetElements()) {
  384. const auto size = list.GetListLength();
  385. const ui64 init = Inclusive ? 1ULL : 0ULL;
  386. auto todo = size;
  387. for (auto e = elements; todo > init; --todo) {
  388. this->Item->SetValue(ctx, NUdf::TUnboxedValue(*e++));
  389. if (!this->Predicate->GetValue(ctx).template Get<bool>())
  390. break;
  391. }
  392. if (init && todo) {
  393. todo -= init;
  394. }
  395. const auto pass = size - todo;
  396. const auto copy = SkipOrTake ? todo : pass;
  397. if (copy == size) {
  398. return list.Release();
  399. }
  400. NUdf::TUnboxedValue* items = nullptr;
  401. const auto result = ctx.HolderFactory.CreateDirectArrayHolder(copy, items);
  402. const auto from = SkipOrTake ? elements + pass : elements;
  403. std::copy_n(from, copy, items);
  404. return result;
  405. }
  406. return ctx.HolderFactory.Create<typename TBaseWrapper::TListValue>(ctx, std::move(list), this->Item, this->Predicate);
  407. }
  408. #ifndef MKQL_DISABLE_CODEGEN
  409. NUdf::TUnboxedValuePod MakeLazyList(TComputationContext& ctx, const NUdf::TUnboxedValuePod value) const {
  410. return ctx.HolderFactory.Create<typename TBaseWrapper::TCustomListCodegenWhileValue>(this->Filter, &ctx, value);
  411. }
  412. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  413. auto& context = ctx.Codegen.GetContext();
  414. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(this->Item);
  415. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  416. const auto list = GetNodeValue(this->List, ctx, block);
  417. const auto lazy = BasicBlock::Create(context, "lazy", ctx.Func);
  418. const auto hard = BasicBlock::Create(context, "hard", ctx.Func);
  419. const auto done = BasicBlock::Create(context, "done", ctx.Func);
  420. const auto out = PHINode::Create(list->getType(), 4U, "out", done);
  421. const auto elementsType = PointerType::getUnqual(list->getType());
  422. const auto elements = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetElements>(elementsType, list, ctx.Codegen, block);
  423. const auto fill = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, elements, ConstantPointerNull::get(elementsType), "fill", block);
  424. BranchInst::Create(hard, lazy, fill, block);
  425. {
  426. block = hard;
  427. const auto size = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::GetListLength>(Type::getInt64Ty(context), list, ctx.Codegen, block);
  428. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  429. const auto test = BasicBlock::Create(context, "test", ctx.Func);
  430. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  431. const auto make = BasicBlock::Create(context, "make", ctx.Func);
  432. const auto index = PHINode::Create(size->getType(), 2U, "index", loop);
  433. const auto zero = ConstantInt::get(size->getType(), 0);
  434. index->addIncoming(zero, block);
  435. const auto none = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, zero, size, "none", block);
  436. out->addIncoming(list, block);
  437. BranchInst::Create(done, loop, none, block);
  438. block = loop;
  439. const auto ptr = GetElementPtrInst::CreateInBounds(list->getType(), elements, {index}, "ptr", block);
  440. const auto plus = BinaryOperator::CreateAdd(index, ConstantInt::get(size->getType(), 1), "plus", block);
  441. const auto more = CmpInst::Create(Instruction::ICmp, Inclusive ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_ULE, plus, size, "more", block);
  442. BranchInst::Create(test, stop, more, block);
  443. block = test;
  444. const auto item = new LoadInst(list->getType(), ptr, "item", block);
  445. codegenItem->CreateSetValue(ctx, block, item);
  446. const auto predicate = GetNodeValue(this->Predicate, ctx, block);
  447. const auto boolPred = CastInst::Create(Instruction::Trunc, predicate, Type::getInt1Ty(context), "bool", block);
  448. index->addIncoming(plus, block);
  449. BranchInst::Create(loop, stop, boolPred, block);
  450. block = stop;
  451. const auto pass = Inclusive ? static_cast<Value*>(plus) : static_cast<Value*>(index);
  452. const auto copy = SkipOrTake ? BinaryOperator::CreateSub(size, pass, "copy", block) : pass;
  453. const auto asis = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, size, copy, "asis", block);
  454. out->addIncoming(list, block);
  455. BranchInst::Create(done, make, asis, block);
  456. block = make;
  457. const auto itemsType = PointerType::getUnqual(list->getType());
  458. const auto itemsPtr = *this->Stateless || ctx.AlwaysInline ?
  459. new AllocaInst(itemsType, 0U, "items_ptr", &ctx.Func->getEntryBlock().back()):
  460. new AllocaInst(itemsType, 0U, "items_ptr", block);
  461. const auto array = GenNewArray(ctx, copy, itemsPtr, block);
  462. const auto items = new LoadInst(itemsType, itemsPtr, "items", block);
  463. const auto from = SkipOrTake ? GetElementPtrInst::CreateInBounds(list->getType(), elements, {pass}, "from", block) : elements;
  464. const auto move = BasicBlock::Create(context, "move", ctx.Func);
  465. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  466. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  467. const auto idx = PHINode::Create(copy->getType(), 2U, "idx", move);
  468. idx->addIncoming(ConstantInt::get(copy->getType(), 0), block);
  469. BranchInst::Create(move, block);
  470. block = move;
  471. const auto finish = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGE, idx, copy, "finish", block);
  472. BranchInst::Create(exit, step, finish, block);
  473. block = step;
  474. const auto src = GetElementPtrInst::CreateInBounds(list->getType(), from, {idx}, "src", block);
  475. const auto itm = new LoadInst(list->getType(), src, "item", block);
  476. ValueAddRef(this->Item->GetRepresentation(), itm, ctx, block);
  477. const auto dst = GetElementPtrInst::CreateInBounds(list->getType(), items, {idx}, "dst", block);
  478. new StoreInst(itm, dst, block);
  479. const auto inc = BinaryOperator::CreateAdd(idx, ConstantInt::get(idx->getType(), 1), "inc", block);
  480. idx->addIncoming(inc, block);
  481. BranchInst::Create(move, block);
  482. block = exit;
  483. if (this->List->IsTemporaryValue()) {
  484. CleanupBoxed(list, ctx, block);
  485. }
  486. out->addIncoming(array, block);
  487. BranchInst::Create(done, block);
  488. }
  489. {
  490. block = lazy;
  491. const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TListWhileWrapper::MakeLazyList));
  492. const auto ptrType = PointerType::getUnqual(StructType::get(context));
  493. const auto self = CastInst::Create(Instruction::IntToPtr, ConstantInt::get(Type::getInt64Ty(context), uintptr_t(this)), ptrType, "self", block);
  494. if (NYql::NCodegen::ETarget::Windows != ctx.Codegen.GetEffectiveTarget()) {
  495. const auto funType = FunctionType::get(list->getType() , {self->getType(), ctx.Ctx->getType(), list->getType()}, false);
  496. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(funType), "function", block);
  497. const auto value = CallInst::Create(funType, doFuncPtr, {self, ctx.Ctx, list}, "value", block);
  498. out->addIncoming(value, block);
  499. } else {
  500. const auto resultPtr = new AllocaInst(list->getType(), 0U, "return", block);
  501. new StoreInst(list, resultPtr, block);
  502. const auto funType = FunctionType::get(Type::getVoidTy(context), {self->getType(), resultPtr->getType(), ctx.Ctx->getType(), resultPtr->getType()}, false);
  503. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(funType), "function", block);
  504. CallInst::Create(funType, doFuncPtr, {self, resultPtr, ctx.Ctx, resultPtr}, "", block);
  505. const auto value = new LoadInst(list->getType(), resultPtr, "value", block);
  506. out->addIncoming(value, block);
  507. }
  508. BranchInst::Create(done, block);
  509. }
  510. block = done;
  511. return out;
  512. }
  513. #endif
  514. private:
  515. void RegisterDependencies() const final {
  516. this->DependsOn(this->List);
  517. this->Own(this->Item);
  518. this->DependsOn(this->Predicate);
  519. }
  520. #ifndef MKQL_DISABLE_CODEGEN
  521. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  522. TMutableCodegeneratorRootNode<TListWhileWrapper<SkipOrTake, Inclusive>>::GenerateFunctions(codegen);
  523. this->FilterFunc = this->GenerateFilter(codegen, TBaseComputation::MakeName("Next"));
  524. codegen.ExportSymbol(this->FilterFunc);
  525. }
  526. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  527. TMutableCodegeneratorRootNode<TListWhileWrapper<SkipOrTake, Inclusive>>::FinalizeFunctions(codegen);
  528. if (this->FilterFunc)
  529. this->Filter = reinterpret_cast<typename TBaseWrapper::TFilterPtr>(codegen.GetPointerToFunction(this->FilterFunc));
  530. }
  531. #endif
  532. };
  533. template <bool SkipOrTake, bool Inclusive>
  534. IComputationNode* WrapFilterWhile(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  535. MKQL_ENSURE(callable.GetInputsCount() == 3, "Expected 3 args");
  536. const auto type = callable.GetType()->GetReturnType();
  537. const auto predicateType = AS_TYPE(TDataType, callable.GetInput(2));
  538. MKQL_ENSURE(predicateType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  539. const auto flow = LocateNode(ctx.NodeLocator, callable, 0);
  540. const auto predicate = LocateNode(ctx.NodeLocator, callable, 2);
  541. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 1);
  542. if (type->IsFlow()) {
  543. return new TWhileFlowWrapper<SkipOrTake, Inclusive>(ctx.Mutables, GetValueRepresentation(type), flow, itemArg, predicate);
  544. } else if (type->IsStream()) {
  545. return new TStreamWhileWrapper<SkipOrTake, Inclusive>(ctx.Mutables, flow, itemArg, predicate);
  546. } else if (type->IsList()) {
  547. return new TListWhileWrapper<SkipOrTake, Inclusive>(ctx.Mutables, flow, itemArg, predicate);
  548. }
  549. THROW yexception() << "Expected flow, list or stream.";
  550. }
  551. }
  552. IComputationNode* WrapTakeWhile(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  553. return WrapFilterWhile<false, false>(callable, ctx);
  554. }
  555. IComputationNode* WrapSkipWhile(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  556. return WrapFilterWhile<true, false>(callable, ctx);
  557. }
  558. IComputationNode* WrapTakeWhileInclusive(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  559. return WrapFilterWhile<false, true>(callable, ctx);
  560. }
  561. IComputationNode* WrapSkipWhileInclusive(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  562. return WrapFilterWhile<true, true>(callable, ctx);
  563. }
  564. }
  565. }