Просмотр исходного кода

YTORM-844: Add protobuf interop checks for valid UTF-8 sequences and unify checks

dgolear 1 год назад
Родитель
Сommit
a4f27deb04

+ 73 - 48
yt/yt/core/yson/protobuf_interop.cpp

@@ -1,30 +1,31 @@
 #include "protobuf_interop.h"
 
+#include "consumer.h"
+#include "forwarding_consumer.h"
+#include "null_consumer.h"
 #include "parser.h"
+#include "protobuf_interop_unknown_fields.h"
+#include "writer.h"
 
 #include <yt/yt_proto/yt/core/yson/proto/protobuf_interop.pb.h>
 
-#include <yt/yt/core/yson/consumer.h>
-#include <yt/yt/core/yson/writer.h>
-#include <yt/yt/core/yson/forwarding_consumer.h>
-#include <yt/yt/core/yson/null_consumer.h>
-#include <yt/yt/core/yson/protobuf_interop_unknown_fields.h>
+#include <yt/yt/core/misc/protobuf_helpers.h>
 
 #include <yt/yt/core/ypath/helpers.h>
 #include <yt/yt/core/ypath/stack.h>
 #include <yt/yt/core/ypath/token.h>
 #include <yt/yt/core/ypath/tokenizer.h>
 
-#include <yt/yt/core/misc/protobuf_helpers.h>
-
 #include <yt/yt/core/ytree/convert.h>
 #include <yt/yt/core/ytree/ephemeral_node_factory.h>
 #include <yt/yt/core/ytree/tree_builder.h>
 
-#include <yt/yt/core/concurrency/thread_affinity.h>
-
 #include <yt/yt_proto/yt/core/ytree/proto/attributes.pb.h>
 
+#include <yt/yt/library/syncmap/map.h>
+
+#include <yt/yt/core/concurrency/thread_affinity.h>
+
 #include <library/cpp/yt/misc/cast.h>
 
 #include <library/cpp/yt/threading/fork_aware_spin_lock.h>
@@ -32,9 +33,7 @@
 #include <library/cpp/yt/coding/varint.h>
 #include <library/cpp/yt/coding/zig_zag.h>
 
-#include <library/cpp/yt/threading/fork_aware_spin_lock.h>
-
-#include <yt/yt/library/syncmap/map.h>
+#include <util/charset/utf8.h>
 
 #include <google/protobuf/descriptor.h>
 #include <google/protobuf/dynamic_message.h>
@@ -42,7 +41,6 @@
 
 #include <google/protobuf/io/coded_stream.h>
 #include <google/protobuf/io/zero_copy_stream.h>
-#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
 
 namespace NYT::NYson {
 
@@ -821,10 +819,10 @@ protected:
                 YT_VERIFY(field);
                 YPathStack_.Push(TString{field->GetYsonName()});
                 THROW_ERROR_EXCEPTION("Missing required field %v",
-                    YPathStack_.GetPath())
+                    YPathStack_.GetHumanReadablePath())
                     << TErrorAttribute("ypath", YPathStack_.GetPath())
-                    << TErrorAttribute("protobuf_type", type->GetFullName())
-                    << TErrorAttribute("protobuf_field", field->GetFullName());
+                    << TErrorAttribute("proto_type", type->GetFullName())
+                    << TErrorAttribute("proto_field", field->GetFullName());
             }
         }
 
@@ -838,9 +836,9 @@ protected:
                 const auto* field = type->GetFieldByNumber(numbers[index]);
                 YPathStack_.Push(TString{field->GetYsonName()});
                 THROW_ERROR_EXCEPTION("Duplicate field %v",
-                    YPathStack_.GetPath())
+                    YPathStack_.GetHumanReadablePath())
                     << TErrorAttribute("ypath", YPathStack_.GetPath())
-                    << TErrorAttribute("protobuf_type", type->GetFullName());
+                    << TErrorAttribute("proto_type", type->GetFullName());
             }
         }
     }
