Browse Source

Inference projections support (#8744)

Ivan Sukhov 6 months ago
parent
commit
18a91edf6d

+ 112 - 42
ydb/core/external_sources/object_storage.cpp

@@ -18,6 +18,7 @@
 #include <ydb/library/yql/providers/s3/object_listers/yql_s3_path.h>
 #include <ydb/library/yql/providers/s3/path_generator/yql_s3_path_generator.h>
 #include <ydb/library/yql/providers/s3/proto/credentials.pb.h>
+#include <ydb/library/yql/utils/yql_panic.h>
 #include <ydb/public/api/protos/ydb_status_codes.pb.h>
 #include <ydb/public/sdk/cpp/client/ydb_value/value.h>
 
@@ -332,44 +333,87 @@ struct TObjectStorageExternalSource : public IExternalSource {
 
         const TString path = meta->TableLocation;
         const TString filePattern = meta->Attributes.Value("filepattern", TString{});
+        const TString projection = meta->Attributes.Value("projection", TString{});
         const TVector<TString> partitionedBy = GetPartitionedByConfig(meta);
+
+        NYql::NPathGenerator::TPathGeneratorPtr pathGenerator;
+        
+        bool shouldInferPartitions = !partitionedBy.empty() && !projection;
+        bool ignoreEmptyListings = !projection.empty();
+
         NYql::NS3Lister::TListingRequest request {
             .Url = meta->DataSourceLocation,
             .Credentials = credentials
         };
+        TVector<NYql::NS3Lister::TListingRequest> requests;
+
+        if (!projection) {
+            auto error = NYql::NS3::BuildS3FilePattern(path, filePattern, partitionedBy, request);
+            if (error) {
+                throw yexception() << *error;
+            }
+            requests.push_back(request);
+        } else {
+            if (NYql::NS3::HasWildcards(path)) {
+                throw yexception() << "Path prefix: '" << path << "' contains wildcards";
+            }
 
-        auto error = NYql::NS3::BuildS3FilePattern(path, filePattern, partitionedBy, request);
-        if (error) {
-            throw yexception() << *error;
+            pathGenerator = NYql::NPathGenerator::CreatePathGenerator(projection, partitionedBy);
+            for (const auto& rule : pathGenerator->GetRules()) {
+                YQL_ENSURE(rule.ColumnValues.size() == partitionedBy.size());
+                
+                request.Pattern = NYql::NS3::NormalizePath(TStringBuilder() << path << "/" << rule.Path << "/*");
+                request.PatternType = NYql::NS3Lister::ES3PatternType::Wildcard;
+                request.Prefix = request.Pattern.substr(0, NYql::NS3::GetFirstWildcardPos(request.Pattern));
+
+                requests.push_back(request);
+            }
         }
 
         auto partByData = std::make_shared<TStringBuilder>();
+        if (shouldInferPartitions) {
+            *partByData << JoinSeq(",", partitionedBy);
+        }
 
+        TVector<NThreading::TFuture<NYql::NS3Lister::TListResult>> futures;
         auto httpGateway = NYql::IHTTPGateway::Make();
         auto httpRetryPolicy = NYql::GetHTTPDefaultRetryPolicy(NYql::THttpRetryPolicyOptions{.RetriedCurlCodes = NYql::FqRetriedCurlCodes()});
-        auto s3Lister = NYql::NS3Lister::MakeS3Lister(httpGateway, httpRetryPolicy, request, Nothing(), AllowLocalFiles, ActorSystem);
-        auto afterListing = s3Lister->Next().Apply([partByData, partitionedBy, path = request.Pattern](const NThreading::TFuture<NYql::NS3Lister::TListResult>& listResFut) {
-            auto& listRes = listResFut.GetValue();
-            auto& partByRef = *partByData;
-            if (std::holds_alternative<NYql::NS3Lister::TListError>(listRes)) {
-                auto& error = std::get<NYql::NS3Lister::TListError>(listRes);
-                throw yexception() << error.Issues.ToString();
-            }
-            auto& entries = std::get<NYql::NS3Lister::TListEntries>(listRes);
-            if (entries.Objects.empty()) {
-                throw yexception() << "couldn't find files at " << path;
-            }
+        for (const auto& req : requests) {
+            auto s3Lister = NYql::NS3Lister::MakeS3Lister(httpGateway, httpRetryPolicy, req, Nothing(), AllowLocalFiles, ActorSystem);
+            futures.push_back(s3Lister->Next());
+        }
 
-            partByRef << JoinSeq(",", partitionedBy);
-            for (const auto& entry : entries.Objects) {
-                Y_ENSURE(entry.MatchedGlobs.size() == partitionedBy.size());
-                partByRef << Endl << JoinSeq(",", entry.MatchedGlobs);
-            }
-            for (const auto& entry : entries.Objects) {
-                if (entry.Size > 0) {
-                    return entry;
+        auto allFuture = NThreading::WaitExceptionOrAll(futures);
+        auto afterListing = allFuture.Apply([partByData, shouldInferPartitions, ignoreEmptyListings, futures = std::move(futures), requests = std::move(requests)](const NThreading::TFuture<void>& result) {
+            result.GetValue();
+            for (size_t i = 0; i < futures.size(); ++i) {
+                auto& listRes = futures[i].GetValue();
+                if (std::holds_alternative<NYql::NS3Lister::TListError>(listRes)) {
+                    auto& error = std::get<NYql::NS3Lister::TListError>(listRes);
+                    throw yexception() << error.Issues.ToString();
+                }
+                auto& entries = std::get<NYql::NS3Lister::TListEntries>(listRes);
+                if (entries.Objects.empty() && !ignoreEmptyListings) {
+                    throw yexception() << "couldn't find files at " << requests[i].Pattern;
+                }
+
+                if (shouldInferPartitions) {
+                    for (const auto& entry : entries.Objects) {
+                        *partByData << Endl << JoinSeq(",", entry.MatchedGlobs);
+                    }
+                }
+
+                for (const auto& entry : entries.Objects) {
+                    if (entry.Size > 0) {
+                        return entry;
+                    }
+                }
+
+                if (!ignoreEmptyListings) {
+                    throw yexception() << "couldn't find any files for type inference, please check that the right path is provided";
                 }
             }
+
             throw yexception() << "couldn't find any files for type inference, please check that the right path is provided";
         });
 
@@ -412,13 +456,45 @@ struct TObjectStorageExternalSource : public IExternalSource {
             ));
 
             return promise.GetFuture();
-        }).Apply([arrowInferencinatorId, meta, partByData, partitionedBy, this](const NThreading::TFuture<TMetadataResult>& result) {
+        }).Apply([arrowInferencinatorId, meta, partByData, partitionedBy, pathGenerator, this](const NThreading::TFuture<TMetadataResult>& result) {
             auto& value = result.GetValue();
             if (!value.Success()) {
                 return result;
             }
 
-            return InferPartitionedColumnsTypes(arrowInferencinatorId, partByData, partitionedBy, result);
+            auto meta = value.Metadata;
+            if (pathGenerator) {
+                for (const auto& rule : pathGenerator->GetConfig().Rules) {
+                    auto& destColumn = *meta->Schema.add_column();
+                    destColumn.mutable_name()->assign(rule.Name);
+                    switch (rule.Type) {
+                    case NYql::NPathGenerator::IPathGenerator::EType::INTEGER:
+                        destColumn.mutable_type()->set_type_id(Ydb::Type::INT64);
+                        break;
+                    
+                    case NYql::NPathGenerator::IPathGenerator::EType::DATE:
+                        destColumn.mutable_type()->set_type_id(Ydb::Type::DATE);
+                        break;
+
+                    case NYql::NPathGenerator::IPathGenerator::EType::ENUM:
+                    default:
+                        destColumn.mutable_type()->set_type_id(Ydb::Type::STRING);
+                        break;
+                    }
+                }
+            } else {
+                for (const auto& partitionName : partitionedBy) {
+                    auto& destColumn = *meta->Schema.add_column();
+                    destColumn.mutable_name()->assign(partitionName);
+                    destColumn.mutable_type()->set_type_id(Ydb::Type::UTF8);
+                }
+            }
+
+            if (!partitionedBy.empty() && !pathGenerator) {
+                return InferPartitionedColumnsTypes(arrowInferencinatorId, partByData, result);
+            }
+
+            return result;
         }).Apply([](const NThreading::TFuture<TMetadataResult>& result) {
             auto& value = result.GetValue();
             if (value.Success()) {
@@ -436,20 +512,10 @@ private:
     NThreading::TFuture<TMetadataResult> InferPartitionedColumnsTypes(
         NActors::TActorId arrowInferencinatorId,
         std::shared_ptr<TStringBuilder> partByData,
-        const TVector<TString>& partitionedBy,
         const NThreading::TFuture<TMetadataResult>& result) const {
 
         auto& value = result.GetValue();
-        if (partitionedBy.empty()) {
-            return result;
-        }
-
         auto meta = value.Metadata;
-        for (const auto& partitionName : partitionedBy) {
-            auto& destColumn = *meta->Schema.add_column();
-            destColumn.mutable_name()->assign(partitionName);
-            destColumn.mutable_type()->set_type_id(Ydb::Type::UTF8);
-        }
 
         arrow::BufferBuilder builder;
         auto partitionBuffer = std::make_shared<arrow::Buffer>(nullptr, 0);
@@ -500,15 +566,19 @@ private:
         THashSet<TString> columns;
         if (auto partitioned = meta->Attributes.FindPtr("partitionedby"); partitioned) {
             NJson::TJsonValue values;
-            Y_ENSURE(NJson::ReadJsonTree(*partitioned, &values));
-            Y_ENSURE(values.GetType() == NJson::JSON_ARRAY);
+            auto successful = NJson::ReadJsonTree(*partitioned, &values);
+            if (!successful) {
+                columns.insert(*partitioned);
+            } else {
+                Y_ENSURE(values.GetType() == NJson::JSON_ARRAY);
 
-            for (const auto& value : values.GetArray()) {
-                Y_ENSURE(value.GetType() == NJson::JSON_STRING);
-                if (columns.contains(value.GetString())) {
-                    throw yexception() << "invalid partitioned_by parameter, column " << value.GetString() << "mentioned twice";
+                for (const auto& value : values.GetArray()) {
+                    Y_ENSURE(value.GetType() == NJson::JSON_STRING);
+                    if (columns.contains(value.GetString())) {
+                        throw yexception() << "invalid partitioned_by parameter, column " << value.GetString() << "mentioned twice";
+                    }
+                    columns.insert(value.GetString());
                 }
-                columns.insert(value.GetString());
             }
         }
 

+ 3 - 6
ydb/core/kqp/provider/read_attributes_utils.cpp

@@ -21,11 +21,6 @@ class TGatheringAttributesVisitor : public IAstAttributesVisitor {
 
     void VisitAttribute(TString key, TString value) override {
         Y_ABORT_UNLESS(CurrentSource, "cannot write %s: %s", key.c_str(), value.c_str());
-        if (key == "partitionedby") {
-            NJson::TJsonArray values({ value });
-            CurrentSource->second.try_emplace(key, NJson::WriteJson({ values }));
-            return;
-        }
         CurrentSource->second.try_emplace(key, value);
     };
 
@@ -126,9 +121,11 @@ public:
         auto nodeChildren = node->Children();
         if (!nodeChildren.empty() && nodeChildren[0]->IsAtom()) {
             TCoAtom attrName{nodeChildren[0]};
-            if (attrName.StringValue().equal("userschema")) {
+            if (attrName.StringValue() == "userschema") {
                 node = BuildSchemaFromMetadata(Read->Pos(), Ctx, Metadata->Columns);
                 ReplacedUserchema = true;
+            } else if (attrName.StringValue() == "partitionedby") {
+                NewAttributes.erase("partitionedby");
             }
         }
         Children.push_back(std::move(node));

+ 74 - 0
ydb/tests/fq/s3/test_s3_0.py

@@ -488,6 +488,80 @@ Pear|15|33|2024-05-06'''
         assert result_set.columns[2].name == "c"
         assert result_set.columns[2].type.type_id == ydb.Type.UTF8
 
+    @yq_v2
+    @pytest.mark.parametrize("client", [{"folder_id": "my_folder"}], indirect=True)
+    def test_inference_projection(self, kikimr, s3, client, unique_prefix):
+        resource = boto3.resource(
+            "s3", endpoint_url=s3.s3_url, aws_access_key_id="key", aws_secret_access_key="secret_key"
+        )
+
+        bucket = resource.Bucket("fbucket")
+        bucket.create(ACL='public-read')
+        bucket.objects.all().delete()
+
+        s3_client = boto3.client(
+            "s3", endpoint_url=s3.s3_url, aws_access_key_id="key", aws_secret_access_key="secret_key"
+        )
+
+        fruits = '''Fruit,Price,Weight
+Banana,3,100
+Apple,2,22
+Pear,15,33'''
+        s3_client.put_object(Body=fruits, Bucket='fbucket', Key='year=2023/fruits.csv', ContentType='text/plain')
+
+        kikimr.control_plane.wait_bootstrap(1)
+        storage_connection_name = unique_prefix + "fruitbucket"
+        client.create_storage_connection(storage_connection_name, "fbucket")
+
+        sql = '''$projection = @@ {
+                "projection.enabled" : "true",
+                "storage.location.template" : "/${date}",
+                "projection.date.type" : "date",
+                "projection.date.min" : "2022-11-02",
+                "projection.date.max" : "2024-12-02",
+                "projection.date.interval" : "1",
+                "projection.date.format" : "/year=%Y",
+                "projection.date.unit" : "YEARS"
+            } @@;''' + f'''
+
+            SELECT *
+            FROM `{storage_connection_name}`.`/`
+            WITH (format=csv_with_names,
+                with_infer='true',
+                partitioned_by=(`date`),
+                projection=$projection);
+            '''
+
+        query_id = client.create_query("simple", sql, type=fq.QueryContent.QueryType.ANALYTICS).result.query_id
+        client.wait_query_status(query_id, fq.QueryMeta.COMPLETED)
+
+        data = client.get_result_data(query_id)
+        result_set = data.result.result_set
+        logging.debug(str(result_set))
+        assert len(result_set.columns) == 4
+        assert result_set.columns[0].name == "Fruit"
+        assert result_set.columns[0].type.type_id == ydb.Type.UTF8
+        assert result_set.columns[1].name == "Price"
+        assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.INT64
+        assert result_set.columns[2].name == "Weight"
+        assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64
+        assert result_set.columns[3].name == "date"
+        assert result_set.columns[3].type.type_id == ydb.Type.DATE
+        assert len(result_set.rows) == 3
+        assert result_set.rows[0].items[0].text_value == "Banana"
+        assert result_set.rows[0].items[1].int64_value == 3
+        assert result_set.rows[0].items[2].int64_value == 100
+        assert result_set.rows[0].items[3].uint32_value == 19663
+        assert result_set.rows[1].items[0].text_value == "Apple"
+        assert result_set.rows[1].items[1].int64_value == 2
+        assert result_set.rows[1].items[2].int64_value == 22
+        assert result_set.rows[1].items[3].uint32_value == 19663
+        assert result_set.rows[2].items[0].text_value == "Pear"
+        assert result_set.rows[2].items[1].int64_value == 15
+        assert result_set.rows[2].items[2].int64_value == 33
+        assert result_set.rows[2].items[3].uint32_value == 19663
+        assert sum(kikimr.control_plane.get_metering(1)) == 10
+
     @yq_all
     @pytest.mark.parametrize("client", [{"folder_id": "my_folder"}], indirect=True)
     def test_csv_with_hopping(self, kikimr, s3, client, unique_prefix):