mkql_match_recognize.cpp 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691
  1. #include "mkql_match_recognize_list.h"
  2. #include "mkql_match_recognize_measure_arg.h"
  3. #include "mkql_match_recognize_matched_vars.h"
  4. #include "mkql_match_recognize_nfa.h"
  5. #include "mkql_match_recognize_rows_formatter.h"
  6. #include "mkql_match_recognize_save_load.h"
  7. #include "mkql_match_recognize_version.h"
  8. #include <yql/essentials/core/sql_types/match_recognize.h>
  9. #include <yql/essentials/minikql/computation/mkql_computation_node_holders.h>
  10. #include <yql/essentials/minikql/computation/mkql_computation_node_holders_codegen.h>
  11. #include <yql/essentials/minikql/computation/mkql_computation_node_impl.h>
  12. #include <yql/essentials/minikql/computation/mkql_computation_node_pack.h>
  13. #include <yql/essentials/minikql/mkql_node.h>
  14. #include <yql/essentials/minikql/mkql_node_cast.h>
  15. #include <yql/essentials/minikql/mkql_string_util.h>
  16. #include <deque>
  17. namespace NKikimr::NMiniKQL {
  18. namespace NMatchRecognize {
/// Parameters shared by all partitions of one MATCH_RECOGNIZE node.
/// Filled by aggregate initialization in WrapMatchRecognizeCore (field order matters there)
/// and held by const reference by every per-partition TStreamingMatchRecognize.
struct TMatchRecognizeProcessorParameters {
    // External node set to the list of rows of the partition before each NFA step.
    IComputationExternalNode* InputDataArg;
    // Row pattern converted from the callable's pattern literal.
    TRowPattern Pattern;
    // Pattern variable names as string values, in the order of the callable's define-name list.
    TUnboxedValueVector VarNames;
    // Variable name -> its position in VarNames.
    THashMap<TString, size_t> VarNamesLookup;
    // External node set to the variables of the match being emitted; also handed to the NFA.
    IComputationExternalNode* MatchedVarsArg;
    // External node set to the current row index before each NFA step.
    IComputationExternalNode* CurrentRowIndexArg;
    // DEFINE expression nodes, one per define name; passed to the NFA.
    TComputationNodePtrVector Defines;
    // External node set to the measure input data of the match being emitted.
    IComputationExternalNode* MeasureInputDataArg;
    // Layout of the measure input data (input table columns plus special columns).
    TMeasureInputColumnOrder MeasureInputColumnOrder;
    // AFTER MATCH SKIP mode and target variable; passed to the NFA.
    TAfterMatchSkipTo SkipTo;
};
  31. class TStreamingMatchRecognize {
  32. public:
  33. TStreamingMatchRecognize(
  34. NUdf::TUnboxedValue&& partitionKey,
  35. const TMatchRecognizeProcessorParameters& parameters,
  36. const IRowsFormatter::TState& rowsFormatterState,
  37. TNfaTransitionGraph::TPtr nfaTransitions)
  38. : PartitionKey(std::move(partitionKey))
  39. , Parameters(parameters)
  40. , RowsFormatter_(IRowsFormatter::Create(rowsFormatterState))
  41. , Nfa(nfaTransitions, parameters.MatchedVarsArg, parameters.Defines, parameters.SkipTo)
  42. {}
  43. bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
  44. Parameters.InputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TListValue>(Rows));
  45. Parameters.CurrentRowIndexArg->SetValue(ctx, NUdf::TUnboxedValuePod(Rows.LastRowIndex()));
  46. Nfa.ProcessRow(Rows.Append(std::move(row)), ctx);
  47. return HasMatched();
  48. }
  49. bool HasMatched() const {
  50. return Nfa.HasMatched();
  51. }
  52. NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
  53. if (auto result = RowsFormatter_->GetOtherMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph())) {
  54. return result;
  55. }
  56. auto match = Nfa.GetMatched();
  57. if (!match) {
  58. return NUdf::TUnboxedValue{};
  59. }
  60. Parameters.MatchedVarsArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TSparseList::TRange>>(ctx.HolderFactory, match->Vars));
  61. Parameters.MeasureInputDataArg->SetValue(ctx, ctx.HolderFactory.Create<TMeasureInputDataValue>(
  62. ctx.HolderFactory.Create<TListValue>(Rows),
  63. Parameters.MeasureInputColumnOrder,
  64. Parameters.MatchedVarsArg->GetValue(ctx),
  65. Parameters.VarNames,
  66. MatchNumber
  67. ));
  68. auto result = RowsFormatter_->GetFirstMatchRow(ctx, Rows, PartitionKey, Nfa.GetTransitionGraph(), *match);
  69. Nfa.AfterMatchSkip(*match);
  70. return result;
  71. }
  72. bool ProcessEndOfData(TComputationContext& ctx) {
  73. return Nfa.ProcessEndOfData(ctx);
  74. }
  75. void Save(TMrOutputSerializer& serializer) const {
  76. // PartitionKey saved in TStateForInterleavedPartitions as key.
  77. Rows.Save(serializer);
  78. Nfa.Save(serializer);
  79. serializer.Write(MatchNumber);
  80. RowsFormatter_->Save(serializer);
  81. }
  82. void Load(TMrInputSerializer& serializer) {
  83. // PartitionKey passed in contructor.
  84. Rows.Load(serializer);
  85. Nfa.Load(serializer);
  86. MatchNumber = serializer.Read<ui64>();
  87. if (serializer.GetStateVersion() >= 2U) {
  88. RowsFormatter_->Load(serializer);
  89. }
  90. }
  91. private:
  92. NUdf::TUnboxedValue PartitionKey;
  93. const TMatchRecognizeProcessorParameters& Parameters;
  94. std::unique_ptr<IRowsFormatter> RowsFormatter_;
  95. TSparseList Rows;
  96. TNfa Nfa;
  97. ui64 MatchNumber = 0;
  98. };
