Browse Source

YQ-2549 Checkpointing in match_recognize (#1860)

Dmitry Kardymon 1 year ago
parent
commit
313d48609d

+ 196 - 23
ydb/library/yql/minikql/comp_nodes/mkql_match_recognize.cpp

@@ -2,6 +2,8 @@
 #include "mkql_match_recognize_matched_vars.h"
 #include "mkql_match_recognize_measure_arg.h"
 #include "mkql_match_recognize_nfa.h"
+#include "mkql_match_recognize_save_load.h"
+
 #include <ydb/library/yql/core/sql_types/match_recognize.h>
 #include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h>
 #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
@@ -20,6 +22,8 @@ namespace NMatchRecognize {
 enum class EOutputColumnSource {PartitionKey, Measure};
 using TOutputColumnOrder = std::vector<std::pair<EOutputColumnSource, size_t>, TMKQLAllocator<std::pair<EOutputColumnSource, size_t>>>;
 
+constexpr ui32 StateVersion = 1;
+
 using namespace NYql::NMatchRecognize;
 
 struct TMatchRecognizeProcessorParameters {
@@ -43,7 +47,17 @@ class TBackTrackingMatchRecognize {
 public:
     //TODO(YQL-16486): create a tree for backtracking(replace var names with indexes)
 
-    struct TPatternConfiguration {};
+    struct TPatternConfiguration {
+        void Save(TOutputSerializer& /*serializer*/) const {
+        }
+
+        void Load(TInputSerializer& /*serializer*/) {
+        }
+
+        friend bool operator==(const TPatternConfiguration&, const TPatternConfiguration&) {
+            return true;
+        }
+    };
 
     struct TPatternConfigurationBuilder {
         using TPatternConfigurationPtr = std::shared_ptr<TPatternConfiguration>;
@@ -124,6 +138,15 @@ public:
         }
         return not Matches.empty();
     }
+
+    void Save(TOutputSerializer& /*serializer*/) const {
+        // Not used in non-streaming mode.
+    }
+
+    void Load(TInputSerializer& /*serializer*/) {
+        // Not used in non-streaming mode.
+    }
+
 private:
     const NUdf::TUnboxedValue PartitionKey;
     const TMatchRecognizeProcessorParameters& Parameters;
@@ -137,8 +160,8 @@ private:
 class TStreamingMatchRecognize {
     using TPartitionList = TSparseList;
     using TRange = TPartitionList::TRange;
-    using TMatchedVars = TMatchedVars<TRange>;
 public:
+    using TPatternConfiguration = TNfaTransitionGraph;
     using TPatternConfigurationBuilder = TNfaTransitionGraphBuilder;
     TStreamingMatchRecognize(
         NUdf::TUnboxedValue&& partitionKey,
@@ -157,12 +180,18 @@ public:
         Parameters.InputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows));
         Parameters.CurrentRowIndexArg->SetValue(ctx, NUdf::TUnboxedValuePod(Rows.Size()));
         Nfa.ProcessRow(Rows.Append(std::move(row)), ctx);
+        return HasMatched();
+    }
+
+    bool HasMatched() const {
         return Nfa.HasMatched();
     }
+
     NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
         auto match = Nfa.GetMatched();
-        if (!match.has_value())
+        if (!match.has_value()) {
             return NUdf::TUnboxedValue{};
+        }
         Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, match.value()));
         Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
             ctx.HolderFactory.Create<TListValue<TSparseList>>(Rows),
@@ -189,6 +218,21 @@ public:
         Y_UNUSED(ctx);
         return false;
     }
+
+    void Save(TOutputSerializer& serializer) const {
+        // PartitionKey saved in TStateForInterleavedPartitions as key.
+        Rows.Save(serializer);
+        Nfa.Save(serializer);
+        serializer.Write(MatchNumber);
+    }
+
+    void Load(TInputSerializer& serializer) {
+        // PartitionKey is passed in the constructor.
+        Rows.Load(serializer);
+        Nfa.Load(serializer);
+        MatchNumber = serializer.Read<ui64>();
+    }
+
 private:
     const NUdf::TUnboxedValue PartitionKey;
     const TMatchRecognizeProcessorParameters& Parameters;
@@ -205,12 +249,15 @@ class TStateForNonInterleavedPartitions
     using TRowPatternConfigurationBuilder = typename Algo::TPatternConfigurationBuilder;
 public:
     TStateForNonInterleavedPartitions(
-            TMemoryUsageInfo* memInfo,
-            IComputationExternalNode* inputRowArg,
-            IComputationNode* partitionKey,
-            TType* partitionKeyType,
-            const TMatchRecognizeProcessorParameters& parameters,
-            const TContainerCacheOnContext& cache
+        TMemoryUsageInfo* memInfo,
+        IComputationExternalNode* inputRowArg,
+        IComputationNode* partitionKey,
+        TType* partitionKeyType,
+        const TMatchRecognizeProcessorParameters& parameters,
+        const TContainerCacheOnContext& cache,
+        TComputationContext &ctx,
+        TType* rowType,
+        const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker
     )
     : TComputationValue<TStateForNonInterleavedPartitions>(memInfo)
     , InputRowArg(inputRowArg)
@@ -220,7 +267,54 @@ public:
     , RowPatternConfiguration(TRowPatternConfigurationBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
     , Cache(cache)
     , Terminating(false)
+    , SerializerContext(ctx, rowType, rowPacker)
     {}
+
+    NUdf::TUnboxedValue Save() const override {
+        TOutputSerializer serializer(SerializerContext);
+        serializer.Write(StateVersion);
+        serializer.Write(CurPartitionPackedKey);
+        bool isValid = static_cast<bool>(PartitionHandler);
+        serializer.Write(isValid);
+        if (isValid) {
+            PartitionHandler->Save(serializer);
+        }
+        isValid = static_cast<bool>(DelayedRow);
+        serializer.Write(isValid);
+        if (isValid) {
+            serializer.Write(DelayedRow);
+        }
+        RowPatternConfiguration->Save(serializer);
+        return serializer.MakeString();
+    }
+
+    void Load(const NUdf::TStringRef& state) override {
+        TInputSerializer serializer(SerializerContext, state);
+        const auto stateVersion = serializer.Read<decltype(StateVersion)>();
+        if (stateVersion == 1) {
+            serializer.Read(CurPartitionPackedKey);
+            bool validPartitionHandler = serializer.Read<bool>();
+            if (validPartitionHandler) {
+                NUdf::TUnboxedValue key = PartitionKeyPacker.Unpack(CurPartitionPackedKey, SerializerContext.Ctx.HolderFactory);
+                PartitionHandler.reset(new Algo(
+                    std::move(key),
+                    Parameters,
+                    RowPatternConfiguration,
+                    Cache
+                ));
+                PartitionHandler->Load(serializer);
+            }
+            bool validDelayedRow = serializer.Read<bool>();
+            if (validDelayedRow) {
+                DelayedRow = serializer.Read<NUdf::TUnboxedValue>();
+            }
+            auto restoredRowPatternConfiguration = std::make_shared<typename Algo::TPatternConfiguration>(); 
+            restoredRowPatternConfiguration->Load(serializer);
+            MKQL_ENSURE(*restoredRowPatternConfiguration == *RowPatternConfiguration, "Restored and current RowPatternConfiguration is different");
+        }
+        MKQL_ENSURE(serializer.Empty(), "State is corrupted");
+    }
+
     bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
         MKQL_ENSURE(not DelayedRow, "Internal logic error"); //we're finalizing previous partition
         InputRowArg->SetValue(ctx, NUdf::TUnboxedValue(row));
@@ -288,6 +382,7 @@ private:
     const TContainerCacheOnContext& Cache;
     NUdf::TUnboxedValue DelayedRow;
     bool Terminating;
+    TSerializerContext SerializerContext;
 };
 
 class TStateForInterleavedPartitions
