proto_table_reader.cpp 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. #include "proto_table_reader.h"
  2. #include "node_table_reader.h"
  3. #include "proto_helpers.h"
  4. #include <yt/yt_proto/yt/formats/extension.pb.h>
  5. #include <util/string/escape.h>
  6. #include <util/string/printf.h>
  7. namespace NYT {
  8. using ::google::protobuf::Descriptor;
  9. using ::google::protobuf::FieldDescriptor;
  10. using ::google::protobuf::EnumValueDescriptor;
  11. const TString& GetFieldColumnName(const FieldDescriptor* fieldDesc) {
  12. const auto& columnName = fieldDesc->options().GetExtension(column_name);
  13. if (!columnName.empty()) {
  14. return columnName;
  15. }
  16. const auto& keyColumnName = fieldDesc->options().GetExtension(key_column_name);
  17. if (!keyColumnName.empty()) {
  18. return keyColumnName;
  19. }
  20. return fieldDesc->name();
  21. }
  22. void ReadMessageFromNode(const TNode& node, Message* row)
  23. {
  24. auto* descriptor = row->GetDescriptor();
  25. auto* reflection = row->GetReflection();
  26. int count = descriptor->field_count();
  27. for (int i = 0; i < count; ++i) {
  28. auto* fieldDesc = descriptor->field(i);
  29. const auto& columnName = GetFieldColumnName(fieldDesc);
  30. const auto& nodeMap = node.AsMap();
  31. auto it = nodeMap.find(columnName);
  32. if (it == nodeMap.end()) {
  33. continue; // no such column
  34. }
  35. auto actualType = it->second.GetType();
  36. if (actualType == TNode::Null) {
  37. continue; // null field
  38. }
  39. auto checkType = [&columnName] (TNode::EType expected, TNode::EType actual) {
  40. if (expected != actual) {
  41. ythrow TNode::TTypeError() << "expected node type " << expected
  42. << ", actual " << actual << " for node " << columnName.data();
  43. }
  44. };
  45. switch (fieldDesc->type()) {
  46. case FieldDescriptor::TYPE_STRING:
  47. case FieldDescriptor::TYPE_BYTES:
  48. checkType(TNode::String, actualType);
  49. reflection->SetString(row, fieldDesc, it->second.AsString());
  50. break;
  51. case FieldDescriptor::TYPE_INT64:
  52. case FieldDescriptor::TYPE_SINT64:
  53. case FieldDescriptor::TYPE_SFIXED64:
  54. checkType(TNode::Int64, actualType);
  55. reflection->SetInt64(row, fieldDesc, it->second.AsInt64());
  56. break;
  57. case FieldDescriptor::TYPE_INT32:
  58. case FieldDescriptor::TYPE_SINT32:
  59. case FieldDescriptor::TYPE_SFIXED32:
  60. checkType(TNode::Int64, actualType);
  61. reflection->SetInt32(row, fieldDesc, it->second.AsInt64());
  62. break;
  63. case FieldDescriptor::TYPE_UINT64:
  64. case FieldDescriptor::TYPE_FIXED64:
  65. checkType(TNode::Uint64, actualType);
  66. reflection->SetUInt64(row, fieldDesc, it->second.AsUint64());
  67. break;
  68. case FieldDescriptor::TYPE_UINT32:
  69. case FieldDescriptor::TYPE_FIXED32:
  70. checkType(TNode::Uint64, actualType);
  71. reflection->SetUInt32(row, fieldDesc, it->second.AsUint64());
  72. break;
  73. case FieldDescriptor::TYPE_DOUBLE:
  74. checkType(TNode::Double, actualType);
  75. reflection->SetDouble(row, fieldDesc, it->second.AsDouble());
  76. break;
  77. case FieldDescriptor::TYPE_FLOAT:
  78. checkType(TNode::Double, actualType);
  79. reflection->SetFloat(row, fieldDesc, it->second.AsDouble());
  80. break;
  81. case FieldDescriptor::TYPE_BOOL:
  82. checkType(TNode::Bool, actualType);
  83. reflection->SetBool(row, fieldDesc, it->second.AsBool());
  84. break;
  85. case FieldDescriptor::TYPE_ENUM: {
  86. TNode::EType columnType = TNode::String;
  87. for (const auto& flag : fieldDesc->options().GetRepeatedExtension(flags)) {
  88. if (flag == EWrapperFieldFlag::ENUM_INT) {
  89. columnType = TNode::Int64;
  90. break;
  91. }
  92. }
  93. checkType(columnType, actualType);
  94. const EnumValueDescriptor* valueDesc = nullptr;
  95. TString stringValue;
  96. if (columnType == TNode::String) {
  97. const auto& value = it->second.AsString();
  98. valueDesc = fieldDesc->enum_type()->FindValueByName(value);
  99. stringValue = value;
  100. } else if (columnType == TNode::Int64) {
  101. const auto& value = it->second.AsInt64();
  102. valueDesc = fieldDesc->enum_type()->FindValueByNumber(value);
  103. stringValue = ToString(value);
  104. } else {
  105. Y_ABORT();
  106. }
  107. if (valueDesc == nullptr) {
  108. ythrow yexception() << "Failed to parse value '" << EscapeC(stringValue) << "' as " << fieldDesc->enum_type()->full_name();
  109. }
  110. reflection->SetEnum(row, fieldDesc, valueDesc);
  111. break;
  112. }
  113. case FieldDescriptor::TYPE_MESSAGE: {
  114. checkType(TNode::String, actualType);
  115. Message* message = reflection->MutableMessage(row, fieldDesc);
  116. if (!message->ParseFromArray(it->second.AsString().data(), it->second.AsString().size())) {
  117. ythrow yexception() << "Failed to parse protobuf message";
  118. }
  119. break;
  120. }
  121. default:
  122. ythrow yexception() << "Incorrect protobuf type";
  123. }
  124. }
  125. }
  126. ////////////////////////////////////////////////////////////////////////////////
  127. TProtoTableReader::TProtoTableReader(
  128. ::TIntrusivePtr<TRawTableReader> input,
  129. TVector<const Descriptor*>&& descriptors)
  130. : NodeReader_(new TNodeTableReader(std::move(input)))
  131. , Descriptors_(std::move(descriptors))
  132. { }
  133. void TProtoTableReader::ReadRow(Message* row)
  134. {
  135. const auto& node = NodeReader_->GetRow();
  136. ReadMessageFromNode(node, row);
  137. }
  138. bool TProtoTableReader::IsValid() const
  139. {
  140. return NodeReader_->IsValid();
  141. }
  142. void TProtoTableReader::Next()
  143. {
  144. NodeReader_->Next();
  145. }
  146. ui32 TProtoTableReader::GetTableIndex() const
  147. {
  148. return NodeReader_->GetTableIndex();
  149. }
  150. ui32 TProtoTableReader::GetRangeIndex() const
  151. {
  152. return NodeReader_->GetRangeIndex();
  153. }
  154. ui64 TProtoTableReader::GetRowIndex() const
  155. {
  156. return NodeReader_->GetRowIndex();
  157. }
  158. void TProtoTableReader::NextKey()
  159. {
  160. NodeReader_->NextKey();
  161. }
  162. TMaybe<size_t> TProtoTableReader::GetReadByteCount() const
  163. {
  164. return NodeReader_->GetReadByteCount();
  165. }
  166. bool TProtoTableReader::IsEndOfStream() const
  167. {
  168. return NodeReader_->IsEndOfStream();
  169. }
  170. bool TProtoTableReader::IsRawReaderExhausted() const
  171. {
  172. return NodeReader_->IsRawReaderExhausted();
  173. }
  174. ////////////////////////////////////////////////////////////////////////////////
  175. TLenvalProtoTableReader::TLenvalProtoTableReader(
  176. ::TIntrusivePtr<TRawTableReader> input,
  177. TVector<const Descriptor*>&& descriptors)
  178. : TLenvalTableReader(std::move(input))
  179. , ValidateProtoDescriptor_(true)
  180. , Descriptors_(std::move(descriptors))
  181. { }
  182. TLenvalProtoTableReader::TLenvalProtoTableReader(
  183. ::TIntrusivePtr<TRawTableReader> input)
  184. : TLenvalTableReader(std::move(input))
  185. , ValidateProtoDescriptor_(false)
  186. { }
  187. void TLenvalProtoTableReader::ReadRow(Message* row)
  188. {
  189. if (ValidateProtoDescriptor_) {
  190. ValidateProtoDescriptor(*row, GetTableIndex(), Descriptors_, true);
  191. }
  192. while (true) {
  193. try {
  194. ParseFromArcadiaStream(&Input_, *row, Length_);
  195. RowTaken_ = true;
  196. // We successfully parsed one more row from the stream,
  197. // so reset retry count to their initial value.
  198. Input_.ResetRetries();
  199. break;
  200. } catch (const std::exception& ) {
  201. if (!TLenvalTableReader::Retry()) {
  202. throw;
  203. }
  204. }
  205. }
  206. }
  207. bool TLenvalProtoTableReader::IsValid() const
  208. {
  209. return TLenvalTableReader::IsValid();
  210. }
  211. void TLenvalProtoTableReader::Next()
  212. {
  213. TLenvalTableReader::Next();
  214. }
  215. ui32 TLenvalProtoTableReader::GetTableIndex() const
  216. {
  217. return TLenvalTableReader::GetTableIndex();
  218. }
  219. ui32 TLenvalProtoTableReader::GetRangeIndex() const
  220. {
  221. return TLenvalTableReader::GetRangeIndex();
  222. }
  223. ui64 TLenvalProtoTableReader::GetRowIndex() const
  224. {
  225. return TLenvalTableReader::GetRowIndex();
  226. }
  227. void TLenvalProtoTableReader::NextKey()
  228. {
  229. TLenvalTableReader::NextKey();
  230. }
  231. TMaybe<size_t> TLenvalProtoTableReader::GetReadByteCount() const
  232. {
  233. return TLenvalTableReader::GetReadByteCount();
  234. }
  235. bool TLenvalProtoTableReader::IsEndOfStream() const
  236. {
  237. return TLenvalTableReader::IsEndOfStream();
  238. }
  239. bool TLenvalProtoTableReader::IsRawReaderExhausted() const
  240. {
  241. return TLenvalTableReader::IsRawReaderExhausted();
  242. }
  243. void TLenvalProtoTableReader::SkipRow()
  244. {
  245. while (true) {
  246. try {
  247. size_t skipped = Input_.Skip(Length_);
  248. if (skipped != Length_) {
  249. ythrow yexception() << "Premature end of stream";
  250. }
  251. break;
  252. } catch (const std::exception& ) {
  253. if (!TLenvalTableReader::Retry()) {
  254. throw;
  255. }
  256. }
  257. }
  258. }
  259. ////////////////////////////////////////////////////////////////////////////////
  260. } // namespace NYT