Browse Source

YQL-14728 fixed any/all correlated sublinks

ref:0b3f2c0347d117b1632df6ce8a86ff724460341f
Vitaly Stoyan 2 years ago
parent
commit
a2a90a1a4b
1 changed files with 74 additions and 35 deletions
  1. 74 35
      ydb/library/yql/core/common_opt/yql_co_pgselect.cpp

+ 74 - 35
ydb/library/yql/core/common_opt/yql_co_pgselect.cpp

@@ -341,16 +341,13 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos,
 
                 TExprNode::TPtr countAllTraits;
                 TExprNode::TPtr someTraits;
-                TExprNode::TPtr countIfTraits;
-                for (ui32 factoryIndex = 0; factoryIndex < 3; ++factoryIndex)
+                TExprNode::TPtr orTraits;
+                TExprNode::TPtr andTraits;
+                for (ui32 factoryIndex = 0; factoryIndex < 4; ++factoryIndex)
                 {
                     TStringBuf name;
                     switch (factoryIndex) {
                     case 0:
-                        if (linkType != "exists" && linkType != "expr") {
-                            continue;
-                        }
-
                         name = "count_all_traits_factory";
                         break;
                     case 1:
@@ -361,11 +358,18 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos,
                         name = "some_traits_factory";
                         break;
                     case 2:
-                        if (linkType != "any" && linkType != "all") {
+                        if (linkType != "any") {
+                            continue;
+                        }
+
+                        name = "or_traits_factory";
+                        break;
+                    case 3:
+                        if (linkType != "all") {
                             continue;
                         }
 
-                        name = "count_if_traits_factory";
+                        name = "and_traits_factory";
                         break;
                     }
 
@@ -383,7 +387,9 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos,
                         root = ctx.NewCallable(node->Pos(), "SingleMember", {
                             ctx.NewCallable(node->Pos(), "RemoveSystemMembers", { arg }) });
                         break;
-                    case 2: {
+                    case 2:
+                    case 3:
+                    {
                         auto value = ctx.NewCallable(node->Pos(), "SingleMember", {
                             ctx.NewCallable(node->Pos(), "RemoveSystemMembers", { arg }) });
 
@@ -392,14 +398,6 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos,
                             {testLambda->Head().Child(1), value},
                             });
 
-                        if (linkType == "all") {
-                            filterExpr = ctx.Builder(node->Pos())
-                                .Callable("PgNot")
-                                    .Add(0, filterExpr)
-                                .Seal()
-                                .Build();
-                        }
-
                         root = ctx.NewCallable(node->Pos(), "FromPg", { filterExpr });
                         break;
                     }
@@ -423,7 +421,10 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos,
                         someTraits = traits;
                         break;
                     case 2:
-                        countIfTraits = traits;
+                        orTraits = traits;
+                        break;
+                    case 3:
+                        andTraits = traits;
                         break;
                     }
                 }
@@ -450,11 +451,31 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos,
                             .Add(1, someTraits)
                         .Seal()
                         .Build());
-                } else if (linkType == "any" || linkType == "all") {
+                } else if (linkType == "any") {
                     aggregateItems.push_back(ctx.Builder(node->Pos())
                         .List()
-                            .Atom(0, columnName)
-                            .Add(1, countIfTraits)
+                            .Atom(0, columnName + "_count")
+                            .Add(1, countAllTraits)
+                        .Seal()
+                        .Build());
+                    aggregateItems.push_back(ctx.Builder(node->Pos())
+                        .List()
+                            .Atom(0, columnName + "_value")
+                            .Add(1, orTraits)
+                        .Seal()
+                        .Build());
+                } else {
+                    YQL_ENSURE(linkType == "all");
+                    aggregateItems.push_back(ctx.Builder(node->Pos())
+                        .List()
+                            .Atom(0, columnName + "_count")
+                            .Add(1, countAllTraits)
+                        .Seal()
+                        .Build());
+                    aggregateItems.push_back(ctx.Builder(node->Pos())
+                        .List()
+                            .Atom(0, columnName + "_value")
+                            .Add(1, andTraits)
                         .Seal()
                         .Build());
                 }
@@ -527,32 +548,50 @@ std::pair<TExprNode::TPtr, TExprNode::TPtr> RewriteSubLinks(TPositionHandle pos,
                 } else if (linkType == "any") {
                     return ctx.Builder(node->Pos())
                         .Callable("ToPg")
-                            .Callable(0, ">")
-                                .Callable(0, "Member")
-                                    .Add(0, originalRow)
-                                    .Atom(1, columnName)
+                            .Callable(0, "And")
+                                .Callable(0, "!=")
+                                    .Callable(0, "Coalesce")
+                                        .Callable(0, "Member")
+                                            .Add(0, originalRow)
+                                            .Atom(1, columnName + "_count")
+                                        .Seal()
+                                        .Callable(1, "Uint64")
+                                            .Atom(0, "0")
+                                        .Seal()
+                                    .Seal()
+                                    .Callable(1, "Uint64")
+                                        .Atom(0, "0")
+                                    .Seal()
                                 .Seal()
-                                .Callable(1, "Uint64")
-                                    .Atom(0, "0")
+                                .Callable(1, "Member")
+                                    .Add(0, originalRow)
+                                    .Atom(1, columnName + "_value")
                                 .Seal()
                             .Seal()
                         .Seal()
                         .Build();
-                } else if (linkType == "all") {
+                } else {
+                    YQL_ENSURE(linkType == "all");
                     return ctx.Builder(node->Pos())
                         .Callable("ToPg")
-                            .Callable(0, "==")
-                                .Callable(0, "Coalesce")
-                                    .Callable(0, "Member")
-                                        .Add(0, originalRow)
-                                        .Atom(1, columnName)
+                            .Callable(0, "Or")
+                                .Callable(0, "==")
+                                    .Callable(0, "Coalesce")
+                                        .Callable(0, "Member")
+                                            .Add(0, originalRow)
+                                            .Atom(1, columnName + "_count")
+                                        .Seal()
+                                        .Callable(1, "Uint64")
+                                            .Atom(0, "0")
+                                        .Seal()
                                     .Seal()
                                     .Callable(1, "Uint64")
                                         .Atom(0, "0")
                                     .Seal()
                                 .Seal()
-                                .Callable(1, "Uint64")
-                                    .Atom(0, "0")
+                                .Callable(1, "Member")
+                                    .Add(0, originalRow)
+                                    .Atom(1, columnName + "_value")
                                 .Seal()
                             .Seal()
                         .Seal()