@@ -302,7 +397,10 @@ public:
         IComputationNode* partitionKey,
         TType* partitionKeyType,
         const TMatchRecognizeProcessorParameters& parameters,
-        const TContainerCacheOnContext& cache
+        const TContainerCacheOnContext& cache,
+        TComputationContext &ctx,
+        TType* rowType,
+        const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker
     )
     : TComputationValue<TStateForInterleavedPartitions>(memInfo)
     , InputRowArg(inputRowArg)
@@ -311,9 +409,59 @@ public:
     , Parameters(parameters)
     , NfaTransitionGraph(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
     , Cache(cache)
-{
-}
-   bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
+    , SerializerContext(ctx, rowType, rowPacker)
+    {}
+
+    NUdf::TUnboxedValue Save() const override {
+        TOutputSerializer serializer(SerializerContext);
+        serializer.Write(StateVersion);
+        serializer.Write(Partitions.size());
+
+        for (const auto& [key, state] : Partitions) {
+            serializer.Write(key);
+            state->Save(serializer);
+        }
+        // HasReadyOutput is not packed because when loading we can recalculate HasReadyOutput from Partitions.
+        serializer.Write(Terminating);
+        NfaTransitionGraph->Save(serializer);
+        return serializer.MakeString();
+    }
+
+    void Load(const NUdf::TStringRef& state) override {
+        TInputSerializer serializer(SerializerContext, state);
+        const auto stateVersion = serializer.Read<decltype(StateVersion)>();
+        if (stateVersion == 1) {
+            Partitions.clear();
+            auto partitionsCount = serializer.Read<TPartitionMap::size_type>();
+            Partitions.reserve(partitionsCount);
+            for (size_t i = 0; i < partitionsCount; ++i) {
+                auto packedKey = serializer.Read<TPartitionMap::key_type, std::string_view>();
+                NUdf::TUnboxedValue key = PartitionKeyPacker.Unpack(packedKey, SerializerContext.Ctx.HolderFactory);
+                auto pair = Partitions.emplace(
+                    packedKey,
+                    std::make_unique<TStreamingMatchRecognize>(
+                        std::move(key),
+                        Parameters,
+                        NfaTransitionGraph,
+                        Cache));
+                pair.first->second->Load(serializer);
+            }
+
+            for (auto it = Partitions.begin(); it != Partitions.end(); ++it) {
+                if (it->second->HasMatched()) {
+                    HasReadyOutput.push(it);
+                }
+            }
+            serializer.Read(Terminating);
+            auto restoredTransitionGraph = std::make_shared<TNfaTransitionGraph>();
+            restoredTransitionGraph->Load(serializer);
+            MKQL_ENSURE(NfaTransitionGraph, "Empty NfaTransitionGraph");
+            MKQL_ENSURE(*restoredTransitionGraph == *NfaTransitionGraph, "Restored and current NfaTransitionGraph is different");
+        }
+        MKQL_ENSURE(serializer.Empty(), "State is corrupted");
+    }
+
+    bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
         auto partition = GetPartitionHandler(row, ctx);
         if (partition->second->ProcessInputRow(std::move(row), ctx)) {
             HasReadyOutput.push(partition);
@@ -375,17 +523,19 @@ private:
     const TMatchRecognizeProcessorParameters& Parameters;
     const TNfaTransitionGraph::TPtr NfaTransitionGraph;
     const TContainerCacheOnContext& Cache;
+    TSerializerContext SerializerContext;
 };
 
 template<class State>
-class TMatchRecognizeWrapper : public TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>> {
-    using TBaseComputation = TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>>;
+class TMatchRecognizeWrapper : public TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true> {
+    using TBaseComputation = TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true>;
 public:
     TMatchRecognizeWrapper(TComputationMutables &mutables, EValueRepresentation kind, IComputationNode *inputFlow,
        IComputationExternalNode *inputRowArg,
        IComputationNode *partitionKey,
        TType* partitionKeyType,
-       const TMatchRecognizeProcessorParameters& parameters
+       const TMatchRecognizeProcessorParameters& parameters,
+       TType* rowType
     )
     :TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded)
     , InputFlow(inputFlow)
@@ -394,17 +544,36 @@ public:
     , PartitionKeyType(partitionKeyType)
     , Parameters(parameters)
     , Cache(mutables)
+    , RowType(rowType)
+    , RowPacker(mutables)
     {}
 
     NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue &stateValue, TComputationContext &ctx) const {
         if (stateValue.IsInvalid()) {
             stateValue = ctx.HolderFactory.Create<State>(
-                    InputRowArg,
-                    PartitionKey,
-                    PartitionKeyType,
-                    Parameters,
-                    Cache
+                InputRowArg,
+                PartitionKey,
+                PartitionKeyType,
+                Parameters,
+                Cache,
+                ctx,
+                RowType,
+                RowPacker
+            );
+        } else if (stateValue.HasValue() && !stateValue.IsBoxed()) {
+            // Load from saved state.
+            NUdf::TUnboxedValue state = ctx.HolderFactory.Create<State>(
+                InputRowArg,
+                PartitionKey,
+                PartitionKeyType,
+                Parameters,
+                Cache,
+                ctx,
+                RowType,
+                RowPacker
             );
+            state.Load(stateValue.AsStringRef());
+            stateValue = state;
         }
         auto state = static_cast<State*>(stateValue.AsBoxed().Get());
         while (true) {
@@ -446,8 +615,9 @@ private:
     IComputationNode* const PartitionKey;
     TType* const PartitionKeyType;
     const TMatchRecognizeProcessorParameters Parameters;
-    TNfaTransitionGraph::TPtr NfaTransitionGraph;
     const TContainerCacheOnContext Cache;
+    TType* const RowType;
+    TMutableObjectOverBoxedValue<TValuePackerBoxed> RowPacker;
 };
 
 TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode partitionKyeColumnsIndexes, TRuntimeNode measureColumnsIndexes) {
@@ -550,7 +720,6 @@ std::pair<TUnboxedValueVector, THashMap<TString, size_t>> ConvertListOfStrings(c
 } //namespace NMatchRecognize
 
 
-
 IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
     using namespace NMatchRecognize;
     size_t inputIndex = 0;
@@ -579,6 +748,7 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
     MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count");
 
     const auto& [vars, varsLookup] = ConvertListOfStrings(varNames);
+    auto* rowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputFlow.GetStaticType())->GetItemType());
 
     const auto parameters = TMatchRecognizeProcessorParameters {
         static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode()))