/// MATCH_RECOGNIZE state used when streaming mode is off (see WrapMatchRecognizeCore).
///
/// Keeps a single partition handler alive at a time. A row whose packed partition key differs
/// from the current one is parked in DelayedRow, and the current partition is finalized. Once the
/// current partition has no more output, a new handler is started from the parked row. A key seen
/// again after a different key therefore starts a fresh partition
/// (NOTE(review): presumably the input is grouped by partition key upstream — confirm).
class TStateForNonInterleavedPartitions
    : public TComputationValue<TStateForNonInterleavedPartitions>
{
public:
    /// @param inputRowArg       external node set to a row before the partition key is computed for it
    /// @param partitionKey      node computing the partition key from inputRowArg
    /// @param partitionKeyType  type used to pack partition keys for comparison and serialization
    /// @param parameters        shared processor parameters (held by reference)
    /// @param rowsFormatterState formatter configuration (held by reference)
    /// @param rowType / rowPacker used by the serializer context for saving rows
    TStateForNonInterleavedPartitions(
        TMemoryUsageInfo* memInfo,
        IComputationExternalNode* inputRowArg,
        IComputationNode* partitionKey,
        TType* partitionKeyType,
        const TMatchRecognizeProcessorParameters& parameters,
        const IRowsFormatter::TState& rowsFormatterState,
        TComputationContext &ctx,
        TType* rowType,
        const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker
    )
        : TComputationValue<TStateForNonInterleavedPartitions>(memInfo)
        , InputRowArg(inputRowArg)
        , PartitionKey(partitionKey)
        , PartitionKeyPacker(true, partitionKeyType)
        , Parameters(parameters)
        , RowsFormatterState(rowsFormatterState)
        , RowPatternConfiguration(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
        , Terminating(false)
        , SerializerContext(ctx, rowType, rowPacker)
        , Ctx(ctx)
    {}

    /// Serializes: packed key of the current partition; a "handler present" flag followed by the
    /// handler state; a "delayed row present" flag followed by the row.
    /// NOTE(review): Terminating is not written, so a state saved after end of data is restored
    /// with Terminating == false — confirm that saves cannot happen in that phase.
    NUdf::TUnboxedValue Save() const override {
        TMrOutputSerializer out(SerializerContext, EMkqlStateType::SIMPLE_BLOB, StateVersion, Ctx);
        out.Write(CurPartitionPackedKey);
        bool isValid = static_cast<bool>(PartitionHandler);
        out.Write(isValid);
        if (isValid) {
            PartitionHandler->Save(out);
        }
        isValid = static_cast<bool>(DelayedRow);
        out.Write(isValid);
        if (isValid) {
            out.Write(DelayedRow);
        }
        return out.MakeState();
    }

    /// Restores a state written by Save(). The partition handler is rebuilt from the unpacked key.
    /// States older than version 2 also carry the NFA transition graph, which must equal the graph
    /// built from the current pattern.
    /// @throws if the graphs differ or unread bytes remain.
    bool Load2(const NUdf::TUnboxedValue& state) override {
        TMrInputSerializer in(SerializerContext, state);
        in.Read(CurPartitionPackedKey);
        bool validPartitionHandler = in.Read<bool>();
        if (validPartitionHandler) {
            NUdf::TUnboxedValue key = PartitionKeyPacker.Unpack(CurPartitionPackedKey, SerializerContext.Ctx.HolderFactory);
            PartitionHandler.reset(new TStreamingMatchRecognize(
                std::move(key),
                Parameters,
                RowsFormatterState,
                RowPatternConfiguration
            ));
            PartitionHandler->Load(in);
        }
        bool validDelayedRow = in.Read<bool>();
        if (validDelayedRow) {
            in(DelayedRow);
        }
        if (in.GetStateVersion() < 2U) {
            auto restoredRowPatternConfiguration = std::make_shared<TNfaTransitionGraph>();
            restoredRowPatternConfiguration->Load(in);
            MKQL_ENSURE(*restoredRowPatternConfiguration == *RowPatternConfiguration, "Restored and current RowPatternConfiguration is different");
        }
        MKQL_ENSURE(in.Empty(), "State is corrupted");
        return true;
    }

    // Always false for a live state (TMatchRecognizeWrapper uses HasListItems() to tell a saved state apart).
    bool HasListItems() const override {
        return false;
    }

    /// Routes one input row.
    /// A row of the current partition goes straight to its handler. A row with a different key is
    /// parked in DelayedRow and the current partition (if any) is finalized.
    /// @return true if output may be ready: a match in the current partition, or the result of
    ///         finalizing the previous partition.
    bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
        MKQL_ENSURE(not DelayedRow, "Internal logic error"); //we're finalizing previous partition
        InputRowArg->SetValue(ctx, NUdf::TUnboxedValue(row));
        auto partitionKey = PartitionKey->GetValue(ctx);
        const auto packedKey = PartitionKeyPacker.Pack(partitionKey);
        //TODO switch to tuple compare for comparable types
        if (packedKey == CurPartitionPackedKey) { //continue in the same partition
            MKQL_ENSURE(PartitionHandler, "Internal logic error");
            return PartitionHandler->ProcessInputRow(std::move(row), ctx);
        }
        //either the first or next partition
        DelayedRow = std::move(row);
        if (PartitionHandler) {
            return PartitionHandler->ProcessEndOfData(ctx);
        }
        //be aware that the very first partition is created in the same manner as subsequent
        return false;
    }

    /// Finalizes the current partition on the first call; later calls return false.
    bool ProcessEndOfData(TComputationContext& ctx) {
        if (Terminating)
            return false;
        Terminating = true;
        if (PartitionHandler) {
            return PartitionHandler->ProcessEndOfData(ctx);
        }
        return false;
    }

    /// Returns the next output row of the current partition. When that partition has nothing left
    /// and a row is parked, starts a new partition from that row. Returns Finish after end of data
    /// when no row is ready, otherwise an empty value.
    NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
        if (PartitionHandler) {
            auto result = PartitionHandler->GetOutputIfReady(ctx);
            if (result) {
                return result;
            }
        }
        if (DelayedRow) {
            //either the first partition or
            //we're finalizing a partition and expect no more output from this partition
            NUdf::TUnboxedValue temp;
            std::swap(temp, DelayedRow);
            InputRowArg->SetValue(ctx, NUdf::TUnboxedValue(temp));
            auto partitionKey = PartitionKey->GetValue(ctx);
            CurPartitionPackedKey = PartitionKeyPacker.Pack(partitionKey);
            PartitionHandler.reset(new TStreamingMatchRecognize(
                std::move(partitionKey),
                Parameters,
                RowsFormatterState,
                RowPatternConfiguration
            ));
            PartitionHandler->ProcessInputRow(std::move(temp), ctx);
        }
        if (Terminating) {
            return NUdf::TUnboxedValue::MakeFinish();
        }
        return NUdf::TUnboxedValue{};
    }
