// mkql_match_recognize_ut.cpp
  1. #include "../mkql_time_order_recover.h"
  2. #include <yql/essentials/minikql/mkql_node.h>
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. #include <yql/essentials/minikql/mkql_program_builder.h>
  5. #include <yql/essentials/minikql/mkql_function_registry.h>
  6. #include <yql/essentials/minikql/computation/mkql_computation_node.h>
  7. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  8. #include <yql/essentials/minikql/computation/mkql_computation_node_graph_saveload.h>
  9. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
  10. #include <yql/essentials/minikql/comp_nodes/mkql_factories.h>
  11. #include <library/cpp/testing/unittest/registar.h>
  12. namespace NKikimr::NMiniKQL {
  13. namespace {
  14. TIntrusivePtr<IRandomProvider> CreateRandomProvider() {
  15. return CreateDeterministicRandomProvider(1);
  16. }
  17. TIntrusivePtr<ITimeProvider> CreateTimeProvider() {
  18. return CreateDeterministicTimeProvider(10000000);
  19. }
  20. struct TSetup {
  21. TSetup(TScopedAlloc& alloc)
  22. : Alloc(alloc)
  23. {
  24. FunctionRegistry = CreateFunctionRegistry(CreateBuiltinRegistry());
  25. RandomProvider = CreateRandomProvider();
  26. TimeProvider = CreateTimeProvider();
  27. Env.Reset(new TTypeEnvironment(Alloc));
  28. PgmBuilder.Reset(new TProgramBuilder(*Env, *FunctionRegistry));
  29. }
  30. THolder<IComputationGraph> BuildGraph(TRuntimeNode pgm, const std::vector<TNode*>& entryPoints = std::vector<TNode*>()) {
  31. Explorer.Walk(pgm.GetNode(), *Env);
  32. TComputationPatternOpts opts(
  33. Alloc.Ref(),
  34. *Env, GetBuiltinFactory(),
  35. FunctionRegistry.Get(),
  36. NUdf::EValidateMode::None,
  37. NUdf::EValidatePolicy::Fail, "OFF", EGraphPerProcess::Multi);
  38. Pattern = MakeComputationPattern(Explorer, pgm, entryPoints, opts);
  39. TComputationOptsFull compOpts = opts.ToComputationOptions(*RandomProvider, *TimeProvider);
  40. return Pattern->Clone(compOpts);
  41. }
  42. TIntrusivePtr<IFunctionRegistry> FunctionRegistry;
  43. TIntrusivePtr<IRandomProvider> RandomProvider;
  44. TIntrusivePtr<ITimeProvider> TimeProvider;
  45. TScopedAlloc& Alloc;
  46. THolder<TTypeEnvironment> Env;
  47. THolder<TProgramBuilder> PgmBuilder;
  48. TExploringNodeVisitor Explorer;
  49. IComputationPattern::TPtr Pattern;
  50. };
  51. using TTestInputData = std::vector<std::tuple<i64, std::string, ui32, std::string>>;
  52. THolder<IComputationGraph> BuildGraph(
  53. TSetup& setup,
  54. bool streamingMode,
  55. const TTestInputData& input) {
  56. TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
  57. const auto structType = pgmBuilder.NewStructType({
  58. {"time", pgmBuilder.NewDataType(NUdf::EDataSlot::Int64)},
  59. {"key", pgmBuilder.NewDataType(NUdf::EDataSlot::String)},
  60. {"sum", pgmBuilder.NewDataType(NUdf::EDataSlot::Uint32)},
  61. {"part", pgmBuilder.NewDataType(NUdf::EDataSlot::String)}
  62. });
  63. TVector<TRuntimeNode> items;
  64. for (size_t i = 0; i < input.size(); ++i) {
  65. const auto& [time, key, sum, part] = input[i];
  66. items.push_back(pgmBuilder.NewStruct({
  67. {"time", pgmBuilder.NewDataLiteral(time)},
  68. {"key", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(key)},
  69. {"sum", pgmBuilder.NewDataLiteral(sum)},
  70. {"part", pgmBuilder.NewDataLiteral<NUdf::EDataSlot::String>(part)},
  71. }));
  72. }
  73. const auto list = pgmBuilder.NewList(structType, std::move(items));
  74. auto inputFlow = pgmBuilder.ToFlow(list);
  75. auto pgmReturn = pgmBuilder.MatchRecognizeCore(
  76. inputFlow,
  77. [&](TRuntimeNode item) {
  78. return pgmBuilder.NewTuple({pgmBuilder.Member(item, "part")});
  79. },
  80. {},
  81. {"key"sv},
  82. {[&](TRuntimeNode /*measureInputDataArg*/, TRuntimeNode /*matchedVarsArg*/) {
  83. return pgmBuilder.NewDataLiteral<ui32>(56);
  84. }},
  85. {
  86. {NYql::NMatchRecognize::TRowPatternFactor{"A", 3, 3, false, false, false}}
  87. },
  88. {"A"sv},
  89. {[&](TRuntimeNode /*inputDataArg*/, TRuntimeNode /*matchedVarsArg*/, TRuntimeNode /*currentRowIndexArg*/) {
  90. return pgmBuilder.NewDataLiteral<bool>(true);
  91. }},
  92. streamingMode,
  93. {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""},
  94. NYql::NMatchRecognize::ERowsPerMatch::OneRow
  95. );
  96. auto graph = setup.BuildGraph(pgmReturn);
  97. return graph;
  98. }
  99. }
  100. Y_UNIT_TEST_SUITE(MatchRecognizeSaveLoadTest) {
  101. void TestWithSaveLoadImpl(bool streamingMode) {
  102. TScopedAlloc alloc(__LOCATION__);
  103. std::vector<std::tuple<ui32, i64, ui32>> result;
  104. TSetup setup1(alloc);
  105. const TTestInputData input = {
  106. {1000, "A", 101, "P"},
  107. {1001, "B", 102, "P"},
  108. {1002, "C", 103, "P"}, // <- match end
  109. {1003, "D", 103, "P"}}; // <- not processed
  110. auto graph1 = BuildGraph(setup1,streamingMode, input);
  111. auto value = graph1->GetValue();
  112. UNIT_ASSERT(!value.IsFinish() && value);
  113. auto v = value.GetElement(0).Get<ui32>();
  114. TString graphState = graph1->SaveGraphState();
  115. graph1.Reset();
  116. TSetup setup2(alloc);
  117. auto graph2 = BuildGraph(setup2, streamingMode, TTestInputData{{1003, "D", 103, "P"}});
  118. graph2->LoadGraphState(graphState);
  119. value = graph2->GetValue();
  120. UNIT_ASSERT(!value.IsFinish() && value);
  121. v = value.GetElement(0).Get<ui32>();
  122. UNIT_ASSERT_VALUES_EQUAL(56, v);
  123. }
  124. Y_UNIT_TEST(StreamingMode) {
  125. TestWithSaveLoadImpl(true);
  126. }
  127. Y_UNIT_TEST(NotStreamingMode) {
  128. TestWithSaveLoadImpl(false);
  129. }
  130. }
  131. } // namespace NKikimr::NMiniKQL