@@ -604,6 +774,7 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
             , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
             , partitionKeySelector.GetStaticType()
             , std::move(parameters)
+            , rowType
         );
     } else {
         const bool useNfaForTables = true; //TODO(YQL-16486) get this flag from an optimizer
@@ -615,6 +786,7 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
                 , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
                 , partitionKeySelector.GetStaticType()
                 , std::move(parameters)
+                , rowType
             );
         } else {
             return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions<TBackTrackingMatchRecognize>>(ctx.Mutables
@@ -624,6 +796,7 @@ IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputation
                 , LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode())
                 , partitionKeySelector.GetStaticType()
                 , std::move(parameters)
+                , rowType
             );
         }
     }

+ 44 - 2
ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_list.h

@@ -1,7 +1,11 @@
 #pragma once
+
+#include "mkql_match_recognize_save_load.h"
+
 #include <ydb/library/yql/minikql/defs.h>
 #include <ydb/library/yql/minikql/computation/mkql_computation_node_impl.h>
 #include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
+#include <ydb/library/yql/minikql/comp_nodes/mkql_saveload.h>
 #include <ydb/library/yql/public/udf/udf_value.h>
 #include <unordered_map>
 
@@ -131,15 +135,37 @@ class TSparseList {
             }
         }
 
+        void Save(TOutputSerializer& serializer) const {
+            serializer(Storage.size());
+            for (const auto& [key, item]: Storage) {
+                serializer(key, item.Value, item.LockCount);
+            }
+        }
+
+        void Load(TInputSerializer& serializer) {
+            auto size = serializer.Read<TStorage::size_type>();
+            Storage.reserve(size);
+            for (size_t i = 0; i < size; ++i) {
+                TStorage::key_type key;
+                NUdf::TUnboxedValue row;
+                decltype(TItem::LockCount) lockCount;
+                serializer(key, row, lockCount);
+                Storage.emplace(key, TItem{row, lockCount});
+            }
+        }
+
     private:
         //TODO consider to replace hash table with contiguous chunks
         using TAllocator = TMKQLAllocator<std::pair<const size_t, TItem>, EMemorySubPool::Temporary>;
-        std::unordered_map<
+
+        using TStorage = std::unordered_map<
             size_t,
             TItem,
             std::hash<size_t>,
             std::equal_to<size_t>,
-            TAllocator> Storage;
+            TAllocator>;
+
+        TStorage Storage;
     };
     using TContainerPtr = TContainer::TPtr;
 
@@ -242,6 +268,14 @@ public:
             ToIndex = -1;
         }
 
+        void Save(TOutputSerializer& serializer) const {
+            serializer(Container, FromIndex, ToIndex);
+       }
+
+        void Load(TInputSerializer& serializer) {
+            serializer(Container, FromIndex, ToIndex);
+        }
+
     private:
         TRange(TContainerPtr container, size_t index)
             : Container(container)
@@ -297,6 +331,14 @@ public:
         return Size() == 0;
     }
 
+    void Save(TOutputSerializer& serializer) const {
+        serializer(Container, ListSize);
+    }
+
+    void Load(TInputSerializer& serializer) {
+        serializer(Container, ListSize);
+    }
+
 private:
     TContainerPtr Container = MakeIntrusive<TContainer>();
     size_t ListSize = 0; //impl: max index ever stored + 1

+ 2 - 2
ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_matched_vars.h