private:
    // Packed key of the partition PartitionHandler is processing.
    TString CurPartitionPackedKey;
    std::unique_ptr<TStreamingMatchRecognize> PartitionHandler;
    IComputationExternalNode* InputRowArg;
    IComputationNode* PartitionKey;
    TValuePackerGeneric<false> PartitionKeyPacker;
    const TMatchRecognizeProcessorParameters& Parameters;
    const IRowsFormatter::TState& RowsFormatterState;
    const TNfaTransitionGraph::TPtr RowPatternConfiguration;
    // First row of the next partition, parked until the current partition is drained.
    NUdf::TUnboxedValue DelayedRow;
    // Set once end of input has been processed.
    bool Terminating;
    TSerializerContext SerializerContext;
    TComputationContext& Ctx;
};
  238. class TStateForInterleavedPartitions
  239. : public TComputationValue<TStateForInterleavedPartitions>
  240. {
  241. using TPartitionMapValue = std::unique_ptr<TStreamingMatchRecognize>;
  242. using TPartitionMap = std::unordered_map<TString, TPartitionMapValue, std::hash<TString>, std:: equal_to<TString>, TMKQLAllocator<std::pair<const TString, TPartitionMapValue>>>;
  243. public:
  244. TStateForInterleavedPartitions(
  245. TMemoryUsageInfo* memInfo,
  246. IComputationExternalNode* inputRowArg,
  247. IComputationNode* partitionKey,
  248. TType* partitionKeyType,
  249. const TMatchRecognizeProcessorParameters& parameters,
  250. const IRowsFormatter::TState& rowsFormatterState,
  251. TComputationContext &ctx,
  252. TType* rowType,
  253. const TMutableObjectOverBoxedValue<TValuePackerBoxed>& rowPacker
  254. )
  255. : TComputationValue<TStateForInterleavedPartitions>(memInfo)
  256. , InputRowArg(inputRowArg)
  257. , PartitionKey(partitionKey)
  258. , PartitionKeyPacker(true, partitionKeyType)
  259. , Parameters(parameters)
  260. , RowsFormatterState(rowsFormatterState)
  261. , NfaTransitionGraph(TNfaTransitionGraphBuilder::Create(parameters.Pattern, parameters.VarNamesLookup))
  262. , SerializerContext(ctx, rowType, rowPacker)
  263. , Ctx(ctx)
  264. {}
  265. NUdf::TUnboxedValue Save() const override {
  266. TMrOutputSerializer serializer(SerializerContext, EMkqlStateType::SIMPLE_BLOB, StateVersion, Ctx);
  267. serializer.Write(Partitions.size());
  268. for (const auto& [key, state] : Partitions) {
  269. serializer.Write(key);
  270. state->Save(serializer);
  271. }
  272. // HasReadyOutput is not packed because when loading we can recalculate HasReadyOutput from Partitions.
  273. serializer.Write(Terminating);
  274. return serializer.MakeState();
  275. }
  276. bool Load2(const NUdf::TUnboxedValue& state) override {
  277. TMrInputSerializer in(SerializerContext, state);
  278. Partitions.clear();
  279. auto partitionsCount = in.Read<TPartitionMap::size_type>();
  280. Partitions.reserve(partitionsCount);
  281. for (size_t i = 0; i < partitionsCount; ++i) {
  282. auto packedKey = in.Read<TPartitionMap::key_type, std::string_view>();
  283. NUdf::TUnboxedValue key = PartitionKeyPacker.Unpack(packedKey, SerializerContext.Ctx.HolderFactory);
  284. auto pair = Partitions.emplace(
  285. packedKey,
  286. std::make_unique<TStreamingMatchRecognize>(
  287. std::move(key),
  288. Parameters,
  289. RowsFormatterState,
  290. NfaTransitionGraph
  291. )
  292. );
  293. pair.first->second->Load(in);
  294. }
  295. for (auto it = Partitions.begin(); it != Partitions.end(); ++it) {
  296. if (it->second->HasMatched()) {
  297. HasReadyOutput.push(it);
  298. }
  299. }
  300. in.Read(Terminating);
  301. if (in.GetStateVersion() < 2U) {
  302. auto restoredTransitionGraph = std::make_shared<TNfaTransitionGraph>();
  303. restoredTransitionGraph->Load(in);
  304. MKQL_ENSURE(NfaTransitionGraph, "Empty NfaTransitionGraph");
  305. MKQL_ENSURE(*restoredTransitionGraph == *NfaTransitionGraph, "Restored and current NfaTransitionGraph is different");
  306. }
  307. MKQL_ENSURE(in.Empty(), "State is corrupted");
  308. return true;
  309. }
  310. bool HasListItems() const override {
  311. return false;
  312. }
  313. bool ProcessInputRow(NUdf::TUnboxedValue&& row, TComputationContext& ctx) {
  314. auto partition = GetPartitionHandler(row, ctx);
  315. if (partition->second->ProcessInputRow(std::move(row), ctx)) {
  316. HasReadyOutput.push(partition);
  317. }
  318. return !HasReadyOutput.empty();
  319. }
  320. bool ProcessEndOfData(TComputationContext& ctx) {
  321. for (auto it = Partitions.begin(); it != Partitions.end(); ++it) {
  322. auto b = it->second->ProcessEndOfData(ctx);
  323. if (b) {
  324. HasReadyOutput.push(it);
  325. }
  326. }
  327. Terminating = true;
  328. return !HasReadyOutput.empty();
  329. }
  330. NUdf::TUnboxedValue GetOutputIfReady(TComputationContext& ctx) {
  331. while (!HasReadyOutput.empty()) {
  332. auto r = HasReadyOutput.top()->second->GetOutputIfReady(ctx);
  333. if (not r) {
  334. //dried up
  335. HasReadyOutput.pop();
  336. continue;
  337. } else {
  338. return r;
  339. }
  340. }
  341. return Terminating ? NUdf::TUnboxedValue(NUdf::TUnboxedValue::MakeFinish()) : NUdf::TUnboxedValue{};
  342. }
  343. private:
  344. TPartitionMap::iterator GetPartitionHandler(const NUdf::TUnboxedValue& row, TComputationContext &ctx) {
  345. InputRowArg->SetValue(ctx, NUdf::TUnboxedValue(row));
  346. auto partitionKey = PartitionKey->GetValue(ctx);
  347. const auto packedKey = PartitionKeyPacker.Pack(partitionKey);
  348. if (const auto it = Partitions.find(TString(packedKey)); it != Partitions.end()) {
  349. return it;
  350. } else {
  351. return Partitions.emplace_hint(it, TString(packedKey), std::make_unique<TStreamingMatchRecognize>(
  352. std::move(partitionKey),
  353. Parameters,
  354. RowsFormatterState,
  355. NfaTransitionGraph
  356. ));
  357. }
  358. }
  359. private:
  360. TPartitionMap Partitions;
  361. std::stack<TPartitionMap::iterator, std::deque<TPartitionMap::iterator, TMKQLAllocator<TPartitionMap::iterator>>> HasReadyOutput;
  362. bool Terminating = false;
  363. IComputationExternalNode* InputRowArg;
  364. IComputationNode* PartitionKey;
  365. //TODO switch to tuple compare
  366. TValuePackerGeneric<false> PartitionKeyPacker;
  367. const TMatchRecognizeProcessorParameters& Parameters;
  368. const IRowsFormatter::TState& RowsFormatterState;
  369. const TNfaTransitionGraph::TPtr NfaTransitionGraph;
  370. TSerializerContext SerializerContext;
  371. TComputationContext& Ctx;
  372. };
