Browse Source

YQL-13710 runtime support for PgCast

ref:d8d2e1117419626450d0f8932e59d76e203b79de
vvvv 3 years ago
parent
commit
8de79fac61

+ 36 - 0
ydb/library/yql/core/type_ann/type_ann_core.cpp

@@ -9427,6 +9427,41 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
         return IGraphTransformer::TStatus::Ok;
     }
 
+    IGraphTransformer::TStatus PgCastWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) { // Type-annotates a PgCast node: child 0 names the target PG type, and the node is typed as that PG type.
+        Y_UNUSED(output); // the input node is annotated in place; output is never assigned
+        if (!EnsureArgsCount(*input, 2, ctx.Expr)) { // argument count is checked against 2 (type name + value)
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        if (!EnsureAtom(*input->Child(0), ctx.Expr)) { // child 0 must be an atom; its content is read as the type name below
+            return IGraphTransformer::TStatus::Error;
+        }
+
+        auto targetTypeId = NPg::LookupType(TString(input->Child(0)->Content())).TypeId; // NOTE(review): an unknown type name presumably fails inside LookupType rather than via ctx.Expr.AddError - confirm
+
+        auto type = input->Tail().GetTypeAnn(); // type annotation of the value being cast (presumably the last child - confirm)
+        ui32 inputTypeId = 0; // 0 means "no source PG type"; it stays 0 when the argument's kind is Null
+        if (type->GetKind() != ETypeAnnotationKind::Null) { // a Null-kind argument is accepted without further checks
+            if (type->GetKind() != ETypeAnnotationKind::Pg) { // any other non-Pg kind is reported as an error
+                ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
+                    TStringBuilder() << "Expected PG type for cast argument, but got: " << type->GetKind()));
+                return IGraphTransformer::TStatus::Error;
+            }
+
+            inputTypeId = type->Cast<TPgExprType>()->GetId();
+        }
+
+        if (inputTypeId != 0 && inputTypeId != targetTypeId) { // only a known source type that differs from the target is checked
+            if (NPg::LookupType(inputTypeId).Category != 'S' && // the cast lookup runs only when neither side has category 'S'
+                NPg::LookupType(targetTypeId).Category != 'S') { // NOTE(review): 'S' is presumably the PG string category - confirm
+                Y_UNUSED(NPg::LookupCast(inputTypeId, targetTypeId)); // NOTE(review): result is discarded, so this validates only if LookupCast fails for an unknown pair - confirm; NPg::HasCast is a bool-returning alternative
+            }
+        }
+
+        input->SetTypeAnn(ctx.Expr.MakeType<TPgExprType>(targetTypeId)); // the node's type is always the target PG type
+        return IGraphTransformer::TStatus::Ok;
+    }
+
     IGraphTransformer::TStatus PgTypeWrapper(const TExprNode::TPtr& input, TExprNode::TPtr& output, TContext& ctx) {
         Y_UNUSED(output);
         if (!EnsureArgsCount(*input, 1, ctx.Expr)) {
@@ -13156,6 +13191,7 @@ template <NKikimr::NUdf::EDataSlot DataSlot>
         Functions["PgAnonWindow"] = &PgAnonWindowWrapper;
         Functions["PgConst"] = &PgConstWrapper;
         Functions["PgType"] = &PgTypeWrapper;
+        Functions["PgCast"] = &PgCastWrapper;
         Functions["AutoDemuxList"] = &AutoDemuxListWrapper;
         Functions["AggrCountInit"] = &AggrCountInitWrapper;
         Functions["AggrCountUpdate"] = &AggrCountUpdateWrapper;

+ 10 - 0
ydb/library/yql/minikql/mkql_program_builder.cpp

@@ -5042,6 +5042,16 @@ TRuntimeNode TProgramBuilder::PgResolvedCall(const std::string_view& name, ui32
     return TRuntimeNode(callableBuilder.Build(), false);
 }
 
+TRuntimeNode TProgramBuilder::PgCast(TRuntimeNode input, TType* returnType) { // Builds a callable node named after this function with `input` as its single argument and `returnType` as its type.
+    if constexpr (RuntimeVersion < 30U) { // the node is not emitted when RuntimeVersion is below 30
+        THROW yexception() << "Runtime version (" << RuntimeVersion << ") too old for " << __func__;
+    }
+
+    TCallableBuilder callableBuilder(Env, __func__, returnType); // __func__ is "PgCast" here, so that becomes the callable's name
+    callableBuilder.Add(input); // the value to cast is the only argument added
+    return TRuntimeNode(callableBuilder.Build(), false); // NOTE(review): the `false` flag presumably marks a non-immediate node - confirm against TRuntimeNode
+}
+
 bool CanExportType(TType* type, const TTypeEnvironment& env) {
     if (type->GetKind() == TType::EKind::Type) {
         return false; // Type of Type

+ 1 - 0
ydb/library/yql/minikql/mkql_program_builder.h

@@ -625,6 +625,7 @@ public:
 
     TRuntimeNode PgConst(TPgType* pgType, const std::string_view& value);
     TRuntimeNode PgResolvedCall(const std::string_view& name, ui32 id, const TArrayRef<const TRuntimeNode>& args, TType* returnType);
+    TRuntimeNode PgCast(TRuntimeNode input, TType* returnType);
 
 protected:
     TRuntimeNode Invoke(const std::string_view& funcName, TType* resultType, const TArrayRef<const TRuntimeNode>& args);

+ 56 - 15
ydb/library/yql/parser/pg_catalog/catalog.cpp

@@ -209,11 +209,17 @@ private:
     bool IsSupported = true;
 };
 
+struct TLazyTypeInfo { // Name-valued attributes of one type entry, kept as raw text by TTypesParser and resolved to ids later.
+    TString ElementType; // value of the "typelem" key; empty when the key was absent
+    TString InFunc; // value of the "typinput" key
+    TString OutFunc; // value of the "typoutput" key
+};
+
 class TTypesParser : public TParser {
 public:
-    TTypesParser(TTypes& types, THashMap<ui32, TString>& elementTypes)
+    TTypesParser(TTypes& types, THashMap<ui32, TLazyTypeInfo>& lazyInfos)
         : Types(types)
-        , ElementTypes(elementTypes)
+        , LazyInfos(lazyInfos)
     {}
 
     void OnKey(const TString& key, const TString& value) override {
@@ -223,8 +229,23 @@ public:
             LastType.ArrayTypeId = FromString<ui32>(value);
         } else if (key == "typname") {
             LastType.Name = value;
+        } else if (key == "typcategory") {
+            Y_ENSURE(value.size() == 1);
+            LastType.Category = value[0];
+        } else if (key == "typlen") {
+            if (value == "NAMEDATALEN") {
+                LastType.TypeLen = 64;
+            } else if (value == "SIZEOF_POINTER") {
+                LastType.TypeLen = 8;
+            } else {
+                LastType.TypeLen = FromString<i32>(value);
+            }
         } else if (key == "typelem") {
-            LastElementType = value; // resolve later
+            LastLazyTypeInfo.ElementType = value; // resolve later
+        } else if (key == "typinput") {
+            LastLazyTypeInfo.InFunc = value; // resolve later
+        } else if (key == "typoutput") {
+            LastLazyTypeInfo.OutFunc = value; // resolve later
         } else if (key == "typbyval") {
             if (value == "f") {
                 LastType.PassByValue = false;
@@ -243,19 +264,17 @@ public:
             Types[LastType.ArrayTypeId] = LastType;
         }
 
-        if (LastElementType) {
-            ElementTypes[LastType.TypeId] = LastElementType;
-        }
+        LazyInfos[LastType.TypeId] = LastLazyTypeInfo;
 
         LastType = TTypeDesc();
-        LastElementType = TString();
+        LastLazyTypeInfo = TLazyTypeInfo();
     }
 
 private:
     TTypes& Types;
-    THashMap<ui32, TString>& ElementTypes;
+    THashMap<ui32, TLazyTypeInfo>& LazyInfos;
     TTypeDesc LastType;
-    TString LastElementType;
+    TLazyTypeInfo LastLazyTypeInfo;
 };
 
 class TCastsParser : public TParser {
@@ -365,9 +384,9 @@ TProcs ParseProcs(const TString& dat, const THashMap<TString, ui32>& typeByName)
     return ret;
 }
 
-TTypes ParseTypes(const TString& dat, THashMap<ui32, TString>& elementTypes) {
+TTypes ParseTypes(const TString& dat, THashMap<ui32, TLazyTypeInfo>& lazyInfos) {
     TTypes ret;
-    TTypesParser parser(ret, elementTypes);
+    TTypesParser parser(ret, lazyInfos);
     parser.Do(dat);
     return ret;
 }
@@ -390,8 +409,8 @@ struct TCatalog {
         Y_ENSURE(NResource::FindExact("pg_proc.dat", &procData));
         TString castData;
         Y_ENSURE(NResource::FindExact("pg_cast.dat", &castData));
-        THashMap<ui32, TString> elementTypes;
-        Types = ParseTypes(typeData, elementTypes);
+        THashMap<ui32, TLazyTypeInfo> lazyTypeInfos;
+        Types = ParseTypes(typeData, lazyTypeInfos);
         for (const auto& [k, v] : Types) {
             if (k == v.TypeId) {
                 Y_ENSURE(TypeByName.insert(std::make_pair(v.Name, k)).second);
@@ -402,8 +421,12 @@ struct TCatalog {
             }
         }
 
-        for (const auto& [k, v]: elementTypes) {
-            auto elemTypePtr = TypeByName.FindPtr(v);
+        for (const auto& [k, v]: lazyTypeInfos) {
+            if (!v.ElementType) {
+                continue;
+            }
+
+            auto elemTypePtr = TypeByName.FindPtr(v.ElementType);
             Y_ENSURE(elemTypePtr);
             auto typePtr = Types.FindPtr(k);
             Y_ENSURE(typePtr);
@@ -417,6 +440,19 @@ struct TCatalog {
             ProcByName[v.Name].push_back(k);
         }
 
+        for (const auto&[k, v] : lazyTypeInfos) { // resolve each type's input/output function names to proc ids
+            auto inFuncIdPtr = ProcByName.FindPtr(v.InFunc);
+            Y_ENSURE(inFuncIdPtr); // the input function name must exist in ProcByName
+            Y_ENSURE(inFuncIdPtr->size() == 1); // and must map to exactly one proc id
+            auto outFuncIdPtr = ProcByName.FindPtr(v.OutFunc);
+            Y_ENSURE(outFuncIdPtr); // same two requirements for the output function
+            Y_ENSURE(outFuncIdPtr->size() == 1);
+            auto typePtr = Types.FindPtr(k); // k is the type id the lazy info was stored under
+            Y_ENSURE(typePtr);
+            typePtr->InFuncId = inFuncIdPtr->at(0);
+            typePtr->OutFuncId = outFuncIdPtr->at(0);
+        }
+
         Casts = ParseCasts(castData, TypeByName, ProcByName, Procs);
         for (const auto&[k, v] : Casts) {
             Y_ENSURE(CastsByDir.insert(std::make_pair(std::make_pair(v.SourceId, v.TargetId), k)).second);
@@ -522,6 +558,11 @@ const TTypeDesc& LookupType(ui32 typeId) {
     return *typePtr;
 }
 
+bool HasCast(ui32 sourceId, ui32 targetId) { // Returns whether CastsByDir holds an entry for the (sourceId, targetId) pair.
+    const auto& catalog = TCatalog::Instance();
+    return catalog.CastsByDir.contains(std::make_pair(sourceId, targetId)); // same (source, target) key shape used when CastsByDir is filled
+}
+
 const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId) {
     const auto& catalog = TCatalog::Instance();
     auto castByDirPtr = catalog.CastsByDir.FindPtr(std::make_pair(sourceId, targetId));

+ 5 - 0
ydb/library/yql/parser/pg_catalog/catalog.h

@@ -36,6 +36,10 @@ struct TTypeDesc {
     TString Name;
     ui32 ElementTypeId = 0;
     bool PassByValue = false;
+    char Category = '\0';
+    ui32 InFuncId = 0;
+    ui32 OutFuncId = 0;
+    i32 TypeLen = 0;
 };
 
 enum class ECastMethod {
@@ -59,6 +63,7 @@ const TProcDesc& LookupProc(ui32 procId);
 const TTypeDesc& LookupType(const TString& name);
 const TTypeDesc& LookupType(ui32 typeId);
 
+bool HasCast(ui32 sourceId, ui32 targetId);
 const TCastDesc& LookupCast(ui32 sourceId, ui32 targetId);
 const TCastDesc& LookupCast(ui32 castId);
 

+ 6 - 0
ydb/library/yql/providers/common/mkql/yql_provider_mkql.cpp

@@ -2254,6 +2254,12 @@ TMkqlCommonCallableCompiler::TShared::TShared() {
         return ctx.ProgramBuilder.PgResolvedCall(name, id, args, returnType);
     });
 
+    AddCallable("PgCast", [](const TExprNode& node, TMkqlBuildContext& ctx) { // Compiles a PgCast expression node into a ProgramBuilder.PgCast call.
+        auto input = MkqlBuildExpr(*node.Child(1), ctx); // child 1 is the value being cast; child 0 (the type-name atom) is not passed on
+        auto returnType = BuildType(node, *node.GetTypeAnn(), ctx.ProgramBuilder); // the return type is built from the node's own type annotation
+        return ctx.ProgramBuilder.PgCast(input, returnType);
+    });
+
     AddCallable("QueueCreate", [](const TExprNode& node, TMkqlBuildContext& ctx) {
         const auto initCapacity = MkqlBuildExpr(*node.Child(1), ctx);
         const auto initSize = MkqlBuildExpr(*node.Child(2), ctx);

+ 22 - 10
ydb/library/yql/sql/pg/pg_sql.cpp

@@ -802,7 +802,7 @@ public:
             return ParseColumnRef(CAST_NODE(ColumnRef, node));
         }
         case T_TypeCast: {
-            return ParseTypeCast(CAST_NODE(TypeCast, node));
+            return ParseTypeCast(CAST_NODE(TypeCast, node), settings);
         }
         case T_BoolExpr: {
             return ParseBoolExpr(CAST_NODE(BoolExpr, node), settings);
@@ -976,7 +976,7 @@ public:
         return VL(args.data(), args.size());
     }
 
-    TAstNode* ParseTypeCast(const TypeCast* value) {
+    TAstNode* ParseTypeCast(const TypeCast* value, const TExprSettings& settings) {
         if (!value->arg) {
             AddError("Expected arg");
             return nullptr;
@@ -989,19 +989,21 @@ public:
 
         auto arg = value->arg;
         auto typeName = value->typeName;
-        if (NodeTag(arg) == T_A_Const &&
-            (NodeTag(CAST_NODE(A_Const, arg)->val) == T_String ||
-            NodeTag(CAST_NODE(A_Const, arg)->val) == T_Null) &&
-            typeName->typeOid == 0 &&
+        auto supportedTypeName = typeName->typeOid == 0 &&
             !typeName->setof &&
             !typeName->pct_type &&
             ListLength(typeName->typmods) == 0 &&
             ListLength(typeName->arrayBounds) == 0 &&
             (ListLength(typeName->names) == 2 &&
-            NodeTag(ListNodeNth(typeName->names, 0)) == T_String &&
-            !StrCompare(StrVal(ListNodeNth(typeName->names, 0)), "pg_catalog") || ListLength(typeName->names) == 1) &&
+                NodeTag(ListNodeNth(typeName->names, 0)) == T_String &&
+                !StrCompare(StrVal(ListNodeNth(typeName->names, 0)), "pg_catalog") || ListLength(typeName->names) == 1) &&
             NodeTag(ListNodeNth(typeName->names, ListLength(typeName->names) - 1)) == T_String &&
-            typeName->typemod == -1) {
+            typeName->typemod == -1;
+
+        if (NodeTag(arg) == T_A_Const &&
+            (NodeTag(CAST_NODE(A_Const, arg)->val) == T_String ||
+            NodeTag(CAST_NODE(A_Const, arg)->val) == T_Null) &&
+            supportedTypeName) {
             TStringBuf targetType = StrVal(ListNodeNth(typeName->names, ListLength(typeName->names) - 1));
             if (NodeTag(CAST_NODE(A_Const, arg)->val) == T_String && targetType == "bool") {
                 auto str = StrVal(CAST_NODE(A_Const, arg)->val);
@@ -1019,7 +1021,7 @@ public:
                 }
             }
 
-            if (NodeTag(CAST_NODE(A_Const, arg)->val) == T_Null) {
+            if (!Settings.PgTypes && NodeTag(CAST_NODE(A_Const, arg)->val) == T_Null) {
                 TString yqlType;
                 if (targetType == "bool") {
                     yqlType = "Bool";
@@ -1037,6 +1039,16 @@ public:
             }
         }
 
+        if (Settings.PgTypes && supportedTypeName) { // reached when the branches above did not return: with PgTypes on, a supported type name yields a PgCast
+            TStringBuf targetType = StrVal(ListNodeNth(typeName->names, ListLength(typeName->names) - 1)); // last component of the type name
+            auto input = ParseExpr(arg, settings); // `settings` is the per-expression parameter, distinct from the `Settings` member tested above
+            if (!input) { // the argument failed to parse; propagate the failure
+                return nullptr;
+            }
+
+            return L(A("PgCast"), QA(TString(targetType)), input); // NOTE(review): presumably builds the list (PgCast '<type> <input>) - confirm against the L/A/QA helpers
+        }
+
         AddError("Unsupported form of type cast");
         return nullptr;
     }