mkql_match_recognize_nfa.h 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. #pragma once
  2. #include "mkql_match_recognize_matched_vars.h"
  3. #include "mkql_match_recognize_save_load.h"
  4. #include "../computation/mkql_computation_node_holders.h"
  5. #include "util/generic/overloaded.h"
  6. #include <yql/essentials/core/sql_types/match_recognize.h>
  7. #include <util/generic/hash_table.h>
  8. #include <util/generic/string.h>
  9. namespace NKikimr::NMiniKQL::NMatchRecognize {
  10. using namespace NYql::NMatchRecognize;
  11. struct TVoidTransition {
  12. friend constexpr bool operator==(const TVoidTransition&, const TVoidTransition&) = default;
  13. };
  14. struct TEpsilonTransitions {
  15. std::vector<size_t, TMKQLAllocator<size_t>> To;
  16. friend constexpr bool operator==(const TEpsilonTransitions&, const TEpsilonTransitions&) = default;
  17. };
  18. struct TMatchedVarTransition {
  19. size_t To;
  20. ui32 VarIndex;
  21. bool SaveState;
  22. bool ExcludeFromOutput;
  23. friend constexpr bool operator==(const TMatchedVarTransition&, const TMatchedVarTransition&) = default;
  24. };
  25. struct TQuantityEnterTransition {
  26. size_t To;
  27. friend constexpr bool operator==(const TQuantityEnterTransition&, const TQuantityEnterTransition&) = default;
  28. };
  29. struct TQuantityExitTransition {
  30. ui64 QuantityMin;
  31. ui64 QuantityMax;
  32. size_t ToFindMore;
  33. size_t ToMatched;
  34. friend constexpr bool operator==(const TQuantityExitTransition&, const TQuantityExitTransition&) = default;
  35. };
  36. template <typename... Ts>
  37. struct TVariantHelper {
  38. using TVariant = std::variant<Ts...>;
  39. using TTuple = std::tuple<Ts...>;
  40. static std::variant<Ts...> GetVariantByIndex(size_t i) {
  41. MKQL_ENSURE(i < sizeof...(Ts), "Wrong variant index");
  42. static std::variant<Ts...> table[] = { Ts{ }... };
  43. return table[i];
  44. }
  45. };
  46. using TNfaTransitionHelper = TVariantHelper<
  47. TVoidTransition,
  48. TMatchedVarTransition,
  49. TEpsilonTransitions,
  50. TQuantityEnterTransition,
  51. TQuantityExitTransition
  52. >;
  53. using TNfaTransition = TNfaTransitionHelper::TVariant;
  54. struct TNfaTransitionDestinationVisitor {
  55. std::function<size_t(size_t)> Callback;
  56. template<typename Callback>
  57. explicit TNfaTransitionDestinationVisitor(Callback callback)
  58. : Callback(std::move(callback)) {}
  59. TNfaTransition operator()(TVoidTransition tr) const {
  60. return tr;
  61. }
  62. TNfaTransition operator()(TMatchedVarTransition tr) const {
  63. tr.To = Callback(tr.To);
  64. return tr;
  65. }
  66. TNfaTransition operator()(TEpsilonTransitions tr) const {
  67. for (size_t& toNode: tr.To) {
  68. toNode = Callback(toNode);
  69. }
  70. return tr;
  71. }
  72. TNfaTransition operator()(TQuantityEnterTransition tr) const {
  73. tr.To = Callback(tr.To);
  74. return tr;
  75. }
  76. TNfaTransition operator()(TQuantityExitTransition tr) const {
  77. tr.ToFindMore = Callback(tr.ToFindMore);
  78. tr.ToMatched = Callback(tr.ToMatched);
  79. return tr;
  80. }
  81. };
  82. struct TNfaTransitionGraph {
  83. using TTransitions = std::vector<TNfaTransition, TMKQLAllocator<TNfaTransition>>;
  84. TTransitions Transitions;
  85. size_t Input;
  86. size_t Output;
  87. using TPtr = std::shared_ptr<TNfaTransitionGraph>;
  88. template<class>
  89. inline constexpr static bool always_false_v = false;
  90. void Save(TMrOutputSerializer& serializer) const {
  91. serializer(Transitions.size());
  92. for (ui64 i = 0; i < Transitions.size(); ++i) {
  93. serializer.Write(Transitions[i].index());
  94. std::visit(TOverloaded{
  95. [&](const TVoidTransition&) {},
  96. [&](const TEpsilonTransitions& tr) {
  97. serializer(tr.To);
  98. },
  99. [&](const TMatchedVarTransition& tr) {
  100. serializer(tr.VarIndex, tr.SaveState, tr.To);
  101. },
  102. [&](const TQuantityEnterTransition& tr) {
  103. serializer(tr.To);
  104. },
  105. [&](const TQuantityExitTransition& tr) {
  106. serializer(tr.QuantityMin, tr.QuantityMax, tr.ToFindMore, tr.ToMatched);
  107. },
  108. }, Transitions[i]);
  109. }
  110. serializer(Input, Output);
  111. }
  112. void Load(TMrInputSerializer& serializer) {
  113. ui64 transitionSize = serializer.Read<TTransitions::size_type>();
  114. Transitions.resize(transitionSize);
  115. for (ui64 i = 0; i < transitionSize; ++i) {
  116. size_t index = serializer.Read<std::size_t>();
  117. Transitions[i] = TNfaTransitionHelper::GetVariantByIndex(index);
  118. std::visit(TOverloaded{
  119. [&](TVoidTransition&) {},
  120. [&](TEpsilonTransitions& tr) {
  121. serializer(tr.To);
  122. },
  123. [&](TMatchedVarTransition& tr) {
  124. serializer(tr.VarIndex, tr.SaveState, tr.To);
  125. },
  126. [&](TQuantityEnterTransition& tr) {
  127. serializer(tr.To);
  128. },
  129. [&](TQuantityExitTransition& tr) {
  130. serializer(tr.QuantityMin, tr.QuantityMax, tr.ToFindMore, tr.ToMatched);
  131. },
  132. }, Transitions[i]);
  133. }
  134. serializer(Input, Output);
  135. }
  136. bool operator==(const TNfaTransitionGraph& other) {
  137. return Transitions == other.Transitions
  138. && Input == other.Input
  139. && Output == other.Output;
  140. }
  141. };
  142. class TNfaTransitionGraphOptimizer {
  143. public:
  144. TNfaTransitionGraphOptimizer(TNfaTransitionGraph::TPtr graph)
  145. : Graph(graph) {}
  146. void DoOptimizations() {
  147. EliminateEpsilonChains();
  148. EliminateSingleEpsilons();
  149. CollectGarbage();
  150. }
  151. private:
  152. void EliminateEpsilonChains() {
  153. for (size_t node = 0; node != Graph->Transitions.size(); node++) {
  154. if (auto* ts = std::get_if<TEpsilonTransitions>(&Graph->Transitions[node])) {
  155. // new vector of eps transitions,
  156. // contains refs to all nodes which are reachable from oldNode via eps transitions
  157. TEpsilonTransitions optimizedTs;
  158. auto dfsStack = ts->To;
  159. while (!dfsStack.empty()) {
  160. auto curNode = dfsStack.back();
  161. dfsStack.pop_back();
  162. if (auto* curTs = std::get_if<TEpsilonTransitions>(&Graph->Transitions[curNode])) {
  163. std::copy(curTs->To.begin(), curTs->To.end(), std::back_inserter(dfsStack));
  164. } else {
  165. optimizedTs.To.push_back(curNode);
  166. }
  167. }
  168. *ts = optimizedTs;
  169. }
  170. }
  171. }
  172. void EliminateSingleEpsilons() {
  173. for (size_t node = 0; node != Graph->Transitions.size(); node++) {
  174. if (std::holds_alternative<TEpsilonTransitions>(Graph->Transitions[node])) {
  175. continue;
  176. }
  177. Graph->Transitions[node] = std::visit(TNfaTransitionDestinationVisitor([&](size_t toNode) -> size_t {
  178. if (auto *tr = std::get_if<TEpsilonTransitions>(&Graph->Transitions[toNode])) {
  179. if (tr->To.size() == 1) {
  180. return tr->To[0];
  181. }
  182. }
  183. return toNode;
  184. }), Graph->Transitions[node]);
  185. }
  186. }
  187. void CollectGarbage() {
  188. auto oldInput = Graph->Input;
  189. auto oldOutput = Graph->Output;
  190. decltype(Graph->Transitions) oldTransitions;
  191. Graph->Transitions.swap(oldTransitions);
  192. // Scan for reachable nodes and map old node ids to new node ids
  193. std::vector<std::optional<size_t>> mapping(oldTransitions.size(), std::nullopt);
  194. std::vector<size_t> dfsStack = {oldInput};
  195. mapping[oldInput] = 0;
  196. Graph->Transitions.emplace_back();
  197. while (!dfsStack.empty()) {
  198. auto oldNode = dfsStack.back();
  199. dfsStack.pop_back();
  200. std::visit(TNfaTransitionDestinationVisitor([&](size_t oldToNode) {
  201. if (!mapping[oldToNode]) {
  202. mapping[oldToNode] = Graph->Transitions.size();
  203. Graph->Transitions.emplace_back();
  204. dfsStack.push_back(oldToNode);
  205. }
  206. return 0;
  207. }), oldTransitions[oldNode]);
  208. }
  209. // Rebuild transition vector
  210. for (size_t oldNode = 0; oldNode != oldTransitions.size(); oldNode++) {
  211. if (!mapping[oldNode]) {
  212. continue;
  213. }
  214. auto node = mapping[oldNode].value();
  215. if (oldNode == oldInput) {
  216. Graph->Input = node;
  217. }
  218. if (oldNode == oldOutput) {
  219. Graph->Output = node;
  220. }
  221. Graph->Transitions[node] = oldTransitions[oldNode];
  222. Graph->Transitions[node] = std::visit(TNfaTransitionDestinationVisitor([&](size_t oldToNode) {
  223. return mapping[oldToNode].value();
  224. }), Graph->Transitions[node]);
  225. }
  226. }
  227. TNfaTransitionGraph::TPtr Graph;
  228. };
  229. class TNfaTransitionGraphBuilder {
  230. private:
  231. struct TNfaItem {
  232. size_t Input;
  233. size_t Output;
  234. };
  235. TNfaTransitionGraphBuilder(TNfaTransitionGraph::TPtr graph)
  236. : Graph(graph) {}
  237. size_t AddNode() {
  238. Graph->Transitions.emplace_back();
  239. return Graph->Transitions.size() - 1;
  240. }
  241. TNfaItem BuildTerms(const TVector<TRowPatternTerm>& terms, const THashMap<TString, size_t>& varNameToIndex) {
  242. auto input = AddNode();
  243. auto output = AddNode();
  244. TEpsilonTransitions fromInput;
  245. for (const auto& t: terms) {
  246. auto a = BuildTerm(t, varNameToIndex);
  247. fromInput.To.push_back(a.Input);
  248. Graph->Transitions[a.Output] = TEpsilonTransitions({output});
  249. }
  250. Graph->Transitions[input] = std::move(fromInput);
  251. return {input, output};
  252. }
  253. TNfaItem BuildTerm(const TRowPatternTerm& term, const THashMap<TString, size_t>& varNameToIndex) {
  254. auto input = AddNode();
  255. auto output = AddNode();
  256. std::vector<TNfaItem, TMKQLAllocator<TNfaItem>> automata;
  257. for (const auto& f: term) {
  258. automata.push_back(BuildFactor(f, varNameToIndex));
  259. }
  260. for (size_t i = 0; i != automata.size() - 1; ++i) {
  261. Graph->Transitions[automata[i].Output] = TEpsilonTransitions({automata[i + 1].Input});
  262. }
  263. Graph->Transitions[input] = TEpsilonTransitions({automata.front().Input});
  264. Graph->Transitions[automata.back().Output] = TEpsilonTransitions({output});
  265. return {input, output};
  266. }
  267. TNfaItem BuildFactor(const TRowPatternFactor& factor, const THashMap<TString, size_t>& varNameToIndex) {
  268. auto input = AddNode();
  269. auto output = AddNode();
  270. auto item = factor.Primary.index() == 0 ?
  271. BuildVar(varNameToIndex.at(std::get<0>(factor.Primary)), !factor.Unused, !factor.Output) :
  272. BuildTerms(std::get<1>(factor.Primary), varNameToIndex);
  273. if (1 == factor.QuantityMin && 1 == factor.QuantityMax) { //simple linear case
  274. Graph->Transitions[input] = TEpsilonTransitions{{item.Input}};
  275. Graph->Transitions[item.Output] = TEpsilonTransitions{{output}};
  276. } else {
  277. auto interim = AddNode();
  278. auto fromInput = TEpsilonTransitions{{interim}};
  279. if (factor.QuantityMin == 0) {
  280. fromInput.To.push_back(output);
  281. }
  282. Graph->Transitions[input] = fromInput;
  283. Graph->Transitions[interim] = TQuantityEnterTransition{item.Input};
  284. Graph->Transitions[item.Output] = TQuantityExitTransition{
  285. factor.QuantityMin,
  286. factor.QuantityMax,
  287. item.Input,
  288. output,
  289. };
  290. }
  291. return {input, output};
  292. }
  293. TNfaItem BuildVar(ui32 varIndex, bool isUsed, bool excludeFromOutput) {
  294. auto input = AddNode();
  295. auto matchVar = AddNode();
  296. auto output = AddNode();
  297. Graph->Transitions[input] = TEpsilonTransitions({matchVar});
  298. Graph->Transitions[matchVar] = TMatchedVarTransition{
  299. output,
  300. varIndex,
  301. isUsed,
  302. excludeFromOutput,
  303. };
  304. return {input, output};
  305. }
  306. public:
  307. static TNfaTransitionGraph::TPtr Create(const TRowPattern& pattern, const THashMap<TString, size_t>& varNameToIndex) {
  308. auto result = std::make_shared<TNfaTransitionGraph>();
  309. TNfaTransitionGraphBuilder builder(result);
  310. auto item = builder.BuildTerms(pattern, varNameToIndex);
  311. result->Input = item.Input;
  312. result->Output = item.Output;
  313. TNfaTransitionGraphOptimizer optimizer(result);
  314. optimizer.DoOptimizations();
  315. return result;
  316. }
  317. private:
  318. TNfaTransitionGraph::TPtr Graph;
  319. };
  320. class TNfa {
  321. using TRange = TSparseList::TRange;
  322. using TMatchedVars = TMatchedVars<TRange>;
  323. public:
  324. struct TMatch {
  325. size_t BeginIndex;
  326. size_t EndIndex;
  327. TMatchedVars Vars;
  328. void Save(TMrOutputSerializer& serializer) const {
  329. serializer(BeginIndex, EndIndex);
  330. serializer.Write(Vars.size());
  331. for (const auto& vector : Vars) {
  332. serializer.Write(vector.size());
  333. for (const auto& range : vector) {
  334. range.Save(serializer);
  335. }
  336. }
  337. }
  338. void Load(TMrInputSerializer& serializer) {
  339. serializer(BeginIndex, EndIndex);
  340. auto varsSize = serializer.Read<size_t>();
  341. Vars.clear();
  342. Vars.resize(varsSize);
  343. for (auto& subvec: Vars) {
  344. ui64 vectorSize = serializer.Read<ui64>();
  345. subvec.resize(vectorSize);
  346. for (auto& item : subvec) {
  347. item.Load(serializer);
  348. }
  349. }
  350. }
  351. };
  352. private:
  353. struct TState {
  354. size_t Index;
  355. TMatch Match;
  356. std::deque<ui64, TMKQLAllocator<ui64>> Quantifiers;
  357. void Save(TMrOutputSerializer& serializer) const {
  358. serializer.Write(Index);
  359. Match.Save(serializer);
  360. serializer.Write(Quantifiers.size());
  361. for (ui64 qnt : Quantifiers) {
  362. serializer.Write(qnt);
  363. }
  364. }
  365. void Load(TMrInputSerializer& serializer) {
  366. serializer.Read(Index);
  367. Match.Load(serializer);
  368. Quantifiers.clear();
  369. auto quantifiersSize = serializer.Read<ui64>();
  370. for (size_t i = 0; i < quantifiersSize; ++i) {
  371. ui64 qnt = serializer.Read<ui64>();
  372. Quantifiers.push_back(qnt);
  373. }
  374. }
  375. friend inline bool operator<(const TState& lhs, const TState& rhs) {
  376. auto lhsMatchEndIndex = -static_cast<i64>(lhs.Match.EndIndex);
  377. auto rhsMatchEndIndex = -static_cast<i64>(rhs.Match.EndIndex);
  378. return std::tie(lhs.Match.BeginIndex, lhsMatchEndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) < std::tie(rhs.Match.BeginIndex, rhsMatchEndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers);
  379. }
  380. friend inline bool operator==(const TState& lhs, const TState& rhs) {
  381. return std::tie(lhs.Match.BeginIndex, lhs.Match.EndIndex, lhs.Index, lhs.Match.Vars, lhs.Quantifiers) == std::tie(rhs.Match.BeginIndex, rhs.Match.EndIndex, rhs.Index, rhs.Match.Vars, rhs.Quantifiers);
  382. }
  383. };
  384. public:
  385. TNfa(
  386. TNfaTransitionGraph::TPtr transitionGraph,
  387. IComputationExternalNode* matchedRangesArg,
  388. const TComputationNodePtrVector& defines,
  389. TAfterMatchSkipTo skipTo)
  390. : TransitionGraph(transitionGraph)
  391. , MatchedRangesArg(matchedRangesArg)
  392. , Defines(defines)
  393. , SkipTo_(skipTo)
  394. {}
  395. void ProcessRow(TSparseList::TRange&& currentRowLock, TComputationContext& ctx) {
  396. TState state(TransitionGraph->Input, TMatch{currentRowLock.From(), currentRowLock.To(), TMatchedVars(Defines.size())}, std::deque<ui64, TMKQLAllocator<ui64>>{});
  397. Insert(std::move(state));
  398. MakeEpsilonTransitions();
  399. TStateSet newStates;
  400. TStateSet deletedStates;
  401. for (const auto& state : ActiveStates) {
  402. //Here we handle only transitions of TMatchedVarTransition type,
  403. //all other transitions are handled in MakeEpsilonTransitions
  404. if (const auto* matchedVarTransition = std::get_if<TMatchedVarTransition>(&TransitionGraph->Transitions[state.Index])) {
  405. MatchedRangesArg->SetValue(ctx, ctx.HolderFactory.Create<TMatchedVarsValue<TRange>>(ctx.HolderFactory, state.Match.Vars));
  406. const auto varIndex = matchedVarTransition->VarIndex;
  407. const auto& v = Defines[varIndex]->GetValue(ctx);
  408. if (v && v.Get<bool>()) {
  409. if (matchedVarTransition->SaveState) {
  410. auto vars = state.Match.Vars; //TODO get rid of this copy
  411. auto& matchedVar = vars[varIndex];
  412. currentRowLock.NfaIndex(state.Index);
  413. Extend(matchedVar, currentRowLock);
  414. newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), std::move(vars)}, state.Quantifiers);
  415. } else {
  416. newStates.emplace(matchedVarTransition->To, TMatch{state.Match.BeginIndex, currentRowLock.To(), state.Match.Vars}, state.Quantifiers);
  417. }
  418. }
  419. deletedStates.insert(state);
  420. }
  421. }
  422. for (auto& state : deletedStates) {
  423. Erase(std::move(state));
  424. }
  425. for (auto& state : newStates) {
  426. Insert(std::move(state));
  427. }
  428. MakeEpsilonTransitions();
  429. }
  430. bool HasMatched() const {
  431. for (auto& state: ActiveStates) {
  432. if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex),
  433. finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex);
  434. ((activeStateIter != ActiveStateCounters.end() &&
  435. finishedStateIter != FinishedStateCounters.end() &&
  436. activeStateIter->second == finishedStateIter->second) ||
  437. EndOfData) &&
  438. state.Index == TransitionGraph->Output) {
  439. return true;
  440. }
  441. }
  442. return false;
  443. }
  444. std::optional<TMatch> GetMatched() {
  445. for (auto& state: ActiveStates) {
  446. if (auto activeStateIter = ActiveStateCounters.find(state.Match.BeginIndex),
  447. finishedStateIter = FinishedStateCounters.find(state.Match.BeginIndex);
  448. ((activeStateIter != ActiveStateCounters.end() &&
  449. finishedStateIter != FinishedStateCounters.end() &&
  450. activeStateIter->second == finishedStateIter->second) ||
  451. EndOfData) &&
  452. state.Index == TransitionGraph->Output) {
  453. auto result = state.Match;
  454. Erase(std::move(state));
  455. return result;
  456. }
  457. }
  458. return std::nullopt;
  459. }
  460. size_t GetActiveStatesCount() const {
  461. return ActiveStates.size();
  462. }
  463. void Save(TMrOutputSerializer& serializer) const {
  464. // TransitionGraph is not saved/loaded, passed in constructor.
  465. serializer.Write(ActiveStates.size());
  466. for (const auto& state : ActiveStates) {
  467. state.Save(serializer);
  468. }
  469. serializer.Write(ActiveStateCounters.size());
  470. for (const auto& counter : ActiveStateCounters) {
  471. serializer(counter);
  472. }
  473. serializer.Write(FinishedStateCounters.size());
  474. for (const auto& counter : FinishedStateCounters) {
  475. serializer(counter);
  476. }
  477. }
  478. void Load(TMrInputSerializer& serializer) {
  479. {
  480. ActiveStates.clear();
  481. auto activeStatesSize = serializer.Read<ui64>();
  482. for (size_t i = 0; i < activeStatesSize; ++i) {
  483. TState state;
  484. state.Load(serializer);
  485. ActiveStates.emplace(state);
  486. }
  487. }
  488. {
  489. ActiveStateCounters.clear();
  490. auto activeStateCountersSize = serializer.Read<ui64>();
  491. for (size_t i = 0; i < activeStateCountersSize; ++i) {
  492. using map_type = decltype(ActiveStateCounters);
  493. auto matchBeginIndex = serializer.Read<map_type::key_type>();
  494. auto counter = serializer.Read<map_type::mapped_type>();
  495. ActiveStateCounters.emplace(matchBeginIndex, counter);
  496. }
  497. }
  498. {
  499. FinishedStateCounters.clear();
  500. auto finishedStateCountersSize = serializer.Read<ui64>();
  501. for (size_t i = 0; i < finishedStateCountersSize; ++i) {
  502. using map_type = decltype(FinishedStateCounters);
  503. auto matchBeginIndex = serializer.Read<map_type::key_type>();
  504. auto counter = serializer.Read<map_type::mapped_type>();
  505. FinishedStateCounters.emplace(matchBeginIndex, counter);
  506. }
  507. }
  508. }
  509. bool ProcessEndOfData(const TComputationContext& /* ctx */) {
  510. EndOfData = true;
  511. return HasMatched();
  512. }
  513. void AfterMatchSkip(const TMatch& match) {
  514. const auto skipToRowIndex = [&]() {
  515. switch (SkipTo_.To) {
  516. case EAfterMatchSkipTo::NextRow:
  517. return match.BeginIndex + 1;
  518. case EAfterMatchSkipTo::PastLastRow:
  519. return match.EndIndex + 1;
  520. case EAfterMatchSkipTo::ToFirst:
  521. MKQL_ENSURE(false, "AFTER MATCH SKIP TO FIRST is not implemented yet");
  522. case EAfterMatchSkipTo::ToLast:
  523. [[fallthrough]];
  524. case EAfterMatchSkipTo::To:
  525. MKQL_ENSURE(false, "AFTER MATCH SKIP TO LAST is not implemented yet");
  526. }
  527. }();
  528. TStateSet deletedStates;
  529. for (const auto& state : ActiveStates) {
  530. if (state.Match.BeginIndex < skipToRowIndex) {
  531. deletedStates.insert(state);
  532. }
  533. }
  534. for (auto& state : deletedStates) {
  535. Erase(std::move(state));
  536. }
  537. }
  538. const TNfaTransitionGraph& GetTransitionGraph() const {
  539. return *TransitionGraph;
  540. }
  541. private:
  542. //TODO (zverevgeny): Consider to change to std::vector for the sake of perf
  543. using TStateSet = std::set<TState, std::less<TState>, TMKQLAllocator<TState>>;
  544. bool MakeEpsilonTransitionsImpl() {
  545. TStateSet newStates;
  546. TStateSet deletedStates;
  547. for (const auto& state: ActiveStates) {
  548. std::visit(TOverloaded {
  549. [&](const TVoidTransition&) {
  550. //Do nothing for void
  551. },
  552. [&](const TMatchedVarTransition&) {
  553. //Transitions of TMatchedVarTransition type are handled in ProcessRow method
  554. },
  555. [&](const TEpsilonTransitions& epsilonTransitions) {
  556. deletedStates.insert(state);
  557. for (const auto& i : epsilonTransitions.To) {
  558. newStates.emplace(i, state.Match, state.Quantifiers);
  559. }
  560. },
  561. [&](const TQuantityEnterTransition& quantityEnterTransition) {
  562. deletedStates.insert(state);
  563. auto quantifiers = state.Quantifiers; //TODO get rid of this copy
  564. quantifiers.push_back(0);
  565. newStates.emplace(quantityEnterTransition.To, state.Match, std::move(quantifiers));
  566. },
  567. [&](const TQuantityExitTransition& quantityExitTransition) {
  568. deletedStates.insert(state);
  569. auto [quantityMin, quantityMax, toFindMore, toMatched] = quantityExitTransition;
  570. if (state.Quantifiers.back() + 1 < quantityMax) {
  571. auto q = state.Quantifiers;
  572. q.back()++;
  573. newStates.emplace(toFindMore, state.Match, std::move(q));
  574. }
  575. if (quantityMin <= state.Quantifiers.back() + 1 && state.Quantifiers.back() + 1 <= quantityMax) {
  576. auto q = state.Quantifiers;
  577. q.pop_back();
  578. newStates.emplace(toMatched, state.Match, std::move(q));
  579. }
  580. },
  581. }, TransitionGraph->Transitions[state.Index]);
  582. }
  583. bool result = newStates != deletedStates;
  584. for (auto& state : deletedStates) {
  585. Erase(std::move(state));
  586. }
  587. for (auto& state : newStates) {
  588. Insert(std::move(state));
  589. }
  590. return result;
  591. }
  592. void MakeEpsilonTransitions() {
  593. while (MakeEpsilonTransitionsImpl());
  594. }
  595. static void Add(THashMap<size_t, i64>& counters, size_t index, i64 value) {
  596. auto countersIter = counters.try_emplace(index, 0).first;
  597. MKQL_ENSURE(countersIter != counters.end(), "Internal logic error");
  598. countersIter->second += value;
  599. if (countersIter->second == 0) {
  600. counters.erase(countersIter);
  601. }
  602. }
  603. void Insert(TState state) {
  604. auto matchBeginIndex = state.Match.BeginIndex;
  605. const auto& transition = TransitionGraph->Transitions[state.Index];
  606. auto diff = static_cast<i64>(ActiveStates.insert(std::move(state)).second);
  607. Add(ActiveStateCounters, matchBeginIndex, diff);
  608. if (std::holds_alternative<TVoidTransition>(transition)) {
  609. Add(FinishedStateCounters, matchBeginIndex, diff);
  610. }
  611. }
  612. void Erase(TState state) {
  613. auto matchBeginIndex = state.Match.BeginIndex;
  614. const auto& transition = TransitionGraph->Transitions[state.Index];
  615. auto diff = -static_cast<i64>(ActiveStates.erase(std::move(state)));
  616. Add(ActiveStateCounters, matchBeginIndex, diff);
  617. if (std::holds_alternative<TVoidTransition>(transition)) {
  618. Add(FinishedStateCounters, matchBeginIndex, diff);
  619. }
  620. }
  621. TNfaTransitionGraph::TPtr TransitionGraph;
  622. IComputationExternalNode* const MatchedRangesArg;
  623. const TComputationNodePtrVector Defines;
  624. TStateSet ActiveStates; //NFA state
  625. THashMap<size_t, i64> ActiveStateCounters;
  626. THashMap<size_t, i64> FinishedStateCounters;
  627. bool EndOfData = false;
  628. TAfterMatchSkipTo SkipTo_;
  629. };
  630. }//namespace NKikimr::NMiniKQL::NMatchRecognize