match_recognize.cpp 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. #include "match_recognize.h"
  2. #include "source.h"
  3. #include "context.h"
  4. namespace NSQLTranslationV1 {
  namespace {
  // Lambda parameter names shared by the MEASURES/DEFINE callables that
  // TMatchRecognize emits and by the navigation expressions that
  // TMatchRecognizeNavigate builds; both sides must agree on these spellings.
  constexpr const char* VarDataName = "data";         // row data, looked up by index via ToIndexDict
  constexpr const char* VarMatchedVarsName = "vars";  // per-variable matches, read with Member
  constexpr const char* VarLastRowIndexName = "lri";  // last row index; declared only by DEFINE lambdas
  } // namespace
  10. class TMatchRecognize: public TAstListNode {
  11. public:
  12. TMatchRecognize(
  13. TPosition pos,
  14. ISource* source,
  15. const TString& inputTable,
  16. std::pair<TPosition, TVector<TNamedFunction>>&& partitioners,
  17. std::pair<TPosition, TVector<TSortSpecificationPtr>>&& sortSpecs,
  18. std::pair<TPosition, TVector<TNamedFunction>>&& measures,
  19. std::pair<TPosition, ERowsPerMatch>&& rowsPerMatch,
  20. std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo>&& skipTo,
  21. std::pair<TPosition, NYql::NMatchRecognize::TRowPattern>&& pattern,
  22. std::pair<TPosition, TNodePtr>&& subset,
  23. std::pair<TPosition, TVector<TNamedFunction>>&& definitions
  24. ): TAstListNode(pos, {BuildAtom(pos, "block")})
  25. {
  26. Add(BuildBlockStatements(
  27. pos,
  28. source,
  29. inputTable,
  30. std::move(partitioners),
  31. std::move(sortSpecs),
  32. std::move(measures),
  33. std::move(rowsPerMatch),
  34. std::move(skipTo),
  35. std::move(pattern),
  36. std::move(subset),
  37. std::move(definitions)
  38. ));
  39. }
  40. private:
  41. TMatchRecognize(const TMatchRecognize& other)
  42. : TAstListNode(other.Pos)
  43. {
  44. Nodes = CloneContainer(other.Nodes);
  45. }
  46. TNodePtr BuildBlockStatements(
  47. TPosition pos,
  48. ISource* source,
  49. const TString& inputTable,
  50. std::pair<TPosition, TVector<TNamedFunction>>&& partitioners,
  51. std::pair<TPosition, TVector<TSortSpecificationPtr>>&& sortSpecs,
  52. std::pair<TPosition, TVector<TNamedFunction>>&& measures,
  53. std::pair<TPosition, ERowsPerMatch>&& rowsPerMatch,
  54. std::pair<TPosition, NYql::NMatchRecognize::TAfterMatchSkipTo>&& skipTo,
  55. std::pair<TPosition, NYql::NMatchRecognize::TRowPattern>&& pattern,
  56. std::pair<TPosition, TNodePtr>&& subset,
  57. std::pair<TPosition, TVector<TNamedFunction>>&& definitions
  58. ) {
  59. Y_UNUSED(pos);
  60. auto inputRowType = Y("ListItemType",Y("TypeOf", inputTable));
  61. auto patternNode = Pattern(pattern.first, pattern.second);
  62. auto partitionColumns = Y();
  63. for (const auto& p: partitioners.second){
  64. partitionColumns->Add(BuildQuotedAtom(p.callable->GetPos(), p.name));
  65. }
  66. partitionColumns = Q(partitionColumns);
  67. auto partitionKeySelector = Y();
  68. for (const auto& p: partitioners.second){
  69. partitionKeySelector->Add(p.callable);
  70. }
  71. partitionKeySelector = BuildLambda(partitioners.first, Y("row"), Q(partitionKeySelector));
  72. auto measureNames = Y();
  73. for (const auto& m: measures.second){
  74. measureNames->Add(BuildQuotedAtom(m.callable->GetPos(), m.name));
  75. }
  76. TNodePtr measuresNode = Y("MatchRecognizeMeasures", inputRowType, patternNode, Q(measureNames));
  77. for (const auto& m: measures.second){
  78. measuresNode->Add(BuildLambda(m.callable->GetPos(), Y(VarDataName, VarMatchedVarsName), m.callable));
  79. }
  80. auto defineNames = Y();
  81. for (const auto& d: definitions.second) {
  82. defineNames->Add(BuildQuotedAtom(d.callable->GetPos(), d.name));
  83. }
  84. TNodePtr defineNode = Y("MatchRecognizeDefines", inputRowType, patternNode, Q(defineNames));
  85. for (const auto& d: definitions.second) {
  86. defineNode->Add(BuildLambda(d.callable->GetPos(), Y(VarDataName, VarMatchedVarsName, VarLastRowIndexName), d.callable));
  87. }
  88. return Q(Y(
  89. Y("let", "input", inputTable),
  90. Y("let", "partitionKeySelector", partitionKeySelector),
  91. Y("let", "partitionColumns", partitionColumns),
  92. Y("let", "sortTraits", sortSpecs.second.empty()? Y("Void") : source->BuildSortSpec(sortSpecs.second, inputTable, true, false)),
  93. Y("let", "measures", measuresNode),
  94. Y("let", "rowsPerMatch", BuildQuotedAtom(rowsPerMatch.first, "RowsPerMatch_" + ToString(rowsPerMatch.second))),
  95. Y("let", "skipTo", BuildTuple(skipTo.first, {Q("AfterMatchSkip_" + ToString(skipTo.second.To)), Q(ToString(skipTo.second.Var))})),
  96. Y("let", "pattern", patternNode),
  97. Y("let", "subset", subset.second ? subset.second : Q("")),
  98. Y("let", "define", defineNode),
  99. Y("let", "res", Y("MatchRecognize",
  100. "input",
  101. "partitionKeySelector",
  102. "partitionColumns",
  103. "sortTraits",
  104. Y("MatchRecognizeParams",
  105. "measures",
  106. "rowsPerMatch",
  107. "skipTo",
  108. "pattern",
  109. "define"
  110. )
  111. )),
  112. Y("return", "res")
  113. ));
  114. }
  115. TPtr PatternFactor(const TPosition& pos, const NYql::NMatchRecognize::TRowPatternFactor& factor) {
  116. return BuildTuple(pos, {
  117. factor.Primary.index() == 0 ?
  118. BuildQuotedAtom(pos, std::get<0>(factor.Primary)) :
  119. Pattern(pos, std::get<1>(factor.Primary)),
  120. BuildQuotedAtom(pos, ToString(factor.QuantityMin)),
  121. BuildQuotedAtom(pos, ToString(factor.QuantityMax)),
  122. BuildQuotedAtom(pos, ToString(factor.Greedy)),
  123. BuildQuotedAtom(pos, ToString(factor.Output)),
  124. BuildQuotedAtom(pos, ToString(factor.Unused))
  125. });
  126. }
  127. TPtr PatternTerm(const TPosition& pos, const NYql::NMatchRecognize::TRowPatternTerm& term) {
  128. auto factors = Y();
  129. for (const auto& f: term)
  130. factors->Add(PatternFactor(pos, f));
  131. return Q(std::move(factors));
  132. }
  133. TPtr Pattern(const TPosition& pos, const NYql::NMatchRecognize::TRowPattern& pattern) {
  134. TNodePtr patternNode = Y("MatchRecognizePattern");
  135. for (const auto& t: pattern) {
  136. patternNode->Add(PatternTerm(pos, t));
  137. }
  138. return patternNode;
  139. }
  140. TPtr DoClone() const final{
  141. return new TMatchRecognize(*this);
  142. }
  143. };
  144. TNodePtr TMatchRecognizeBuilder::Build(TContext& ctx, TString&& inputTable, ISource* source){
  145. TNodePtr node = new TMatchRecognize(
  146. Pos,
  147. source,
  148. std::move(inputTable),
  149. std::move(Partitioners),
  150. std::move(SortSpecs),
  151. std::move(Measures),
  152. std::move(RowsPerMatch),
  153. std::move(SkipTo),
  154. std::move(Pattern),
  155. std::move(Subset),
  156. std::move(Definitions)
  157. );
  158. if (!node->Init(ctx, source))
  159. return nullptr;
  160. return node;
  161. }
  namespace {
  // Name of the implicit navigation that wraps a bare `VAR.column` reference.
  // TMatchRecognizeNavigate::DoInit compares its Name against this value and,
  // when it matches, reads the row at `lri`.
  constexpr const char* DefaultNavigatingFunction = "MatchRecognizeDefaultNavigating";
  } // namespace
  165. bool TMatchRecognizeVarAccessNode::DoInit(TContext& ctx, ISource* src) {
  166. //If referenced var is the var that is currently being defined
  167. //then it's a reference to the last row in a partition
  168. Node = new TMatchRecognizeNavigate(ctx.Pos(), DefaultNavigatingFunction, TVector<TNodePtr>{this->Clone()});
  169. return Node->Init(ctx, src);
  170. }
  171. bool TMatchRecognizeNavigate::DoInit(TContext& ctx, ISource* src) {
  172. Y_UNUSED(src);
  173. if (Args.size() != 1) {
  174. ctx.Error(Pos) << "Exactly one argument is required in MATCH_RECOGNIZE navigation function";
  175. return false;
  176. }
  177. const auto varColumn = dynamic_cast<TMatchRecognizeVarAccessNode *>(Args[0].Get());
  178. if (not varColumn) {
  179. ctx.Error(Pos) << "Row pattern navigation operations are applicable to row pattern variable only";
  180. return false;
  181. }
  182. const auto varData = BuildAtom(ctx.Pos(), VarDataName);
  183. const auto varMatchedVars = BuildAtom(ctx.Pos(), VarMatchedVarsName);
  184. const auto varLastRowIndex = BuildAtom(ctx.Pos(), VarLastRowIndexName);
  185. const auto matchedRanges = Y("Member", varMatchedVars, Q(varColumn->GetVar()));
  186. TNodePtr navigatedRowIndex;
  187. if (DefaultNavigatingFunction == Name) {
  188. if (not varColumn->IsTheSameVar()) {
  189. ctx.Error(Pos) << "Row pattern navigation function is required";
  190. return false;
  191. }
  192. navigatedRowIndex = varLastRowIndex;
  193. }
  194. else if ("PREV" == Name) {
  195. if (not varColumn->IsTheSameVar()) {
  196. ctx.Error(Pos) << "PREV relative to matched vars is not implemented yet";
  197. return false;
  198. }
  199. navigatedRowIndex = Y(
  200. "-",
  201. varLastRowIndex,
  202. Y("Uint64", Q("1"))
  203. );
  204. } else if ("FIRST" == Name) {
  205. navigatedRowIndex = Y(
  206. "Member",
  207. Y("Head", matchedRanges),
  208. Q("From")
  209. );
  210. } else if ("LAST" == Name) {
  211. navigatedRowIndex = Y(
  212. "Member",
  213. Y("Last", matchedRanges),
  214. Q("To")
  215. );
  216. } else {
  217. ctx.Error(Pos) << "Internal logic error";
  218. return false;
  219. }
  220. Add("Member");
  221. Add(
  222. Y(
  223. "Lookup",
  224. Y("ToIndexDict", varData),
  225. navigatedRowIndex
  226. )
  227. ),
  228. Add(Q(varColumn->GetColumn()));
  229. return true;
  230. }
  231. } // namespace NSQLTranslationV1