/// Stateful flow node implementing MATCH_RECOGNIZE on top of a per-flow State
/// (TStateForInterleavedPartitions or TStateForNonInterleavedPartitions, chosen in WrapMatchRecognizeCore).
/// @tparam State must provide GetOutputIfReady, ProcessInputRow and ProcessEndOfData and be
///               constructible from the arguments passed in DoCalculate.
template<class State>
class TMatchRecognizeWrapper : public TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true> {
    using TBaseComputation = TStatefulFlowComputationNode<TMatchRecognizeWrapper<State>, true>;
public:
    TMatchRecognizeWrapper(
        TComputationMutables& mutables,
        EValueRepresentation kind,
        IComputationNode *inputFlow,
        IComputationExternalNode *inputRowArg,
        IComputationNode *partitionKey,
        TType* partitionKeyType,
        TMatchRecognizeProcessorParameters&& parameters,
        IRowsFormatter::TState&& rowsFormatterState,
        TType* rowType)
        : TBaseComputation(mutables, inputFlow, kind, EValueRepresentation::Embedded)
        , InputFlow(inputFlow)
        , InputRowArg(inputRowArg)
        , PartitionKey(partitionKey)
        , PartitionKeyType(partitionKeyType)
        , Parameters(std::move(parameters))
        , RowsFormatterState(std::move(rowsFormatterState))
        , RowType(rowType)
        , RowPacker(mutables)
    {}