@@ -948,6 +946,12 @@ private:
             const auto* field = FieldStack_.back().Field;
             switch (field->GetType()) {
                 case FieldDescriptor::TYPE_STRING:
+                    if (Options_.CheckUtf8 && !IsUtf(value)) {
+                        THROW_ERROR_EXCEPTION("Field %v accepts only valid UTF-8 sequence",
+                            YPathStack_.GetHumanReadablePath())
+                            << TErrorAttribute("ypath", YPathStack_.GetPath())
+                            << TErrorAttribute("proto_field", field->GetFullName());
+                    }
                 case FieldDescriptor::TYPE_BYTES:
                     BodyCodedStream_.WriteVarint64(value.length());
                     BodyCodedStream_.WriteRaw(value.begin(), static_cast<int>(value.length()));
@@ -958,7 +962,7 @@ private:
                     auto optionalValue = enumType->FindValueByLiteral(value);
                     if (!optionalValue) {
                         THROW_ERROR_EXCEPTION("Field %v cannot have value %Qv",
-                            YPathStack_.GetPath(),
+                            YPathStack_.GetHumanReadablePath(),
                             value)
                             << TErrorAttribute("ypath", YPathStack_.GetPath())
                             << TErrorAttribute("proto_type", enumType->GetFullName());
@@ -972,7 +976,7 @@ private:
 
                 default:
                     THROW_ERROR_EXCEPTION("Field %v cannot be parsed from \"string\" values",
-                        YPathStack_.GetPath())
+                        YPathStack_.GetHumanReadablePath())
                         << TErrorAttribute("ypath", YPathStack_.GetPath())
                         << TErrorAttribute("proto_field", field->GetFullName());
             }
@@ -1014,7 +1018,7 @@ private:
 
                 default:
                     THROW_ERROR_EXCEPTION("Field %v cannot be parsed from \"double\" values",
-                        YPathStack_.GetPath())
+                        YPathStack_.GetHumanReadablePath())
                         << TErrorAttribute("ypath", YPathStack_.GetPath())
                         << TErrorAttribute("proto_field", field->GetFullName());
             }
@@ -1028,7 +1032,7 @@ private:
             auto type = field->GetType();
             if (type != FieldDescriptor::TYPE_BOOL) {
                 THROW_ERROR_EXCEPTION("Field %v cannot be parsed from \"boolean\" values",
-                    YPathStack_.GetPath())
+                    YPathStack_.GetHumanReadablePath())
                     << TErrorAttribute("ypath", YPathStack_.GetPath())
                     << TErrorAttribute("proto_field", field->GetFullName());
             }
@@ -1093,7 +1097,7 @@ private:
         if (!field->IsYsonMap() || FieldStack_.back().ParsingYsonMapFromList) {
             if (field->GetType() != FieldDescriptor::TYPE_MESSAGE) {
                 THROW_ERROR_EXCEPTION("Field %v cannot be parsed from \"map\" values",
-                    YPathStack_.GetPath())
+                    YPathStack_.GetHumanReadablePath())
                     << TErrorAttribute("ypath", YPathStack_.GetPath())
                     << TErrorAttribute("proto_field", field->GetFullName());
             }
@@ -1147,7 +1151,7 @@ private:
                 i64 keyValue; // the widest singed integral type
                 if (!TryFromString(key, keyValue)) {
                     THROW_ERROR_EXCEPTION("Cannot parse a signed integral key of map %v from %Qv",
-                        YPathStack_.GetPath(),
+                        YPathStack_.GetHumanReadablePath(),
                         key)
                         << TErrorAttribute("ypath", YPathStack_.GetPath())
                         << TErrorAttribute("proto_field", field->GetFullName());
@@ -1163,7 +1167,7 @@ private:
                 ui64 keyValue; // the widest unsigned integral type
                 if (!TryFromString(key, keyValue)) {
                     THROW_ERROR_EXCEPTION("Cannot parse an unsigned integral key of map %v from %Qv",
-                        YPathStack_.GetPath(),
+                        YPathStack_.GetHumanReadablePath(),
                         key)
                         << TErrorAttribute("ypath", YPathStack_.GetPath())
                         << TErrorAttribute("proto_field", field->GetFullName());
@@ -1439,7 +1443,7 @@ private:
         if (FieldStack_.empty()) {
             THROW_ERROR_EXCEPTION("Protobuf message can only be parsed from \"map\" values")
                 << TErrorAttribute("ypath", YPathStack_.GetPath())
-                << TErrorAttribute("protobuf_type", RootType_->GetFullName());
+                << TErrorAttribute("proto_type", RootType_->GetFullName());
         }
     }
 
@@ -1453,13 +1457,13 @@ private:
             THROW_ERROR_EXCEPTION("Map %v cannot be parsed from scalar values",
                 YPathStack_.GetHumanReadablePath())
                 << TErrorAttribute("ypath", YPathStack_.GetPath())
-                << TErrorAttribute("protobuf_field", field->GetFullName());
+                << TErrorAttribute("proto_field", field->GetFullName());
         }
         if (field->IsRepeated()) {
             THROW_ERROR_EXCEPTION("Field %v is repeated and cannot be parsed from scalar values",
                 YPathStack_.GetHumanReadablePath())
                 << TErrorAttribute("ypath", YPathStack_.GetPath())
-                << TErrorAttribute("protobuf_field", field->GetFullName());
+                << TErrorAttribute("proto_field", field->GetFullName());
         }
     }
 
