optimizer.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719
  1. #include "utils.h"
  2. #include "optimizer.h"
  3. #include <iostream>
  4. #include <yql/essentials/parser/pg_wrapper/arena_ctx.h>
  5. #include <yql/essentials/utils/yql_panic.h>
  6. #include <yql/essentials/ast/yql_expr.h>
  7. #include <util/string/builder.h>
  8. #include <util/generic/scope.h>
  9. #ifdef _WIN32
  10. #define __restrict
  11. #endif
  12. #define TypeName PG_TypeName
  13. #define SortBy PG_SortBy
  14. #undef SIZEOF_SIZE_T
  15. extern "C" {
  16. Y_PRAGMA_DIAGNOSTIC_PUSH
  17. #ifdef _win_
  18. Y_PRAGMA("GCC diagnostic ignored \"-Wshift-count-overflow\"")
  19. #endif
  20. Y_PRAGMA("GCC diagnostic ignored \"-Wunused-parameter\"")
  21. #include "postgres.h"
  22. #include "miscadmin.h"
  23. #include "optimizer/paths.h"
  24. #include "nodes/print.h"
  25. #include "utils/selfuncs.h"
  26. #include "utils/palloc.h"
  27. Y_PRAGMA_DIAGNOSTIC_POP
  28. }
  29. #undef Min
  30. #undef Max
  31. #undef TypeName
  32. #undef SortBy
  33. namespace NYql {
  34. namespace {
  35. bool RelationStatsHook(
  36. PlannerInfo *root,
  37. RangeTblEntry *rte,
  38. AttrNumber attnum,
  39. VariableStatData *vardata)
  40. {
  41. Y_UNUSED(root);
  42. Y_UNUSED(rte);
  43. Y_UNUSED(attnum);
  44. vardata->statsTuple = nullptr;
  45. return true;
  46. }
  47. } // namespace
  48. Var* MakeVar(int relno, int varno) {
  49. Var* v = makeNode(Var);
  50. v->varno = relno; // table number
  51. v->varattno = varno; // column number in table
  52. // ?
  53. v->vartype = 25; // ?
  54. v->vartypmod = -1; // ?
  55. v->varcollid = 0;
  56. v->varnosyn = v->varno;
  57. v->varattnosyn = v->varattno;
  58. v->location = -1;
  59. return v;
  60. }
  61. RelOptInfo* MakeRelOptInfo(const IOptimizer::TRel& r, int relno) {
  62. RelOptInfo* rel = makeNode(RelOptInfo);
  63. rel->rows = r.Rows;
  64. rel->tuples = r.Rows;
  65. rel->pages = r.Rows;
  66. rel->allvisfrac = 1.0;
  67. rel->relid = relno;
  68. rel->amflags = 1.0;
  69. rel->rel_parallel_workers = -1;
  70. PathTarget* t = makeNode(PathTarget);
  71. int maxattno = 0;
  72. for (int i = 0; i < (int)r.TargetVars.size(); i++) {
  73. t->exprs = lappend(t->exprs, MakeVar(relno, i+1));
  74. maxattno = i+1;
  75. }
  76. t->width = 8;
  77. rel->reltarget = t;
  78. rel->max_attr = maxattno;
  79. Path* p = makeNode(Path);
  80. p->pathtype = T_SeqScan;
  81. p->rows = r.Rows;
  82. p->startup_cost = 0;
  83. p->total_cost = r.TotalCost;
  84. p->pathtarget = t;
  85. p->parent = rel;
  86. rel->pathlist = list_make1(p);
  87. rel->cheapest_total_path = p;
  88. rel->relids = bms_add_member(nullptr, rel->relid);
  89. rel->attr_needed = (Relids*)palloc0((1+maxattno)*sizeof(Relids));
  90. return rel;
  91. }
  92. List* MakeRelOptInfoList(const IOptimizer::TInput& input) {
  93. List* l = nullptr;
  94. int id = 1;
  95. for (auto& rel : input.Rels) {
  96. l = lappend(l, MakeRelOptInfo(rel, id++));
  97. }
  98. return l;
  99. }
  100. TPgOptimizer::TPgOptimizer(
  101. const TInput& input,
  102. const std::function<void(const TString&)>& log)
  103. : Input(input)
  104. , Log(log)
  105. {
  106. get_relation_stats_hook = RelationStatsHook;
  107. }
  108. TPgOptimizer::~TPgOptimizer()
  109. { }
  110. TPgOptimizer::TOutput TPgOptimizer::JoinSearch()
  111. {
  112. TArenaMemoryContext ctx;
  113. auto prev_work_mem = work_mem;
  114. work_mem = 4096;
  115. Y_DEFER {
  116. work_mem = prev_work_mem;
  117. };
  118. auto* rel = JoinSearchInternal();
  119. return MakeOutput(rel->cheapest_total_path);
  120. }
  121. Var* TPgOptimizer::MakeVar(TVarId varId) {
  122. auto*& var = Vars[varId];
  123. return var
  124. ? var
  125. : (var = ::NYql::MakeVar(std::get<0>(varId), std::get<1>(varId)));
  126. }
  127. EquivalenceClass* TPgOptimizer::MakeEqClass(int i) {
  128. EquivalenceClass* eq = makeNode(EquivalenceClass);
  129. for (auto [relno, varno] : Input.EqClasses[i].Vars) {
  130. EquivalenceMember* m = makeNode(EquivalenceMember);
  131. m->em_expr = (Expr*)MakeVar(TVarId{relno, varno});
  132. m->em_relids = bms_add_member(nullptr, relno);
  133. m->em_datatype = 20;
  134. eq->ec_opfamilies = list_make1_oid(1976);
  135. eq->ec_members = lappend(eq->ec_members, m);
  136. eq->ec_relids = bms_union(eq->ec_relids, m->em_relids);
  137. }
  138. return eq;
  139. }
  140. List* TPgOptimizer::MakeEqClasses() {
  141. List* l = nullptr;
  142. for (int i = 0; i < (int)Input.EqClasses.size(); i++) {
  143. l = lappend(l, MakeEqClass(i));
  144. }
  145. return l;
  146. }
  147. void TPgOptimizer::LogNode(const TString& prefix, void* node)
  148. {
  149. if (Log) {
  150. auto* str = nodeToString(node);
  151. auto* fmt = pretty_format_node_dump(str);
  152. pfree(str);
  153. Log(TStringBuilder() << prefix << ": " << fmt);
  154. pfree(fmt);
  155. }
  156. }
  157. IOptimizer::TOutput TPgOptimizer::MakeOutput(Path* path) {
  158. TOutput output = {{}, &Input};
  159. output.Rows = path->rows;
  160. output.TotalCost = path->total_cost;
  161. MakeOutputJoin(output, path);
  162. return output;
  163. }
  164. int TPgOptimizer::MakeOutputJoin(TOutput& output, Path* path) {
  165. if (path->type == T_MaterialPath) {
  166. return MakeOutputJoin(output, ((MaterialPath*)path)->subpath);
  167. }
  168. int id = output.Nodes.size();
  169. TJoinNode node = output.Nodes.emplace_back(TJoinNode{});
  170. int relid = -1;
  171. while ((relid = bms_next_member(path->parent->relids, relid)) >= 0)
  172. {
  173. node.Rels.emplace_back(relid);
  174. }
  175. if (path->type != T_Path) {
  176. node.Strategy = EJoinStrategy::Unknown;
  177. if (path->type == T_HashPath) {
  178. node.Strategy = EJoinStrategy::Hash;
  179. } else if (path->type == T_NestPath) {
  180. node.Strategy = EJoinStrategy::Loop;
  181. } else {
  182. YQL_ENSURE(false, "Uknown pathtype " << (int)path->type);
  183. }
  184. JoinPath* jpath = (JoinPath*)path;
  185. switch (jpath->jointype) {
  186. case JOIN_INNER:
  187. node.Mode = EJoinType::Inner;
  188. break;
  189. case JOIN_LEFT:
  190. node.Mode = EJoinType::Left;
  191. break;
  192. case JOIN_RIGHT:
  193. node.Mode = EJoinType::Right;
  194. break;
  195. default:
  196. YQL_ENSURE(false, "Unsupported join type");
  197. break;
  198. }
  199. YQL_ENSURE(list_length(jpath->joinrestrictinfo) >= 1, "Unsupported joinrestrictinfo len");
  200. for (int i = 0; i < list_length(jpath->joinrestrictinfo); i++) {
  201. RestrictInfo* rinfo = (RestrictInfo*)jpath->joinrestrictinfo->elements[i].ptr_value;
  202. Var* left = nullptr;
  203. Var* right = nullptr;
  204. if (jpath->jointype == JOIN_INNER) {
  205. YQL_ENSURE(rinfo->left_em->em_expr->type == T_Var, "Unsupported left em type");
  206. YQL_ENSURE(rinfo->right_em->em_expr->type == T_Var, "Unsupported right em type");
  207. left = (Var*)rinfo->left_em->em_expr;
  208. right = (Var*)rinfo->right_em->em_expr;
  209. } else if (jpath->jointype == JOIN_LEFT || jpath->jointype == JOIN_RIGHT) {
  210. YQL_ENSURE(rinfo->clause->type == T_OpExpr);
  211. OpExpr* expr = (OpExpr*)rinfo->clause;
  212. YQL_ENSURE(list_length(expr->args) == 2);
  213. Expr* a1 = (Expr*)list_nth(expr->args, 0);
  214. Expr* a2 = (Expr*)list_nth(expr->args, 1);
  215. YQL_ENSURE(a1->type == T_Var, "Unsupported left arg type");
  216. YQL_ENSURE(a2->type == T_Var, "Unsupported right arg type");
  217. left = (Var*)a1;
  218. right = (Var*)a2;
  219. }
  220. node.LeftVars.emplace_back(std::make_tuple(left->varno, left->varattno));
  221. node.RightVars.emplace_back(std::make_tuple(right->varno, right->varattno));
  222. if (!bms_is_member(left->varno, jpath->outerjoinpath->parent->relids)) {
  223. std::swap(node.LeftVars.back(), node.RightVars.back());
  224. }
  225. }
  226. node.Inner = MakeOutputJoin(output, jpath->innerjoinpath);
  227. node.Outer = MakeOutputJoin(output, jpath->outerjoinpath);
  228. }
  229. output.Nodes[id] = node;
  230. return id;
  231. }
  232. void TPgOptimizer::MakeLeftOrRightRestrictions(std::vector<RestrictInfo*>& dst, const std::vector<TEq>& src)
  233. {
  234. for (const auto& eq : src) {
  235. YQL_ENSURE(eq.Vars.size() == 2);
  236. RestrictInfo* ri = makeNode(RestrictInfo);
  237. ri->can_join = 1;
  238. ri->norm_selec = -1;
  239. ri->outer_selec = -1;
  240. OpExpr* oe = makeNode(OpExpr);
  241. oe->opno = 410;
  242. oe->opfuncid = 467;
  243. oe->opresulttype = 16;
  244. ri->clause = (Expr*)oe;
  245. bool left = true;
  246. for (const auto& [relId, varId] : eq.Vars) {
  247. ri->required_relids = bms_add_member(ri->required_relids, relId);
  248. ri->clause_relids = bms_add_member(ri->clause_relids, relId);
  249. if (left) {
  250. ri->outer_relids = bms_add_member(nullptr, relId);
  251. ri->left_relids = bms_add_member(nullptr, relId);
  252. left = false;
  253. } else {
  254. ri->right_relids = bms_add_member(nullptr, relId);
  255. }
  256. oe->args = lappend(oe->args, MakeVar(TVarId{relId, varId}));
  257. RestrictInfos[relId].emplace_back(ri);
  258. }
  259. dst.emplace_back(ri);
  260. }
  261. }
  262. RelOptInfo* TPgOptimizer::JoinSearchInternal() {
  263. RestrictInfos.clear();
  264. RestrictInfos.resize(Input.Rels.size()+1);
  265. LeftRestriction.clear();
  266. LeftRestriction.reserve(Input.Left.size());
  267. MakeLeftOrRightRestrictions(LeftRestriction, Input.Left);
  268. MakeLeftOrRightRestrictions(RightRestriction, Input.Right);
  269. List* rels = MakeRelOptInfoList(Input);
  270. ListCell* l;
  271. int relId = 1;
  272. foreach (l, rels) {
  273. RelOptInfo* rel = (RelOptInfo*)lfirst(l);
  274. for (auto* ri : RestrictInfos[relId++]) {
  275. rel->joininfo = lappend(rel->joininfo, ri);
  276. }
  277. }
  278. if (Log) {
  279. int i = 1;
  280. foreach (l, rels) {
  281. LogNode(TStringBuilder() << "Input: " << i++, lfirst(l));
  282. }
  283. }
  284. PlannerInfo root;
  285. memset(&root, 0, sizeof(root));
  286. root.type = T_PlannerInfo;
  287. root.query_level = 1;
  288. root.simple_rel_array_size = rels->length+1;
  289. root.simple_rel_array = (RelOptInfo**)palloc0(
  290. root.simple_rel_array_size
  291. * sizeof(RelOptInfo*));
  292. root.simple_rte_array = (RangeTblEntry**)palloc0(
  293. root.simple_rel_array_size * sizeof(RangeTblEntry*)
  294. );
  295. for (int i = 0; i <= rels->length; i++) {
  296. root.simple_rte_array[i] = makeNode(RangeTblEntry);
  297. root.simple_rte_array[i]->rtekind = RTE_RELATION;
  298. }
  299. root.all_baserels = bms_add_range(nullptr, 1, rels->length);
  300. root.eq_classes = MakeEqClasses();
  301. for (auto* ri : LeftRestriction) {
  302. root.left_join_clauses = lappend(root.left_join_clauses, ri);
  303. root.hasJoinRTEs = 1;
  304. root.outer_join_rels = bms_add_members(root.outer_join_rels, ri->right_relids);
  305. SpecialJoinInfo* ji = makeNode(SpecialJoinInfo);
  306. ji->min_lefthand = bms_add_member(ji->min_lefthand, bms_next_member(ri->left_relids, -1));
  307. ji->min_righthand = bms_add_member(ji->min_righthand, bms_next_member(ri->right_relids, -1));
  308. ji->syn_lefthand = bms_add_members(ji->min_lefthand, ri->left_relids);
  309. ji->syn_righthand = bms_add_members(ji->min_righthand, ri->right_relids);
  310. ji->jointype = JOIN_LEFT;
  311. ji->lhs_strict = 1;
  312. root.join_info_list = lappend(root.join_info_list, ji);
  313. }
  314. for (auto* ri : RightRestriction) {
  315. root.right_join_clauses = lappend(root.right_join_clauses, ri);
  316. root.hasJoinRTEs = 1;
  317. root.outer_join_rels = bms_add_members(root.outer_join_rels, ri->left_relids);
  318. SpecialJoinInfo* ji = makeNode(SpecialJoinInfo);
  319. ji->min_lefthand = bms_add_member(ji->min_lefthand, bms_next_member(ri->right_relids, -1));
  320. ji->min_righthand = bms_add_member(ji->min_righthand, bms_next_member(ri->left_relids, -1));
  321. ji->syn_lefthand = bms_add_members(ji->min_lefthand, ri->right_relids);
  322. ji->syn_righthand = bms_add_members(ji->min_righthand, ri->left_relids);
  323. ji->jointype = JOIN_LEFT;
  324. ji->lhs_strict = 1;
  325. root.join_info_list = lappend(root.join_info_list, ji);
  326. }
  327. root.planner_cxt = CurrentMemoryContext;
  328. for (int i = 0; i < rels->length; i++) {
  329. auto* r = (RelOptInfo*)rels->elements[i].ptr_value;
  330. root.simple_rel_array[i+1] = r;
  331. }
  332. for (int eqId = 0; eqId < (int)Input.EqClasses.size(); eqId++) {
  333. for (auto& [relno, _] : Input.EqClasses[eqId].Vars) {
  334. root.simple_rel_array[relno]->eclass_indexes = bms_add_member(
  335. root.simple_rel_array[relno]->eclass_indexes,
  336. eqId);
  337. }
  338. }
  339. for (int i = 0; i < rels->length; i++) {
  340. root.simple_rel_array[i+1]->has_eclass_joins = bms_num_members(root.simple_rel_array[i+1]->eclass_indexes) > 1;
  341. }
  342. root.ec_merging_done = 1;
  343. LogNode("Context: ", &root);
  344. auto* result = standard_join_search(&root, rels->length, rels);
  345. LogNode("Result: ", result);
  346. return result;
  347. }
  348. struct TPgOptimizerImpl
  349. {
  350. TPgOptimizerImpl(
  351. const std::shared_ptr<TJoinOptimizerNode>& root,
  352. TExprContext& ctx,
  353. const std::function<void(const TString&)>& log)
  354. : Root(root)
  355. , Ctx(ctx)
  356. , Log(log)
  357. { }
  358. std::shared_ptr<TJoinOptimizerNode> Do() {
  359. CollectRels(Root);
  360. if (!CollectOps(Root)) {
  361. return Root;
  362. }
  363. IOptimizer::TInput input;
  364. input.EqClasses = std::move(EqClasses);
  365. input.Left = std::move(Left);
  366. input.Right = std::move(Right);
  367. input.Rels = std::move(Rels);
  368. input.Normalize();
  369. Log("Input: " + input.ToString());
  370. std::unique_ptr<IOptimizer> opt = std::unique_ptr<IOptimizer>(MakePgOptimizerInternal(input, Log));
  371. Result = opt->JoinSearch();
  372. Log("Result: " + Result.ToString());
  373. std::shared_ptr<IBaseOptimizerNode> res = Convert(0);
  374. YQL_ENSURE(res);
  375. return std::static_pointer_cast<TJoinOptimizerNode>(res);
  376. }
  377. void OnLeaf(const std::shared_ptr<TRelOptimizerNode>& leaf) {
  378. int relId = Rels.size() + 1;
  379. Rels.emplace_back(IOptimizer::TRel{});
  380. Var2TableCol.emplace_back();
  381. // rel -> varIds
  382. VarIds.emplace_back(THashMap<TStringBuf, int>{});
  383. // rel -> tables
  384. RelTables.emplace_back(std::vector<TStringBuf>{});
  385. for (const auto& table : leaf->Labels()) {
  386. RelTables.back().emplace_back(table);
  387. Table2RelIds[table].emplace_back(relId);
  388. }
  389. auto& rel = Rels[relId - 1];
  390. rel.Rows = leaf->Stats.Nrows;
  391. rel.TotalCost = leaf->Stats.Cost;
  392. int leafIndex = relId - 1;
  393. if (leafIndex >= static_cast<int>(Leafs.size())) {
  394. Leafs.resize(leafIndex + 1);
  395. }
  396. Leafs[leafIndex] = leaf;
  397. }
  398. int GetVarId(int relId, TStringBuf column) {
  399. int varId = 0;
  400. auto maybeVarId = VarIds[relId-1].find(column);
  401. if (maybeVarId != VarIds[relId-1].end()) {
  402. varId = maybeVarId->second;
  403. } else {
  404. varId = Rels[relId - 1].TargetVars.size() + 1;
  405. VarIds[relId - 1][column] = varId;
  406. Rels[relId - 1].TargetVars.emplace_back();
  407. Var2TableCol[relId - 1].emplace_back();
  408. }
  409. return varId;
  410. }
  411. void ExtractVars(
  412. std::vector<std::tuple<int,int,TStringBuf,TStringBuf>>& leftVars,
  413. std::vector<std::tuple<int,int,TStringBuf,TStringBuf>>& rightVars,
  414. const std::shared_ptr<TJoinOptimizerNode>& op)
  415. {
  416. for (size_t i=0; i<op->LeftJoinKeys.size(); i++ ) {
  417. auto& ltable = op->LeftJoinKeys[i].RelName;
  418. auto& lcol = op->LeftJoinKeys[i].AttributeName;
  419. auto& rtable = op->RightJoinKeys[i].RelName;
  420. auto& rcol = op->RightJoinKeys[i].AttributeName;
  421. const auto& lrelIds = Table2RelIds[ltable];
  422. YQL_ENSURE(!lrelIds.empty());
  423. const auto& rrelIds = Table2RelIds[rtable];
  424. YQL_ENSURE(!rrelIds.empty());
  425. for (int relId : lrelIds) {
  426. int varId = GetVarId(relId, lcol);
  427. leftVars.emplace_back(std::make_tuple(relId, varId, ltable, lcol));
  428. }
  429. for (int relId : rrelIds) {
  430. int varId = GetVarId(relId, rcol);
  431. rightVars.emplace_back(std::make_tuple(relId, varId, rtable, rcol));
  432. }
  433. }
  434. }
  435. IOptimizer::TEq MakeEqClass(const auto& vars) {
  436. IOptimizer::TEq eqClass;
  437. for (auto& [relId, varId, table, column] : vars) {
  438. eqClass.Vars.emplace_back(std::make_tuple(relId, varId));
  439. Var2TableCol[relId - 1][varId - 1] = std::make_tuple(table, column);
  440. }
  441. return eqClass;
  442. }
  443. void MakeEqClasses(std::vector<IOptimizer::TEq>& res, const auto& leftVars, const auto& rightVars) {
  444. for (int i = 0; i < (int)leftVars.size(); i++) {
  445. auto& [lrelId, lvarId, ltable, lcolumn] = leftVars[i];
  446. auto& [rrelId, rvarId, rtable, rcolumn] = rightVars[i];
  447. IOptimizer::TEq eqClass; eqClass.Vars.reserve(2);
  448. eqClass.Vars.emplace_back(std::make_tuple(lrelId, lvarId));
  449. eqClass.Vars.emplace_back(std::make_tuple(rrelId, rvarId));
  450. Var2TableCol[lrelId - 1][lvarId - 1] = std::make_tuple(ltable, lcolumn);
  451. Var2TableCol[rrelId - 1][rvarId - 1] = std::make_tuple(rtable, rcolumn);
  452. res.emplace_back(std::move(eqClass));
  453. }
  454. }
  455. bool OnOp(const std::shared_ptr<TJoinOptimizerNode>& op) {
  456. #define CHECK(A, B) \
  457. if (Y_UNLIKELY(!(A))) { \
  458. TIssues issues; \
  459. issues.AddIssue(TIssue(B).SetCode(0, NYql::TSeverityIds::S_INFO)); \
  460. Ctx.IssueManager.AddIssues(issues); \
  461. return false; \
  462. }
  463. if (op->JoinType == InnerJoin) {
  464. // relId, varId, table, column
  465. std::vector<std::tuple<int,int,TStringBuf,TStringBuf>> leftVars;
  466. std::vector<std::tuple<int,int,TStringBuf,TStringBuf>> rightVars;
  467. ExtractVars(leftVars, rightVars, op);
  468. CHECK(leftVars.size() == rightVars.size(), "Left and right labels must have the same size");
  469. MakeEqClasses(EqClasses, leftVars, rightVars);
  470. } else if (op->JoinType == LeftJoin || op->JoinType == RightJoin) {
  471. CHECK(op->LeftJoinKeys.size() == 1 && op->RightJoinKeys.size() == 1, "Only 1 var per join supported");
  472. std::vector<std::tuple<int,int,TStringBuf,TStringBuf>> leftVars, rightVars;
  473. ExtractVars(leftVars, rightVars, op);
  474. IOptimizer::TEq leftEqClass = MakeEqClass(leftVars);
  475. IOptimizer::TEq rightEqClass = MakeEqClass(rightVars);
  476. IOptimizer::TEq eqClass = leftEqClass;
  477. eqClass.Vars.insert(eqClass.Vars.end(), rightEqClass.Vars.begin(), rightEqClass.Vars.end());
  478. CHECK(eqClass.Vars.size() == 2, "Only a=b left|right join supported yet");
  479. EqClasses.emplace_back(std::move(leftEqClass));
  480. EqClasses.emplace_back(std::move(rightEqClass));
  481. if (op->JoinType == LeftJoin) {
  482. Left.emplace_back(eqClass);
  483. } else {
  484. Right.emplace_back(eqClass);
  485. }
  486. } else {
  487. CHECK(false, "Unsupported join type");
  488. }
  489. #undef CHECK
  490. return true;
  491. }
  492. bool CollectOps(const std::shared_ptr<IBaseOptimizerNode>& node)
  493. {
  494. if (node->Kind == JoinNodeType) {
  495. auto op = std::static_pointer_cast<TJoinOptimizerNode>(node);
  496. return OnOp(op)
  497. && CollectOps(op->LeftArg)
  498. && CollectOps(op->RightArg);
  499. }
  500. return true;
  501. }
  502. void CollectRels(const std::shared_ptr<IBaseOptimizerNode>& node) {
  503. if (node->Kind == JoinNodeType) {
  504. auto op = std::static_pointer_cast<TJoinOptimizerNode>(node);
  505. CollectRels(op->LeftArg);
  506. CollectRels(op->RightArg);
  507. } else if (node->Kind == RelNodeType) {
  508. OnLeaf(std::static_pointer_cast<TRelOptimizerNode>(node));
  509. } else {
  510. YQL_ENSURE(false, "Unknown node kind");
  511. }
  512. }
  513. std::shared_ptr<IBaseOptimizerNode> Convert(int nodeId) const {
  514. const auto* node = &Result.Nodes[nodeId];
  515. if (node->Outer == -1 && node->Inner == -1) {
  516. YQL_ENSURE(node->Rels.size() == 1);
  517. auto leaf = Leafs[node->Rels[0]-1];
  518. return leaf;
  519. } else if (node->Outer != -1 && node->Inner != -1) {
  520. EJoinKind joinKind;
  521. switch (node->Mode) {
  522. case IOptimizer::EJoinType::Inner:
  523. joinKind = InnerJoin; break;
  524. case IOptimizer::EJoinType::Left:
  525. joinKind = LeftJoin; break;
  526. case IOptimizer::EJoinType::Right:
  527. joinKind = RightJoin; break;
  528. default:
  529. YQL_ENSURE(false, "Unsupported join type");
  530. break;
  531. };
  532. auto left = Convert(node->Outer);
  533. auto right = Convert(node->Inner);
  534. YQL_ENSURE(node->LeftVars.size() == node->RightVars.size());
  535. TVector<NDq::TJoinColumn> leftJoinKeys;
  536. TVector<NDq::TJoinColumn> rightJoinKeys;
  537. for (size_t i = 0; i < node->LeftVars.size(); i++) {
  538. auto [lrelId, lvarId] = node->LeftVars[i];
  539. auto [rrelId, rvarId] = node->RightVars[i];
  540. auto [ltable, lcolumn] = Var2TableCol[lrelId - 1][lvarId - 1];
  541. auto [rtable, rcolumn] = Var2TableCol[rrelId - 1][rvarId - 1];
  542. leftJoinKeys.push_back(NDq::TJoinColumn(TString(ltable), TString(lcolumn)));
  543. rightJoinKeys.push_back(NDq::TJoinColumn(TString(rtable), TString(rcolumn)));
  544. }
  545. return std::make_shared<TJoinOptimizerNode>(
  546. left, right,
  547. leftJoinKeys,
  548. rightJoinKeys,
  549. joinKind,
  550. EJoinAlgoType::MapJoin,
  551. false,
  552. false
  553. );
  554. } else {
  555. YQL_ENSURE(false, "Wrong CBO node");
  556. }
  557. return nullptr;
  558. }
  559. std::shared_ptr<TJoinOptimizerNode> Root;
  560. TExprContext& Ctx;
  561. std::function<void(const TString&)> Log;
  562. THashMap<TStringBuf, std::vector<int>> Table2RelIds;
  563. std::vector<IOptimizer::TRel> Rels;
  564. std::vector<std::vector<TStringBuf>> RelTables;
  565. std::vector<std::shared_ptr<TRelOptimizerNode>> Leafs;
  566. std::vector<std::vector<std::tuple<TStringBuf, TStringBuf>>> Var2TableCol;
  567. std::vector<THashMap<TStringBuf, int>> VarIds;
  568. std::vector<IOptimizer::TEq> EqClasses;
  569. std::vector<IOptimizer::TEq> Left;
  570. std::vector<IOptimizer::TEq> Right;
  571. IOptimizer::TOutput Result;
  572. };
  573. class TPgOptimizerNew: public IOptimizerNew
  574. {
  575. public:
  576. TPgOptimizerNew(IProviderContext& pctx, TExprContext& ctx, const std::function<void(const TString&)>& log)
  577. : IOptimizerNew(pctx)
  578. , Ctx(ctx)
  579. , Log(log)
  580. { }
  581. std::shared_ptr<TJoinOptimizerNode> JoinSearch(
  582. const std::shared_ptr<TJoinOptimizerNode>& joinTree,
  583. const TOptimizerHints& hints = {}) override
  584. {
  585. Y_UNUSED(hints);
  586. return TPgOptimizerImpl(joinTree, Ctx, Log).Do();
  587. }
  588. private:
  589. TExprContext& Ctx;
  590. std::function<void(const TString&)> Log;
  591. };
  592. IOptimizer* MakePgOptimizerInternal(const IOptimizer::TInput& input, const std::function<void(const TString&)>& log)
  593. {
  594. return new TPgOptimizer(input, log);
  595. }
  596. IOptimizerNew* MakePgOptimizerNew(IProviderContext& pctx, TExprContext& ctx, const std::function<void(const TString&)>& log)
  597. {
  598. return new TPgOptimizerNew(pctx, ctx, log);
  599. }
  600. } // namespace NYql {