    /// Produces the next output item of the flow.
    ///
    /// On the first call (invalid stateValue) a fresh State is created. A boxed stateValue reporting
    /// HasListItems() is treated as saved state: live State objects always answer false, so a fresh
    /// State is created and loaded from it (NOTE(review): presumably the saved blob answers true — confirm).
    /// Then the loop alternates between returning ready output and pulling input. Finish from the
    /// input triggers end-of-data processing, and any other special input value is returned as is.
    NUdf::TUnboxedValue DoCalculate(NUdf::TUnboxedValue &stateValue, TComputationContext &ctx) const {
        if (stateValue.IsInvalid()) {
            stateValue = ctx.HolderFactory.Create<State>(
                InputRowArg,
                PartitionKey,
                PartitionKeyType,
                Parameters,
                RowsFormatterState,
                ctx,
                RowType,
                RowPacker
            );
        } else if (stateValue.HasValue()) {
            MKQL_ENSURE(stateValue.IsBoxed(), "Expected boxed value");
            bool isStateToLoad = stateValue.HasListItems();
            if (isStateToLoad) {
                // Load from saved state.
                NUdf::TUnboxedValue state = ctx.HolderFactory.Create<State>(
                    InputRowArg,
                    PartitionKey,
                    PartitionKeyType,
                    Parameters,
                    RowsFormatterState,
                    ctx,
                    RowType,
                    RowPacker
                );
                state.Load2(stateValue);
                stateValue = state;
            }
        }
        auto state = static_cast<State*>(stateValue.AsBoxed().Get());
        while (true) {
            // Drain ready output first; input is pulled only when nothing is ready.
            if (auto output = state->GetOutputIfReady(ctx); output) {
                return output;
            }
            auto item = InputFlow->GetValue(ctx);
            if (item.IsFinish()) {
                state->ProcessEndOfData(ctx);
                continue;
            } else if (item.IsSpecial()) {
                return item;
            }
            state->ProcessInputRow(std::move(item), ctx);
        }
    }
private:
    using TBaseComputation::Own;
    using TBaseComputation::DependsOn;

