mkql_condense.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. #include "mkql_condense.h"
  2. #include "mkql_squeeze_state.h"
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. #include <yql/essentials/minikql/mkql_node_builder.h>
  5. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  6. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  9. namespace {
  10. template <bool Interruptable, bool UseCtx>
  11. class TCondenseFlowWrapper : public TStatefulFlowCodegeneratorNode<TCondenseFlowWrapper<Interruptable, UseCtx>> {
  12. typedef TStatefulFlowCodegeneratorNode<TCondenseFlowWrapper<Interruptable, UseCtx>> TBaseComputation;
  13. public:
  14. TCondenseFlowWrapper(
  15. TComputationMutables& mutables,
  16. EValueRepresentation kind,
  17. IComputationNode* flow,
  18. IComputationExternalNode* item,
  19. IComputationExternalNode* state,
  20. IComputationNode* outSwitch,
  21. IComputationNode* initState,
  22. IComputationNode* updateState)
  23. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Embedded),
  24. Flow(flow), Item(item), State(state), Switch(outSwitch), InitState(initState), UpdateState(updateState)
  25. {}
  26. NUdf::TUnboxedValuePod DoCalculate(NUdf::TUnboxedValue& state, TComputationContext& ctx) const {
  27. if (state.IsFinish()) {
  28. return static_cast<const NUdf::TUnboxedValuePod&>(state);
  29. }
  30. if (state.IsInvalid()) {
  31. state = NUdf::TUnboxedValuePod();
  32. State->SetValue(ctx, InitState->GetValue(ctx));
  33. } else if (state.HasValue()) {
  34. if constexpr (UseCtx) {
  35. CleanupCurrentContext();
  36. }
  37. state = NUdf::TUnboxedValuePod();
  38. State->SetValue(ctx, InitState->GetValue(ctx));
  39. State->SetValue(ctx, UpdateState->GetValue(ctx));
  40. }
  41. while (true) {
  42. auto item = Flow->GetValue(ctx);
  43. if (item.IsYield()) {
  44. return item.Release();
  45. }
  46. if (item.IsFinish()) {
  47. break;
  48. }
  49. Item->SetValue(ctx, std::move(item));
  50. if (Switch) {
  51. const auto& reset = Switch->GetValue(ctx);
  52. if (Interruptable && !reset) {
  53. break;
  54. }
  55. if (reset.template Get<bool>()) {
  56. state = NUdf::TUnboxedValuePod::Zero();
  57. return State->GetValue(ctx).Release();
  58. }
  59. }
  60. State->SetValue(ctx, UpdateState->GetValue(ctx));
  61. }
  62. state = NUdf::TUnboxedValue::MakeFinish();
  63. return State->GetValue(ctx).Release();
  64. }
  65. #ifndef MKQL_DISABLE_CODEGEN
  66. Value* DoGenerateGetValue(const TCodegenContext& ctx, Value* statePtr, BasicBlock*& block) const {
  67. auto& context = ctx.Codegen.GetContext();
  68. const auto codegenItem = dynamic_cast<ICodegeneratorExternalNode*>(Item);
  69. MKQL_ENSURE(codegenItem, "Item must be codegenerator node.");
  70. const auto codegenState = dynamic_cast<ICodegeneratorExternalNode*>(State);
  71. MKQL_ENSURE(codegenState, "State must be codegenerator node.");
  72. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  73. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  74. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  75. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  76. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  77. const auto exit = BasicBlock::Create(context, "exit", ctx.Func);
  78. const auto valueType = Type::getInt128Ty(context);
  79. const auto state = new LoadInst(valueType, statePtr, "state", block);
  80. const auto result = PHINode::Create(valueType, Switch ? 4U : 3U, "result", exit);
  81. result->addIncoming(state, block);
  82. const auto select = SwitchInst::Create(state, work, 3U, block);
  83. select->addCase(GetFinish(context), exit);
  84. select->addCase(GetInvalid(context), init);
  85. select->addCase(GetFalse(context), next);
  86. block = init;
  87. new StoreInst(GetEmpty(context), statePtr, block);
  88. codegenState->CreateSetValue(ctx, block, GetNodeValue(InitState, ctx, block));
  89. BranchInst::Create(work, block);
  90. block = next;
  91. if constexpr (UseCtx) {
  92. const auto cleanup = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&CleanupCurrentContext));
  93. const auto cleanupType = FunctionType::get(Type::getVoidTy(context), {}, false);
  94. const auto cleanupPtr = CastInst::Create(Instruction::IntToPtr, cleanup, PointerType::getUnqual(cleanupType), "cleanup_ctx", block);
  95. CallInst::Create(cleanupType, cleanupPtr, {}, "", block);
  96. }
  97. new StoreInst(GetEmpty(context), statePtr, block);
  98. codegenState->CreateSetValue(ctx, block, GetNodeValue(InitState, ctx, block));
  99. codegenState->CreateSetValue(ctx, block, GetNodeValue(UpdateState, ctx, block));
  100. BranchInst::Create(work, block);
  101. block = work;
  102. const auto item = GetNodeValue(Flow, ctx, block);
  103. result->addIncoming(item, block);
  104. const auto action = SwitchInst::Create(item, good, 2U, block);
  105. action->addCase(GetFinish(context), stop);
  106. action->addCase(GetYield(context), exit);
  107. block = good;
  108. codegenItem->CreateSetValue(ctx, block, item);
  109. if (Switch) {
  110. const auto swap = BasicBlock::Create(context, "swap", ctx.Func);
  111. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  112. const auto reset = GetNodeValue(Switch, ctx, block);
  113. if constexpr (Interruptable) {
  114. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  115. BranchInst::Create(stop, next, IsEmpty(reset, block, context), block);
  116. block = next;
  117. }
  118. const auto cast = CastInst::Create(Instruction::Trunc, reset, Type::getInt1Ty(context), "bool", block);
  119. BranchInst::Create(swap, skip, cast, block);
  120. block = swap;
  121. new StoreInst(GetFalse(context), statePtr, block);
  122. result->addIncoming(GetNodeValue(State, ctx, block), block);
  123. BranchInst::Create(exit, block);
  124. block = skip;
  125. }
  126. codegenState->CreateSetValue(ctx, block, GetNodeValue(UpdateState, ctx, block));
  127. BranchInst::Create(work, block);
  128. block = stop;
  129. new StoreInst(GetFinish(context), statePtr, block);
  130. const auto output = codegenState->CreateGetValue(ctx, block);
  131. result->addIncoming(output, block);
  132. BranchInst::Create(exit, block);
  133. block = exit;
  134. return result;
  135. }
  136. #endif
  137. private:
  138. void RegisterDependencies() const final {
  139. if (const auto flow = this->FlowDependsOn(Flow)) {
  140. this->Own(flow, Item);
  141. this->Own(flow, State);
  142. this->DependsOn(flow, InitState);
  143. this->DependsOn(flow, Switch);
  144. this->DependsOn(flow, UpdateState);
  145. }
  146. }
  147. IComputationNode* const Flow;
  148. IComputationExternalNode* const Item;
  149. IComputationExternalNode* const State;
  150. IComputationNode* const Switch;
  151. IComputationNode* const InitState;
  152. IComputationNode* const UpdateState;
  153. };
  154. template <bool Interruptable, bool UseCtx>
  155. class TCondenseWrapper : public TCustomValueCodegeneratorNode<TCondenseWrapper<Interruptable, UseCtx>> {
  156. typedef TCustomValueCodegeneratorNode<TCondenseWrapper<Interruptable, UseCtx>> TBaseComputation;
  157. public:
  158. class TValue : public TComputationValue<TValue> {
  159. public:
  160. using TBase = TComputationValue<TValue>;
  161. TValue(
  162. TMemoryUsageInfo* memInfo,
  163. NUdf::TUnboxedValue&& stream,
  164. const TSqueezeState& state,
  165. TComputationContext& ctx)
  166. : TBase(memInfo)
  167. , Stream(std::move(stream))
  168. , Ctx(ctx)
  169. , State(state)
  170. {}
  171. private:
  172. ui32 GetTraverseCount() const final {
  173. return 1;
  174. }
  175. NUdf::TUnboxedValue GetTraverseItem(ui32 index) const final {
  176. Y_UNUSED(index);
  177. return Stream;
  178. }
  179. NUdf::TUnboxedValue Save() const final {
  180. return State.Save(Ctx);
  181. }
  182. void Load(const NUdf::TStringRef& state) final {
  183. State.Load(Ctx, state);
  184. }
  185. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) override {
  186. switch (State.Stage) {
  187. case ESqueezeState::Finished:
  188. return NUdf::EFetchStatus::Finish;
  189. case ESqueezeState::Idle:
  190. State.Stage = ESqueezeState::Work;
  191. State.State->SetValue(Ctx, State.InitState->GetValue(Ctx));
  192. break;
  193. case ESqueezeState::NeedInit:
  194. if constexpr (UseCtx) {
  195. CleanupCurrentContext();
  196. }
  197. State.Stage = ESqueezeState::Work;
  198. State.State->SetValue(Ctx, State.InitState->GetValue(Ctx));
  199. State.State->SetValue(Ctx, State.UpdateState->GetValue(Ctx));
  200. break;
  201. default:
  202. break;
  203. }
  204. while (true) {
  205. const auto status = Stream.Fetch(State.Item->RefValue(Ctx));
  206. if (status == NUdf::EFetchStatus::Yield) {
  207. return status;
  208. }
  209. if (status == NUdf::EFetchStatus::Finish) {
  210. break;
  211. }
  212. if (State.Switch) {
  213. const auto& reset = State.Switch->GetValue(Ctx);
  214. if (Interruptable && !reset) {
  215. break;
  216. }
  217. if (reset.template Get<bool>()) {
  218. State.Stage = ESqueezeState::NeedInit;
  219. result = State.State->GetValue(Ctx);
  220. return NUdf::EFetchStatus::Ok;
  221. }
  222. }
  223. State.State->SetValue(Ctx, State.UpdateState->GetValue(Ctx));
  224. }
  225. State.Stage = ESqueezeState::Finished;
  226. result = State.State->GetValue(Ctx);
  227. return NUdf::EFetchStatus::Ok;
  228. }
  229. const NUdf::TUnboxedValue Stream;
  230. TComputationContext& Ctx;
  231. TSqueezeState State;
  232. };
  233. TCondenseWrapper(
  234. TComputationMutables& mutables,
  235. IComputationNode* stream,
  236. IComputationExternalNode* item,
  237. IComputationExternalNode* state,
  238. IComputationNode* outSwitch,
  239. IComputationNode* initState,
  240. IComputationNode* updateState,
  241. IComputationExternalNode* inSave = nullptr,
  242. IComputationNode* outSave = nullptr,
  243. IComputationExternalNode* inLoad = nullptr,
  244. IComputationNode* outLoad = nullptr,
  245. TType* stateType = nullptr)
  246. : TBaseComputation(mutables)
  247. , Stream(stream)
  248. , State(item, state, outSwitch, initState, updateState, inSave, outSave, inLoad, outLoad, stateType)
  249. {
  250. this->Stateless = false;
  251. }
  252. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  253. #ifndef MKQL_DISABLE_CODEGEN
  254. if (ctx.ExecuteLLVM && Fetch)
  255. return ctx.HolderFactory.Create<TSqueezeCodegenValue>(State, Fetch, ctx, Stream->GetValue(ctx));
  256. #endif
  257. return ctx.HolderFactory.Create<TValue>(Stream->GetValue(ctx), State, ctx);
  258. }
  259. private:
  260. void RegisterDependencies() const final {
  261. this->DependsOn(Stream);
  262. this->Own(State.Item);
  263. this->Own(State.State);
  264. this->DependsOn(State.Switch);
  265. this->DependsOn(State.InitState);
  266. this->DependsOn(State.UpdateState);
  267. this->Own(State.InSave);
  268. this->DependsOn(State.OutSave);
  269. this->Own(State.InLoad);
  270. this->DependsOn(State.OutLoad);
  271. }
  272. #ifndef MKQL_DISABLE_CODEGEN
  273. void GenerateFunctions(NYql::NCodegen::ICodegen& codegen) final {
  274. FetchFunc = GenerateFetch(codegen);
  275. codegen.ExportSymbol(FetchFunc);
  276. }
  277. void FinalizeFunctions(NYql::NCodegen::ICodegen& codegen) final {
  278. if (FetchFunc) {
  279. Fetch = reinterpret_cast<TFetchPtr>(codegen.GetPointerToFunction(FetchFunc));
  280. }
  281. }
  282. Function* GenerateFetch(NYql::NCodegen::ICodegen& codegen) const {
  283. auto& module = codegen.GetModule();
  284. auto& context = codegen.GetContext();
  285. const auto codegenItemArg = dynamic_cast<ICodegeneratorExternalNode*>(State.Item);
  286. const auto codegenStateArg = dynamic_cast<ICodegeneratorExternalNode*>(State.State);
  287. MKQL_ENSURE(codegenItemArg, "Item arg must be codegenerator node.");
  288. MKQL_ENSURE(codegenStateArg, "State arg must be codegenerator node.");
  289. const auto& name = TBaseComputation::MakeName("Fetch");
  290. if (const auto f = module.getFunction(name.c_str()))
  291. return f;
  292. const auto valueType = Type::getInt128Ty(context);
  293. const auto containerType = static_cast<Type*>(valueType);
  294. const auto contextType = GetCompContextType(context);
  295. const auto statusType = Type::getInt32Ty(context);
  296. const auto stateType = Type::getInt8Ty(context);
  297. const auto funcType = FunctionType::get(statusType, {PointerType::getUnqual(contextType), containerType, PointerType::getUnqual(valueType), PointerType::getUnqual(stateType)}, false);
  298. TCodegenContext ctx(codegen);
  299. ctx.Func = cast<Function>(module.getOrInsertFunction(name.c_str(), funcType).getCallee());
  300. DISubprogramAnnotator annotator(ctx, ctx.Func);
  301. auto args = ctx.Func->arg_begin();
  302. ctx.Ctx = &*args;
  303. const auto containerArg = &*++args;
  304. const auto valuePtr = &*++args;
  305. const auto statePtr = &*++args;
  306. const auto main = BasicBlock::Create(context, "main", ctx.Func);
  307. auto block = main;
  308. const auto container = static_cast<Value*>(containerArg);
  309. const auto state = new LoadInst(stateType, statePtr, "state", block);
  310. const auto init = BasicBlock::Create(context, "init", ctx.Func);
  311. const auto next = BasicBlock::Create(context, "next", ctx.Func);
  312. const auto work = BasicBlock::Create(context, "work", ctx.Func);
  313. const auto none = BasicBlock::Create(context, "none", ctx.Func);
  314. const auto select = SwitchInst::Create(state, work, 3U, block);
  315. select->addCase(ConstantInt::get(stateType, static_cast<ui8>(ESqueezeState::Finished)), none);
  316. select->addCase(ConstantInt::get(stateType, static_cast<ui8>(ESqueezeState::Idle)), init);
  317. select->addCase(ConstantInt::get(stateType, static_cast<ui8>(ESqueezeState::NeedInit)), next);
  318. block = none;
  319. ReturnInst::Create(context, ConstantInt::get(statusType, static_cast<ui32>(NUdf::EFetchStatus::Finish)), block);
  320. block = init;
  321. new StoreInst(ConstantInt::get(state->getType(), static_cast<ui8>(ESqueezeState::Work)), statePtr, block);
  322. codegenStateArg->CreateSetValue(ctx, block, GetNodeValue(State.InitState, ctx, block));
  323. BranchInst::Create(work, block);
  324. block = next;
  325. if constexpr (UseCtx) {
  326. const auto cleanup = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&CleanupCurrentContext));
  327. const auto cleanupType = FunctionType::get(Type::getVoidTy(context), {}, false);
  328. const auto cleanupPtr = CastInst::Create(Instruction::IntToPtr, cleanup, PointerType::getUnqual(cleanupType), "cleanup_ctx", block);
  329. CallInst::Create(cleanupType, cleanupPtr, {}, "", block);
  330. }
  331. new StoreInst(ConstantInt::get(state->getType(), static_cast<ui8>(ESqueezeState::Work)), statePtr, block);
  332. codegenStateArg->CreateSetValue(ctx, block, GetNodeValue(State.InitState, ctx, block));
  333. codegenStateArg->CreateSetValue(ctx, block, GetNodeValue(State.UpdateState, ctx, block));
  334. BranchInst::Create(work, block);
  335. block = work;
  336. const auto loop = BasicBlock::Create(context, "loop", ctx.Func);
  337. BranchInst::Create(loop, block);
  338. block = loop;
  339. const auto itemPtr = codegenItemArg->CreateRefValue(ctx, block);
  340. const auto status = CallBoxedValueVirtualMethod<NUdf::TBoxedValueAccessor::EMethod::Fetch>(statusType, container, codegen, block, itemPtr);
  341. const auto ychk = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Yield)), "ychk", block);
  342. const auto yies = BasicBlock::Create(context, "yies", ctx.Func);
  343. const auto nope = BasicBlock::Create(context, "nope", ctx.Func);
  344. BranchInst::Create(yies, nope, ychk, block);
  345. block = yies;
  346. ReturnInst::Create(context, status, block);
  347. block = nope;
  348. const auto icmp = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, status, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Finish)), "cond", block);
  349. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  350. const auto stop = BasicBlock::Create(context, "stop", ctx.Func);
  351. BranchInst::Create(stop, good, icmp, block);
  352. block = good;
  353. if (State.Switch) {
  354. const auto swap = BasicBlock::Create(context, "swap", ctx.Func);
  355. const auto skip = BasicBlock::Create(context, "skip", ctx.Func);
  356. const auto reset = GetNodeValue(State.Switch, ctx, block);
  357. if constexpr (Interruptable) {
  358. const auto pass = BasicBlock::Create(context, "pass", ctx.Func);
  359. const auto done = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, reset, ConstantInt::get(reset->getType(), 0), "done", block);
  360. BranchInst::Create(stop, pass, done, block);
  361. block = pass;
  362. }
  363. const auto cast = CastInst::Create(Instruction::Trunc, reset, Type::getInt1Ty(context), "bool", block);
  364. BranchInst::Create(swap, skip, cast, block);
  365. block = swap;
  366. new StoreInst(ConstantInt::get(state->getType(), static_cast<ui8>(ESqueezeState::NeedInit)), statePtr, block);
  367. SafeUnRefUnboxedOne(valuePtr, ctx, block);
  368. const auto state = codegenStateArg->CreateGetValue(ctx, block);
  369. new StoreInst(state, valuePtr, block);
  370. ValueAddRef(State.State->GetRepresentation(), valuePtr, ctx, block);
  371. ReturnInst::Create(context, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Ok)), block);
  372. block = skip;
  373. }
  374. codegenStateArg->CreateSetValue(ctx, block, GetNodeValue(State.UpdateState, ctx, block));
  375. BranchInst::Create(loop, block);
  376. block = stop;
  377. new StoreInst(ConstantInt::get(state->getType(), static_cast<ui8>(ESqueezeState::Finished)), statePtr, block);
  378. SafeUnRefUnboxedOne(valuePtr, ctx, block);
  379. const auto result = codegenStateArg->CreateGetValue(ctx, block);
  380. new StoreInst(result, valuePtr, block);
  381. ValueAddRef(State.State->GetRepresentation(), valuePtr, ctx, block);
  382. ReturnInst::Create(context, ConstantInt::get(status->getType(), static_cast<ui32>(NUdf::EFetchStatus::Ok)), block);
  383. return ctx.Func;
  384. }
  385. using TFetchPtr = TSqueezeCodegenValue::TFetchPtr;
  386. Function* FetchFunc = nullptr;
  387. TFetchPtr Fetch = nullptr;
  388. #endif
  389. IComputationNode* const Stream;
  390. TSqueezeState State;
  391. };
  392. }
  393. template <bool UseCtx>
  394. IComputationNode* WrapCondenseImpl(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  395. const auto stream = LocateNode(ctx.NodeLocator, callable, 0);
  396. const auto initState = LocateNode(ctx.NodeLocator, callable, 1);
  397. const auto outSwitch = LocateNode(ctx.NodeLocator, callable, 4);
  398. const auto updateState = LocateNode(ctx.NodeLocator, callable, 5);
  399. const auto item = LocateExternalNode(ctx.NodeLocator, callable, 2);
  400. const auto state = LocateExternalNode(ctx.NodeLocator, callable, 3);
  401. bool isOptional;
  402. const auto dataType = UnpackOptionalData(callable.GetInput(4), isOptional);
  403. MKQL_ENSURE(dataType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool.");
  404. const auto type = callable.GetType()->GetReturnType();
  405. if (type->IsFlow()) {
  406. const auto kind = GetValueRepresentation(AS_TYPE(TFlowType, type)->GetItemType());
  407. if (isOptional) {
  408. return new TCondenseFlowWrapper<true, UseCtx>(ctx.Mutables, kind, stream, item, state, outSwitch, initState, updateState);
  409. } else {
  410. return new TCondenseFlowWrapper<false, UseCtx>(ctx.Mutables, kind, stream, item, state, outSwitch, initState, updateState);
  411. }
  412. } else {
  413. if (isOptional) {
  414. return new TCondenseWrapper<true, UseCtx>(ctx.Mutables, stream, item, state, outSwitch, initState, updateState);
  415. } else {
  416. return new TCondenseWrapper<false, UseCtx>(ctx.Mutables, stream, item, state, outSwitch, initState, updateState);
  417. }
  418. }
  419. }
  420. IComputationNode* WrapCondense(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  421. MKQL_ENSURE(callable.GetInputsCount() == 6 || callable.GetInputsCount() == 7, "Expected 6 or 7 args");
  422. bool useCtx = false;
  423. if (callable.GetInputsCount() >= 7) {
  424. useCtx = AS_VALUE(TDataLiteral, callable.GetInput(6))->AsValue().Get<bool>();
  425. }
  426. if (useCtx) {
  427. return WrapCondenseImpl<true>(callable, ctx);
  428. } else {
  429. return WrapCondenseImpl<false>(callable, ctx);
  430. }
  431. }
  432. IComputationNode* WrapSqueeze(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  433. MKQL_ENSURE(callable.GetInputsCount() == 9, "Expected 9 args");
  434. const auto stream = LocateNode(ctx.NodeLocator, callable, 0);
  435. const auto initState = LocateNode(ctx.NodeLocator, callable, 1);
  436. const auto updateState = LocateNode(ctx.NodeLocator, callable, 4);
  437. const auto item = LocateExternalNode(ctx.NodeLocator, callable, 2);
  438. const auto state = LocateExternalNode(ctx.NodeLocator, callable, 3);
  439. IComputationExternalNode* inSave = nullptr;
  440. IComputationNode* outSave = nullptr;
  441. IComputationExternalNode* inLoad = nullptr;
  442. IComputationNode* outLoad = nullptr;
  443. const auto hasSaveLoad = !callable.GetInput(6).GetStaticType()->IsVoid();
  444. if (hasSaveLoad) {
  445. outSave = LocateNode(ctx.NodeLocator, callable, 6);
  446. outLoad = LocateNode(ctx.NodeLocator, callable, 8);
  447. inSave = LocateExternalNode(ctx.NodeLocator, callable, 5);
  448. inLoad = LocateExternalNode(ctx.NodeLocator, callable, 7);
  449. }
  450. const auto stateType = hasSaveLoad ? callable.GetInput(6).GetStaticType() : nullptr;
  451. return new TCondenseWrapper<false, false>(ctx.Mutables, stream, item, state, nullptr, initState, updateState, inSave, outSave, inLoad, outLoad, stateType);
  452. }
  453. }
  454. }