Browse Source

Use left/right join reordering in yt provider

aozeritsky 1 year ago
parent
commit
01ad3c1241

+ 16 - 1
ydb/library/yql/providers/yt/provider/ut/yql_yt_cbo_ut.cpp

@@ -118,7 +118,7 @@ Y_UNIT_TEST(OrderJoins2TablesTableIn2Rels)
     UNIT_ASSERT(optimizedTree != tree);
 }
 
-Y_UNIT_TEST(UnsupportedJoin)
+Y_UNIT_TEST(OrderLeftJoin)
 {
     TExprContext exprCtx;
     auto tree = MakeOp({"c", "c_nationkey"}, {"n", "n_nationkey"}, {"c", "n"}, exprCtx);
@@ -126,6 +126,21 @@ Y_UNIT_TEST(UnsupportedJoin)
     tree->Right = MakeLeaf({"n"}, {"n"}, 10000, 12333, exprCtx);
     tree->JoinKind = exprCtx.NewAtom(exprCtx.AppendPosition({}), "Left");
 
+    TYtState::TPtr state = MakeIntrusive<TYtState>();
+    state->Configuration->CostBasedOptimizer = ECostBasedOptimizer::PG;
+    auto optimizedTree = OrderJoins(tree, state, exprCtx, true);
+    UNIT_ASSERT(optimizedTree != tree);
+    UNIT_ASSERT_STRINGS_EQUAL("Left", optimizedTree->JoinKind->Content());
+}
+
+Y_UNIT_TEST(UnsupportedJoin)
+{
+    TExprContext exprCtx;
+    auto tree = MakeOp({"c", "c_nationkey"}, {"n", "n_nationkey"}, {"c", "n"}, exprCtx);
+    tree->Left = MakeLeaf({"c"}, {"c"}, 1000000, 1233333, exprCtx);
+    tree->Right = MakeLeaf({"n"}, {"n"}, 10000, 12333, exprCtx);
+    tree->JoinKind = exprCtx.NewAtom(exprCtx.AppendPosition({}), "Full");
+
     TYtState::TPtr state = MakeIntrusive<TYtState>();
     state->Configuration->CostBasedOptimizer = ECostBasedOptimizer::PG;
     auto optimizedTree = OrderJoins(tree, state, exprCtx, true);

+ 61 - 13
ydb/library/yql/providers/yt/provider/yql_yt_join_reorder.cpp

@@ -94,6 +94,8 @@ public:
 
         IOptimizer::TInput input;
         input.EqClasses = std::move(EqClasses);
+        input.Left = std::move(Left);
+        input.Right = std::move(Right);
         input.Rels = std::move(Rels);
         input.Normalize();
         if (Debug) {
@@ -215,6 +217,17 @@ private:
         Leafs[leafIndex] = leaf;
     };
 
+    IOptimizer::TEq MakeEqClass(const auto& vars) {
+        IOptimizer::TEq eqClass;
+
+        for (auto& [relId, varId, table, column] : vars) {
+            eqClass.Vars.emplace_back(std::make_tuple(relId, varId));
+            Var2TableCol[relId - 1][varId - 1] = std::make_tuple(table, column);
+        }
+
+        return eqClass;
+    }
+
     bool OnOp(TYtJoinNodeOp* op) {
 #define CHECK(A, B) \
         if (Y_UNLIKELY(!(A))) { \
@@ -224,26 +237,45 @@ private:
             return false; \
         }
 
-        CHECK(op->JoinKind->Content() == "Inner", "Unsupported join type");
         CHECK(!op->Output, "Non empty output");
         CHECK(op->StarOptions.empty(), "Non empty StarOptions");
 
         CHECK(op->LeftLabel->ChildrenSize() == 2, "Only 1 var per join supported");
         CHECK(op->RightLabel->ChildrenSize() == 2, "Only 1 var per join supported");
 
-        // relId, varId, table, column
-        std::vector<std::tuple<int,int,TStringBuf,TStringBuf>> vars;
-        ExtractVars(vars, op->LeftLabel);
-        ExtractVars(vars, op->RightLabel);
+        const auto& joinKind = op->JoinKind->Content();
 
-        IOptimizer::TEq eqClass;
+        if (joinKind == "Inner") {
+            // relId, varId, table, column
+            std::vector<std::tuple<int,int,TStringBuf,TStringBuf>> vars;
+            ExtractVars(vars, op->LeftLabel);
+            ExtractVars(vars, op->RightLabel);
 
-        for (auto& [relId, varId, table, column] : vars) {
-            eqClass.Vars.emplace_back(std::make_tuple(relId, varId));
-            Var2TableCol[relId - 1][varId - 1] = std::make_tuple(table, column);
-        }
+            IOptimizer::TEq eqClass = MakeEqClass(vars);
+
+            EqClasses.emplace_back(std::move(eqClass));
+        } else if (joinKind == "Left" || joinKind == "Right") {
+            std::vector<std::tuple<int,int,TStringBuf,TStringBuf>> leftVars, rightVars;
+            ExtractVars(leftVars, op->LeftLabel);
+            ExtractVars(rightVars, op->RightLabel);
+
+            IOptimizer::TEq leftEqClass = MakeEqClass(leftVars);
+            IOptimizer::TEq rightEqClass = MakeEqClass(rightVars);
+            IOptimizer::TEq eqClass = leftEqClass;
+            eqClass.Vars.insert(eqClass.Vars.end(), rightEqClass.Vars.begin(), rightEqClass.Vars.end());
 
-        EqClasses.emplace_back(std::move(eqClass));
+            CHECK(eqClass.Vars.size() == 2, "Only a=b left|right join supported yet");
+
+            EqClasses.emplace_back(std::move(leftEqClass));
+            EqClasses.emplace_back(std::move(rightEqClass));
+            if (joinKind == "Left") {
+                Left.emplace_back(eqClass);
+            } else {
+                Right.emplace_back(eqClass);
+            }
+        } else {
+            CHECK(false, "Unsupported join type");
+        }
 
 #undef CHECK
         return true;
@@ -296,8 +328,22 @@ private:
             return leaf;
         } else if (node->Outer != -1 && node->Inner != -1) {
             auto ret = MakeIntrusive<TYtJoinNodeOp>();
-            YQL_ENSURE(node->Mode == IOptimizer::EJoinType::Inner, "Unsupported join type");
-            ret->JoinKind = Ctx.NewAtom(Root->JoinKind->Pos(), "Inner");
+            TString joinKind;
+            switch (node->Mode) {
+            case IOptimizer::EJoinType::Inner:
+                joinKind = "Inner";
+                break;
+            case IOptimizer::EJoinType::Left:
+                joinKind = "Left";
+                break;
+            case IOptimizer::EJoinType::Right:
+                joinKind = "Right";
+                break;
+            default:
+                YQL_ENSURE(false, "Unsupported join type");
+                break;
+            }
+            ret->JoinKind = Ctx.NewAtom(Root->JoinKind->Pos(), joinKind);
             ret->LeftLabel = MakeLabel(node->LeftVar);
             ret->RightLabel = MakeLabel(node->RightVar);
             int index = scope.size();
@@ -324,6 +370,8 @@ private:
     std::vector<THashMap<TStringBuf, int>> VarIds;
 
     std::vector<IOptimizer::TEq> EqClasses;
+    std::vector<IOptimizer::TEq> Left;
+    std::vector<IOptimizer::TEq> Right;
 
     IOptimizer::TOutput Result;
 };