// proto_table_reader.cpp (9.4 KB)
  1. #include "proto_table_reader.h"
  2. #include "node_table_reader.h"
  3. #include "proto_helpers.h"
  4. #include <yt/yt/core/misc/protobuf_helpers.h>
  5. #include <yt/yt_proto/yt/formats/extension.pb.h>
  6. #include <util/string/escape.h>
  7. #include <util/string/printf.h>
  8. namespace NYT {
  9. using ::google::protobuf::Descriptor;
  10. using ::google::protobuf::FieldDescriptor;
  11. using ::google::protobuf::EnumValueDescriptor;
  12. using NYT::FromProto;
  13. TString GetFieldColumnName(const FieldDescriptor* fieldDesc)
  14. {
  15. auto columnName = FromProto<TString>(fieldDesc->options().GetExtension(column_name));
  16. if (!columnName.empty()) {
  17. return columnName;
  18. }
  19. auto keyColumnName = FromProto<TString>(fieldDesc->options().GetExtension(key_column_name));
  20. if (!keyColumnName.empty()) {
  21. return keyColumnName;
  22. }
  23. return FromProto<TString>(fieldDesc->name());
  24. }
  25. void ReadMessageFromNode(const TNode& node, Message* row)
  26. {
  27. auto* descriptor = row->GetDescriptor();
  28. auto* reflection = row->GetReflection();
  29. int count = descriptor->field_count();
  30. for (int i = 0; i < count; ++i) {
  31. auto* fieldDesc = descriptor->field(i);
  32. const auto& columnName = GetFieldColumnName(fieldDesc);
  33. const auto& nodeMap = node.AsMap();
  34. auto it = nodeMap.find(columnName);
  35. if (it == nodeMap.end()) {
  36. continue; // no such column
  37. }
  38. auto actualType = it->second.GetType();
  39. if (actualType == TNode::Null) {
  40. continue; // null field
  41. }
  42. auto checkType = [fieldDesc] (TNode::EType expected, TNode::EType actual) {
  43. if (expected != actual) {
  44. ythrow TNode::TTypeError() << "expected node type " << expected
  45. << ", actual " << actual << " for node " << GetFieldColumnName(fieldDesc);
  46. }
  47. };
  48. switch (fieldDesc->type()) {
  49. case FieldDescriptor::TYPE_STRING:
  50. case FieldDescriptor::TYPE_BYTES:
  51. checkType(TNode::String, actualType);
  52. reflection->SetString(row, fieldDesc, it->second.AsString());
  53. break;
  54. case FieldDescriptor::TYPE_INT64:
  55. case FieldDescriptor::TYPE_SINT64:
  56. case FieldDescriptor::TYPE_SFIXED64:
  57. checkType(TNode::Int64, actualType);
  58. reflection->SetInt64(row, fieldDesc, it->second.AsInt64());
  59. break;
  60. case FieldDescriptor::TYPE_INT32:
  61. case FieldDescriptor::TYPE_SINT32:
  62. case FieldDescriptor::TYPE_SFIXED32:
  63. checkType(TNode::Int64, actualType);
  64. reflection->SetInt32(row, fieldDesc, it->second.AsInt64());
  65. break;
  66. case FieldDescriptor::TYPE_UINT64:
  67. case FieldDescriptor::TYPE_FIXED64:
  68. checkType(TNode::Uint64, actualType);
  69. reflection->SetUInt64(row, fieldDesc, it->second.AsUint64());
  70. break;
  71. case FieldDescriptor::TYPE_UINT32:
  72. case FieldDescriptor::TYPE_FIXED32:
  73. checkType(TNode::Uint64, actualType);
  74. reflection->SetUInt32(row, fieldDesc, it->second.AsUint64());
  75. break;
  76. case FieldDescriptor::TYPE_DOUBLE:
  77. checkType(TNode::Double, actualType);
  78. reflection->SetDouble(row, fieldDesc, it->second.AsDouble());
  79. break;
  80. case FieldDescriptor::TYPE_FLOAT:
  81. checkType(TNode::Double, actualType);
  82. reflection->SetFloat(row, fieldDesc, it->second.AsDouble());
  83. break;
  84. case FieldDescriptor::TYPE_BOOL:
  85. checkType(TNode::Bool, actualType);
  86. reflection->SetBool(row, fieldDesc, it->second.AsBool());
  87. break;
  88. case FieldDescriptor::TYPE_ENUM: {
  89. TNode::EType columnType = TNode::String;
  90. for (const auto& flag : fieldDesc->options().GetRepeatedExtension(flags)) {
  91. if (flag == EWrapperFieldFlag::ENUM_INT) {
  92. columnType = TNode::Int64;
  93. break;
  94. }
  95. }
  96. checkType(columnType, actualType);
  97. const EnumValueDescriptor* valueDesc = nullptr;
  98. TString stringValue;
  99. if (columnType == TNode::String) {
  100. const auto& value = it->second.AsString();
  101. valueDesc = fieldDesc->enum_type()->FindValueByName(value);
  102. stringValue = value;
  103. } else if (columnType == TNode::Int64) {
  104. const auto& value = it->second.AsInt64();
  105. valueDesc = fieldDesc->enum_type()->FindValueByNumber(value);
  106. stringValue = ToString(value);
  107. } else {
  108. Y_ABORT();
  109. }
  110. if (valueDesc == nullptr) {
  111. ythrow yexception() << "Failed to parse value '" << EscapeC(stringValue) << "' as " << fieldDesc->enum_type()->full_name();
  112. }
  113. reflection->SetEnum(row, fieldDesc, valueDesc);
  114. break;
  115. }
  116. case FieldDescriptor::TYPE_MESSAGE: {
  117. checkType(TNode::String, actualType);
  118. Message* message = reflection->MutableMessage(row, fieldDesc);
  119. if (!message->ParseFromArray(it->second.AsString().data(), it->second.AsString().size())) {
  120. ythrow yexception() << "Failed to parse protobuf message";
  121. }
  122. break;
  123. }
  124. default:
  125. ythrow yexception() << "Incorrect protobuf type";
  126. }
  127. }
  128. }
  129. ////////////////////////////////////////////////////////////////////////////////
  130. TProtoTableReader::TProtoTableReader(
  131. ::TIntrusivePtr<TRawTableReader> input,
  132. TVector<const Descriptor*>&& descriptors)
  133. : NodeReader_(new TNodeTableReader(std::move(input)))
  134. , Descriptors_(std::move(descriptors))
  135. { }
  136. void TProtoTableReader::ReadRow(Message* row)
  137. {
  138. const auto& node = NodeReader_->GetRow();
  139. ReadMessageFromNode(node, row);
  140. }
  141. bool TProtoTableReader::IsValid() const
  142. {
  143. return NodeReader_->IsValid();
  144. }
  145. void TProtoTableReader::Next()
  146. {
  147. NodeReader_->Next();
  148. }
  149. ui32 TProtoTableReader::GetTableIndex() const
  150. {
  151. return NodeReader_->GetTableIndex();
  152. }
  153. ui32 TProtoTableReader::GetRangeIndex() const
  154. {
  155. return NodeReader_->GetRangeIndex();
  156. }
  157. ui64 TProtoTableReader::GetRowIndex() const
  158. {
  159. return NodeReader_->GetRowIndex();
  160. }
  161. void TProtoTableReader::NextKey()
  162. {
  163. NodeReader_->NextKey();
  164. }
  165. TMaybe<size_t> TProtoTableReader::GetReadByteCount() const
  166. {
  167. return NodeReader_->GetReadByteCount();
  168. }
  169. bool TProtoTableReader::IsEndOfStream() const
  170. {
  171. return NodeReader_->IsEndOfStream();
  172. }
  173. bool TProtoTableReader::IsRawReaderExhausted() const
  174. {
  175. return NodeReader_->IsRawReaderExhausted();
  176. }
  177. ////////////////////////////////////////////////////////////////////////////////
  178. TLenvalProtoTableReader::TLenvalProtoTableReader(
  179. ::TIntrusivePtr<TRawTableReader> input,
  180. TVector<const Descriptor*>&& descriptors)
  181. : TLenvalTableReader(std::move(input))
  182. , ValidateProtoDescriptor_(true)
  183. , Descriptors_(std::move(descriptors))
  184. { }
  185. TLenvalProtoTableReader::TLenvalProtoTableReader(
  186. ::TIntrusivePtr<TRawTableReader> input)
  187. : TLenvalTableReader(std::move(input))
  188. , ValidateProtoDescriptor_(false)
  189. { }
  190. void TLenvalProtoTableReader::ReadRow(Message* row)
  191. {
  192. if (ValidateProtoDescriptor_) {
  193. ValidateProtoDescriptor(*row, GetTableIndex(), Descriptors_, true);
  194. }
  195. while (true) {
  196. try {
  197. ParseFromArcadiaStream(&Input_, *row, Length_);
  198. RowTaken_ = true;
  199. // We successfully parsed one more row from the stream,
  200. // so reset retry count to their initial value.
  201. Input_.ResetRetries();
  202. break;
  203. } catch (const std::exception& ex) {
  204. if (!TLenvalTableReader::Retry(std::make_exception_ptr(ex))) {
  205. throw;
  206. }
  207. }
  208. }
  209. }
  210. bool TLenvalProtoTableReader::IsValid() const
  211. {
  212. return TLenvalTableReader::IsValid();
  213. }
  214. void TLenvalProtoTableReader::Next()
  215. {
  216. TLenvalTableReader::Next();
  217. }
  218. ui32 TLenvalProtoTableReader::GetTableIndex() const
  219. {
  220. return TLenvalTableReader::GetTableIndex();
  221. }
  222. ui32 TLenvalProtoTableReader::GetRangeIndex() const
  223. {
  224. return TLenvalTableReader::GetRangeIndex();
  225. }
  226. ui64 TLenvalProtoTableReader::GetRowIndex() const
  227. {
  228. return TLenvalTableReader::GetRowIndex();
  229. }
  230. void TLenvalProtoTableReader::NextKey()
  231. {
  232. TLenvalTableReader::NextKey();
  233. }
  234. TMaybe<size_t> TLenvalProtoTableReader::GetReadByteCount() const
  235. {
  236. return TLenvalTableReader::GetReadByteCount();
  237. }
  238. bool TLenvalProtoTableReader::IsEndOfStream() const
  239. {
  240. return TLenvalTableReader::IsEndOfStream();
  241. }
  242. bool TLenvalProtoTableReader::IsRawReaderExhausted() const
  243. {
  244. return TLenvalTableReader::IsRawReaderExhausted();
  245. }
  246. void TLenvalProtoTableReader::SkipRow()
  247. {
  248. while (true) {
  249. try {
  250. size_t skipped = Input_.Skip(Length_);
  251. if (skipped != Length_) {
  252. ythrow yexception() << "Premature end of stream";
  253. }
  254. break;
  255. } catch (const std::exception& ex) {
  256. if (!TLenvalTableReader::Retry(std::make_exception_ptr(ex))) {
  257. throw;
  258. }
  259. }
  260. }
  261. }
  262. ////////////////////////////////////////////////////////////////////////////////
  263. } // namespace NYT