sql_match_recognize.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. #include "sql_match_recognize.h"
  2. #include "node.h"
  3. #include "sql_expression.h"
  4. #include <yql/essentials/core/sql_types/match_recognize.h>
  5. #include <algorithm>
  6. namespace NSQLTranslationV1 {
  7. using namespace NSQLv1Generated;
  8. namespace {
  9. TPosition TokenPosition(const TToken& token){
  10. return TPosition{token.GetColumn(), token.GetLine()};
  11. }
  12. TString PatternVar(const TRule_row_pattern_variable_name& node, TSqlMatchRecognizeClause& ctx){
  13. return Id(node.GetRule_identifier1(), ctx);
  14. }
  15. } //namespace
  16. TMatchRecognizeBuilderPtr TSqlMatchRecognizeClause::CreateBuilder(const NSQLv1Generated::TRule_row_pattern_recognition_clause &matchRecognizeClause) {
  17. TPosition pos(matchRecognizeClause.GetToken1().GetColumn(), matchRecognizeClause.GetToken1().GetLine());
  18. if (!Ctx.FeatureR010) {
  19. Ctx.Error(pos, TIssuesIds::CORE) << "Unexpected MATCH_RECOGNIZE";
  20. return {};
  21. }
  22. TVector<TNamedFunction> partitioners;
  23. TPosition partitionsPos = pos;
  24. if (matchRecognizeClause.HasBlock3()) {
  25. const auto& partitionClause = matchRecognizeClause.GetBlock3().GetRule_window_partition_clause1();
  26. partitionsPos = TokenPosition(partitionClause.GetToken1());
  27. partitioners = ParsePartitionBy(partitionClause);
  28. if (!partitioners)
  29. return {};
  30. }
  31. TVector<TSortSpecificationPtr> sortSpecs;
  32. TPosition orderByPos = pos;
  33. if (matchRecognizeClause.HasBlock4()) {
  34. const auto& orderByClause = matchRecognizeClause.GetBlock4().GetRule_order_by_clause1();
  35. orderByPos = TokenPosition(orderByClause.GetToken1());
  36. if (!OrderByClause(orderByClause, sortSpecs)) {
  37. return {};
  38. }
  39. }
  40. TPosition measuresPos = pos;
  41. TVector<TNamedFunction> measures;
  42. if (matchRecognizeClause.HasBlock5()) {
  43. const auto& measuresClause = matchRecognizeClause.GetBlock5().GetRule_row_pattern_measures1();
  44. measuresPos = TokenPosition(measuresClause.GetToken1());
  45. measures = ParseMeasures(measuresClause.GetRule_row_pattern_measure_list2());
  46. }
  47. TPosition rowsPerMatchPos = pos;
  48. ERowsPerMatch rowsPerMatch = ERowsPerMatch::OneRow;
  49. if (matchRecognizeClause.HasBlock6()) {
  50. std::tie(rowsPerMatchPos, rowsPerMatch) = ParseRowsPerMatch(matchRecognizeClause.GetBlock6().GetRule_row_pattern_rows_per_match1());
  51. if (ERowsPerMatch::AllRows == rowsPerMatch) {
  52. //https://st.yandex-team.ru/YQL-16213
  53. Ctx.Error(pos, TIssuesIds::CORE) << "ALL ROWS PER MATCH is not supported yet";
  54. return {};
  55. }
  56. }
  57. const auto& commonSyntax = matchRecognizeClause.GetRule_row_pattern_common_syntax7();
  58. if (commonSyntax.HasBlock2()) {
  59. const auto& initialOrSeek = commonSyntax.GetBlock2().GetRule_row_pattern_initial_or_seek1();
  60. Ctx.Error(TokenPosition(initialOrSeek.GetToken1())) << "InitialOrSeek subclause is not allowed in FROM clause";
  61. return {};
  62. }
  63. auto pattern = ParsePattern(commonSyntax.GetRule_row_pattern5());
  64. const auto& patternPos = TokenPosition(commonSyntax.token3());
  65. //this block is located before pattern block in grammar,
  66. // but depends on it, so it is processed after pattern block
  67. std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo> skipTo {
  68. pos,
  69. NYql::NMatchRecognize::TAfterMatchSkipTo{
  70. NYql::NMatchRecognize::EAfterMatchSkipTo::PastLastRow,
  71. TString()
  72. }
  73. };
  74. if (commonSyntax.HasBlock1()){
  75. skipTo = ParseAfterMatchSkipTo(commonSyntax.GetBlock1().GetRule_row_pattern_skip_to3());
  76. const auto varRequired =
  77. NYql::NMatchRecognize::EAfterMatchSkipTo::ToFirst == skipTo.second.To ||
  78. NYql::NMatchRecognize::EAfterMatchSkipTo::ToLast == skipTo.second.To ||
  79. NYql::NMatchRecognize::EAfterMatchSkipTo::To == skipTo.second.To;
  80. if (varRequired) {
  81. const auto& allVars = NYql::NMatchRecognize::GetPatternVars(pattern);
  82. if (allVars.find(skipTo.second.Var) == allVars.cend()) {
  83. Ctx.Error(skipTo.first) << "Unknown pattern variable in AFTER MATCH";
  84. return {};
  85. }
  86. }
  87. }
  88. TNodePtr subset;
  89. TPosition subsetPos = pos;
  90. if (commonSyntax.HasBlock7()) {
  91. const auto& rowPatternSubset = commonSyntax.GetBlock7().GetRule_row_pattern_subset_clause1();
  92. subsetPos = TokenPosition(rowPatternSubset.GetToken1());
  93. Ctx.Error() << "SUBSET is not implemented yet";
  94. //TODO https://st.yandex-team.ru/YQL-16225
  95. return {};
  96. }
  97. const auto& definitions = ParseDefinitions(commonSyntax.GetRule_row_pattern_definition_list9());
  98. const auto& definitionsPos = TokenPosition(commonSyntax.GetToken8());
  99. const auto& rowPatternVariables = GetPatternVars(pattern);
  100. for (const auto& [callable, name]: definitions) {
  101. if (!rowPatternVariables.contains(name)) {
  102. Ctx.Error(callable->GetPos()) << "ROW PATTERN VARIABLE " << name << " is defined, but not mentioned in the PATTERN";
  103. return {};
  104. }
  105. }
  106. return new TMatchRecognizeBuilder{
  107. pos,
  108. std::pair{partitionsPos, std::move(partitioners)},
  109. std::pair{orderByPos, std::move(sortSpecs)},
  110. std::pair{measuresPos, measures},
  111. std::pair{rowsPerMatchPos, rowsPerMatch},
  112. std::move(skipTo),
  113. std::pair{patternPos, std::move(pattern)},
  114. std::pair{subsetPos, std::move(subset)},
  115. std::pair{definitionsPos, std::move(definitions)}
  116. };
  117. }
  118. TVector<TNamedFunction> TSqlMatchRecognizeClause::ParsePartitionBy(const TRule_window_partition_clause& partitionClause) {
  119. TColumnRefScope scope(Ctx, EColumnRefState::Allow);
  120. TVector<TNodePtr> partitionExprs;
  121. if (!NamedExprList(
  122. partitionClause.GetRule_named_expr_list4(),
  123. partitionExprs)) {
  124. return {};
  125. }
  126. TVector<TNamedFunction> partitioners;
  127. for (const auto& p: partitionExprs) {
  128. auto label = p->GetLabel();
  129. if (!label && p->GetColumnName()) {
  130. label = *p->GetColumnName();
  131. }
  132. partitioners.push_back(TNamedFunction{p, label});
  133. }
  134. return partitioners;
  135. }
  136. TNamedFunction TSqlMatchRecognizeClause::ParseOneMeasure(const TRule_row_pattern_measure_definition& node) {
  137. TColumnRefScope scope(Ctx, EColumnRefState::MatchRecognize);
  138. const auto& expr = TSqlExpression(Ctx, Mode).Build(node.GetRule_expr1());
  139. const auto& name = Id(node.GetRule_an_id3(), *this);
  140. //TODO https://st.yandex-team.ru/YQL-16186
  141. //Each measure must be a lambda, that accepts 2 args:
  142. // - List<InputTableColumns + _yql_Classifier, _yql_MatchNumber>
  143. // - Struct that maps row pattern variables to ranges in the queue
  144. return {expr, name};
  145. }
  146. TVector<TNamedFunction> TSqlMatchRecognizeClause::ParseMeasures(const TRule_row_pattern_measure_list& node) {
  147. TVector<TNamedFunction> result{ ParseOneMeasure(node.GetRule_row_pattern_measure_definition1()) };
  148. for (const auto& m: node.GetBlock2()) {
  149. result.push_back(ParseOneMeasure(m.GetRule_row_pattern_measure_definition2()));
  150. }
  151. return result;
  152. }
  153. std::pair<TPosition, ERowsPerMatch> TSqlMatchRecognizeClause::ParseRowsPerMatch(const TRule_row_pattern_rows_per_match& rowsPerMatchClause) {
  154. switch(rowsPerMatchClause.GetAltCase()) {
  155. case TRule_row_pattern_rows_per_match::kAltRowPatternRowsPerMatch1:
  156. return std::pair {
  157. TokenPosition(rowsPerMatchClause.GetAlt_row_pattern_rows_per_match1().GetToken1()),
  158. ERowsPerMatch::OneRow
  159. };
  160. case TRule_row_pattern_rows_per_match::kAltRowPatternRowsPerMatch2:
  161. return std::pair {
  162. TokenPosition(rowsPerMatchClause.GetAlt_row_pattern_rows_per_match2().GetToken1()),
  163. ERowsPerMatch::AllRows
  164. };
  165. case TRule_row_pattern_rows_per_match::ALT_NOT_SET:
  166. Y_ABORT("You should change implementation according to grammar changes");
  167. }
  168. }
  169. std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo> TSqlMatchRecognizeClause::ParseAfterMatchSkipTo(const TRule_row_pattern_skip_to& skipToClause) {
  170. switch (skipToClause.GetAltCase()) {
  171. case TRule_row_pattern_skip_to::kAltRowPatternSkipTo1:
  172. return std::pair{
  173. TokenPosition(skipToClause.GetAlt_row_pattern_skip_to1().GetToken1()),
  174. NYql::NMatchRecognize::TAfterMatchSkipTo{NYql::NMatchRecognize::EAfterMatchSkipTo::NextRow, ""}
  175. };
  176. case TRule_row_pattern_skip_to::kAltRowPatternSkipTo2:
  177. return std::pair{
  178. TokenPosition(skipToClause.GetAlt_row_pattern_skip_to2().GetToken1()),
  179. NYql::NMatchRecognize::TAfterMatchSkipTo{NYql::NMatchRecognize::EAfterMatchSkipTo::PastLastRow, ""}
  180. };
  181. case TRule_row_pattern_skip_to::kAltRowPatternSkipTo3:
  182. return std::pair{
  183. TokenPosition(skipToClause.GetAlt_row_pattern_skip_to3().GetToken1()),
  184. NYql::NMatchRecognize::TAfterMatchSkipTo{
  185. NYql::NMatchRecognize::EAfterMatchSkipTo::ToFirst,
  186. skipToClause.GetAlt_row_pattern_skip_to3().GetRule_row_pattern_skip_to_variable_name4().GetRule_row_pattern_variable_name1().GetRule_identifier1().GetToken1().GetValue()
  187. }
  188. };
  189. case TRule_row_pattern_skip_to::kAltRowPatternSkipTo4:
  190. return std::pair{
  191. TokenPosition(skipToClause.GetAlt_row_pattern_skip_to4().GetToken1()),
  192. NYql::NMatchRecognize::TAfterMatchSkipTo{
  193. NYql::NMatchRecognize::EAfterMatchSkipTo::ToLast,
  194. skipToClause.GetAlt_row_pattern_skip_to4().GetRule_row_pattern_skip_to_variable_name4().GetRule_row_pattern_variable_name1().GetRule_identifier1().GetToken1().GetValue()
  195. }
  196. };
  197. case TRule_row_pattern_skip_to::kAltRowPatternSkipTo5:
  198. return std::pair{
  199. TokenPosition(skipToClause.GetAlt_row_pattern_skip_to5().GetToken1()),
  200. NYql::NMatchRecognize::TAfterMatchSkipTo{
  201. NYql::NMatchRecognize::EAfterMatchSkipTo::To,
  202. skipToClause.GetAlt_row_pattern_skip_to5().GetRule_row_pattern_skip_to_variable_name3().GetRule_row_pattern_variable_name1().GetRule_identifier1().GetToken1().GetValue()
  203. }
  204. };
  205. case TRule_row_pattern_skip_to::ALT_NOT_SET:
  206. Y_ABORT("You should change implementation according to grammar changes");
  207. }
  208. }
  209. NYql::NMatchRecognize::TRowPatternTerm TSqlMatchRecognizeClause::ParsePatternTerm(const TRule_row_pattern_term& node){
  210. NYql::NMatchRecognize::TRowPatternTerm term;
  211. TPosition pos;
  212. for (const auto& factor: node.GetBlock1()) {
  213. const auto& primaryVar = factor.GetRule_row_pattern_factor1().GetRule_row_pattern_primary1();
  214. NYql::NMatchRecognize::TRowPatternPrimary primary;
  215. bool output = true;
  216. switch (primaryVar.GetAltCase()) {
  217. case TRule_row_pattern_primary::kAltRowPatternPrimary1:
  218. primary = PatternVar(primaryVar.GetAlt_row_pattern_primary1().GetRule_row_pattern_primary_variable_name1().GetRule_row_pattern_variable_name1(), *this);
  219. break;
  220. case TRule_row_pattern_primary::kAltRowPatternPrimary2:
  221. primary = primaryVar.GetAlt_row_pattern_primary2().GetToken1().GetValue();
  222. Y_ENSURE("$" == std::get<0>(primary));
  223. break;
  224. case TRule_row_pattern_primary::kAltRowPatternPrimary3:
  225. primary = primaryVar.GetAlt_row_pattern_primary3().GetToken1().GetValue();
  226. Y_ENSURE("^" == std::get<0>(primary));
  227. break;
  228. case TRule_row_pattern_primary::kAltRowPatternPrimary4: {
  229. if (++PatternNestingLevel <= NYql::NMatchRecognize::MaxPatternNesting) {
  230. primary = ParsePattern(primaryVar.GetAlt_row_pattern_primary4().GetBlock2().GetRule_row_pattern1());
  231. --PatternNestingLevel;
  232. } else {
  233. Ctx.Error(TokenPosition(primaryVar.GetAlt_row_pattern_primary4().GetToken1()))
  234. << "To big nesting level in the pattern";
  235. return NYql::NMatchRecognize::TRowPatternTerm{};
  236. }
  237. break;
  238. }
  239. case TRule_row_pattern_primary::kAltRowPatternPrimary5:
  240. output = false;
  241. Ctx.Error(TokenPosition(primaryVar.GetAlt_row_pattern_primary4().GetToken1()))
  242. << "ALL ROWS PER MATCH and {- -} are not supported yet"; //https://st.yandex-team.ru/YQL-16227
  243. break;
  244. case TRule_row_pattern_primary::kAltRowPatternPrimary6: {
  245. std::vector<NYql::NMatchRecognize::TRowPatternPrimary> items{ParsePattern(
  246. primaryVar.GetAlt_row_pattern_primary6().GetRule_row_pattern_permute1().GetRule_row_pattern3())
  247. };
  248. for (const auto& p: primaryVar.GetAlt_row_pattern_primary6().GetRule_row_pattern_permute1().GetBlock4()) {
  249. items.push_back(ParsePattern(p.GetRule_row_pattern2()));
  250. }
  251. //Permutations now is a syntactic sugar and converted to all possible alternatives
  252. if (items.size() > NYql::NMatchRecognize::MaxPermutedItems) {
  253. Ctx.Error(TokenPosition(primaryVar.GetAlt_row_pattern_primary4().GetToken1()))
  254. << "Too many items in permute";
  255. return NYql::NMatchRecognize::TRowPatternTerm{};
  256. }
  257. std::vector<size_t> indexes(items.size());
  258. std::generate(begin(indexes), end(indexes), [n = 0] () mutable { return n++; });
  259. NYql::NMatchRecognize::TRowPattern permuted;
  260. do {
  261. NYql::NMatchRecognize::TRowPatternTerm term;
  262. term.reserve(indexes.size());
  263. for (size_t i = 0; i != indexes.size(); ++i) {
  264. term.push_back({items[indexes[i]], 1, 1, true, false, false});
  265. }
  266. permuted.push_back(std::move(term));
  267. } while (std::next_permutation(indexes.begin(), indexes.end()));
  268. primary = permuted;
  269. break;
  270. }
  271. case TRule_row_pattern_primary::ALT_NOT_SET:
  272. Y_ABORT("You should change implementation according to grammar changes");
  273. }
  274. uint64_t quantityMin = 1;
  275. uint64_t quantityMax = 1;
  276. constexpr uint64_t infinity = std::numeric_limits<uint64_t>::max();
  277. bool greedy = true;
  278. if (factor.GetRule_row_pattern_factor1().HasBlock2()) {
  279. const auto& quantifier = factor.GetRule_row_pattern_factor1().GetBlock2().GetRule_row_pattern_quantifier1();
  280. switch(quantifier.GetAltCase()){
  281. case TRule_row_pattern_quantifier::kAltRowPatternQuantifier1: //*
  282. quantityMin = 0;
  283. quantityMax = infinity;
  284. greedy = !quantifier.GetAlt_row_pattern_quantifier1().HasBlock2();
  285. break;
  286. case TRule_row_pattern_quantifier::kAltRowPatternQuantifier2: //+
  287. quantityMax = infinity;
  288. greedy = !quantifier.GetAlt_row_pattern_quantifier2().HasBlock2();
  289. break;
  290. case TRule_row_pattern_quantifier::kAltRowPatternQuantifier3: //?
  291. quantityMin = 0;
  292. greedy = !quantifier.GetAlt_row_pattern_quantifier3().HasBlock2();
  293. break;
  294. case TRule_row_pattern_quantifier::kAltRowPatternQuantifier4: //{ 2?, 4?}
  295. if (quantifier.GetAlt_row_pattern_quantifier4().HasBlock2()) {
  296. quantityMin = FromString(quantifier.GetAlt_row_pattern_quantifier4().GetBlock2().GetRule_integer1().GetToken1().GetValue());
  297. }
  298. else {
  299. quantityMin = 0;;
  300. }
  301. if (quantifier.GetAlt_row_pattern_quantifier4().HasBlock4()) {
  302. quantityMax = FromString(quantifier.GetAlt_row_pattern_quantifier4().GetBlock4().GetRule_integer1().GetToken1().GetValue());
  303. }
  304. else {
  305. quantityMax = infinity;
  306. }
  307. greedy = !quantifier.GetAlt_row_pattern_quantifier4().HasBlock6();
  308. break;
  309. case TRule_row_pattern_quantifier::kAltRowPatternQuantifier5:
  310. quantityMin = quantityMax = FromString(quantifier.GetAlt_row_pattern_quantifier5().GetRule_integer2().GetToken1().GetValue());
  311. break;
  312. case TRule_row_pattern_quantifier::ALT_NOT_SET:
  313. Y_ABORT("You should change implementation according to grammar changes");
  314. }
  315. }
  316. term.push_back(NYql::NMatchRecognize::TRowPatternFactor{std::move(primary), quantityMin, quantityMax, greedy, output, false});
  317. }
  318. return term;
  319. }
  320. NYql::NMatchRecognize::TRowPattern TSqlMatchRecognizeClause::ParsePattern(const TRule_row_pattern& node){
  321. TVector<NYql::NMatchRecognize::TRowPatternTerm> result;
  322. result.push_back(ParsePatternTerm(node.GetRule_row_pattern_term1()));
  323. for (const auto& term: node.GetBlock2())
  324. result.push_back(ParsePatternTerm(term.GetRule_row_pattern_term2()));
  325. return result;
  326. }
  327. TNamedFunction TSqlMatchRecognizeClause::ParseOneDefinition(const TRule_row_pattern_definition& node){
  328. const auto& varName = PatternVar(node.GetRule_row_pattern_definition_variable_name1().GetRule_row_pattern_variable_name1(), *this);
  329. TColumnRefScope scope(Ctx, EColumnRefState::MatchRecognize, true, varName);
  330. const auto& searchCondition = TSqlExpression(Ctx, Mode).Build(node.GetRule_row_pattern_definition_search_condition3().GetRule_search_condition1().GetRule_expr1());
  331. return TNamedFunction{searchCondition, varName};
  332. }
  333. TVector<TNamedFunction> TSqlMatchRecognizeClause::ParseDefinitions(const TRule_row_pattern_definition_list& node) {
  334. TVector<TNamedFunction> result { ParseOneDefinition(node.GetRule_row_pattern_definition1())};
  335. for (const auto& d: node.GetBlock2()) {
  336. //TODO https://st.yandex-team.ru/YQL-16186
  337. //Each define must be a predicate lambda, that accepts 3 args:
  338. // - List<input table rows>
  339. // - A struct that maps row pattern variables to ranges in the queue
  340. // - An index of the current row
  341. result.push_back(ParseOneDefinition(d.GetRule_row_pattern_definition2()));
  342. }
  343. return result;
  344. }
  345. } //namespace NSQLTranslationV1