@@ -8,6 +8,7 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {
 
 template<class R>
 using TMatchedVar = std::vector<R, TMKQLAllocator<R>>;
+
 template<class R>
 void Extend(TMatchedVar<R>& var, const R& r) {
     if (var.empty()) {
@@ -110,8 +111,7 @@ public:
         : TComputationValue<TMatchedVarsValue>(memInfo)
         , HolderFactory(holderFactory)
         , Vars(vars)
-    {
-    }
+    {}
 
     NUdf::TUnboxedValue GetElement(ui32 index) const override {
         return HolderFactory.Create<TRangeList>(HolderFactory, Vars[index]);

+ 1 - 0
ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_measure_arg.h

@@ -32,6 +32,7 @@ public:
             , VarNames(varNames)
             , MatchNumber(matchNumber)
     {}
+
     NUdf::TUnboxedValue GetElement(ui32 index) const override {
         switch(ColumnOrder[index].first) {
             case EMeasureInputDataSpecialColumns::Classifier: {

+ 132 - 5
ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_nfa.h

@@ -1,6 +1,7 @@
 #pragma once
 
 #include "mkql_match_recognize_matched_vars.h"
+#include "mkql_match_recognize_save_load.h"
 #include "../computation/mkql_computation_node_holders.h"
 #include "../computation/mkql_computation_node_impl.h"
 #include <ydb/library/yql/core/sql_types/match_recognize.h>
@@ -12,13 +13,29 @@ namespace NKikimr::NMiniKQL::NMatchRecognize {
 using namespace NYql::NMatchRecognize;
 
 struct TVoidTransition {
+    friend bool operator==(const TVoidTransition&, const TVoidTransition&) {
+        return true;
+    }
 };
 using TEpsilonTransition = size_t; //to
 using TEpsilonTransitions = std::vector<TEpsilonTransition, TMKQLAllocator<TEpsilonTransition>>;
 using TMatchedVarTransition = std::pair<std::pair<ui32, bool>, size_t>; //{{varIndex, saveState}, to}
 using TQuantityEnterTransition = size_t; //to
 using TQuantityExitTransition = std::pair<std::pair<ui64, ui64>, std::pair<size_t, size_t>>; //{{min, max}, {foFindMore, toMatched}}
-using TNfaTransition = std::variant<
+
+template <typename... Ts>
+struct TVariantHelper {
+    using TVariant =  std::variant<Ts...>;
+    using TTuple =  std::tuple<Ts...>;
+
+    static std::variant<Ts...> getVariantByIndex(size_t i) {
+        MKQL_ENSURE(i < sizeof...(Ts), "Wrong variant index");
+        static std::variant<Ts...> table[] = { Ts{ }... };
+        return table[i];
+    }
+};
+
+using TNfaTransitionHelper = TVariantHelper<
     TVoidTransition,
     TMatchedVarTransition,
     TEpsilonTransitions,
@@ -26,6 +43,8 @@ using TNfaTransition = std::variant<
     TQuantityExitTransition
 >;
 
+using TNfaTransition = TNfaTransitionHelper::TVariant;
+
 struct TNfaTransitionDestinationVisitor {
     std::function<size_t(size_t)> callback;
 
@@ -61,11 +80,42 @@ struct TNfaTransitionDestinationVisitor {
 };
 
 struct TNfaTransitionGraph {
-    std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>> Transitions;
+    using TTransitions = std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>>;
+
+    TTransitions Transitions;
     size_t Input;
     size_t Output;
 
     using TPtr = std::shared_ptr<TNfaTransitionGraph>;
+
+    template<class>
+    inline constexpr static bool always_false_v = false;
+
+    void Save(TOutputSerializer& serializer) const {
+        serializer(Transitions.size());
+        for (ui64 i = 0; i < Transitions.size(); ++i) {
+            serializer.Write(Transitions[i].index());
+            std::visit(serializer, Transitions[i]);
+        }
+        serializer(Input, Output);
+    }
+
+    void Load(TInputSerializer& serializer) {
+        ui64 transitionSize = serializer.Read<TTransitions::size_type>();
+        Transitions.resize(transitionSize);
+        for (ui64 i = 0; i < transitionSize; ++i) {
+            size_t index = serializer.Read<std::size_t>();
+            Transitions[i] = TNfaTransitionHelper::getVariantByIndex(index);
+            std::visit(serializer, Transitions[i]);
+        }
+        serializer(Input, Output);
+    }
+
+    bool operator==(const TNfaTransitionGraph& other) {
+        return Transitions == other.Transitions
+            && Input == other.Input
+            && Output == other.Output;
+    }
 };
 
 class TNfaTransitionGraphOptimizer {
@@ -78,6 +128,7 @@ public:
         EliminateSingleEpsilons();
         CollectGarbage();
     }
+
 private:
     void EliminateEpsilonChains() {
         for (size_t node = 0; node != Graph->Transitions.size(); node++) {
@@ -250,14 +301,69 @@ private:
 class TNfa {
     using TRange = TSparseList::TRange;
     using TMatchedVars = TMatchedVars<TRange>;
+
+
     struct TState {
+        
+        TState() {}
+
         TState(size_t index, const TMatchedVars& vars, std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>&& quantifiers)
             : Index(index)
             , Vars(vars)
             , Quantifiers(quantifiers) {}
-        const size_t Index;
+        size_t Index;
         TMatchedVars Vars;
-        std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>> Quantifiers; //get rid of this
+        
+        using TQuantifiersStdStack = std::stack<
+            ui64,
+            std::deque<ui64, TMKQLAllocator<ui64>>>; //get rid of this
+
+        struct TQuantifiersStack: public TQuantifiersStdStack {
+            template<typename...TArgs>
+            TQuantifiersStack(TArgs... args) : TQuantifiersStdStack(args...) {}
+            
+            auto begin() const { return c.begin(); }
+            auto end() const { return c.end(); }
+            auto clear() { return c.clear(); }
+        };
+
+        TQuantifiersStack Quantifiers;
+
+        void Save(TOutputSerializer& serializer) const {
+            serializer.Write(Index);
+            serializer.Write(Vars.size());
+            for (const auto& vector : Vars) {
+                serializer.Write(vector.size());
+                for (const auto& range : vector) {
+                    range.Save(serializer);
+                }
+            }
+            serializer.Write(Quantifiers.size());
+            for (ui64 qnt : Quantifiers) {
+                serializer.Write(qnt);
+            }
+        }
+
+        void Load(TInputSerializer& serializer) {
+            serializer.Read(Index);
+
+            auto varsSize = serializer.Read<TMatchedVars::size_type>();
+            Vars.clear();
+            Vars.resize(varsSize);
+            for (auto& subvec: Vars) {
+                ui64 vectorSize = serializer.Read<ui64>();
+                subvec.resize(vectorSize);
+                for (auto& item : subvec) {
+                    item.Load(serializer);
+                }
+            }
+            Quantifiers.clear();
+            auto quantifiersSize = serializer.Read<ui64>();
+            for (size_t i = 0; i < quantifiersSize; ++i) {
+                ui64 qnt = serializer.Read<ui64>();
+                Quantifiers.push(qnt);
+            }
+        }
 
         friend inline bool operator<(const TState& lhs, const TState& rhs) {
             return std::tie(lhs.Index, lhs.Quantifiers, lhs.Vars) < std::tie(rhs.Index, rhs.Quantifiers, rhs.Vars);
@@ -267,13 +373,14 @@ class TNfa {
         }
     };
 public:
+
     TNfa(TNfaTransitionGraph::TPtr transitionGraph, IComputationExternalNode* matchedRangesArg, const TComputationNodePtrVector& defines)
         : TransitionGraph(transitionGraph)
         , MatchedRangesArg(matchedRangesArg)
         , Defines(defines) {
     }
 
-    void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
+    void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {        
         ActiveStates.emplace(TransitionGraph->Input, TMatchedVars(Defines.size()), std::stack<ui64, std::deque<ui64, TMKQLAllocator<ui64>>>{});
         MakeEpsilonTransitions();
         std::set<TState, std::less<TState>, TMKQLAllocator<TState>> newStates;
@@ -329,6 +436,25 @@ public:
         return ActiveStates.size();
     }
 
+    void Save(TOutputSerializer& serializer) const {
+        // TransitionGraph is not saved/loaded, passed in constructor.
+        serializer.Write(ActiveStates.size());
+        for (const auto& state : ActiveStates) {
+            state.Save(serializer);
+        }
+        serializer.Write(EpsilonTransitionsLastRow);
+    }
+
+    void Load(TInputSerializer& serializer) {
+        auto stateSize = serializer.Read<ui64>();
+        for (size_t i = 0; i < stateSize; ++i) {
+            TState state;
+            state.Load(serializer);
+            ActiveStates.emplace(state);
+        }
+        serializer.Read(EpsilonTransitionsLastRow);
+    }
+
 private:
     //TODO (zverevgeny): Consider to change to std::vector for the sake of perf
     using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
@@ -376,6 +502,7 @@ private:
         TStateSet& NewStates;
         TStateSet& DeletedStates;
     };
+
     bool MakeEpsilonTransitionsImpl() {
         TStateSet newStates;
         TStateSet deletedStates;

+ 217 - 0
ydb/library/yql/minikql/comp_nodes/mkql_match_recognize_save_load.h

@@ -0,0 +1,217 @@
+#pragma once
+
+#include <ydb/library/yql/minikql/computation/mkql_computation_node.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
+#include <ydb/library/yql/minikql/comp_nodes/mkql_saveload.h>
+#include <ydb/library/yql/minikql/mkql_string_util.h>
+
+namespace NKikimr::NMiniKQL::NMatchRecognize {
+
+struct TSerializerContext {
+
+    TComputationContext&    Ctx;
+    TType*                  RowType;
+    const TMutableObjectOverBoxedValue<TValuePackerBoxed>& RowPacker;
+};
+
+template<class>
+inline constexpr bool always_false_v = false;
+
+struct TOutputSerializer {
+private:
+    enum class TPtrStateMode {
+        Saved = 0,
+        FromCache = 1
+    };
+
+public:
+    TOutputSerializer(const TSerializerContext& context)
+        : Context(context)
+    {} 
+
+    template <typename... Ts>
+    void operator()(Ts&&... args) {
+        (Write(std::forward<Ts>(args)), ...);
+    }
+
+    template<typename Type>
+    void Write(const Type& value ) {
+        if constexpr (std::is_same_v<std::remove_cv_t<Type>, TString>) {
+            WriteString(Buf, value);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, ui64>) {
+            WriteUi64(Buf, value);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, bool>) {
+            WriteBool(Buf, value);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, ui8>) {
+            WriteByte(Buf, value);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, ui32>) {
+            WriteUi32(Buf, value);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, NUdf::TUnboxedValue>) {     // Only Row type (StateType) supported !
+            WriteUnboxedValue(Buf, Context.RowPacker.RefMutableObject(Context.Ctx, false, Context.RowType), value);
+        } else if constexpr (std::is_empty_v<Type>){
+            // Empty struct is not saved/loaded.
+        } else {
+            static_assert(always_false_v<Type>, "Not supported type / not implemented");
+        }
+    }
+
+    template<class Type>
+    void Write(const TIntrusivePtr<Type>& ptr) {
+        bool isValid = static_cast<bool>(ptr);
+        WriteBool(Buf, isValid);
+        if (!isValid) {
+            return;
+        }
+        auto addr = reinterpret_cast<std::uintptr_t>(ptr.Get());
+        WriteUi64(Buf, addr);
+
+        auto it = Cache.find(addr);
+        if (it != Cache.end()) {
+            WriteByte(Buf, static_cast<ui8>(TPtrStateMode::FromCache));
+            return;
+        }
+        WriteByte(Buf, static_cast<ui8>(TPtrStateMode::Saved));
+        ptr->Save(*this);
+        Cache[addr] = addr;
+    }
+
+    template<class Type1, class Type2>
+    void Write(const std::pair<Type1, Type2>& value) {
+        Write(value.first);
+        Write(value.second);
+    }
+
+    template<class Type, class Allocator>
+    void Write(const std::vector<Type, Allocator>& value) {
+        Write(value.size());
+        for (size_t i = 0; i < value.size(); ++i) {
+            Write(value[i]);
+        }
+    }
+
+    NUdf::TUnboxedValuePod MakeString() {
+        auto strRef = NUdf::TStringRef(Buf.data(), Buf.size());
+        return NKikimr::NMiniKQL::MakeString(strRef);
+    }
+
+private:
+    const TSerializerContext& Context;
+    TString Buf;
+    mutable std::map<std::uintptr_t, std::uintptr_t> Cache;
+};
+
+struct TInputSerializer {
+private:
+    enum class TPtrStateMode {
+        Saved = 0,
+        FromCache = 1
+    };
+
+public:
+    TInputSerializer(TSerializerContext& context, const NUdf::TStringRef& state)
+        : Context(context)
+        , Buf(state.Data(), state.Size())
+    {}
+
+    template <typename... Ts>
+    void operator()(Ts&... args) {
+        (Read(args), ...);
+    }
+
+    template<typename Type, typename ReturnType = Type>
+    ReturnType Read() {
+        if constexpr (std::is_same_v<std::remove_cv_t<Type>, TString>) {
+            return ReadString(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, ui64>) {
+            return ReadUi64(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, bool>) {
+            return ReadBool(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, ui8>) {
+            return ReadByte(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, ui32>) {
+            return ReadUi32(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, NUdf::TUnboxedValue>) {
+            return ReadUnboxedValue(Buf, Context.RowPacker.RefMutableObject(Context.Ctx, false, Context.RowType), Context.Ctx);
+        } else if constexpr (std::is_empty_v<Type>){
+            // Empty struct is not saved/loaded.
+        } else {
+            static_assert(always_false_v<Type>, "Not supported type / not implemented");
+        }
+    }
+
+    template<typename Type>
+    void Read(Type& value) {
+        if constexpr (std::is_same_v<std::remove_cv_t<Type>, TString>) {
+            value = ReadString(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, ui64>) {
+            value = ReadUi64(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, bool>) {
+            value = ReadBool(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, ui8>) {
+            value = ReadByte(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, ui32>) {
+            value = ReadUi32(Buf);
+        } else if constexpr (std::is_same_v<std::remove_cv_t<Type>, NUdf::TUnboxedValue>) {
+            value = ReadUnboxedValue(Buf, Context.RowPacker.RefMutableObject(Context.Ctx, false, Context.RowType), Context.Ctx);
+        } else if constexpr (std::is_empty_v<Type>){
+            // Empty struct is not saved/loaded.
+        } else {
+            static_assert(always_false_v<Type>, "Not supported type / not implemented");
+        }
+    }
+
+    template<class Type>
+    void Read(TIntrusivePtr<Type>& ptr) {
+        bool isValid = Read<bool>();
+        if (!isValid) {
+            ptr.Reset();
+            return;
+        }
+        ui64 addr = Read<ui64>();
+        TPtrStateMode mode = static_cast<TPtrStateMode>(Read<ui8>());
+        if (mode == TPtrStateMode::Saved) {
+            ptr = MakeIntrusive<Type>();
+            ptr->Load(*this);
+            Cache[addr] = ptr.Get();
+            return;
+        }
+        auto it = Cache.find(addr);
+        MKQL_ENSURE(it != Cache.end(), "Internal error");
+        auto* cachePtr = static_cast<Type*>(it->second);
+        ptr = TIntrusivePtr<Type>(cachePtr);
+    }
+
+    template<class Type1, class Type2>
+    void Read(std::pair<Type1, Type2>& value) {
+        Read(value.first);
+        Read(value.second);
+    }
+
+    template<class Type, class Allocator>
+    void Read(std::vector<Type, Allocator>& value) {
+        using TVector = std::vector<Type, Allocator>;
+        auto size = Read<typename TVector::size_type>();
+        //auto size = Read<TVector::size_type>();
+        value.clear();
+        value.resize(size);
+        for (size_t i = 0; i < size; ++i) {
+            Read(value[i]);
+        }
+    }
+
+    NUdf::TUnboxedValuePod MakeString() {
+        auto strRef = NUdf::TStringRef(Buf.data(), Buf.size());
+        return NKikimr::NMiniKQL::MakeString(strRef);
+    }
+
+    bool Empty() const {
+        return Buf.empty();
+    }
+
+private:
+    TSerializerContext& Context;
+    TStringBuf Buf;
+    mutable std::map<std::uintptr_t, void *> Cache;
+};
+
+} //namespace NKikimr::NMiniKQL::NMatchRecognize 

+ 1 - 1
ydb/library/yql/minikql/comp_nodes/mkql_time_order_recover.cpp

@@ -60,7 +60,7 @@ public:
             auto begin() const { return c.begin(); }
             auto end() const { return c.end(); }
             auto clear() { return c.clear(); }
-    };
+        };
 
     public:
 

+ 171 - 0
ydb/library/yql/minikql/comp_nodes/ut/mkql_match_recognize_ut.cpp

@@ -0,0 +1,171 @@
+#include "../mkql_time_order_recover.h"
+#include <ydb/library/yql/minikql/mkql_node.h>
+#include <ydb/library/yql/minikql/mkql_node_cast.h>
+#include <ydb/library/yql/minikql/mkql_program_builder.h>
+#include <ydb/library/yql/minikql/mkql_function_registry.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_holders.h>
+#include <ydb/library/yql/minikql/computation/mkql_computation_node_graph_saveload.h>
+#include <ydb/library/yql/minikql/invoke_builtins/mkql_builtins.h>
+#include <ydb/library/yql/minikql/comp_nodes/mkql_factories.h>
+
+#include <library/cpp/testing/unittest/registar.h>
+
+namespace NKikimr {
+    namespace NMiniKQL {
+
+        namespace {
+            // Deterministic randomness keeps the save/load tests reproducible.
+            TIntrusivePtr<IRandomProvider> CreateRandomProvider() {
+                return CreateDeterministicRandomProvider(1);
+            }
+
+            // Fixed time source for the same reason: graph runs must not
+            // depend on wall-clock time.
+            TIntrusivePtr<ITimeProvider> CreateTimeProvider() {
+                return CreateDeterministicTimeProvider(10000000);
+            }
+
+            // Test harness bundling the allocator, type environment, program
+            // builder and providers needed to compile and run a computation
+            // graph in the tests below.
+            struct TSetup {
+                TSetup(TScopedAlloc& alloc)
+                    : Alloc(alloc)
+                {
+                    FunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry());
+                    RandomProvider = CreateRandomProvider();
+                    TimeProvider = CreateTimeProvider();
+
+                    Env.Reset(new TTypeEnvironment(Alloc));
+                    PgmBuilder.Reset(new TProgramBuilder(*Env, *FunctionRegistry));
+                }
+
+                // Compiles the given runtime node into an executable graph
+                // instance; the pattern is kept alive in this struct.
+                THolder<IComputationGraph> BuildGraph(TRuntimeNode pgm, const std::vector<TNode*>& entryPoints = std::vector<TNode*>()) {
+                    Explorer.Walk(pgm.GetNode(), *Env);
+                    TComputationPatternOpts opts(
+                    Alloc.Ref(),
+                    *Env, GetBuiltinFactory(),
+                    FunctionRegistry.Get(),
+                    NUdf::EValidateMode::None,
+                    NUdf::EValidatePolicy::Fail, "OFF", EGraphPerProcess::Multi);
+                    Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts);
+                    TComputationOptsFull compOpts = opts.ToComputationOptions(*RandomProvider, *TimeProvider);
+                    return Pattern->Clone(compOpts);
+                }
+
+                TIntrusivePtr<IFunctionRegistry> FunctionRegistry;
+                TIntrusivePtr<IRandomProvider> RandomProvider;
+                TIntrusivePtr<ITimeProvider> TimeProvider;
+                TScopedAlloc& Alloc;
+                THolder<TTypeEnvironment> Env;
+                THolder<TProgramBuilder> PgmBuilder;
+                TExploringNodeVisitor Explorer;
+                IComputationPattern::TPtr Pattern;
+            };
+
+            using TTestInputData = std::vector<std::tuple<i64, std::string, ui32, std::string>>;
+
+            // Builds a MatchRecognizeCore graph over a literal list of input
+            // rows (time, key, sum, part). The pattern is a single variable A
+            // repeated exactly 3 times whose DEFINE is always true, and the
+            // only measure ("key") is the constant 56 — so every 3-row group
+            // yields one match with measure value 56.
+            THolder<IComputationGraph> BuildGraph(
+                TSetup& setup,
+                bool streamingMode,
+                const TTestInputData& input) {
+                TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
+
+                auto structType = pgmBuilder.NewStructType({
+                    {"time", pgmBuilder.NewDataType(NUdf::TDataType<i64>::Id)},
+                    {"key", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)},
+                    {"sum", pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id)},
+                    {"part", pgmBuilder.NewDataType(NUdf::TDataType<char*>::Id)}});
+
+                // Materialize each input tuple as a struct literal.
+                TVector<TRuntimeNode> items;
+                for (size_t i = 0; i < input.size(); ++i)
+                {
+                    auto time = pgmBuilder.NewDataLiteral<i64>(std::get<0>(input[i]));
+                    auto key = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<1>(input[i])));
+                    auto sum = pgmBuilder.NewDataLiteral<ui32>(std::get<2>(input[i]));
+                    auto part = pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(NUdf::TStringRef(std::get<3>(input[i])));
+
+                    auto item = pgmBuilder.NewStruct(structType,
+                        {{"time", time}, {"key", key}, {"sum", sum},  {"part", part}});
+                    items.push_back(std::move(item));
+                }
+
+                // The literal list is converted to a flow that feeds the node.
+                const auto list = pgmBuilder.NewList(structType, std::move(items));
+                auto inputFlow = pgmBuilder.ToFlow(list);
+
+                TVector<TStringBuf> partitionColumns;
+                // Measure "key" ignores its inputs and returns the constant 56.
+                TVector<std::pair<TStringBuf, TProgramBuilder::TBinaryLambda>> getMeasures = {{
+                    std::make_pair(
+                        TStringBuf("key"),
+                        [&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) {
+                            return pgmBuilder.NewDataLiteral<ui32>(56);
+                        }
+                )}};
+                // DEFINE A as True: every row satisfies the pattern variable.
+                TVector<std::pair<TStringBuf, TProgramBuilder::TTernaryLambda>> getDefines = {{
+                    std::make_pair(
+                        TStringBuf("A"),
+                        [&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) {
+                            return pgmBuilder.NewDataLiteral<bool>(true);
+                        }
+                )}};
+
+                auto pgmReturn = pgmBuilder.MatchRecognizeCore(
+                    inputFlow,
+                    [&](TRuntimeNode item) {
+                        return pgmBuilder.Member(item, "part");
+                    },
+                    partitionColumns,
+                    getMeasures,
+                    {
+                        // A{3,3}: exactly three consecutive rows per match.
+                        {NYql::NMatchRecognize::TRowPatternFactor{"A", 3, 3, false, false, false}}
+                    },
+                    getDefines,
+                    streamingMode);
+
+                auto graph = setup.BuildGraph(pgmReturn);
+                return graph;
+            }
+        }
+
+        Y_UNIT_TEST_SUITE(TMiniKQLMatchRecognizeSaveLoadTest) {
+            // Exercises the save/load (checkpointing) path of
+            // MatchRecognizeCore: reads one match, snapshots the graph state,
+            // restores it into a freshly built graph and checks the restored
+            // graph still produces the expected measure value.
+            void TestWithSaveLoadImpl(
+                bool streamingMode)
+            {
+                TScopedAlloc alloc(__LOCATION__);
+                TSetup setup1(alloc);
+
+                const TTestInputData input = {
+                    {1000, "A", 101, "P"},
+                    {1001, "B", 102, "P"},
+                    {1002, "C", 103, "P"},      // <- match end
+                    {1003, "D", 103, "P"}};     // <- not processed
+
+                auto graph1 = BuildGraph(setup1, streamingMode, input);
+
+                auto value = graph1->GetValue();
+                UNIT_ASSERT(!value.IsFinish() && value);
+                // The "key" measure is hard-coded to 56 in BuildGraph, so the
+                // first match must already carry it (previously this value was
+                // read into an unused local and never checked).
+                UNIT_ASSERT_VALUES_EQUAL(56, value.GetElement(0).Get<ui32>());
+
+                TString graphState = graph1->SaveGraphState();
+                graph1.Reset();
+
+                TSetup setup2(alloc);
+                auto graph2 = BuildGraph(setup2, streamingMode, TTestInputData{{1003, "D", 103, "P"}});
+                graph2->LoadGraphState(graphState);
+
+                // After restore, the unprocessed row "D" completes the next match.
+                value = graph2->GetValue();
+                UNIT_ASSERT(!value.IsFinish() && value);
+                UNIT_ASSERT_VALUES_EQUAL(56, value.GetElement(0).Get<ui32>());
+            }
+
+            Y_UNIT_TEST(StreamingMode) {
+                TestWithSaveLoadImpl(true);
+            }
+
+            Y_UNIT_TEST(NotStreamingMode) {
+                TestWithSaveLoadImpl(false);
+            }
+        }
+
+    } // namespace NMiniKQL
+} // namespace NKikimr

+ 1 - 0
ydb/library/yql/minikql/comp_nodes/ut/ya.make.inc

@@ -52,6 +52,7 @@ SET(ORIG_SOURCES
     mkql_match_recognize_matched_vars_ut.cpp
     mkql_match_recognize_list_ut.cpp
     mkql_match_recognize_nfa_ut.cpp
+    mkql_match_recognize_ut.cpp
     mkql_safe_circular_buffer_ut.cpp
     mkql_sort_ut.cpp
     mkql_switch_ut.cpp

+ 121 - 67
ydb/tests/fq/yds/test_recovery_match_recognize.py

@@ -9,6 +9,7 @@ import time
 import ydb.tests.library.common.yatest_common as yatest_common
 from ydb.tests.tools.fq_runner.kikimr_runner import StreamingOverKikimr
 from ydb.tests.tools.fq_runner.kikimr_runner import StreamingOverKikimrConfig
+from ydb.tests.tools.fq_runner.kikimr_runner import TenantConfig
 import library.python.retry as retry
 from ydb.tests.tools.fq_runner.kikimr_utils import yq_v1
 from ydb.tests.tools.datastreams_helpers.test_yds_base import TestYdsBase
@@ -17,7 +18,7 @@ import ydb.public.api.protos.draft.fq_pb2 as fq
 
 @pytest.fixture
 def kikimr(request):
-    kikimr_conf = StreamingOverKikimrConfig(cloud_mode=True, node_count=2)
+    kikimr_conf = StreamingOverKikimrConfig(cloud_mode=True, node_count={"/cp": TenantConfig(1), "/compute": TenantConfig(1)})
     kikimr = StreamingOverKikimr(kikimr_conf)
     kikimr.start_mvp_mock_server()
     kikimr.start()
@@ -33,13 +34,6 @@ class TestRecoveryMatchRecognize(TestYdsBase):
         # for retry
         cls.retry_conf = retry.RetryConf().upto(seconds=30).waiting(0.1)
 
-    @retry.retry_intrusive
-    def get_graph_master_node_id(self, kikimr, query_id):
-        for node_index in kikimr.control_plane.kikimr_cluster.nodes:
-            if kikimr.control_plane.get_task_count(node_index, query_id) > 0:
-                return node_index
-        assert False, "No active graphs found"
-
     def get_ca_count(self, kikimr, node_index):
         result = kikimr.control_plane.get_sensors(node_index, "utils").find_sensor(
             {"activity": "DQ_COMPUTE_ACTOR", "sensor": "ActorsAliveByActivity", "execpool": "User"}
@@ -70,11 +64,59 @@ class TestRecoveryMatchRecognize(TestYdsBase):
                     logging.debug("Node {}, workers {}, ca {}".format(s, w, c))
                 assert False, "Workers={} and CAs={}, but {} and {} expected".format(wcs, ccs, worker_count, ca_count)
 
+    def restart_node(self, kikimr, query_id):
+        # Restart the first compute-plane node that currently hosts workers,
+        # i.e. the node running this query's compute actors (CAs), forcing
+        # the query to recover from its latest checkpoint.
+        # NOTE(review): query_id is currently unused here; kept for call-site
+        # symmetry with recovery_impl.
+
+        node_to_restart = None
+
+        for node_index in kikimr.compute_plane.kikimr_cluster.nodes:
+            wc = kikimr.compute_plane.get_worker_count(node_index)
+            if wc is not None:
+                if wc > 0 and node_to_restart is None:
+                    node_to_restart = node_index
+        assert node_to_restart is not None, "Can't find any task on node"
+
+        logging.debug("Restart compute node {}".format(node_to_restart))
+
+        kikimr.compute_plane.kikimr_cluster.nodes[node_to_restart].stop()
+        kikimr.compute_plane.kikimr_cluster.nodes[node_to_restart].start()
+        kikimr.compute_plane.wait_bootstrap(node_to_restart)
+
+    def recovery_impl(self, kikimr, client, yq_version, sql_template, test_name, messages_before_restart, messages_after_restart, expected):
+
+        self.init_topics(f"{test_name}_{yq_version}")
+
+        sql = sql_template.format(self.input_topic, self.output_topic);
+
+        client.create_yds_connection("myyds", os.getenv("YDB_DATABASE"), os.getenv("YDB_ENDPOINT"))
+        query_id = client.create_query("simple", sql, type=fq.QueryContent.QueryType.STREAMING).result.query_id
+        client.wait_query_status(query_id, fq.QueryMeta.RUNNING)
+        kikimr.compute_plane.wait_zero_checkpoint(query_id)
+
+        self.write_stream(messages_before_restart)
+
+        logging.debug("get_completed_checkpoints {}".format(kikimr.compute_plane.get_completed_checkpoints(query_id)))
+        kikimr.compute_plane.wait_completed_checkpoints(
+            query_id, kikimr.compute_plane.get_completed_checkpoints(query_id) + 1
+        )
+
+        self.restart_node(kikimr, query_id)
+        self.write_stream(messages_after_restart)
+
+        assert client.get_query_status(query_id) == fq.QueryMeta.RUNNING
+
+        read_data = self.read_stream(len(expected))
+        logging.info("Data was read: {}".format(read_data))
+
+        assert read_data == expected
+
+        client.abort_query(query_id)
+        client.wait_query(query_id)
+        self.dump_workers(kikimr, 0, 0)
+
     @yq_v1
     @pytest.mark.parametrize("kikimr", [(None, None, None)], indirect=["kikimr"])
-    def test_program_state_recovery(self, kikimr, client, yq_version):
-
-        self.init_topics(f"pq_kikimr_streaming_{yq_version}")
+    def test_time_order_recoverer(self, kikimr, client, yq_version, request):
 
         sql = R'''
             PRAGMA dq.MaxTasksPerStage="2";
@@ -83,9 +125,9 @@ class TestRecoveryMatchRecognize(TestYdsBase):
             pragma config.flags("TimeOrderRecoverDelay", "-1000000");
             pragma config.flags("TimeOrderRecoverAhead", "1000000");
 
-            INSERT INTO myyds.`{output_topic}`
+            INSERT INTO myyds.`{1}`
             SELECT ToBytes(Unwrap(Json::SerializeJson(Yson::From(TableRow()))))
-            FROM (SELECT * FROM myyds.`{input_topic}`
+            FROM (SELECT * FROM myyds.`{0}`
                 WITH (
                     format=json_each_row,
                     SCHEMA
@@ -99,62 +141,74 @@ class TestRecoveryMatchRecognize(TestYdsBase):
                 ONE ROW PER MATCH
                 PATTERN ( ALL_TRUE )
                 DEFINE
-                    ALL_TRUE as True)''' \
-            .format(
-            input_topic=self.input_topic,
-            output_topic=self.output_topic,
-        )
-
-        client.create_yds_connection("myyds", os.getenv("YDB_DATABASE"), os.getenv("YDB_ENDPOINT"))
-        query_id = client.create_query("simple", sql, type=fq.QueryContent.QueryType.STREAMING).result.query_id
-        client.wait_query_status(query_id, fq.QueryMeta.RUNNING)
-        kikimr.compute_plane.wait_zero_checkpoint(query_id)
-
-        master_node_index = self.get_graph_master_node_id(kikimr, query_id)
-        logging.debug("Master node {}".format(master_node_index))
-
-        messages1 = ['{"dt": 1696849942400002}', '{"dt": 1696849942000001}']
-        self.write_stream(messages1)
+                    ALL_TRUE as True)'''
+
+        messages_before_restart = [
+            '{"dt":1696849942400002}',
+            '{"dt":1696849942000001}']
+        messages_after_restart = [
+            '{"dt":1696849942800000}',
+            '{"dt":1696849943200003}',
+            '{"dt":1696849943300003}',
+            '{"dt":1696849943600003}',
+            '{"dt":1696849943900003}']
+        expected = [
+            '{"dt":1696849942000001}',
+            '{"dt":1696849942400002}',
+            '{"dt":1696849942800000}']
+
+        self.recovery_impl(kikimr, client, yq_version, sql, request.node.name, messages_before_restart, messages_after_restart, expected)
 
-        logging.debug("get_completed_checkpoints {}".format(kikimr.compute_plane.get_completed_checkpoints(query_id)))
-        kikimr.compute_plane.wait_completed_checkpoints(
-            query_id, kikimr.compute_plane.get_completed_checkpoints(query_id) + 1
-        )
-
-        # restart node with CA
-        node_to_restart = None
-        for node_index in kikimr.control_plane.kikimr_cluster.nodes:
-            wc = kikimr.control_plane.get_worker_count(node_index)
-            if wc is not None:
-                if wc > 0 and node_index != master_node_index and node_to_restart is None:
-                    node_to_restart = node_index
-        assert node_to_restart is not None, "Can't find any task on non master node"
-
-        logging.debug("Restart non-master node {}".format(node_to_restart))
-
-        kikimr.control_plane.kikimr_cluster.nodes[node_to_restart].stop()
-        kikimr.control_plane.kikimr_cluster.nodes[node_to_restart].start()
-        kikimr.control_plane.wait_bootstrap(node_to_restart)
-
-        messages2 = [
-            '{"dt": 1696849942800000}',
-            '{"dt": 1696849943200003}',
-            '{"dt": 1696849943300003}',
-            '{"dt": 1696849943600003}',
-            '{"dt": 1696849943900003}'
-        ]
-        self.write_stream(messages2)
-
-        assert client.get_query_status(query_id) == fq.QueryMeta.RUNNING
-
-        expected = ['{"dt":1696849942000001}', '{"dt":1696849942400002}', '{"dt":1696849942800000}']
+    @yq_v1
+    @pytest.mark.parametrize("kikimr", [(None, None, None)], indirect=["kikimr"])
+    def test_match_recognize(self, kikimr, client, yq_version, request):
+        # Recovery test for MATCH_RECOGNIZE proper: a partial match (A, B) is
+        # checkpointed inside the MatchRecognize node, the node is restarted,
+        # and the remaining rows must complete the A B C match after recovery.
 
-        read_data = self.read_stream(len(expected))
-        logging.info("Data was read: {}".format(read_data))
+        sql = R'''
+            PRAGMA dq.MaxTasksPerStage="2";
 
-        assert read_data == expected
+            pragma FeatureR010="prototype";
+            pragma config.flags("TimeOrderRecoverDelay", "-1000000");
+            pragma config.flags("TimeOrderRecoverAhead", "1000000");
+            pragma config.flags("MatchRecognizeStream", "auto");
 
-        client.abort_query(query_id)
-        client.wait_query(query_id)
+            INSERT INTO myyds.`{1}`
+            SELECT ToBytes(Unwrap(Json::SerializeJson(Yson::From(TableRow()))))
+            FROM (SELECT * FROM myyds.`{0}`
+                WITH (
+                    format=json_each_row,
+                    SCHEMA
+                    (
+                        dt UINT64,
+                        str STRING
+                    )))
+            MATCH_RECOGNIZE(
+                ORDER BY CAST(dt as Timestamp)
+                MEASURES
+                   LAST(A.dt) as dt_begin,
+                   LAST(C.dt) as dt_end,
+                   LAST(A.str) as a_str,
+                   LAST(B.str) as b_str,
+                   LAST(C.str) as c_str
+                ONE ROW PER MATCH
+                PATTERN ( A B C )
+                DEFINE
+                    A as A.str='A',
+                    B as B.str='B',
+                    C as C.str='C')'''
+
+        messages_before_restart = [
+            '{"dt": 1696849942000001, "str": "A" }',
+            '{"dt": 1696849942500001, "str": "B" }',
+            '{"dt": 1696849943000001, "str": "C" }',
+            '{"dt": 1696849943600001, "str": "D" }']       # push A+B from TimeOrderRecoverer to MatchRecognize 
+
+        # Before restart:
+        #    A + B : in MatchRecognize
+        #    C + D : in TimeOrderRecoverer
+
+        messages_after_restart = [
+            '{"dt": 1696849944100001, "str": "E" }']
+        expected = [
+            '{"a_str":"A","b_str":"B","c_str":"C","dt_begin":1696849942000001,"dt_end":1696849943000001}']
+        self.recovery_impl(kikimr, client, yq_version, sql, request.node.name, messages_before_restart, messages_after_restart, expected)
 
-        self.dump_workers(kikimr, 0, 0)