    // Registers the external argument nodes this node sets as owned by the input flow, and the
    // partition key, measure and define nodes as dependent on it.
    void RegisterDependencies() const final {
        if (const auto flow = TBaseComputation::FlowDependsOn(InputFlow)) {
            Own(flow, InputRowArg);
            Own(flow, Parameters.InputDataArg);
            Own(flow, Parameters.MatchedVarsArg);
            Own(flow, Parameters.CurrentRowIndexArg);
            Own(flow, Parameters.MeasureInputDataArg);
            DependsOn(flow, PartitionKey);
            for (auto& m: RowsFormatterState.Measures) {
                DependsOn(flow, m);
            }
            for (auto& d: Parameters.Defines) {
                DependsOn(flow, d);
            }
        }
    }

    IComputationNode* const InputFlow;
    IComputationExternalNode* const InputRowArg;
    IComputationNode* const PartitionKey;
    TType* const PartitionKeyType;
    // Owned here; State objects hold references to Parameters and RowsFormatterState.
    TMatchRecognizeProcessorParameters Parameters;
    IRowsFormatter::TState RowsFormatterState;
    TType* const RowType;
    TMutableObjectOverBoxedValue<TValuePackerBoxed> RowPacker;
};
  471. TOutputColumnOrder GetOutputColumnOrder(TRuntimeNode partitionKyeColumnsIndexes, TRuntimeNode measureColumnsIndexes) {
  472. std::unordered_map<size_t, TOutputColumnEntry, std::hash<size_t>, std::equal_to<size_t>, TMKQLAllocator<std::pair<const size_t, TOutputColumnEntry>, EMemorySubPool::Temporary>> temp;
  473. {
  474. auto list = AS_VALUE(TListLiteral, partitionKyeColumnsIndexes);
  475. for (ui32 i = 0; i != list->GetItemsCount(); ++i) {
  476. auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>();
  477. temp[index] = {i, EOutputColumnSource::PartitionKey};
  478. }
  479. }
  480. {
  481. auto list = AS_VALUE(TListLiteral, measureColumnsIndexes);
  482. for (ui32 i = 0; i != list->GetItemsCount(); ++i) {
  483. auto index = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().Get<ui32>();
  484. temp[index] = {i, EOutputColumnSource::Measure};
  485. }
  486. }
  487. if (temp.empty())
  488. return {};
  489. auto outputSize = std::ranges::max_element(temp, {}, &std::pair<const size_t, TOutputColumnEntry>::first)->first + 1;
  490. TOutputColumnOrder result(outputSize);
  491. for (const auto& [i, v]: temp) {
  492. result[i] = v;
  493. }
  494. return result;
  495. }
  496. TRowPattern ConvertPattern(const TRuntimeNode& pattern) {
  497. TVector<TRowPatternTerm> result;
  498. const auto& inputPattern = AS_VALUE(TTupleLiteral, pattern);
  499. for (ui32 i = 0; i != inputPattern->GetValuesCount(); ++i) {
  500. const auto& inputTerm = AS_VALUE(TTupleLiteral, inputPattern->GetValue(i));
  501. TVector<TRowPatternFactor> term;
  502. for (ui32 j = 0; j != inputTerm->GetValuesCount(); ++j) {
  503. const auto& inputFactor = AS_VALUE(TTupleLiteral, inputTerm->GetValue(j));
  504. MKQL_ENSURE(inputFactor->GetValuesCount() == 6, "Internal logic error");
  505. const auto& primary = inputFactor->GetValue(0);
  506. term.push_back(TRowPatternFactor{
  507. primary.GetRuntimeType()->IsData() ?
  508. TRowPatternPrimary(TString(AS_VALUE(TDataLiteral, primary)->AsValue().AsStringRef())) :
  509. ConvertPattern(primary),
  510. AS_VALUE(TDataLiteral, inputFactor->GetValue(1))->AsValue().Get<ui64>(),
  511. AS_VALUE(TDataLiteral, inputFactor->GetValue(2))->AsValue().Get<ui64>(),
  512. AS_VALUE(TDataLiteral, inputFactor->GetValue(3))->AsValue().Get<bool>(),
  513. AS_VALUE(TDataLiteral, inputFactor->GetValue(4))->AsValue().Get<bool>(),
  514. AS_VALUE(TDataLiteral, inputFactor->GetValue(5))->AsValue().Get<bool>()
  515. });
  516. }
  517. result.push_back(std::move(term));
  518. }
  519. return result;
  520. }
  521. TMeasureInputColumnOrder GetMeasureColumnOrder(const TListLiteral& specialColumnIndexes, ui32 inputRowColumnCount) {
  522. //Use Last enum value to denote that c colum comes from the input table
  523. TMeasureInputColumnOrder result(inputRowColumnCount + specialColumnIndexes.GetItemsCount(), std::make_pair(EMeasureInputDataSpecialColumns::Last, 0));
  524. if (specialColumnIndexes.GetItemsCount() != 0) {
  525. MKQL_ENSURE(specialColumnIndexes.GetItemsCount() == static_cast<size_t>(EMeasureInputDataSpecialColumns::Last),
  526. "Internal logic error");
  527. for (size_t i = 0; i != specialColumnIndexes.GetItemsCount(); ++i) {
  528. auto ind = AS_VALUE(TDataLiteral, specialColumnIndexes.GetItems()[i])->AsValue().Get<ui32>();
  529. result[ind] = std::make_pair(static_cast<EMeasureInputDataSpecialColumns>(i), 0);
  530. }
  531. }
  532. //update indexes for input table columns
  533. ui32 inputIdx = 0;
  534. for (auto& [t, i]: result) {
  535. if (EMeasureInputDataSpecialColumns::Last == t) {
  536. i = inputIdx++;
  537. }
  538. }
  539. return result;
  540. }
  541. TComputationNodePtrVector ConvertVectorOfCallables(const TRuntimeNode::TList& v, const TComputationNodeFactoryContext& ctx) {
  542. TComputationNodePtrVector result;
  543. result.reserve(v.size());
  544. for (auto& c: v) {
  545. result.push_back(LocateNode(ctx.NodeLocator, *c.GetNode()));
  546. }
  547. return result;
  548. }
  549. std::pair<TUnboxedValueVector, THashMap<TString, size_t>> ConvertListOfStrings(const TRuntimeNode& l) {
  550. TUnboxedValueVector vec;
  551. THashMap<TString, size_t> lookup;
  552. const auto& list = AS_VALUE(TListLiteral, l);
  553. vec.reserve(list->GetItemsCount());
  554. for (ui32 i = 0; i != list->GetItemsCount(); ++i) {
  555. const auto& varName = AS_VALUE(TDataLiteral, list->GetItems()[i])->AsValue().AsStringRef();
  556. vec.push_back(MakeString(varName));
  557. lookup[TString(varName)] = i;
  558. }
  559. return {vec, lookup};
  560. }
  561. } //namespace NMatchRecognize
  562. IComputationNode* WrapMatchRecognizeCore(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  563. using namespace NMatchRecognize;
  564. size_t inputIndex = 0;
  565. const auto& inputFlow = callable.GetInput(inputIndex++);
  566. const auto& inputRowArg = callable.GetInput(inputIndex++);
  567. const auto& partitionKeySelector = callable.GetInput(inputIndex++);
  568. const auto& partitionColumnIndexes = callable.GetInput(inputIndex++);
  569. const auto& measureInputDataArg = callable.GetInput(inputIndex++);
  570. const auto& measureSpecialColumnIndexes = callable.GetInput(inputIndex++);
  571. const auto& inputRowColumnCount = callable.GetInput(inputIndex++);
  572. const auto& matchedVarsArg = callable.GetInput(inputIndex++);
  573. const auto& measureColumnIndexes = callable.GetInput(inputIndex++);
  574. TRuntimeNode::TList measures;
  575. for (size_t i = 0; i != AS_VALUE(TListLiteral, measureColumnIndexes)->GetItemsCount(); ++i) {
  576. measures.push_back(callable.GetInput(inputIndex++));
  577. }
  578. const auto& pattern = callable.GetInput(inputIndex++);
  579. const auto& currentRowIndexArg = callable.GetInput(inputIndex++);
  580. const auto& inputDataArg = callable.GetInput(inputIndex++);
  581. const auto& defineNames = callable.GetInput(inputIndex++);
  582. TRuntimeNode::TList defines;
  583. for (size_t i = 0; i != AS_VALUE(TListLiteral, defineNames)->GetItemsCount(); ++i) {
  584. defines.push_back(callable.GetInput(inputIndex++));
  585. }
  586. const auto& streamingMode = callable.GetInput(inputIndex++);
  587. NYql::NMatchRecognize::TAfterMatchSkipTo skipTo = {NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""};
  588. if (inputIndex + 2 <= callable.GetInputsCount()) {
  589. skipTo.To = static_cast<EAfterMatchSkipTo>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>());
  590. skipTo.Var = AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().AsStringRef();
  591. }
  592. NYql::NMatchRecognize::ERowsPerMatch rowsPerMatch = NYql::NMatchRecognize::ERowsPerMatch::OneRow;
  593. TOutputColumnOrder outputColumnOrder;
  594. if (inputIndex + 2 <= callable.GetInputsCount()) {
  595. rowsPerMatch = static_cast<ERowsPerMatch>(AS_VALUE(TDataLiteral, callable.GetInput(inputIndex++))->AsValue().Get<i32>());
  596. outputColumnOrder = IRowsFormatter::GetOutputColumnOrder(callable.GetInput(inputIndex++));
  597. } else {
  598. outputColumnOrder = GetOutputColumnOrder(partitionColumnIndexes, measureColumnIndexes);
  599. }
  600. MKQL_ENSURE(callable.GetInputsCount() == inputIndex, "Wrong input count");
  601. const auto& [varNames, varNamesLookup] = ConvertListOfStrings(defineNames);
  602. auto* rowType = AS_TYPE(TStructType, AS_TYPE(TFlowType, inputFlow.GetStaticType())->GetItemType());
  603. auto parameters = TMatchRecognizeProcessorParameters {
  604. static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputDataArg.GetNode())),
  605. ConvertPattern(pattern),
  606. varNames,
  607. varNamesLookup,
  608. static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *matchedVarsArg.GetNode())),
  609. static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *currentRowIndexArg.GetNode())),
  610. ConvertVectorOfCallables(defines, ctx),
  611. static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *measureInputDataArg.GetNode())),
  612. GetMeasureColumnOrder(
  613. *AS_VALUE(TListLiteral, measureSpecialColumnIndexes),
  614. AS_VALUE(TDataLiteral, inputRowColumnCount)->AsValue().Get<ui32>()
  615. ),
  616. skipTo
  617. };
  618. IRowsFormatter::TState rowsFormatterState(ctx, outputColumnOrder, ConvertVectorOfCallables(measures, ctx), rowsPerMatch);
  619. if (AS_VALUE(TDataLiteral, streamingMode)->AsValue().Get<bool>()) {
  620. return new TMatchRecognizeWrapper<TStateForInterleavedPartitions>(
  621. ctx.Mutables,
  622. GetValueRepresentation(inputFlow.GetStaticType()),
  623. LocateNode(ctx.NodeLocator, *inputFlow.GetNode()),
  624. static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())),
  625. LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()),
  626. partitionKeySelector.GetStaticType(),
  627. std::move(parameters),
  628. std::move(rowsFormatterState),
  629. rowType
  630. );
  631. } else {
  632. return new TMatchRecognizeWrapper<TStateForNonInterleavedPartitions>(
  633. ctx.Mutables,
  634. GetValueRepresentation(inputFlow.GetStaticType()),
  635. LocateNode(ctx.NodeLocator, *inputFlow.GetNode()),
  636. static_cast<IComputationExternalNode*>(LocateNode(ctx.NodeLocator, *inputRowArg.GetNode())),
  637. LocateNode(ctx.NodeLocator, *partitionKeySelector.GetNode()),
  638. partitionKeySelector.GetStaticType(),
  639. std::move(parameters),
  640. std::move(rowsFormatterState),
  641. rowType
  642. );
  643. }
  644. }
  645. } //namespace NKikimr::NMiniKQL