proto_table_writer.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. #include "proto_table_writer.h"
  2. #include "node_table_writer.h"
  3. #include "proto_helpers.h"
  4. #include <yt/cpp/mapreduce/common/node_builder.h>
  5. #include <yt/cpp/mapreduce/interface/io.h>
  6. #include <yt/yt_proto/yt/formats/extension.pb.h>
  7. #include <google/protobuf/unknown_field_set.h>
  8. namespace NYT {
  9. using ::google::protobuf::Descriptor;
  10. using ::google::protobuf::FieldDescriptor;
  11. ////////////////////////////////////////////////////////////////////////////////
  12. TNode MakeNodeFromMessage(const Message& row)
  13. {
  14. TNode node;
  15. TNodeBuilder builder(&node);
  16. builder.OnBeginMap();
  17. auto* descriptor = row.GetDescriptor();
  18. auto* reflection = row.GetReflection();
  19. int count = descriptor->field_count();
  20. for (int i = 0; i < count; ++i) {
  21. auto* fieldDesc = descriptor->field(i);
  22. if (fieldDesc->is_repeated()) {
  23. Y_ENSURE(reflection->FieldSize(row, fieldDesc) == 0, "Storing repeated protobuf fields is not supported yet");
  24. continue;
  25. } else if (!reflection->HasField(row, fieldDesc)) {
  26. continue;
  27. }
  28. TString columnName = fieldDesc->options().GetExtension(column_name);
  29. if (columnName.empty()) {
  30. const auto& keyColumnName = fieldDesc->options().GetExtension(key_column_name);
  31. columnName = keyColumnName.empty() ? fieldDesc->name() : keyColumnName;
  32. }
  33. builder.OnKeyedItem(columnName);
  34. switch (fieldDesc->type()) {
  35. case FieldDescriptor::TYPE_STRING:
  36. case FieldDescriptor::TYPE_BYTES:
  37. builder.OnStringScalar(reflection->GetString(row, fieldDesc));
  38. break;
  39. case FieldDescriptor::TYPE_INT64:
  40. case FieldDescriptor::TYPE_SINT64:
  41. case FieldDescriptor::TYPE_SFIXED64:
  42. builder.OnInt64Scalar(reflection->GetInt64(row, fieldDesc));
  43. break;
  44. case FieldDescriptor::TYPE_INT32:
  45. case FieldDescriptor::TYPE_SINT32:
  46. case FieldDescriptor::TYPE_SFIXED32:
  47. builder.OnInt64Scalar(reflection->GetInt32(row, fieldDesc));
  48. break;
  49. case FieldDescriptor::TYPE_UINT64:
  50. case FieldDescriptor::TYPE_FIXED64:
  51. builder.OnUint64Scalar(reflection->GetUInt64(row, fieldDesc));
  52. break;
  53. case FieldDescriptor::TYPE_UINT32:
  54. case FieldDescriptor::TYPE_FIXED32:
  55. builder.OnUint64Scalar(reflection->GetUInt32(row, fieldDesc));
  56. break;
  57. case FieldDescriptor::TYPE_DOUBLE:
  58. builder.OnDoubleScalar(reflection->GetDouble(row, fieldDesc));
  59. break;
  60. case FieldDescriptor::TYPE_FLOAT:
  61. builder.OnDoubleScalar(reflection->GetFloat(row, fieldDesc));
  62. break;
  63. case FieldDescriptor::TYPE_BOOL:
  64. builder.OnBooleanScalar(reflection->GetBool(row, fieldDesc));
  65. break;
  66. case FieldDescriptor::TYPE_ENUM:
  67. builder.OnStringScalar(reflection->GetEnum(row, fieldDesc)->name());
  68. break;
  69. case FieldDescriptor::TYPE_MESSAGE:
  70. builder.OnStringScalar(reflection->GetMessage(row, fieldDesc).SerializeAsString());
  71. break;
  72. default:
  73. ythrow yexception() << "Invalid field type for column: " << columnName;
  74. break;
  75. }
  76. }
  77. builder.OnEndMap();
  78. return node;
  79. }
  80. ////////////////////////////////////////////////////////////////////////////////
  81. TProtoTableWriter::TProtoTableWriter(
  82. THolder<IProxyOutput> output,
  83. TVector<const Descriptor*>&& descriptors)
  84. : NodeWriter_(new TNodeTableWriter(std::move(output)))
  85. , Descriptors_(std::move(descriptors))
  86. { }
  87. TProtoTableWriter::~TProtoTableWriter()
  88. { }
  89. size_t TProtoTableWriter::GetBufferMemoryUsage() const
  90. {
  91. return NodeWriter_->GetBufferMemoryUsage();
  92. }
  93. size_t TProtoTableWriter::GetTableCount() const
  94. {
  95. return NodeWriter_->GetTableCount();
  96. }
  97. void TProtoTableWriter::FinishTable(size_t tableIndex)
  98. {
  99. NodeWriter_->FinishTable(tableIndex);
  100. }
  101. void TProtoTableWriter::AddRow(const Message& row, size_t tableIndex)
  102. {
  103. NodeWriter_->AddRow(MakeNodeFromMessage(row), tableIndex);
  104. }
  105. void TProtoTableWriter::AddRow(Message&& row, size_t tableIndex)
  106. {
  107. TProtoTableWriter::AddRow(row, tableIndex);
  108. }
  109. void TProtoTableWriter::Abort()
  110. {
  111. NodeWriter_->Abort();
  112. }
  113. ////////////////////////////////////////////////////////////////////////////////
  114. TLenvalProtoTableWriter::TLenvalProtoTableWriter(
  115. THolder<IProxyOutput> output,
  116. TVector<const Descriptor*>&& descriptors)
  117. : Output_(std::move(output))
  118. , Descriptors_(std::move(descriptors))
  119. { }
  120. TLenvalProtoTableWriter::~TLenvalProtoTableWriter()
  121. { }
  122. size_t TLenvalProtoTableWriter::GetBufferMemoryUsage() const
  123. {
  124. return Output_->GetBufferMemoryUsage();
  125. }
  126. size_t TLenvalProtoTableWriter::GetTableCount() const
  127. {
  128. return Output_->GetStreamCount();
  129. }
  130. void TLenvalProtoTableWriter::FinishTable(size_t tableIndex)
  131. {
  132. Output_->GetStream(tableIndex)->Finish();
  133. }
  134. void TLenvalProtoTableWriter::AddRow(const Message& row, size_t tableIndex)
  135. {
  136. ValidateProtoDescriptor(row, tableIndex, Descriptors_, false);
  137. Y_ABORT_UNLESS(row.GetReflection()->GetUnknownFields(row).empty(),
  138. "Message has unknown fields. This probably means bug in client code.\n"
  139. "Message: %s", row.DebugString().data());
  140. auto* stream = Output_->GetStream(tableIndex);
  141. i32 size = row.ByteSize();
  142. stream->Write(&size, sizeof(size));
  143. bool serializedOk = row.SerializeToArcadiaStream(stream);
  144. Y_ENSURE(serializedOk, "Failed to serialize protobuf message");
  145. Output_->OnRowFinished(tableIndex);
  146. }
  147. void TLenvalProtoTableWriter::AddRow(Message&& row, size_t tableIndex)
  148. {
  149. TLenvalProtoTableWriter::AddRow(row, tableIndex);
  150. }
  151. void TLenvalProtoTableWriter::Abort()
  152. {
  153. Output_->Abort();
  154. }
  155. ////////////////////////////////////////////////////////////////////////////////
  156. } // namespace NYT