mkql_mapnext.cpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. #include "mkql_mapnext.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  3. namespace NKikimr {
  4. namespace NMiniKQL {
  5. namespace {
  6. struct TState : public TComputationValue<TState> {
  7. using TComputationValue::TComputationValue;
  8. std::optional<NUdf::TUnboxedValue> Prev;
  9. bool Finish = false;
  10. };
  11. class TFlowMapNextWrapper : public TStatefulFlowComputationNode<TFlowMapNextWrapper> {
  12. typedef TStatefulFlowComputationNode<TFlowMapNextWrapper> TBaseComputation;
  13. public:
  14. TFlowMapNextWrapper(TComputationMutables& mutables, EValueRepresentation kind, IComputationNode* flow,
  15. IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem)
  16. : TBaseComputation(mutables, flow, kind, EValueRepresentation::Any)
  17. , Flow(flow)
  18. , Item(item)
  19. , NextItem(nextItem)
  20. , NewItem(newItem)
  21. {}
  22. NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue& stateValue, TComputationContext& ctx) const {
  23. if (!stateValue.HasValue()) {
  24. stateValue = ctx.HolderFactory.Create<TState>();
  25. }
  26. TState& state = *static_cast<TState*>(stateValue.AsBoxed().Get());
  27. NUdf::TUnboxedValue result;
  28. for (;;) {
  29. if (state.Finish) {
  30. if (!state.Prev) {
  31. return NUdf::TUnboxedValuePod::MakeFinish();
  32. }
  33. Item->SetValue(ctx, std::move(*state.Prev));
  34. state.Prev.reset();
  35. NextItem->SetValue(ctx, NUdf::TUnboxedValuePod());
  36. return NewItem->GetValue(ctx);
  37. }
  38. auto item = Flow->GetValue(ctx);
  39. if (item.IsYield()) {
  40. return item;
  41. }
  42. if (item.IsFinish()) {
  43. state.Finish = true;
  44. continue;
  45. }
  46. if (!state.Prev) {
  47. state.Prev = std::move(item);
  48. continue;
  49. }
  50. Item->SetValue(ctx, std::move(*state.Prev));
  51. state.Prev = item;
  52. NextItem->SetValue(ctx, std::move(item));
  53. result = NewItem->GetValue(ctx);
  54. break;
  55. }
  56. return result;
  57. }
  58. private:
  59. void RegisterDependencies() const final {
  60. if (const auto flow = FlowDependsOn(Flow)) {
  61. Own(flow, Item);
  62. Own(flow, NextItem);
  63. DependsOn(flow, NewItem);
  64. }
  65. }
  66. IComputationNode* const Flow;
  67. IComputationExternalNode* const Item;
  68. IComputationExternalNode* const NextItem;
  69. IComputationNode* const NewItem;
  70. };
  71. class TStreamMapNextWrapper : public TMutableComputationNode<TStreamMapNextWrapper> {
  72. typedef TMutableComputationNode<TStreamMapNextWrapper> TBaseComputation;
  73. public:
  74. TStreamMapNextWrapper(TComputationMutables& mutables, IComputationNode* stream,
  75. IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem)
  76. : TBaseComputation(mutables)
  77. , Stream(stream)
  78. , Item(item)
  79. , NextItem(nextItem)
  80. , NewItem(newItem)
  81. , StateIndex(mutables.CurValueIndex++)
  82. {}
  83. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  84. return ctx.HolderFactory.Create<TStreamValue>(ctx, Stream->GetValue(ctx), Item, NextItem, NewItem, StateIndex);
  85. }
  86. private:
  87. void RegisterDependencies() const final {
  88. DependsOn(Stream);
  89. Own(Item);
  90. Own(NextItem);
  91. DependsOn(NewItem);
  92. }
  93. class TStreamValue : public TComputationValue<TStreamValue> {
  94. public:
  95. using TBase = TComputationValue<TStreamValue>;
  96. TStreamValue(TMemoryUsageInfo* memInfo, TComputationContext& compCtx, NUdf::TUnboxedValue&& stream,
  97. IComputationExternalNode* item, IComputationExternalNode* nextItem, IComputationNode* newItem, ui32 stateIndex)
  98. : TBase(memInfo)
  99. , CompCtx(compCtx)
  100. , Stream(std::move(stream))
  101. , Item(item)
  102. , NextItem(nextItem)
  103. , NewItem(newItem)
  104. , StateIndex(stateIndex)
  105. {
  106. }
  107. private:
  108. ui32 GetTraverseCount() const final {
  109. return 1U;
  110. }
  111. NUdf::TUnboxedValue GetTraverseItem(ui32) const final {
  112. return Stream;
  113. }
  114. NUdf::TUnboxedValue Save() const final {
  115. return NUdf::TUnboxedValuePod::Zero();
  116. }
  117. void Load(const NUdf::TStringRef&) final {}
  118. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
  119. auto& state = GetState();
  120. for (;;) {
  121. if (state.Finish) {
  122. if (!state.Prev) {
  123. return NUdf::EFetchStatus::Finish;
  124. }
  125. Item->SetValue(CompCtx, std::move(*state.Prev));
  126. state.Prev.reset();
  127. NextItem->SetValue(CompCtx, NUdf::TUnboxedValuePod());
  128. result = NewItem->GetValue(CompCtx);
  129. return NUdf::EFetchStatus::Ok;
  130. }
  131. NUdf::TUnboxedValue item;
  132. const auto status = Stream.Fetch(item);
  133. if (status == NUdf::EFetchStatus::Yield) {
  134. return status;
  135. }
  136. if (status == NUdf::EFetchStatus::Finish) {
  137. state.Finish = true;
  138. continue;
  139. }
  140. if (!state.Prev) {
  141. state.Prev = std::move(item);
  142. continue;
  143. }
  144. Item->SetValue(CompCtx, std::move(*state.Prev));
  145. state.Prev = item;
  146. NextItem->SetValue(CompCtx, std::move(item));
  147. result = NewItem->GetValue(CompCtx);
  148. break;
  149. }
  150. return NUdf::EFetchStatus::Ok;
  151. }
  152. TState& GetState() const {
  153. auto& result = CompCtx.MutableValues[StateIndex];
  154. if (!result.HasValue()) {
  155. result = CompCtx.HolderFactory.Create<TState>();
  156. }
  157. return *static_cast<TState*>(result.AsBoxed().Get());
  158. }
  159. TComputationContext& CompCtx;
  160. const NUdf::TUnboxedValue Stream;
  161. IComputationExternalNode* const Item;
  162. IComputationExternalNode* const NextItem;
  163. IComputationNode* const NewItem;
  164. const ui32 StateIndex;
  165. };
  166. IComputationNode* const Stream;
  167. IComputationExternalNode* const Item;
  168. IComputationExternalNode* const NextItem;
  169. IComputationNode* const NewItem;
  170. const ui32 StateIndex;
  171. };
  172. }
  173. IComputationNode* WrapMapNext(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  174. MKQL_ENSURE(callable.GetInputsCount() == 4, "Expected 4 args, got " << callable.GetInputsCount());
  175. const auto type = callable.GetType()->GetReturnType();
  176. const auto input = LocateNode(ctx.NodeLocator, callable, 0);
  177. const auto itemArg = LocateExternalNode(ctx.NodeLocator, callable, 1);
  178. const auto nextItemArg = LocateExternalNode(ctx.NodeLocator, callable, 2);
  179. const auto newItem = LocateNode(ctx.NodeLocator, callable, 3);
  180. if (type->IsFlow()) {
  181. return new TFlowMapNextWrapper(ctx.Mutables, GetValueRepresentation(type), input, itemArg, nextItemArg, newItem);
  182. } else if (type->IsStream()) {
  183. return new TStreamMapNextWrapper(ctx.Mutables, input, itemArg, nextItemArg, newItem);
  184. }
  185. THROW yexception() << "Expected flow or stream.";
  186. }
  187. }
  188. }