mkql_condense1.cpp 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. #include "mkql_condense1.h"
  2. #include "mkql_squeeze_state.h"
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  6. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  9. namespace {
  10. template <bool Interruptable, bool UseCtx>
  11. class TCondense1FlowWrapper : public TStatefulFlowCodegeneratorNode<TCondense1FlowWrapper<Interruptable, UseCtx>> {
  12. typedef TStatefulFlowCodegeneratorNode<TCondense1FlowWrapper<Interruptable, UseCtx>> TBaseComputation;
  13. public:
  14. TCondense1FlowWrapper(
  15. TComputationMutables& mutables,
  16. EValueRepresentation kind,
  17. IComputationNode* flow,
  18. IComputationExternalNode* item,
  19. IComputationExternalNode* state,
  20. IComputationNode* outSwitch,
  21. IComputationNode* initState,
  22. IComputationNode* updateState)
  23. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Embedded),
  24. Flow(flow), Item(item), State(state), Switch(outSwitch), InitState(initState), UpdateState(updateState)
  25. {}
  26. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  27. if (state.IsFinish()) {
  28. return static_cast<const NUdf::TUnboxedValuePod&>(state);
  29. } else if (state.HasValue()) {
  30. if constexpr (UseCtx) {
  31. CleanupCurrentContext();
  32. }
  33. state = NUdf::TUnboxedValuePod();
  34. State->SetValue(ctx, InitState->GetValue(ctx));
  35. }
  36. while (true) {
  37. auto item = Flow->GetValue(ctx);
  38. if (item.IsYield()) {
  39. return item.Release();
  40. }
  41. if (item.IsFinish()) {
  42. break;
  43. }
  44. Item->SetValue(ctx, std::move(item));
  45. if (state.IsInvalid()) {
  46. state = NUdf::TUnboxedValuePod();
  47. State->SetValue(ctx, InitState->GetValue(ctx));
  48. } else {
  49. if (Switch) {
  50. const auto& reset = Switch->GetValue(ctx);
  51. if (Interruptable && !reset) {
  52. break;
  53. }
  54. if (reset.template Get<bool>()) {
  55. state = NUdf::TUnboxedValuePod::Zero();
  56. return State->GetValue(ctx).Release();
  57. }
  58. }
  59. State->SetValue(ctx, UpdateState->GetValue(ctx));
  60. }
  61. }
  62. const bool empty = state.IsInvalid();
  63. state = NUdf::TUnboxedValuePod::MakeFinish();
  64. return empty ? NUdf::TUnboxedValuePod::MakeFinish() : State->GetValue(ctx).Release();
  65. }
  66. #ifndef MKQL_DISABLE_CODEGEN
  67. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  68. auto& context = ctx.Codegen.GetContext();
  69. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  70. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  71. const auto codegenState = dynamic_cast<ICodegeneratorExternalNode*>(State);
  72. MKQL_ENSURE(codegenState, "State must be codegenerator node.");
  73. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  74. const auto frst = BasicBlock::Create(context, "frst", ctx.Func);
  75. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  76. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  77. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  78. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  79. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  80. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  81. const auto valueType = Type::getInt128Ty(context);
  82. const auto state = new LoadInst(valueType, statePtr, "state", block);
  83. const auto result = PHINode::Create(valueType, Switch ? 4U : 3U, "result", exit);
  84. result->addIncoming(state, block);
  85. const auto way = SwitchInst::Create(state, frst, 2U, block);
  86. way->addCase(GetFalse(context), step);
  87. way->addCase(GetFinish(context), exit);
  88. block = step;
  89. if constexpr (UseCtx) {
  90. const auto cleanup = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&CleanupCurrentContext));
  91. const auto cleanupType = FunctionType::get(Type::getVoidTy(context), {}, false);
  92. const auto cleanupPtr = CastInst::Create(Instruction::IntToPtr, cleanup, PointerType::getUnqual(cleanupType), "cleanup_ctx", block);
  93. CallInst::Create(cleanupType, cleanupPtr, {}, "", block);
  94. }
  95. new StoreInst(GetEmpty(context), statePtr, block);
  96. codegenState->CreateSetValue(ctx, block, GetNodeValue(InitState, ctx, block));
  97. BranchInst::Create(frst, block);
  98. block = frst;
  99. const auto invalid = IsInvalid(state, block, context);
  100. const auto empty = PHINode::Create(invalid->getType(), 3U, "empty", work);
  101. empty->addIncoming(invalid, block);
  102. BranchInst::Create(work, block);
  103. block = work;
  104. const auto item = GetNodeValue(Flow, ctx, block);
  105. result->addIncoming(item, block);
  106. const auto action = SwitchInst::Create(item, good, 2U, block);
  107. action->addCase(GetFinish(context), stop);
  108. action->addCase(GetYield(context), exit);
  109. block = good;
  110. codegenItem->CreateSetValue(ctx, block, item);
  111. BranchInst::Create(init, next, empty, block);
  112. block = init;
  113. new StoreInst(GetEmpty(context), statePtr, block);
  114. codegenState->CreateSetValue(ctx, block, GetNodeValue(InitState, ctx, block));
  115. empty->addIncoming(ConstantInt::getFalse(context), block);
  116. BranchInst::Create(work, block);
  117. block = next;
  118. if (Switch) {
  119. const auto swap = BasicBlock::Create(context, "swap", ctx.Func);
  120. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  121. const auto reset = GetNodeValue(Switch, ctx, block);
  122. if constexpr (Interruptable) {
  123. const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
  124. BranchInst::Create(stop, pass, IsEmpty(reset, block, context), block);
  125. block = pass;
  126. }
  127. const auto cast = CastInst::Create(Instruction::Trunc, reset, Type::getInt1Ty(context), "bool", block);
  128. BranchInst::Create(swap, skip, cast, block);
  129. block = swap;
  130. new StoreInst(GetFalse(context), statePtr, block);
  131. const auto output = codegenState->CreateGetValue(ctx, block);
  132. result->addIncoming(output, block);
  133. BranchInst::Create(exit, block);
  134. block = skip;
  135. }
  136. codegenState->CreateSetValue(ctx, block, GetNodeValue(UpdateState, ctx, block));
  137. empty->addIncoming(ConstantInt::getFalse(context), block);
  138. BranchInst::Create(work, block);
  139. block = stop;
  140. new StoreInst(GetFinish(context), statePtr, block);
  141. const auto output = codegenState->CreateGetValue(ctx, block);
  142. const auto select = SelectInst::Create(empty, GetFinish(context), output, "output", block);
  143. result->addIncoming(select, block);
  144. BranchInst::Create(exit, block);
  145. block = exit;
  146. return result;
  147. }
  148. #endif
  149. private:
  150. void RegisterDependencies() const final {
  151. if (const auto flow = this->FlowDependsOn(Flow)) {
  152. this->Own(flow, Item);
  153. this->Own(flow, State);
  154. this->DependsOn(flow, Switch);
  155. this->DependsOn(flow, InitState);
  156. this->DependsOn(flow, UpdateState);
  157. }
  158. }
  159. IComputationNode* const Flow;
  160. IComputationExternalNode* const Item;
  161. IComputationExternalNode* const State;
  162. IComputationNode* const Switch;
  163. IComputationNode* const InitState;
  164. IComputationNode* const UpdateState;
  165. };
  166. template <bool Interruptable, bool UseCtx>
  167. class TCondense1Wrapper : public TCustomValueCodegeneratorNode<TCondense1Wrapper<Interruptable, UseCtx>> {
  168. typedef TCustomValueCodegeneratorNode<TCondense1Wrapper<Interruptable, UseCtx>> TBaseComputation;
  169. public:
  170. class TValue : public TComputationValue<TValue> {
  171. public:
  172. using TBase = TComputationValue<TValue>;
  173. TValue(
  174. TMemoryUsageInfo* memInfo,
  175. NUdf::TUnboxedValue&& stream,
  176. const TSqueezeState& state,
  177. TComputationContext& ctx)
  178. : TBase(memInfo)
  179. , Stream(std::move(stream))
  180. , Ctx(ctx)
  181. , State(state)
  182. {}
  183. private:
  184. ui32 GetTraverseCount() const final {
  185. return 1;
  186. }
  187. NUdf::TUnboxedValue GetTraverseItem(ui32 index) const final {
  188. Y_UNUSED(index);
  189. return Stream;
  190. }
  191. NUdf::TUnboxedValue Save() const final {
  192. return State.Save(Ctx);
  193. }
  194. void Load(const NUdf::TStringRef& state) final {
  195. State.Load(Ctx, state);
  196. }
  197. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
  198. switch (State.Stage) {
  199. case ESqueezeState::Finished:
  200. return NUdf::EFetchStatus::Finish;
  201. case ESqueezeState::NeedInit:
  202. if constexpr (UseCtx) {
  203. CleanupCurrentContext();
  204. }
  205. State.Stage = ESqueezeState::Work;
  206. State.State->SetValue(Ctx, State.InitState->GetValue(Ctx));
  207. break;
  208. default:
  209. break;
  210. }
  211. while (true) {
  212. const auto status = Stream.Fetch(State.Item->RefValue(Ctx));
  213. if (status == NUdf::EFetchStatus::Yield) {
  214. return status;
  215. }
  216. if (status == NUdf::EFetchStatus::Finish) {
  217. break;
  218. }
  219. if (ESqueezeState::Idle == State.Stage) {
  220. State.Stage = ESqueezeState::Work;
  221. State.State->SetValue(Ctx, State.InitState->GetValue(Ctx));
  222. } else {
  223. if (State.Switch) {
  224. const auto& reset = State.Switch->GetValue(Ctx);
  225. if (Interruptable && !reset) {
  226. break;
  227. }
  228. if (reset.template Get<bool>()) {
  229. State.Stage = ESqueezeState::NeedInit;
  230. result = State.State->GetValue(Ctx);
  231. return NUdf::EFetchStatus::Ok;
  232. }
  233. }
  234. State.State->SetValue(Ctx, State.UpdateState->GetValue(Ctx));
  235. }
  236. }
  237. if (ESqueezeState::Idle == State.Stage) {
  238. State.Stage = ESqueezeState::Finished;
  239. return NUdf::EFetchStatus::Finish;
  240. }
  241. result = State.State->GetValue(Ctx);
  242. State.Stage = ESqueezeState::Finished;
  243. return NUdf::EFetchStatus::Ok;
  244. }
  245. const NUdf::TUnboxedValue Stream;
  246. TComputationContext& Ctx;
  247. TSqueezeState State;
  248. };
  249. TCondense1Wrapper(
  250. TComputationMutables& mutables,
  251. IComputationNode* stream,
  252. IComputationExternalNode* item,
  253. IComputationExternalNode* state,
  254. IComputationNode* outSwitch,
  255. IComputationNode* initState,
  256. IComputationNode* updateState,
  257. IComputationExternalNode* inSave = nullptr,
  258. IComputationNode* outSave = nullptr,
  259. IComputationExternalNode* inLoad = nullptr,
  260. IComputationNode* outLoad = nullptr,
  261. TType* stateType = nullptr)
  262. : TBaseComputation(mutables)
  263. , Stream(stream)
  264. , State(item, state, outSwitch, initState, updateState, inSave, outSave, inLoad, outLoad, stateType)
  265. {
  266. this->Stateless = false;
  267. }
  268. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  269. #ifndef MKQL_DISABLE_CODEGEN
  270. if (ctx.ExecuteLLVM && Fetch)
  271. return ctx.HolderFactory.Create<TSqueezeCodegenValue>(State, Fetch, ctx, Stream->GetValue(ctx));
  272. #endif
  273. return ctx.HolderFactory.Create<TValue>(Stream->GetValue(ctx), State, ctx);
  274. }
  275. private:
  276. void RegisterDependencies() const final {
  277. this->DependsOn(Stream);
  278. this->Own(State.Item);
  279. this->Own(State.State);
  280. this->DependsOn(State.Switch);
  281. this->DependsOn(State.InitState);
  282. this->DependsOn(State.UpdateState);
  283. this->Own(State.InSave);
  284. this->DependsOn(State.OutSave);
  285. this->Own(State.InLoad);
  286. this->DependsOn(State.OutLoad);
  287. }
  288. #ifndef MKQL_DISABLE_CODEGEN
  289. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  290. FetchFunc = GenerateFetch(codegen);
  291. codegen.ExportSymbol(FetchFunc);
  292. }
  293. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  294. if (FetchFunc) {
  295. Fetch = reinterpret_cast<TFetchPtr>(codegen.GetPointerToFunction(FetchFunc));
  296. }
  297. }
  298. Function* GenerateFetch(NYql::NCodegen::ICodegen& codegen) const {
  299. auto& module = codegen.GetModule();
  300. auto& context = codegen.GetContext();
  301. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(State.Item);
  302. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(State.State);
  303. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  304. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  305. const auto& name = TBaseComputation::MakeName("Fetch");
  306. if (const auto f = module.getFunction(name.c_str()))
  307. return f;
  308. const auto valueType = Type::getInt128Ty(context);
  309. const auto containerType = static_cast<Type*>(valueType);
  310. const auto contextType = GetCompContextType(context);
  311. const auto statusType = Type::getInt32Ty(context);
  312. const auto stateType = Type::getInt8Ty(context);
  313. const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, PointerType::getUnqual(valueType), PointerType::getUnqual(stateType)}, false);
  314. TCodegenContext ctx(codegen);
  315. ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
  316. DISubprogramAnnotator annotator(ctx, ctx.Func);
  317. auto args = ctx.Func->arg_begin();
  318. ctx.Ctx = &*args;
  319. const auto containerArg = &*++args;
  320. const auto valuePtr = &*++args;
  321. const auto statePtr = &*++args;
  322. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  323. auto block = main;
  324. const auto container = static_cast<Value*>(containerArg);
  325. const auto step = BasicBlock::Create(context, "step", ctx.Func);
  326. const auto none = BasicBlock::Create(context, "none", ctx.Func);
  327. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  328. const auto state0 = new LoadInst(stateType, statePtr, "state0", block);
  329. const auto select = SwitchInst::Create(state0, loop, 2U, block);
  330. select->addCase(ConstantInt::get(stateType, static_cast<ui8>(ESqueezeState::Finished)), none);
  331. select->addCase(ConstantInt::get(stateType, static_cast<ui8>(ESqueezeState::NeedInit)), step);
  332. block = none;
  333. ReturnInst::Create(context, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), block);
  334. block = step;
  335. if constexpr (UseCtx) {
  336. const auto cleanup = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&CleanupCurrentContext));
  337. const auto cleanupType = FunctionType::get(Type::getVoidTy(context), {}, false);
  338. const auto cleanupPtr = CastInst::Create(Instruction::IntToPtr, cleanup, PointerType::getUnqual(cleanupType), "cleanup_ctx", block);
  339. CallInst::Create(cleanupType, cleanupPtr, {}, "", block);
  340. }
  341. new StoreInst(ConstantInt::get(state0->getType(), static_cast<ui8>(ESqueezeState::Work)), statePtr, block);
  342. codegenStateArg->CreateSetValue(ctx, block, GetNodeValue(State.InitState, ctx, block));
  343. BranchInst::Create(loop, block);
  344. block = loop;
  345. const auto itemPtr = codegenItemArg->CreateRefValue(ctx, block);
  346. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, container, codegen, block, itemPtr);
  347. const auto ychk = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Yield)), "ychk", block);
  348. const auto yies = BasicBlock::Create(context, "yies", ctx.Func);
  349. const auto nope = BasicBlock::Create(context, "nope", ctx.Func);
  350. BranchInst::Create(yies, nope, ychk, block);
  351. block = yies;
  352. ReturnInst::Create(context, status, block);
  353. block = nope;
  354. const auto icmp = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Finish)), "cond", block);
  355. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  356. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  357. BranchInst::Create(stop, good, icmp, block);
  358. block = good;
  359. const auto state1 = new LoadInst(stateType, statePtr, "state1", block);
  360. const auto one = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, state1, ConstantInt::get(state1->getType(), static_cast<ui8>(ESqueezeState::Idle)), "one", block);
  361. const auto wait = BasicBlock::Create(context, "wait", ctx.Func);
  362. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  363. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  364. const auto phi = PHINode::Create(valueType, 2, "phi", next);
  365. BranchInst::Create(wait, work, one, block);
  366. block = wait;
  367. new StoreInst(ConstantInt::get(state1->getType(), static_cast<ui8>(ESqueezeState::Work)), statePtr, block);
  368. const auto reset = GetNodeValue(State.InitState, ctx, block);
  369. phi->addIncoming(reset, block);
  370. BranchInst::Create(next, block);
  371. block = work;
  372. if (State.Switch) {
  373. const auto swap = BasicBlock::Create(context, "swap", ctx.Func);
  374. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  375. const auto reset = GetNodeValue(State.Switch, ctx, block);
  376. if constexpr (Interruptable) {
  377. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  378. const auto done = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, reset, ConstantInt::get(reset->getType(), 0), "done", block);
  379. BranchInst::Create(stop, next, done, block);
  380. block = next;
  381. }
  382. const auto cast = CastInst::Create(Instruction::Trunc, reset, Type::getInt1Ty(context), "bool", block);
  383. BranchInst::Create(swap, skip, cast, block);
  384. block = swap;
  385. new StoreInst(ConstantInt::get(state1->getType(), static_cast<ui8>(ESqueezeState::NeedInit)), statePtr, block);
  386. SafeUnRefUnboxedOne(valuePtr, ctx, block);
  387. const auto state = codegenStateArg->CreateGetValue(ctx, block);
  388. new StoreInst(state, valuePtr, block);
  389. ValueAddRef(State.State->GetRepresentation(), valuePtr, ctx, block);
  390. ReturnInst::Create(context, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Ok)), block);
  391. block = skip;
  392. }
  393. const auto state = GetNodeValue(State.UpdateState, ctx, block);
  394. phi->addIncoming(state, block);
  395. BranchInst::Create(next, block);
  396. block = next;
  397. codegenStateArg->CreateSetValue(ctx, block, phi);
  398. BranchInst::Create(loop, block);
  399. block = stop;
  400. const auto state2 = new LoadInst(stateType, statePtr, "state2", block);
  401. const auto full = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, state2, ConstantInt::get(state2->getType(), static_cast<ui8>(ESqueezeState::Work)), "full", block);
  402. new StoreInst(ConstantInt::get(state2->getType(), static_cast<ui8>(ESqueezeState::Finished)), statePtr, block);
  403. const auto fill = BasicBlock::Create(context, "fill", ctx.Func);
  404. BranchInst::Create(fill, yies, full, block);
  405. block = fill;
  406. SafeUnRefUnboxedOne(valuePtr, ctx, block);
  407. const auto result = codegenStateArg->CreateGetValue(ctx, block);
  408. new StoreInst(result, valuePtr, block);
  409. ValueAddRef(State.State->GetRepresentation(), valuePtr, ctx, block);
  410. ReturnInst::Create(context, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Ok)), block);
  411. return ctx.Func;
  412. }
  413. using TFetchPtr = TSqueezeCodegenValue::TFetchPtr;
  414. Function* FetchFunc = nullptr;
  415. TFetchPtr Fetch = nullptr;
  416. #endif
  417. IComputationNode* const Stream;
  418. TSqueezeState State;
  419. };
  420. }
  421. template <bool UseCtx>
  422. IComputationNode* WrapCondense1Impl(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  423. const auto stream = LocateNode(ctx.NodeLocator, callable, 0);
  424. const auto initState = LocateNode(ctx.NodeLocator, callable, 2);
  425. const auto outSwitch = LocateNode(ctx.NodeLocator, callable, 4);
  426. const auto updateState = LocateNode(ctx.NodeLocator, callable, 5);
  427. const auto item = LocateExternalNode(ctx.NodeLocator, callable, 1);
  428. const auto state = LocateExternalNode(ctx.NodeLocator, callable, 3);
  429. bool isOptional;
  430. const auto dataType = UnpackOptionalData(callable.GetInput(4), isOptional);
  431. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool.");
  432. const auto type = callable.GetType()->GetReturnType();
  433. if (type->IsFlow()) {
  434. const auto kind = GetValueRepresentation(AS_TYPE(TFlowType, type)->GetItemType());
  435. if (isOptional) {
  436. return new TCondense1FlowWrapper<true, UseCtx>(ctx.Mutables, kind, stream, item, state, outSwitch, initState, updateState);
  437. } else {
  438. return new TCondense1FlowWrapper<false, UseCtx>(ctx.Mutables, kind, stream, item, state, outSwitch, initState, updateState);
  439. }
  440. } else {
  441. if (isOptional) {
  442. return new TCondense1Wrapper<true, UseCtx>(ctx.Mutables, stream, item, state, outSwitch, initState, updateState);
  443. } else {
  444. return new TCondense1Wrapper<false, UseCtx>(ctx.Mutables, stream, item, state, outSwitch, initState, updateState);
  445. }
  446. }
  447. }
  448. IComputationNode* WrapCondense1(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  449. MKQL_ENSURE(callable.GetInputsCount() == 6 || callable.GetInputsCount() == 7, "Expected 6 or 7 args");
  450. bool useCtx = false;
  451. if (callable.GetInputsCount() >= 7) {
  452. useCtx = AS_VALUE(TDataLiteral, callable.GetInput(6))->AsValue().Get<bool>();
  453. }
  454. if (useCtx) {
  455. return WrapCondense1Impl<true>(callable, ctx);
  456. } else {
  457. return WrapCondense1Impl<false>(callable, ctx);
  458. }
  459. }
  460. IComputationNode* WrapSqueeze1(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  461. MKQL_ENSURE(callable.GetInputsCount() == 9, "Expected 9 args");
  462. const auto stream = LocateNode(ctx.NodeLocator, callable, 0);
  463. const auto initState = LocateNode(ctx.NodeLocator, callable, 2);
  464. const auto updateState = LocateNode(ctx.NodeLocator, callable, 4);
  465. const auto item = LocateExternalNode(ctx.NodeLocator, callable, 1);
  466. const auto state = LocateExternalNode(ctx.NodeLocator, callable, 3);
  467. IComputationExternalNode* inSave = nullptr;
  468. IComputationNode* outSave = nullptr;
  469. IComputationExternalNode* inLoad = nullptr;
  470. IComputationNode* outLoad = nullptr;
  471. const auto hasSaveLoad = !callable.GetInput(6).GetStaticType()->IsVoid();
  472. if (hasSaveLoad) {
  473. outSave = LocateNode(ctx.NodeLocator, callable, 6);
  474. outLoad = LocateNode(ctx.NodeLocator, callable, 8);
  475. inSave = LocateExternalNode(ctx.NodeLocator, callable, 5);
  476. inLoad = LocateExternalNode(ctx.NodeLocator, callable, 7);
  477. }
  478. const auto stateType = hasSaveLoad ? callable.GetInput(6).GetStaticType() : nullptr;
  479. return new TCondense1Wrapper<false, false>(ctx.Mutables, stream, item, state, nullptr, initState, updateState, inSave, outSave, inLoad, outLoad, stateType);
  480. }
  481. }
  482. }