match_recognize.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. #include "match_recognize.h"
  2. #include "source.h"
  3. #include "context.h"
  4. #include <util/generic/overloaded.h>
  5. namespace NSQLTranslationV1 {
  6. namespace {
  // Parameter names of the lambdas generated for DEFINE expressions; the
  // generated lookups refer to the same names, so they are kept in one place.
  constexpr const char* VarDataName = "data";
  constexpr const char* VarMatchedVarsName = "vars";
  constexpr const char* VarLastRowIndexName = "lri";
  10. class TMatchRecognizeColumnAccessNode final : public TAstListNode {
  11. public:
  12. TMatchRecognizeColumnAccessNode(TPosition pos, TString var, TString column)
  13. : TAstListNode(pos)
  14. , Var(std::move(var))
  15. , Column(std::move(column)) {
  16. }
  17. const TString* GetColumnName() const override {
  18. return std::addressof(Column);
  19. }
  20. bool DoInit(TContext& ctx, ISource* /* src */) override {
  21. switch (ctx.GetColumnReferenceState()) {
  22. case EColumnRefState::MatchRecognizeMeasures:
  23. if (!ctx.SetMatchRecognizeAggrVar(Var)) {
  24. return false;
  25. }
  26. Add(
  27. "Member",
  28. BuildAtom(Pos, "row"),
  29. Q(Column)
  30. );
  31. break;
  32. case EColumnRefState::MatchRecognizeDefine:
  33. if (ctx.GetMatchRecognizeDefineVar() != Var) {
  34. ctx.Error() << "Row pattern navigation function is required";
  35. return false;
  36. }
  37. BuildLookup(VarLastRowIndexName);
  38. break;
  39. case EColumnRefState::MatchRecognizeDefineAggregate:
  40. if (!ctx.SetMatchRecognizeAggrVar(Var)) {
  41. return false;
  42. }
  43. BuildLookup("index");
  44. break;
  45. default:
  46. ctx.Error(Pos) << "Unexpected column reference state";
  47. return false;
  48. }
  49. return true;
  50. }
  51. TNodePtr DoClone() const override {
  52. return MakeIntrusive<TMatchRecognizeColumnAccessNode>(Pos, Var, Column);
  53. }
  54. private:
  55. void BuildLookup(TString varKeyName) {
  56. Add(
  57. "Member",
  58. Y(
  59. "Lookup",
  60. Y(
  61. "ToIndexDict",
  62. BuildAtom(Pos, VarDataName)
  63. ),
  64. BuildAtom(Pos, std::move(varKeyName))
  65. ),
  66. Q(Column)
  67. );
  68. }
  69. private:
  70. TString Var;
  71. TString Column;
  72. };
  73. class TMatchRecognizeDefineAggregate final : public TAstListNode {
  74. public:
  75. TMatchRecognizeDefineAggregate(TPosition pos, TString name, TVector<TNodePtr> args)
  76. : TAstListNode(pos)
  77. , Name(std::move(name))
  78. , Args(std::move(args)) {
  79. }
  80. bool DoInit(TContext& ctx, ISource* src) override {
  81. if (EColumnRefState::MatchRecognizeDefine != ctx.GetColumnReferenceState()) {
  82. ctx.Error(Pos) << "Unexpected column reference state";
  83. return false;
  84. }
  85. TColumnRefScope scope(ctx, EColumnRefState::MatchRecognizeDefineAggregate, false, ctx.GetMatchRecognizeDefineVar());
  86. if (Args.size() != 1) {
  87. ctx.Error() << "Exactly one argument is required in MATCH_RECOGNIZE navigation function";
  88. return false;
  89. }
  90. const auto arg = Args[0];
  91. if (!arg || !arg->Init(ctx, src)) {
  92. return false;
  93. }
  94. const auto body = [&]() -> TNodePtr {
  95. if ("first" == Name) {
  96. return Y("Member", Y("Head", "item"), Q("From"));
  97. } else if ("last" == Name) {
  98. return Y("Member", Y("Last", "item"), Q("To"));
  99. } else {
  100. ctx.Error() << "Unknown row pattern navigation function: " << Name;
  101. return {};
  102. }
  103. }();
  104. if (!body) {
  105. return false;
  106. }
  107. Add("Apply", BuildLambda(Pos, Y("index"), arg), body);
  108. return true;
  109. }
  110. TNodePtr DoClone() const override {
  111. return MakeIntrusive<TMatchRecognizeDefineAggregate>(Pos, Name, Args);
  112. }
  113. private:
  114. TString Name;
  115. TVector<TNodePtr> Args;
  116. };
  117. class TMatchRecognizeVarAccessNode final : public INode {
  118. public:
  119. TMatchRecognizeVarAccessNode(TPosition pos, TNodePtr aggr)
  120. : INode(pos)
  121. , Aggr(std::move(aggr)) {
  122. }
  123. bool DoInit(TContext& ctx, ISource* src) override {
  124. if (!Aggr || !Aggr->Init(ctx, src)) {
  125. return false;
  126. }
  127. auto var = ctx.ExtractMatchRecognizeAggrVar();
  128. Expr = [&]() -> TNodePtr {
  129. switch (ctx.GetColumnReferenceState()) {
  130. case EColumnRefState::MatchRecognizeMeasures: {
  131. ctx.GetMatchRecognizeAggregations().emplace_back(std::move(var), Aggr->GetAggregation());
  132. return Aggr;
  133. }
  134. case EColumnRefState::MatchRecognizeDefine:
  135. return Y(
  136. "Apply",
  137. BuildLambda(Pos, Y("item"), Aggr),
  138. Y(
  139. "Member",
  140. BuildAtom(ctx.Pos(), VarMatchedVarsName),
  141. Q(std::move(var))
  142. )
  143. );
  144. default:
  145. ctx.Error(Pos) << "Unexpected column reference state";
  146. return {};
  147. }
  148. }();
  149. return Expr && Expr->Init(ctx, src);
  150. }
  151. TNodePtr DoClone() const override {
  152. return MakeIntrusive<TMatchRecognizeVarAccessNode>(Pos, Aggr);
  153. }
  154. TAstNode* Translate(TContext& ctx) const override {
  155. return Expr->Translate(ctx);
  156. }
  157. private:
  158. TNodePtr Aggr;
  159. TNodePtr Expr;
  160. };
  161. class TMatchRecognize final : public TAstListNode {
  162. public:
  163. TMatchRecognize(
  164. TPosition pos,
  165. TString label,
  166. TNodePtr partitionKeySelector,
  167. TNodePtr partitionColumns,
  168. TVector<TSortSpecificationPtr> sortSpecs,
  169. TVector<TNamedFunction> measures,
  170. TNodePtr rowsPerMatch,
  171. TNodePtr skipTo,
  172. TNodePtr pattern,
  173. TNodePtr patternVars,
  174. TNodePtr subset,
  175. TVector<TNamedFunction> definitions)
  176. : TAstListNode(pos)
  177. , Label(std::move(label))
  178. , PartitionKeySelector(std::move(partitionKeySelector))
  179. , PartitionColumns(std::move(partitionColumns))
  180. , SortSpecs(std::move(sortSpecs))
  181. , Measures(std::move(measures))
  182. , RowsPerMatch(std::move(rowsPerMatch))
  183. , SkipTo(std::move(skipTo))
  184. , Pattern(std::move(pattern))
  185. , PatternVars(std::move(patternVars))
  186. , Subset(std::move(subset))
  187. , Definitions(std::move(definitions)) {
  188. }
  189. private:
  190. bool DoInit(TContext& ctx, ISource* src) override {
  191. auto inputRowType = Y("ListItemType", Y("TypeOf", Label));
  192. if (!PartitionKeySelector || !PartitionKeySelector->Init(ctx, src)) {
  193. return false;
  194. }
  195. if (!PartitionColumns || !PartitionColumns->Init(ctx, src)) {
  196. return false;
  197. }
  198. const auto sortTraits = SortSpecs.empty() ? Y("Void") : src->BuildSortSpec(SortSpecs, Label, true, false);
  199. if (!sortTraits || !sortTraits->Init(ctx, src)) {
  200. return false;
  201. }
  202. auto measureNames = Y();
  203. auto measuresCallables = Y();
  204. for (auto& m: Measures) {
  205. TColumnRefScope scope(ctx, EColumnRefState::MatchRecognizeMeasures);
  206. if (!m.Callable || !m.Callable->Init(ctx, src)) {
  207. return false;
  208. }
  209. const auto pos = m.Callable->GetPos();
  210. measureNames = L(measureNames, BuildQuotedAtom(m.Callable->GetPos(), std::move(m.Name)));
  211. auto measuresVars = Y();
  212. auto measuresAggregates = Y();
  213. for (auto& [var, aggr] : ctx.GetMatchRecognizeAggregations()) {
  214. if (!aggr) {
  215. return false;
  216. }
  217. auto [traits, result] = aggr->AggregationTraits(Y("TypeOf", Label), false, false, false, ctx);
  218. if (!result) {
  219. return false;
  220. }
  221. measuresVars = L(measuresVars, BuildQuotedAtom(pos, std::move(var)));
  222. measuresAggregates = L(measuresAggregates, std::move(traits));
  223. }
  224. ctx.GetMatchRecognizeAggregations().clear();
  225. measuresCallables = L(
  226. measuresCallables,
  227. Y(
  228. "MatchRecognizeMeasuresCallable",
  229. BuildLambda(pos, Y("row"), std::move(m.Callable)),
  230. Q(measuresVars),
  231. Q(measuresAggregates)
  232. )
  233. );
  234. }
  235. auto measuresNode = Y("MatchRecognizeMeasuresCallables", inputRowType, Q(PatternVars), Q(measureNames), Q(measuresCallables));
  236. if (!RowsPerMatch || !RowsPerMatch->Init(ctx, src)) {
  237. return false;
  238. }
  239. if (!SkipTo || !SkipTo->Init(ctx, src)) {
  240. return false;
  241. }
  242. if (!Pattern || !Pattern->Init(ctx, src)) {
  243. return false;
  244. }
  245. if (!PatternVars || !PatternVars->Init(ctx, src)) {
  246. return false;
  247. }
  248. auto defineNames = Y();
  249. for (auto& d: Definitions) {
  250. defineNames = L(defineNames, BuildQuotedAtom(d.Callable->GetPos(), d.Name));
  251. }
  252. auto defineNode = Y("MatchRecognizeDefines", inputRowType, Q(PatternVars), Q(defineNames));
  253. for (auto& d: Definitions) {
  254. TColumnRefScope scope(ctx, EColumnRefState::MatchRecognizeDefine, true, d.Name);
  255. if (!d.Callable || !d.Callable->Init(ctx, src)) {
  256. return false;
  257. }
  258. const auto pos = d.Callable->GetPos();
  259. defineNode = L(defineNode, BuildLambda(pos, Y(VarDataName, VarMatchedVarsName, VarLastRowIndexName), std::move(d.Callable)));
  260. }
  261. Add(
  262. "block",
  263. Q(Y(
  264. Y("let", "input", Label),
  265. Y("let", "partitionKeySelector", PartitionKeySelector),
  266. Y("let", "partitionColumns", PartitionColumns),
  267. Y("let", "sortTraits", sortTraits),
  268. Y("let", "measures", measuresNode),
  269. Y("let", "rowsPerMatch", RowsPerMatch),
  270. Y("let", "skipTo", SkipTo),
  271. Y("let", "pattern", Pattern),
  272. Y("let", "subset", Subset ? Subset : Q("")),
  273. Y("let", "define", defineNode),
  274. Y("let", "res", Y("MatchRecognize",
  275. "input",
  276. "partitionKeySelector",
  277. "partitionColumns",
  278. "sortTraits",
  279. Y("MatchRecognizeParams",
  280. "measures",
  281. "rowsPerMatch",
  282. "skipTo",
  283. "pattern",
  284. "define"
  285. )
  286. )),
  287. Y("return", "res")
  288. ))
  289. );
  290. return true;
  291. }
  292. TNodePtr DoClone() const override {
  293. return MakeIntrusive<TMatchRecognize>(
  294. Pos,
  295. Label,
  296. PartitionKeySelector,
  297. PartitionColumns,
  298. SortSpecs,
  299. Measures,
  300. RowsPerMatch,
  301. SkipTo,
  302. Pattern,
  303. PatternVars,
  304. Subset,
  305. Definitions
  306. );
  307. }
  308. private:
  309. TString Label;
  310. TNodePtr PartitionKeySelector;
  311. TNodePtr PartitionColumns;
  312. TVector<TSortSpecificationPtr> SortSpecs;
  313. TVector<TNamedFunction> Measures;
  314. TNodePtr RowsPerMatch;
  315. TNodePtr SkipTo;
  316. TNodePtr Pattern;
  317. TNodePtr PatternVars;
  318. TNodePtr Subset;
  319. TVector<TNamedFunction> Definitions;
  320. };
  321. } // anonymous namespace
  322. TNodePtr TMatchRecognizeBuilder::Build(TContext& ctx, TString label, ISource* src) {
  323. const auto node = MakeIntrusive<TMatchRecognize>(
  324. Pos,
  325. std::move(label),
  326. std::move(PartitionKeySelector),
  327. std::move(PartitionColumns),
  328. std::move(SortSpecs),
  329. std::move(Measures),
  330. std::move(RowsPerMatch),
  331. std::move(SkipTo),
  332. std::move(Pattern),
  333. std::move(PatternVars),
  334. std::move(Subset),
  335. std::move(Definitions)
  336. );
  337. if (!node->Init(ctx, src)) {
  338. return {};
  339. }
  340. return node;
  341. }
  342. TNodePtr BuildMatchRecognizeColumnAccess(TPosition pos, TString var, TString column) {
  343. return MakeIntrusive<TMatchRecognizeColumnAccessNode>(pos, std::move(var), std::move(column));
  344. }
  345. TNodePtr BuildMatchRecognizeDefineAggregate(TPosition pos, TString name, TVector<TNodePtr> args) {
  346. const auto result = MakeIntrusive<TMatchRecognizeDefineAggregate>(pos, std::move(name), std::move(args));
  347. return BuildMatchRecognizeVarAccess(pos, std::move(result));
  348. }
  349. TNodePtr BuildMatchRecognizeVarAccess(TPosition pos, TNodePtr aggr) {
  350. return MakeIntrusive<TMatchRecognizeVarAccessNode>(pos, std::move(aggr));
  351. }
  352. } // namespace NSQLTranslationV1