Browse Source

YQL-14786: pg syntax: support case/when

ref:ffc2cd34eb7df0c072f4ddf57e2b96fb76c5e323
Sergey Uzhakov 2 years ago
parent
commit
afea45a528
1 changed files with 85 additions and 1 deletions
  1. 85 1
      ydb/library/yql/sql/pg/pg_sql.cpp

+ 85 - 1
ydb/library/yql/sql/pg/pg_sql.cpp

@@ -36,6 +36,10 @@ const T* CastNode(const void* nodeptr, int tag) {
     return static_cast<const T*>(nodeptr);
 }
 
+const Node* Expr2Node(const Expr* e) {
+    return reinterpret_cast<const Node*>(e);
+}
+
 int NodeTag(const Node* node) {
     return nodeTag(node);
 }
@@ -943,7 +947,10 @@ public:
             AddError("NullTest: unsupported argisrow");
             return nullptr;
         }
-        auto arg = ParseExpr((const Node*)value->arg, settings);
+        auto arg = ParseExpr(Expr2Node(value->arg), settings);
+        if (!arg) {
+            return nullptr;
+        }
         auto result = L(A("Exists"), arg);
         if (value->nulltesttype == IS_NULL) {
             result = L(A("Not"), result);
@@ -951,6 +958,80 @@ public:
         return L(A("ToPg"), result);
     }
 
+    struct TCaseBranch {
+        TAstNode* Pred;
+        TAstNode* Value;
+    };
+
+    TCaseBranch ReduceCaseBranches(std::vector<TCaseBranch>::const_iterator begin, std::vector<TCaseBranch>::const_iterator end) {
+        Y_ENSURE(begin < end);
+        const size_t branchCount = end - begin;
+        if (branchCount == 1) {
+            return *begin;
+        }
+
+        auto mid = begin + branchCount / 2;
+        auto left = ReduceCaseBranches(begin, mid);
+        auto right = ReduceCaseBranches(mid, end);
+
+        TVector<TAstNode*> preds;
+        preds.reserve(branchCount + 1);
+        preds.push_back(A("Or"));
+        for (auto it = begin; it != end; ++it) {
+            preds.push_back(it->Pred);
+        }
+
+        TCaseBranch result;
+        result.Pred = VL(&preds[0], preds.size());
+        result.Value = L(A("If"), left.Pred, left.Value, right.Value);
+        return result;
+
+    }
+
+    TAstNode* ParseCaseExpr(const CaseExpr* value, const TExprSettings& settings) {
+        TAstNode* testExpr = nullptr;
+        if (value->arg) {
+            testExpr = ParseExpr(Expr2Node(value->arg), settings);
+            if (!testExpr) {
+                return nullptr;
+            }
+        }
+        std::vector<TCaseBranch> branches;
+        for (int i = 0; i < ListLength(value->args); ++i) {
+            auto node = ListNodeNth(value->args, i);
+            auto whenNode = CAST_NODE(CaseWhen, node);
+            auto whenExpr = ParseExpr(Expr2Node(whenNode->expr), settings);
+            if (!whenExpr) {
+                return nullptr;
+            }
+            if (testExpr) {
+                whenExpr = L(A("PgOp"), QA("="), testExpr, whenExpr);
+            }
+
+            whenExpr = L(A("Coalesce"),
+                L(A("FromPg"), whenExpr),
+                L(A("Bool"), QA("false"))
+            );
+
+            auto whenResult = ParseExpr(Expr2Node(whenNode->result), settings);
+            if (!whenResult) {
+                return nullptr;
+            }
+            branches.emplace_back(TCaseBranch{ .Pred = whenExpr,.Value = whenResult });
+        }
+        TAstNode* defaultResult = nullptr;
+        if (value->defresult) {
+            defaultResult = ParseExpr(Expr2Node(value->defresult), settings);
+            if (!defaultResult) {
+                return nullptr;
+            }
+        } else {
+            defaultResult = L(A("Null"));
+        }
+        auto final = ReduceCaseBranches(branches.begin(), branches.end());
+        return L(A("If"), final.Pred, final.Value, defaultResult);
+    }
+
     TAstNode* ParseExpr(const Node* node, const TExprSettings& settings) {
         switch (NodeTag(node)) {
         case T_A_Const: {
@@ -959,6 +1040,9 @@ public:
         case T_A_Expr: {
             return ParseAExpr(CAST_NODE(A_Expr, node), settings);
         }
+        case T_CaseExpr: {
+            return ParseCaseExpr(CAST_NODE(CaseExpr, node), settings);
+        }
         case T_ColumnRef: {
             if (!settings.AllowColumns) {
                 AddError(TStringBuilder() << "Columns are not allowed in: " << settings.Scope);