@@ -1476,7 +1480,7 @@ private:
             THROW_ERROR_EXCEPTION("Field %v is not repeated and cannot be parsed from \"list\" values",
                 YPathStack_.GetHumanReadablePath())
                 << TErrorAttribute("ypath", YPathStack_.GetPath())
-                << TErrorAttribute("protobuf_field", field->GetFullName());
+                << TErrorAttribute("proto_field", field->GetFullName());
         }
     }
 
@@ -1677,7 +1681,7 @@ private:
                 YPathStack_.GetHumanReadablePath(),
                 toTypeName)
                 << TErrorAttribute("ypath", YPathStack_.GetPath())
-                << TErrorAttribute("protobuf_field", field->GetFullName());
+                << TErrorAttribute("proto_field", field->GetFullName());
         }
         return result;
     }
@@ -1956,7 +1960,8 @@ private:
                         ui64 keyValue;
                         if (!CodedStream_.ReadVarint64(&keyValue)) {
                             THROW_ERROR_EXCEPTION("Error reading \"varint\" value for protobuf map key")
-                                << TErrorAttribute("ypath", YPathStack_.GetPath());
+                                << TErrorAttribute("ypath", YPathStack_.GetPath())
+                                << TErrorAttribute("proto_field", field->GetFullName());
                         }
 
                         switch (field->GetType()) {
@@ -1990,7 +1995,8 @@ private:
                         ui32 keyValue;
                         if (!CodedStream_.ReadRaw(&keyValue, sizeof(keyValue))) {
                             THROW_ERROR_EXCEPTION("Error reading \"fixed32\" value for protobuf map key")
-                                << TErrorAttribute("ypath", YPathStack_.GetPath());
+                                << TErrorAttribute("ypath", YPathStack_.GetPath())
+                                << TErrorAttribute("proto_field", field->GetFullName());
                         }
 
                         if (IsSignedIntegralType(field->GetType())) {
@@ -2005,7 +2011,8 @@ private:
                         ui64 keyValue;
                         if (!CodedStream_.ReadRaw(&keyValue, sizeof(keyValue))) {
                             THROW_ERROR_EXCEPTION("Error reading \"fixed64\" value for protobuf map key")
-                                << TErrorAttribute("ypath", YPathStack_.GetPath());
+                                << TErrorAttribute("ypath", YPathStack_.GetPath())
+                                << TErrorAttribute("proto_field", field->GetFullName());
                         }
 
                         if (IsSignedIntegralType(field->GetType())) {
@@ -2020,13 +2027,15 @@ private:
                         ui64 keyLength;
                         if (!CodedStream_.ReadVarint64(&keyLength)) {
                             THROW_ERROR_EXCEPTION("Error reading \"varint\" value for protobuf map key length")
-                                << TErrorAttribute("ypath", YPathStack_.GetPath());
+                                << TErrorAttribute("ypath", YPathStack_.GetPath())
+                                << TErrorAttribute("proto_field", field->GetFullName());
                         }
 
                         PooledString_.resize(keyLength);
                         if (!CodedStream_.ReadRaw(PooledString_.data(), keyLength)) {
                             THROW_ERROR_EXCEPTION("Error reading \"string\" value for protobuf map key")
-                                << TErrorAttribute("ypath", YPathStack_.GetPath());
+                                << TErrorAttribute("ypath", YPathStack_.GetPath())
+                                << TErrorAttribute("proto_field", field->GetFullName());
                         }
                         break;
                     }
@@ -2034,7 +2043,8 @@ private:
                     default:
                         THROW_ERROR_EXCEPTION("Unexpected wire type tag %x for protobuf map key",
                             tag)
-                            << TErrorAttribute("ypath", YPathStack_.GetPath());
+                            << TErrorAttribute("ypath", YPathStack_.GetPath())
+                            << TErrorAttribute("proto_field", field->GetFullName());
                 }
 
                 OnKeyedItem(TString(PooledString_.data(), PooledString_.size()));
@@ -2200,7 +2210,7 @@ private:
     }
 
     template <class T>
