join.cpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602
  1. #include "node.h"
  2. #include "context.h"
  3. #include <yql/essentials/utils/yql_panic.h>
  4. #include <library/cpp/charset/ci_string.h>
  5. #include <util/generic/hash_set.h>
  6. #include <util/string/cast.h>
  7. #include <util/string/split.h>
  8. #include <util/string/join.h>
  9. using namespace NYql;
  10. namespace NSQLTranslationV0 {
  11. namespace {
  12. }
  13. TString NormalizeJoinOp(const TString& joinOp) {
  14. TVector<TString> joinOpsParts;
  15. Split(joinOp, " ", joinOpsParts);
  16. for (auto&x : joinOpsParts) {
  17. x.to_title();
  18. }
  19. return JoinSeq("", joinOpsParts);
  20. }
  21. struct TJoinDescr {
  22. TString Op;
  23. struct TFullColumn {
  24. ui32 Source;
  25. TNodePtr Column;
  26. };
  27. TVector<std::pair<TFullColumn, TFullColumn>> Keys;
  28. TJoinDescr(const TString& op)
  29. : Op(op)
  30. {}
  31. };
  32. class TJoinBase: public IJoin {
  33. public:
  34. TJoinBase(TPosition pos, TVector<TSourcePtr>&& sources)
  35. : IJoin(pos)
  36. , Sources(std::move(sources))
  37. {
  38. }
  39. TMaybe<bool> AddColumn(TContext& ctx, TColumnNode& column) override {
  40. ISource* srcByName = nullptr;
  41. if (column.IsArtificial()) {
  42. return true;
  43. }
  44. if (const auto sourceName = *column.GetSourceName()) {
  45. for (auto& source: Sources) {
  46. if (sourceName == source->GetLabel()) {
  47. srcByName = source.Get();
  48. break;
  49. }
  50. }
  51. if (!srcByName) {
  52. if (column.IsAsterisk()) {
  53. ctx.Error(column.GetPos()) << "Unknown correlation name for asterisk: " << sourceName;
  54. return {};
  55. }
  56. // \todo add warning, either mistake in correlation name, either it's a column
  57. column.ResetColumn("", sourceName);
  58. column.SetUseSourceAsColumn();
  59. column.SetAsNotReliable();
  60. }
  61. }
  62. if (column.IsAsterisk()) {
  63. if (!column.GetCountHint()) {
  64. if (srcByName) {
  65. srcByName->AllColumns();
  66. } else {
  67. for (auto& source: Sources) {
  68. source->AllColumns();
  69. }
  70. }
  71. }
  72. return true;
  73. }
  74. if (srcByName) {
  75. column.ResetAsReliable();
  76. if (!srcByName->AddColumn(ctx, column)) {
  77. return {};
  78. }
  79. if (!KeysInitializing && !column.IsAsterisk()) {
  80. column.SetUseSource();
  81. }
  82. return true;
  83. } else {
  84. unsigned acceptedColumns = 0;
  85. TIntrusivePtr<TColumnNode> tryColumn = static_cast<TColumnNode*>(column.Clone().Get());
  86. tryColumn->SetAsNotReliable();
  87. TString lastAcceptedColumnSource;
  88. for (auto& source: Sources) {
  89. if (source->AddColumn(ctx, *tryColumn)) {
  90. ++acceptedColumns;
  91. lastAcceptedColumnSource = source->GetLabel();
  92. }
  93. }
  94. if (!acceptedColumns) {
  95. TStringBuilder sb;
  96. const auto& fullColumnName = FullColumnName(column);
  97. sb << "Column " << fullColumnName << " is not fit to any source";
  98. for (auto& source: Sources) {
  99. if (const auto mistype = source->FindColumnMistype(fullColumnName)) {
  100. sb << ". Did you mean " << mistype.GetRef() << "?";
  101. break;
  102. }
  103. }
  104. ctx.Error(column.GetPos()) << sb;
  105. return {};
  106. } else {
  107. column.SetAsNotReliable();
  108. }
  109. return false;
  110. }
  111. }
  112. const TColumns* GetColumns() const override {
  113. YQL_ENSURE(IsColumnDone, "Unable to GetColumns while it's not finished");
  114. return &JoinedColumns;
  115. }
  116. void GetInputTables(TTableList& tableList) const override {
  117. for (auto& src: Sources) {
  118. src->GetInputTables(tableList);
  119. }
  120. ISource::GetInputTables(tableList);
  121. }
  122. TNodePtr BuildJoinKeys(TContext& ctx, const TVector<TDeferredAtom>& names) override {
  123. const size_t n = JoinOps.size();
  124. TString what(Sources[n]->GetLabel());
  125. static const TSet<TString> noRightSourceJoinOps = {"LeftOnly", "LeftSemi"};
  126. for (size_t nn = n; nn > 0 && noRightSourceJoinOps.contains(JoinOps[nn-1]); --nn) {
  127. what = Sources[nn-1]->GetLabel();
  128. }
  129. const TString with(Sources[n + 1]->GetLabel());
  130. for (auto index = n; index <= n + 1; ++index) {
  131. const auto& label = Sources[index]->GetLabel();
  132. if (label.Contains('.')) {
  133. ctx.Error(Sources[index]->GetPos()) << "Invalid label: " << label << ", unable to use name with dot symbol, you should use AS <simple alias name>";
  134. return nullptr;
  135. }
  136. }
  137. if (what.empty() && with.empty()) {
  138. ctx.Error() << "At least one correlation name is required in join";
  139. return nullptr;
  140. }
  141. if (what == with) {
  142. ctx.Error() << "Self joins are not supporting ON syntax";
  143. return nullptr;
  144. }
  145. TPosition pos(ctx.Pos());
  146. TNodePtr expr;
  147. for (auto& name: names) {
  148. auto lhs = BuildColumn(Pos, name, what);
  149. auto rhs = BuildColumn(Pos, name, with);
  150. if (!lhs || !rhs) {
  151. return nullptr;
  152. }
  153. TNodePtr eq(BuildBinaryOp(pos, "==", lhs, rhs));
  154. if (expr) {
  155. expr = BuildBinaryOp(pos, "And", expr, eq);
  156. } else {
  157. expr = eq;
  158. }
  159. }
  160. if (expr && Sources.size() > 2) {
  161. ctx.Warning(ctx.Pos(), TIssuesIds::YQL_MULTIWAY_JOIN_WITH_USING) << "Multi-way JOINs should be connected with ON clause instead of USING clause";
  162. }
  163. return expr;
  164. }
  165. bool DoInit(TContext& ctx, ISource* src) override;
  166. void SetupJoin(const TString& opName, TNodePtr expr) override {
  167. JoinOps.push_back(opName);
  168. JoinExprs.push_back(expr);
  169. }
  170. protected:
  171. static TString FullColumnName(const TColumnNode& column) {
  172. auto sourceName = *column.GetSourceName();
  173. auto columnName = *column.GetColumnName();
  174. return sourceName ? DotJoin(sourceName, columnName) : columnName;
  175. }
  176. bool InitKeysOrFilters(TContext& ctx, ui32 joinIdx, TNodePtr expr) {
  177. const TString joinOp(JoinOps[joinIdx]);
  178. const TCallNode* op = nullptr;
  179. if (expr) {
  180. const TString opName(expr->GetOpName());
  181. if (opName != "==") {
  182. ctx.Error(expr->GetPos()) << "JOIN ON expression must be a conjunction of equality predicates";
  183. return false;
  184. }
  185. op = dynamic_cast<const TCallNode*>(expr.Get());
  186. YQL_ENSURE(op, "Invalid JOIN equal operation node");
  187. YQL_ENSURE(op->GetArgs().size() == 2, "Invalid JOIN equal operation arguments");
  188. }
  189. ui32 idx = 0;
  190. THashMap<TString, ui32> sources;
  191. for (auto& source: Sources) {
  192. sources.insert({ source->GetLabel(), idx });
  193. ++idx;
  194. }
  195. if (sources.size() != Sources.size()) {
  196. ctx.Error(expr ? expr->GetPos() : Pos) << "JOIN: all correlation names must be different";
  197. return false;
  198. }
  199. ui32 pos = 0;
  200. ui32 leftArg = 0;
  201. ui32 rightArg = 0;
  202. ui32 leftSourceIdx = 0;
  203. ui32 rightSourceIdx = 0;
  204. const TString* leftSource = nullptr;
  205. const TString* rightSource = nullptr;
  206. const TString* sameColumnNamePtr = nullptr;
  207. TSet<TString> joinedSources;
  208. if (op) {
  209. const TString* columnNamePtr = nullptr;
  210. for (auto& arg : op->GetArgs()) {
  211. const auto sourceNamePtr = arg->GetSourceName();
  212. if (!sourceNamePtr) {
  213. ctx.Error(expr->GetPos()) << "JOIN: equality predicate arguments must not be constant";
  214. return false;
  215. }
  216. const auto sourceName = *sourceNamePtr;
  217. if (sourceName.empty()) {
  218. ctx.Error(expr->GetPos()) << "JOIN: column requires correlation name";
  219. return false;
  220. }
  221. auto it = sources.find(sourceName);
  222. if (it != sources.end()) {
  223. joinedSources.insert(sourceName);
  224. if (it->second == joinIdx + 1) {
  225. rightArg = pos;
  226. rightSource = sourceNamePtr;
  227. rightSourceIdx = it->second;
  228. }
  229. else if (it->second > joinIdx + 1) {
  230. ctx.Error(expr->GetPos()) << "JOIN: can not use source: " << sourceName << " in equality predicate, it is out of current join scope";
  231. return false;
  232. }
  233. else {
  234. leftArg = pos;
  235. leftSource = sourceNamePtr;
  236. leftSourceIdx = it->second;
  237. }
  238. }
  239. else {
  240. ctx.Error(expr->GetPos()) << "JOIN: unknown corellation name: " << sourceName;
  241. return false;
  242. }
  243. if (!columnNamePtr) {
  244. columnNamePtr = arg->GetColumnName();
  245. } else {
  246. auto curColumnNamePtr = arg->GetColumnName();
  247. if (curColumnNamePtr && *curColumnNamePtr == *columnNamePtr) {
  248. sameColumnNamePtr = columnNamePtr;
  249. }
  250. }
  251. ++pos;
  252. }
  253. } else {
  254. for (auto& x : sources) {
  255. if (x.second == joinIdx) {
  256. leftArg = pos;
  257. leftSourceIdx = x.second;
  258. joinedSources.insert(x.first);
  259. }
  260. else if (x.second = joinIdx + 1) {
  261. rightArg = pos;
  262. rightSourceIdx = x.second;
  263. joinedSources.insert(x.first);
  264. }
  265. }
  266. }
  267. if (joinedSources.size() == 1) {
  268. ctx.Error(expr ? expr->GetPos() : Pos) << "JOIN: different correlation names are required for joined tables";
  269. return false;
  270. }
  271. if (op) {
  272. if (joinedSources.size() != 2) {
  273. ctx.Error(expr->GetPos()) << "JOIN ON expression must be a conjunction of equality predicates over at most two sources";
  274. return false;
  275. }
  276. if (!rightSource) {
  277. ctx.Error(expr->GetPos()) << "JOIN ON equality predicate must have one of its arguments from the rightmost source";
  278. return false;
  279. }
  280. }
  281. KeysInitializing = true;
  282. if (op) {
  283. ctx.PushBlockShortcuts();
  284. for (auto& arg : op->GetArgs()) {
  285. if (!arg->Init(ctx, this)) {
  286. return false;
  287. }
  288. }
  289. KeysGround = ctx.GroundBlockShortcuts(GetPos(), KeysGround);
  290. Y_DEBUG_ABORT_UNLESS(leftSource);
  291. if (sameColumnNamePtr) {
  292. SameKeyMap[*sameColumnNamePtr].insert(*leftSource);
  293. SameKeyMap[*sameColumnNamePtr].insert(*rightSource);
  294. }
  295. }
  296. if (joinIdx == JoinDescrs.size()) {
  297. JoinDescrs.push_back(TJoinDescr(joinOp));
  298. }
  299. JoinDescrs.back().Keys.push_back({ { leftSourceIdx, op ? op->GetArgs()[leftArg] : nullptr},
  300. { rightSourceIdx, op ? op->GetArgs()[rightArg] : nullptr } });
  301. KeysInitializing = false;
  302. return true;
  303. }
  304. bool IsJoinKeysInitializing() const override {
  305. return KeysInitializing;
  306. }
  307. protected:
  308. TVector<TString> JoinOps;
  309. TVector<TNodePtr> JoinExprs;
  310. TVector<TJoinDescr> JoinDescrs;
  311. TNodePtr KeysGround;
  312. THashMap<TString, THashSet<TString>> SameKeyMap;
  313. TVector<TSourcePtr> Sources;
  314. TColumns JoinedColumns;
  315. bool KeysInitializing = false;
  316. bool IsColumnDone = false;
  317. void FinishColumns() override {
  318. if (IsColumnDone) {
  319. return;
  320. }
  321. YQL_ENSURE(JoinOps.size()+1 == Sources.size());
  322. bool excludeNextSource = false;
  323. decltype(JoinOps)::const_iterator opIter = JoinOps.begin();
  324. for (auto& src: Sources) {
  325. if (excludeNextSource) {
  326. excludeNextSource = false;
  327. if (opIter != JoinOps.end()) {
  328. ++opIter;
  329. }
  330. continue;
  331. }
  332. if (opIter != JoinOps.end()) {
  333. auto joinOper = *opIter;
  334. ++opIter;
  335. if (joinOper == "LeftSemi" || joinOper == "LeftOnly") {
  336. excludeNextSource = true;
  337. }
  338. if (joinOper == "RightSemi" || joinOper == "RightOnly") {
  339. continue;
  340. }
  341. }
  342. auto columnsPtr = src->GetColumns();
  343. if (!columnsPtr) {
  344. continue;
  345. }
  346. TColumns upColumns;
  347. upColumns.Merge(*columnsPtr);
  348. upColumns.SetPrefix(src->GetLabel());
  349. JoinedColumns.Merge(upColumns);
  350. }
  351. IsColumnDone = true;
  352. }
  353. };
  354. bool TJoinBase::DoInit(TContext& ctx, ISource* src) {
  355. for (auto& source: Sources) {
  356. if (!source->Init(ctx, src)) {
  357. return false;
  358. }
  359. }
  360. YQL_ENSURE(JoinOps.size() == JoinExprs.size(), "Invalid join exprs number");
  361. const TSet<TString> allowedJoinOps = {"Inner", "Left", "Right", "Full", "LeftOnly", "RightOnly", "Exclusion", "LeftSemi", "RightSemi", "Cross"};
  362. for (auto& opName: JoinOps) {
  363. if (!allowedJoinOps.contains(opName)) {
  364. ctx.Error(Pos) << "Invalid join op: " << opName;
  365. return false;
  366. }
  367. }
  368. ui32 idx = 0;
  369. for (auto expr: JoinExprs) {
  370. if (expr) {
  371. TDeque<TNodePtr> conjQueue;
  372. conjQueue.push_back(expr);
  373. while (!conjQueue.empty()) {
  374. TNodePtr cur = conjQueue.front();
  375. conjQueue.pop_front();
  376. if (cur->GetOpName() == "And") {
  377. auto conj = dynamic_cast<const TCallNode*>(cur.Get());
  378. YQL_ENSURE(conj, "Invalid And operation node");
  379. conjQueue.insert(conjQueue.begin(), conj->GetArgs().begin(), conj->GetArgs().end());
  380. } else if (!InitKeysOrFilters(ctx, idx, cur)) {
  381. return false;
  382. }
  383. }
  384. } else {
  385. if (!InitKeysOrFilters(ctx, idx, nullptr)) {
  386. return false;
  387. }
  388. }
  389. ++idx;
  390. }
  391. TSet<ui32> joinedSources;
  392. for (auto& descr: JoinDescrs) {
  393. for (auto& key : descr.Keys) {
  394. joinedSources.insert(key.first.Source);
  395. joinedSources.insert(key.second.Source);
  396. }
  397. }
  398. for (idx = 0; idx < Sources.size(); ++idx) {
  399. if (!joinedSources.contains(idx)) {
  400. ctx.Error(Sources[idx]->GetPos()) << "Source: " << Sources[idx]->GetLabel() << " was not used in join expressions";
  401. return false;
  402. }
  403. }
  404. return ISource::DoInit(ctx, src);
  405. }
  406. class TEquiJoin: public TJoinBase {
  407. public:
  408. TEquiJoin(TPosition pos, TVector<TSourcePtr>&& sources)
  409. : TJoinBase(pos, std::move(sources))
  410. {
  411. }
  412. TNodePtr Build(TContext& ctx) override {
  413. TMap<std::pair<TString, TString>, TNodePtr> extraColumns;
  414. TNodePtr joinTree;
  415. for (auto& descr: JoinDescrs) {
  416. auto leftBranch = joinTree;
  417. if (!leftBranch) {
  418. leftBranch = BuildQuotedAtom(Pos, Sources[descr.Keys[0].first.Source]->GetLabel());
  419. }
  420. auto leftKeys = GetColumnNames(ctx, extraColumns, descr.Keys, true);
  421. auto rightKeys = GetColumnNames(ctx, extraColumns, descr.Keys, false);
  422. if (!leftKeys || !rightKeys) {
  423. return nullptr;
  424. }
  425. joinTree = Q(Y(
  426. Q(descr.Op),
  427. leftBranch,
  428. BuildQuotedAtom(Pos, Sources[descr.Keys[0].second.Source]->GetLabel()),
  429. leftKeys,
  430. rightKeys,
  431. Q(Y())
  432. ));
  433. }
  434. TNodePtr equiJoin(Y("EquiJoin"));
  435. bool ordered = false;
  436. for (auto& source: Sources) {
  437. auto sourceNode = source->Build(ctx);
  438. if (!sourceNode) {
  439. return nullptr;
  440. }
  441. const bool useOrderedForSource = ctx.UseUnordered(*source);
  442. ordered = ordered || useOrderedForSource;
  443. if (source->IsFlattenByColumns() || source->IsFlattenColumns()) {
  444. auto flatten = source->IsFlattenByColumns() ?
  445. source->BuildFlattenByColumns("row") :
  446. source->BuildFlattenColumns("row");
  447. if (!flatten) {
  448. return nullptr;
  449. }
  450. auto block = Y(Y("let", "flatten", sourceNode));
  451. block = L(block, Y("let", "flatten", Y(useOrderedForSource ? "OrderedFlatMap" : "FlatMap", "flatten", BuildLambda(Pos, Y("row"), flatten, "res"))));
  452. sourceNode = Y("block", Q(L(block, Y("return", "flatten"))));
  453. }
  454. TNodePtr extraMembers;
  455. for (auto it = extraColumns.lower_bound({ source->GetLabel(), "" }); it != extraColumns.end(); ++it) {
  456. if (it->first.first != source->GetLabel()) {
  457. break;
  458. }
  459. if (!extraMembers) {
  460. extraMembers = KeysGround ? KeysGround : Y();
  461. }
  462. extraMembers = L(
  463. extraMembers,
  464. Y("let", "row", Y("AddMember", "row", BuildQuotedAtom(it->second->GetPos(), it->first.second), it->second))
  465. );
  466. }
  467. if (extraMembers) {
  468. sourceNode = Y(useOrderedForSource ? "OrderedMap" : "Map", sourceNode, BuildLambda(Pos, Y("row"), extraMembers, "row"));
  469. }
  470. if (ctx.EnableSystemColumns && source->IsTableSource()) {
  471. sourceNode = Y("RemoveSystemMembers", sourceNode);
  472. }
  473. equiJoin = L(equiJoin, Q(Y(sourceNode, BuildQuotedAtom(source->GetPos(), source->GetLabel()))));
  474. }
  475. TNodePtr removeMembers;
  476. for(auto it: extraColumns) {
  477. if (!removeMembers) {
  478. removeMembers = Y();
  479. }
  480. removeMembers = L(
  481. removeMembers,
  482. Y("let", "row", Y("ForceRemoveMember", "row", BuildQuotedAtom(Pos, DotJoin(it.first.first, it.first.second))))
  483. );
  484. }
  485. auto options = Y();
  486. equiJoin = L(equiJoin, joinTree, Q(options));
  487. if (removeMembers) {
  488. equiJoin = Y(ordered ? "OrderedMap" : "Map", equiJoin, BuildLambda(Pos, Y("row"), removeMembers, "row"));
  489. }
  490. return equiJoin;
  491. }
  492. const THashMap<TString, THashSet<TString>>& GetSameKeysMap() const override {
  493. return SameKeyMap;
  494. }
  495. const TSet<TString> GetJoinLabels() const override {
  496. TSet<TString> labels;
  497. for (auto& source: Sources) {
  498. const auto label = source->GetLabel();
  499. YQL_ENSURE(label);
  500. labels.emplace(label);
  501. }
  502. return labels;
  503. }
  504. TPtr DoClone() const final {
  505. TVector<TSourcePtr> clonedSources;
  506. for (auto& cur: Sources) {
  507. clonedSources.push_back(cur->CloneSource());
  508. }
  509. auto newSource = MakeIntrusive<TEquiJoin>(Pos, std::move(clonedSources));
  510. newSource->JoinOps = JoinOps;
  511. newSource->JoinExprs = CloneContainer(JoinExprs);
  512. return newSource;
  513. }
  514. private:
  515. TNodePtr GetColumnNames(
  516. TContext& ctx,
  517. TMap<std::pair<TString, TString>, TNodePtr>& extraColumns,
  518. const TVector<std::pair<TJoinDescr::TFullColumn, TJoinDescr::TFullColumn>>& keys,
  519. bool left
  520. ) {
  521. Y_UNUSED(ctx);
  522. auto res = Y();
  523. for (auto& it: keys) {
  524. auto tableName = Sources[left ? it.first.Source : it.second.Source]->GetLabel();
  525. TString columnName;
  526. auto column = left ? it.first.Column : it.second.Column;
  527. if (!column) {
  528. continue;
  529. }
  530. if (column->GetColumnName()) {
  531. columnName = *column->GetColumnName();
  532. } else {
  533. TStringStream str;
  534. str << "_equijoin_column_" << extraColumns.size();
  535. columnName = str.Str();
  536. extraColumns.insert({ std::make_pair(tableName, columnName), column });
  537. }
  538. res = L(res, BuildQuotedAtom(Pos, tableName));
  539. res = L(res, BuildQuotedAtom(Pos, columnName));
  540. }
  541. return Q(res);
  542. }
  543. };
  544. TSourcePtr BuildEquiJoin(TPosition pos, TVector<TSourcePtr>&& sources) {
  545. return new TEquiJoin(pos, std::move(sources));
  546. }
  547. } // namespace NSQLTranslationV0