mkql_multihopping_saveload_ut.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. #include "../mkql_multihopping.h"
  2. #include "mkql_computation_node_ut.h"
  3. #include <yql/essentials/minikql/mkql_node.h>
  4. #include <yql/essentials/minikql/mkql_node_cast.h>
  5. #include <yql/essentials/minikql/mkql_program_builder.h>
  6. #include <yql/essentials/minikql/mkql_function_registry.h>
  7. #include <yql/essentials/minikql/computation/mkql_computation_node.h>
  8. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  9. #include <yql/essentials/minikql/computation/mkql_computation_node_graph_saveload.h>
  10. #include <yql/essentials/minikql/invoke_builtins/mkql_builtins.h>
  11. #include <yql/essentials/minikql/comp_nodes/mkql_factories.h>
  12. #include <library/cpp/testing/unittest/registar.h>
  13. namespace NKikimr {
  14. namespace NMiniKQL {
  15. namespace {
  16. TComputationNodeFactory GetAuxCallableFactory(TWatermark& watermark) {
  17. return [&watermark](TCallable& callable, const TComputationNodeFactoryContext& ctx) -> IComputationNode* {
  18. if (callable.GetType()->GetName() == "OneYieldStream") {
  19. return new TExternalComputationNode(ctx.Mutables);
  20. } else if (callable.GetType()->GetName() == "MultiHoppingCore") {
  21. return WrapMultiHoppingCore(callable, ctx, watermark);
  22. }
  23. return GetBuiltinFactory()(callable, ctx);
  24. };
  25. }
  26. struct TStreamWithYield : public NUdf::TBoxedValue {
  27. TStreamWithYield(const TUnboxedValueVector& items, ui32 yieldPos, ui32 index)
  28. : Items(items)
  29. , YieldPos(yieldPos)
  30. , Index(index)
  31. {}
  32. private:
  33. TUnboxedValueVector Items;
  34. ui32 YieldPos;
  35. ui32 Index;
  36. ui32 GetTraverseCount() const override {
  37. return 0;
  38. }
  39. NUdf::TUnboxedValue Save() const override {
  40. return NUdf::TUnboxedValue::Zero();
  41. }
  42. bool Load2(const NUdf::TUnboxedValue& state) override {
  43. Y_UNUSED(state);
  44. return false;
  45. }
  46. NUdf::EFetchStatus Fetch(NUdf::TUnboxedValue& result) final {
  47. if (Index >= Items.size()) {
  48. return NUdf::EFetchStatus::Finish;
  49. }
  50. if (Index == YieldPos) {
  51. return NUdf::EFetchStatus::Yield;
  52. }
  53. result = Items[Index++];
  54. return NUdf::EFetchStatus::Ok;
  55. }
  56. };
  57. THolder<IComputationGraph> BuildGraph(TSetup<false>& setup, const std::vector<std::tuple<ui32, i64, ui32>> items,
  58. ui32 yieldPos, ui32 startIndex, bool dataWatermarks) {
  59. TProgramBuilder& pgmBuilder = *setup.PgmBuilder;
  60. auto structType = pgmBuilder.NewEmptyStructType();
  61. structType = pgmBuilder.NewStructType(structType, "key",
  62. pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id));
  63. structType = pgmBuilder.NewStructType(structType, "time",
  64. pgmBuilder.NewDataType(NUdf::TDataType<NUdf::TTimestamp>::Id));
  65. structType = pgmBuilder.NewStructType(structType, "sum",
  66. pgmBuilder.NewDataType(NUdf::TDataType<ui32>::Id));
  67. auto keyIndex = AS_TYPE(TStructType, structType)->GetMemberIndex("key");
  68. auto timeIndex = AS_TYPE(TStructType, structType)->GetMemberIndex("time");
  69. auto sumIndex = AS_TYPE(TStructType, structType)->GetMemberIndex("sum");
  70. auto inStreamType = pgmBuilder.NewStreamType(structType);
  71. TCallableBuilder inStream(pgmBuilder.GetTypeEnvironment(), "OneYieldStream", inStreamType);
  72. auto streamNode = inStream.Build();
  73. ui64 hop = 10, interval = 30, delay = 20;
  74. auto pgmReturn = pgmBuilder.MultiHoppingCore(
  75. TRuntimeNode(streamNode, false),
  76. [&](TRuntimeNode item) { // keyExtractor
  77. return pgmBuilder.Member(item, "key");
  78. },
  79. [&](TRuntimeNode item) { // timeExtractor
  80. return pgmBuilder.Member(item, "time");
  81. },
  82. [&](TRuntimeNode item) { // init
  83. std::vector<std::pair<std::string_view, TRuntimeNode>> members;
  84. members.emplace_back("sum", pgmBuilder.Member(item, "sum"));
  85. return pgmBuilder.NewStruct(members);
  86. },
  87. [&](TRuntimeNode item, TRuntimeNode state) { // update
  88. auto add = pgmBuilder.AggrAdd(
  89. pgmBuilder.Member(item, "sum"),
  90. pgmBuilder.Member(state, "sum"));
  91. std::vector<std::pair<std::string_view, TRuntimeNode>> members;
  92. members.emplace_back("sum", add);
  93. return pgmBuilder.NewStruct(members);
  94. },
  95. [&](TRuntimeNode state) { // save
  96. return pgmBuilder.Member(state, "sum");
  97. },
  98. [&](TRuntimeNode savedState) { // load
  99. std::vector<std::pair<std::string_view, TRuntimeNode>> members;
  100. members.emplace_back("sum", savedState);
  101. return pgmBuilder.NewStruct(members);
  102. },
  103. [&](TRuntimeNode state1, TRuntimeNode state2) { // merge
  104. auto add = pgmBuilder.AggrAdd(
  105. pgmBuilder.Member(state1, "sum"),
  106. pgmBuilder.Member(state2, "sum"));
  107. std::vector<std::pair<std::string_view, TRuntimeNode>> members;
  108. members.emplace_back("sum", add);
  109. return pgmBuilder.NewStruct(members);
  110. },
  111. [&](TRuntimeNode key, TRuntimeNode state, TRuntimeNode time) { // finish
  112. Y_UNUSED(time);
  113. std::vector<std::pair<std::string_view, TRuntimeNode>> members;
  114. members.emplace_back("key", key);
  115. members.emplace_back("sum", pgmBuilder.Member(state, "sum"));
  116. return pgmBuilder.NewStruct(members);
  117. },
  118. pgmBuilder.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&hop, sizeof(hop))), // hop
  119. pgmBuilder.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&interval, sizeof(interval))), // interval
  120. pgmBuilder.NewDataLiteral<NUdf::EDataSlot::Interval>(NUdf::TStringRef((const char*)&delay, sizeof(delay))), // delay
  121. pgmBuilder.NewDataLiteral<bool>(dataWatermarks), // dataWatermarks
  122. pgmBuilder.NewDataLiteral<bool>(false)
  123. );
  124. auto graph = setup.BuildGraph(pgmReturn, {streamNode});
  125. TUnboxedValueVector streamItems;
  126. for (size_t i = 0; i < items.size(); ++i) {
  127. NUdf::TUnboxedValue* itemsPtr;
  128. auto structValues = graph->GetHolderFactory().CreateDirectArrayHolder(3, itemsPtr);
  129. itemsPtr[keyIndex] = NUdf::TUnboxedValuePod(std::get<0>(items[i]));
  130. itemsPtr[timeIndex] = NUdf::TUnboxedValuePod(std::get<1>(items[i]));
  131. itemsPtr[sumIndex] = NUdf::TUnboxedValuePod(std::get<2>(items[i]));
  132. streamItems.push_back(std::move(structValues));
  133. }
  134. auto streamValue = NUdf::TUnboxedValuePod(new TStreamWithYield(streamItems, yieldPos, startIndex));
  135. graph->GetEntryPoint(0, true)->SetValue(graph->GetContext(), std::move(streamValue));
  136. return graph;
  137. }
  138. }
  139. Y_UNIT_TEST_SUITE(TMiniKQLMultiHoppingSaveLoadTest) {
  140. void TestWithSaveLoadImpl(
  141. const std::vector<std::tuple<ui32, i64, ui32>> input,
  142. const std::vector<std::tuple<ui32, ui32>> expected,
  143. bool withTraverse,
  144. bool dataWatermarks)
  145. {
  146. TWatermark watermark;
  147. for (ui32 yieldPos = 0; yieldPos < input.size(); ++yieldPos) {
  148. std::vector<std::tuple<ui32, ui32>> result;
  149. TSetup<false> setup1(GetAuxCallableFactory(watermark));
  150. auto graph1 = BuildGraph(setup1, input, yieldPos, 0, dataWatermarks);
  151. auto root1 = graph1->GetValue();
  152. NUdf::EFetchStatus status = NUdf::EFetchStatus::Ok;
  153. while (status == NUdf::EFetchStatus::Ok) {
  154. NUdf::TUnboxedValue val;
  155. status = root1.Fetch(val);
  156. if (status == NUdf::EFetchStatus::Ok) {
  157. result.emplace_back(val.GetElement(0).Get<ui32>(), val.GetElement(1).Get<ui32>());
  158. }
  159. }
  160. UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Yield);
  161. TString graphState;
  162. if (withTraverse) {
  163. SaveGraphState(&root1, 1, 0ULL, graphState);
  164. } else {
  165. graphState = graph1->SaveGraphState();
  166. }
  167. TSetup<false> setup2(GetAuxCallableFactory(watermark));
  168. auto graph2 = BuildGraph(setup2, input, -1, yieldPos, dataWatermarks);
  169. NUdf::TUnboxedValue root2;
  170. if (withTraverse) {
  171. root2 = graph2->GetValue();
  172. LoadGraphState(&root2, 1, 0ULL, graphState);
  173. } else {
  174. graph2->LoadGraphState(graphState);
  175. root2 = graph2->GetValue();
  176. }
  177. status = NUdf::EFetchStatus::Ok;
  178. while (status == NUdf::EFetchStatus::Ok) {
  179. NUdf::TUnboxedValue val;
  180. status = root2.Fetch(val);
  181. if (status == NUdf::EFetchStatus::Ok) {
  182. result.emplace_back(val.GetElement(0).Get<ui32>(), val.GetElement(1).Get<ui32>());
  183. }
  184. }
  185. UNIT_ASSERT_EQUAL(status, NUdf::EFetchStatus::Finish);
  186. auto sortedExpected = expected;
  187. std::sort(result.begin(), result.end());
  188. std::sort(sortedExpected.begin(), sortedExpected.end());
  189. UNIT_ASSERT_EQUAL(result, sortedExpected);
  190. }
  191. }
  192. const std::vector<std::tuple<ui32, i64, ui32>> input1 = {
  193. // Group; Time; Value
  194. {2, 1, 2},
  195. {1, 1, 2},
  196. {2, 2, 3},
  197. {1, 2, 3},
  198. {2, 15, 4},
  199. {1, 15, 4},
  200. {2, 23, 6},
  201. {1, 23, 6},
  202. {2, 24, 5},
  203. {1, 24, 5},
  204. {2, 25, 7},
  205. {1, 25, 7},
  206. {2, 40, 2},
  207. {1, 40, 2},
  208. {2, 47, 1},
  209. {1, 47, 1},
  210. {2, 51, 6},
  211. {1, 51, 6},
  212. {2, 59, 2},
  213. {1, 59, 2},
  214. {2, 85, 8},
  215. {1, 85, 8}
  216. };
  217. const std::vector<std::tuple<ui32, ui32>> expected = {
  218. {1, 8}, {1, 8}, {1, 8}, {1, 8},
  219. {1, 11}, {1, 11}, {1, 21}, {1, 22},
  220. {1, 27},
  221. {2, 8}, {2, 8}, {2, 8}, {2, 8},
  222. {2, 11}, {2, 11}, {2, 21},
  223. {2, 22}, {2, 27}};
  224. Y_UNIT_TEST(Test1) {
  225. TestWithSaveLoadImpl(input1, expected, true, false);
  226. }
  227. Y_UNIT_TEST(Test2) {
  228. TestWithSaveLoadImpl(input1, expected, false, false);
  229. }
  230. }
  231. } // namespace NMiniKQL
  232. } // namespace NKikimr