Browse Source

Preserve original join node (#2372)

Alexey Ozeritskiy 1 year ago
parent
commit
9588de16ce
1 changed files with 11 additions and 10 deletions
  1. 11 10
      ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp

+ 11 - 10
ydb/library/yql/dq/opt/dq_opt_join_cost_based.cpp

@@ -315,20 +315,20 @@ std::shared_ptr<TJoinOptimizerNodeInternal> PickBestJoin(std::shared_ptr<IBaseOp
 /**
 /**
  * Iterate over all join algorithms and pick the best join that is applicable
  * Iterate over all join algorithms and pick the best join that is applicable
 */
 */
-std::shared_ptr<TJoinOptimizerNode> PickBestNonReorderabeJoin(std::shared_ptr<IBaseOptimizerNode> left,
-    std::shared_ptr<IBaseOptimizerNode> right,
-    const std::set<std::pair<TJoinColumn, TJoinColumn>>& leftJoinConditions,
-    const TVector<TString>& leftJoinKeys,
-    const TVector<TString>& rightJoinKeys,
-    EJoinKind joinKind,
+std::shared_ptr<TJoinOptimizerNode> PickBestNonReorderabeJoin(const std::shared_ptr<TJoinOptimizerNode>& node,
     IProviderContext& ctx) {
     IProviderContext& ctx) {
 
 
     EJoinAlgoType bestJoinAlgo;
     EJoinAlgoType bestJoinAlgo;
     bool bestJoinValid = false;
     bool bestJoinValid = false;
     double bestJoinCost;
     double bestJoinCost;
+    const auto& left = node->LeftArg;
+    const auto& right = node->RightArg;
+    const auto& joinConditions = node->JoinConditions;
+    const auto& leftJoinKeys = node->LeftJoinKeys;
+    const auto& rightJoinKeys = node->RightJoinKeys;
 
 
     for ( auto joinAlgo : AllJoinAlgos ) {
     for ( auto joinAlgo : AllJoinAlgos ) {
-        if (ctx.IsJoinApplicable(left, right, leftJoinConditions, leftJoinKeys, rightJoinKeys, joinAlgo)){
+        if (ctx.IsJoinApplicable(left, right, joinConditions, leftJoinKeys, rightJoinKeys, joinAlgo)){
             auto cost = ComputeJoinStats(*right->Stats, *left->Stats,  rightJoinKeys, leftJoinKeys, joinAlgo, ctx).Cost;
             auto cost = ComputeJoinStats(*right->Stats, *left->Stats,  rightJoinKeys, leftJoinKeys, joinAlgo, ctx).Cost;
             if (bestJoinValid) {
             if (bestJoinValid) {
                 if (cost < bestJoinCost) {
                 if (cost < bestJoinCost) {
@@ -344,7 +344,9 @@ std::shared_ptr<TJoinOptimizerNode> PickBestNonReorderabeJoin(std::shared_ptr<IB
     }
     }
 
 
     Y_ENSURE(bestJoinValid,"No join was chosen!");
     Y_ENSURE(bestJoinValid,"No join was chosen!");
-    return MakeJoin(left, right, leftJoinConditions, leftJoinKeys, rightJoinKeys, joinKind, bestJoinAlgo, true, ctx);
+    node->Stats = std::make_shared<TOptimizerStatistics>(ComputeJoinStats(*left->Stats, *right->Stats, leftJoinKeys, rightJoinKeys, bestJoinAlgo, ctx));
+    node->JoinAlgo = bestJoinAlgo;
+    return node;
 }
 }
 
 
 struct pair_hash {
 struct pair_hash {
@@ -1185,8 +1187,7 @@ void ComputeStatistics(const std::shared_ptr<TJoinOptimizerNode>& join, IProvide
 */
 */
 std::shared_ptr<TJoinOptimizerNode> OptimizeSubtree(const std::shared_ptr<TJoinOptimizerNode>& joinTree, ui32 maxDPccpDPTableSize, IProviderContext& ctx) {
 std::shared_ptr<TJoinOptimizerNode> OptimizeSubtree(const std::shared_ptr<TJoinOptimizerNode>& joinTree, ui32 maxDPccpDPTableSize, IProviderContext& ctx) {
     if (!joinTree->IsReorderable) {
     if (!joinTree->IsReorderable) {
-        return PickBestNonReorderabeJoin(joinTree->LeftArg, joinTree->RightArg, joinTree->JoinConditions,
-            joinTree->LeftJoinKeys, joinTree->RightJoinKeys, joinTree->JoinType, ctx);
+        return PickBestNonReorderabeJoin(joinTree, ctx);
     }
     }
 
 
     TVector<std::shared_ptr<IBaseOptimizerNode>> rels;
     TVector<std::shared_ptr<IBaseOptimizerNode>> rels;