-    void ParseFixedPacked(ui64 length, auto field, auto&& func)
+    void ParseFixedPacked(ui64 length, const TProtobufField* field, auto&& func)
     {
         YT_ASSERT(length % sizeof(T) == 0);
         for (auto index = 0u; index < length / sizeof(T); ++index) {
@@ -2215,7 +2225,8 @@ private:
                 THROW_ERROR_EXCEPTION("Error reading %Qv value from field %v",
                     field->GetTypeName(),
                     YPathStack_.GetHumanReadablePath())
-                    << TErrorAttribute("ypath", YPathStack_.GetPath());
+                    << TErrorAttribute("ypath", YPathStack_.GetPath())
+                    << TErrorAttribute("proto_field", field->GetFullName());
             }
             if (index > 0) {
                 YT_ASSERT(field->IsRepeated());
@@ -2226,7 +2237,7 @@ private:
     }
 
     template <class T>
-    void ParseVarintPacked(ui64 length, auto field, auto&& func)
+    void ParseVarintPacked(ui64 length, const TProtobufField* field, auto&& func)
     {
         const void* data = nullptr;
         int size = 0;
@@ -2247,7 +2258,8 @@ private:
                 THROW_ERROR_EXCEPTION("Error reading \"%v\" value for field %v",
                     field->GetTypeName(),
                     YPathStack_.GetHumanReadablePath())
-                    << TErrorAttribute("ypath", YPathStack_.GetPath());
+                    << TErrorAttribute("ypath", YPathStack_.GetPath())
+                    << TErrorAttribute("proto_field", field->GetFullName());
             }
             if (index > 0) {
                 YT_ASSERT(field->IsRepeated());
@@ -2270,7 +2282,8 @@ private:
                 if (!CodedStream_.ReadVarint64(&unsignedValue)) {
                     THROW_ERROR_EXCEPTION("Error reading \"varint\" value for field %v",
                         YPathStack_.GetHumanReadablePath())
-                        << TErrorAttribute("ypath", YPathStack_.GetPath());
+                        << TErrorAttribute("ypath", YPathStack_.GetPath())
+                        << TErrorAttribute("proto_field", field->GetFullName());
                 }
 
                 switch (field->GetType()) {
@@ -2334,7 +2347,8 @@ private:
                 if (!CodedStream_.ReadLittleEndian32(&unsignedValue)) {
                     THROW_ERROR_EXCEPTION("Error reading \"fixed32\" value for field %v",
                         YPathStack_.GetHumanReadablePath())
-                        << TErrorAttribute("ypath", YPathStack_.GetPath());
+                        << TErrorAttribute("ypath", YPathStack_.GetPath())
+                        << TErrorAttribute("proto_field", field->GetFullName());
                 }
 
                 switch (field->GetType()) {
@@ -2374,7 +2388,8 @@ private:
                 if (!CodedStream_.ReadLittleEndian64(&unsignedValue)) {
                     THROW_ERROR_EXCEPTION("Error reading \"fixed64\" value for field %v",
                         YPathStack_.GetHumanReadablePath())
-                        << TErrorAttribute("ypath", YPathStack_.GetPath());
+                        << TErrorAttribute("ypath", YPathStack_.GetPath())
+                        << TErrorAttribute("proto_field", field->GetFullName());
                 }
 
                 switch (field->GetType()) {
@@ -2414,7 +2429,8 @@ private:
                 if (!CodedStream_.ReadVarint64(&length)) {
                     THROW_ERROR_EXCEPTION("Error reading \"varint\" value for field %v",
                         YPathStack_.GetHumanReadablePath())
-                        << TErrorAttribute("ypath", YPathStack_.GetPath());
+                        << TErrorAttribute("ypath", YPathStack_.GetPath())
+                        << TErrorAttribute("proto_field", field->GetFullName());
                 }
 
                 switch (field->GetType()) {
@@ -2424,9 +2440,16 @@ private:
                         if (!CodedStream_.ReadRaw(PooledString_.data(), length)) {
                             THROW_ERROR_EXCEPTION("Error reading \"string\" value for field %v",
                                 YPathStack_.GetHumanReadablePath())
-                                << TErrorAttribute("ypath", YPathStack_.GetPath());
+                                << TErrorAttribute("ypath", YPathStack_.GetPath())
+                                << TErrorAttribute("proto_field", field->GetFullName());
                         }
                         TStringBuf data(PooledString_.data(), length);
+                        if (Options_.CheckUtf8 && field->GetType() == FieldDescriptor::TYPE_STRING && !IsUtf(data)) {
+                            THROW_ERROR_EXCEPTION("Field %v expected to contain valid UTF-8 sequence",
+                                YPathStack_.GetHumanReadablePath())
+                                << TErrorAttribute("ypath", YPathStack_.GetPath())
+                                << TErrorAttribute("proto_field", field->GetFullName());
+                        }
                         ParseScalar([&] {
                             if (field->GetBytesFieldConverter()) {
                                 const auto& converter = *field->GetBytesFieldConverter();
@@ -2535,7 +2558,8 @@ private:
             default:
                 THROW_ERROR_EXCEPTION("Unexpected wire type tag %x",
                     tag)
-                    << TErrorAttribute("ypath", YPathStack_.GetPath());
+                    << TErrorAttribute("ypath", YPathStack_.GetPath())
+                    << TErrorAttribute("proto_field", field->GetFullName());
         }
     }
 
@@ -2602,7 +2626,7 @@ private:
                 << TErrorAttribute("ypath", YPathStack_.GetPath());
         };
 
-        auto readVarint64 = [&] () {
+        auto readVarint64 = [&] {
             ui64 value;
             if (!CodedStream_.ReadVarint64(&value)) {
                 THROW_ERROR_EXCEPTION("Error reading \"varint\" value while parsing key-value pair at %v",
@@ -2729,7 +2753,8 @@ TProtobufElementResolveResult GetProtobufElementFromField(
     if (std::holds_alternative<std::unique_ptr<TProtobufScalarElement>>(element) && !tokenizer.GetSuffix().empty()) {
         THROW_ERROR_EXCEPTION("Field %v is scalar and does not support nested access",
             FormatYPath(tokenizer.GetPrefixPlusToken()))
-            << TErrorAttribute("ypath", tokenizer.GetPrefixPlusToken());
+            << TErrorAttribute("ypath", tokenizer.GetPrefixPlusToken())
+            << TErrorAttribute("proto_field", field->GetFullName());
     }
     return TProtobufElementResolveResult{
         std::move(element),

+ 3 - 0
yt/yt/core/yson/protobuf_interop.h

@@ -148,6 +148,9 @@ struct TProtobufParserOptions
     //! If |true| then required fields not found in protobuf metadata are
     //! silently skipped; otherwise an exception is thrown.
     bool SkipRequiredFields = false;
+
+    // Check if |string| fields contain actual UTF-8 strings.
+    bool CheckUtf8 = false;
 };
 
 //! Parses a byte sequence and translates it into IYsonConsumer calls.

+ 3 - 0
yt/yt/core/yson/protobuf_interop_options.h

@@ -31,6 +31,9 @@ struct TProtobufWriterOptions
     //! If |true| then required fields not found in protobuf metadata are
     //! silently skipped; otherwise an exception is thrown.
     bool SkipRequiredFields = false;
+
+    // Check if |string| fields contain actual UTF-8 strings.
+    bool CheckUtf8 = false;
 };
 
 ////////////////////////////////////////////////////////////////////////////////

+ 28 - 0
yt/yt/core/yson/unittests/protobuf_yson_ut.cpp

@@ -1013,6 +1013,34 @@ TEST(TYsonToProtobufTest, Entities)
     EXPECT_FALSE(message.has_nested_message1());
 }
 
+TEST(TYsonToProtobufTest, ValidUtf8StringCheck)
+{
+    TProtobufWriterOptions options{
+        .CheckUtf8 = true,
+    };
+
+    TString invalidUtf8 = "\xc3\x28";
+
+    auto check = [&] {
+        TEST_PROLOGUE_WITH_OPTIONS(TMessage, options)
+            .BeginMap()
+                .Item("string_field").Value(invalidUtf8)
+            .EndMap();
+    };
+
+    EXPECT_THROW_WITH_SUBSTRING(check(), "valid UTF-8");
+
+    NProto::TMessage message;
+    message.set_string_field(invalidUtf8);
+    TString newYsonString;
+    TStringOutput newYsonOutputStream(newYsonString);
+    TYsonWriter ysonWriter(&newYsonOutputStream, EYsonFormat::Pretty);
+
+    EXPECT_THROW_WITH_SUBSTRING(
+        WriteProtobufMessage(&ysonWriter, message, TProtobufParserOptions{.CheckUtf8 = true}),
+        "valid UTF-8");
+}
+
 TEST(TYsonToProtobufTest, CustomUnknownFieldsModeResolver)
 {
     {

+ 3 - 3
yt/yt/library/syncmap/map.h

@@ -1,14 +1,14 @@
 #pragma once
 
-#include <util/generic/hash.h>
-#include <util/generic/noncopyable.h>
-
 #include <yt/yt/core/misc/finally.h>
 #include <yt/yt/core/misc/hazard_ptr.h>
 #include <yt/yt/core/misc/ref_counted.h>
 
 #include <library/cpp/yt/threading/spin_lock.h>
 
+#include <util/generic/hash.h>
+#include <util/generic/noncopyable.h>
+
 #include <atomic>
 
 namespace NYT::NConcurrency {