proto_table_writer.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
#include "proto_table_writer.h"

#include "node_table_writer.h"
#include "proto_helpers.h"

#include <yt/yt/core/misc/protobuf_helpers.h>
#include <yt/cpp/mapreduce/common/node_builder.h>
#include <yt/cpp/mapreduce/interface/io.h>
#include <yt/cpp/mapreduce/io/job_writer.h>
#include <yt/yt_proto/yt/formats/extension.pb.h>

#include <google/protobuf/unknown_field_set.h>

#include <limits>
  10. namespace NYT {
  11. using ::google::protobuf::Descriptor;
  12. using ::google::protobuf::FieldDescriptor;
  13. ////////////////////////////////////////////////////////////////////////////////
  14. TNode MakeNodeFromMessage(const Message& row)
  15. {
  16. TNode node;
  17. TNodeBuilder builder(&node);
  18. builder.OnBeginMap();
  19. auto* descriptor = row.GetDescriptor();
  20. auto* reflection = row.GetReflection();
  21. int count = descriptor->field_count();
  22. for (int i = 0; i < count; ++i) {
  23. auto* fieldDesc = descriptor->field(i);
  24. if (fieldDesc->is_repeated()) {
  25. Y_ENSURE(reflection->FieldSize(row, fieldDesc) == 0, "Storing repeated protobuf fields is not supported yet");
  26. continue;
  27. } else if (!reflection->HasField(row, fieldDesc)) {
  28. continue;
  29. }
  30. auto columnName = fieldDesc->options().GetExtension(column_name);
  31. if (columnName.empty()) {
  32. const auto& keyColumnName = fieldDesc->options().GetExtension(key_column_name);
  33. columnName = keyColumnName.empty() ? fieldDesc->name() : keyColumnName;
  34. }
  35. builder.OnKeyedItem(columnName);
  36. switch (fieldDesc->type()) {
  37. case FieldDescriptor::TYPE_STRING:
  38. case FieldDescriptor::TYPE_BYTES:
  39. builder.OnStringScalar(reflection->GetString(row, fieldDesc));
  40. break;
  41. case FieldDescriptor::TYPE_INT64:
  42. case FieldDescriptor::TYPE_SINT64:
  43. case FieldDescriptor::TYPE_SFIXED64:
  44. builder.OnInt64Scalar(reflection->GetInt64(row, fieldDesc));
  45. break;
  46. case FieldDescriptor::TYPE_INT32:
  47. case FieldDescriptor::TYPE_SINT32:
  48. case FieldDescriptor::TYPE_SFIXED32:
  49. builder.OnInt64Scalar(reflection->GetInt32(row, fieldDesc));
  50. break;
  51. case FieldDescriptor::TYPE_UINT64:
  52. case FieldDescriptor::TYPE_FIXED64:
  53. builder.OnUint64Scalar(reflection->GetUInt64(row, fieldDesc));
  54. break;
  55. case FieldDescriptor::TYPE_UINT32:
  56. case FieldDescriptor::TYPE_FIXED32:
  57. builder.OnUint64Scalar(reflection->GetUInt32(row, fieldDesc));
  58. break;
  59. case FieldDescriptor::TYPE_DOUBLE:
  60. builder.OnDoubleScalar(reflection->GetDouble(row, fieldDesc));
  61. break;
  62. case FieldDescriptor::TYPE_FLOAT:
  63. builder.OnDoubleScalar(reflection->GetFloat(row, fieldDesc));
  64. break;
  65. case FieldDescriptor::TYPE_BOOL:
  66. builder.OnBooleanScalar(reflection->GetBool(row, fieldDesc));
  67. break;
  68. case FieldDescriptor::TYPE_ENUM:
  69. builder.OnStringScalar(reflection->GetEnum(row, fieldDesc)->name());
  70. break;
  71. case FieldDescriptor::TYPE_MESSAGE:
  72. builder.OnStringScalar(reflection->GetMessage(row, fieldDesc).SerializeAsString());
  73. break;
  74. default:
  75. ythrow yexception() << "Invalid field type for column: " << columnName;
  76. break;
  77. }
  78. }
  79. builder.OnEndMap();
  80. return node;
  81. }
  82. ////////////////////////////////////////////////////////////////////////////////
  83. TProtoTableWriter::TProtoTableWriter(
  84. THolder<IProxyOutput> output,
  85. TVector<const Descriptor*>&& descriptors)
  86. : NodeWriter_(new TNodeTableWriter(std::move(output)))
  87. , Descriptors_(std::move(descriptors))
  88. { }
  89. TProtoTableWriter::~TProtoTableWriter()
  90. { }
  91. size_t TProtoTableWriter::GetBufferMemoryUsage() const
  92. {
  93. return NodeWriter_->GetBufferMemoryUsage();
  94. }
  95. size_t TProtoTableWriter::GetTableCount() const
  96. {
  97. return NodeWriter_->GetTableCount();
  98. }
  99. void TProtoTableWriter::FinishTable(size_t tableIndex)
  100. {
  101. NodeWriter_->FinishTable(tableIndex);
  102. }
  103. void TProtoTableWriter::AddRow(const Message& row, size_t tableIndex)
  104. {
  105. NodeWriter_->AddRow(MakeNodeFromMessage(row), tableIndex);
  106. }
  107. void TProtoTableWriter::AddRow(Message&& row, size_t tableIndex)
  108. {
  109. TProtoTableWriter::AddRow(row, tableIndex);
  110. }
  111. void TProtoTableWriter::Abort()
  112. {
  113. NodeWriter_->Abort();
  114. }
  115. ////////////////////////////////////////////////////////////////////////////////
  116. TLenvalProtoTableWriter::TLenvalProtoTableWriter(
  117. THolder<IProxyOutput> output,
  118. TVector<const Descriptor*>&& descriptors)
  119. : Output_(std::move(output))
  120. , Descriptors_(std::move(descriptors))
  121. { }
  122. TLenvalProtoTableWriter::~TLenvalProtoTableWriter()
  123. { }
  124. size_t TLenvalProtoTableWriter::GetBufferMemoryUsage() const
  125. {
  126. return Output_->GetBufferMemoryUsage();
  127. }
  128. size_t TLenvalProtoTableWriter::GetTableCount() const
  129. {
  130. return Output_->GetStreamCount();
  131. }
  132. void TLenvalProtoTableWriter::FinishTable(size_t tableIndex)
  133. {
  134. Output_->GetStream(tableIndex)->Finish();
  135. }
  136. void TLenvalProtoTableWriter::AddRow(const Message& row, size_t tableIndex)
  137. {
  138. ValidateProtoDescriptor(row, tableIndex, Descriptors_, false);
  139. Y_ABORT_UNLESS(row.GetReflection()->GetUnknownFields(row).empty(),
  140. "Message has unknown fields. This probably means bug in client code.\n"
  141. "Message: %s", row.DebugString().data());
  142. auto* stream = Output_->GetStream(tableIndex);
  143. i32 size = row.ByteSizeLong();
  144. stream->Write(&size, sizeof(size));
  145. // NB: Scope is essential here since output stream adaptor flushes in destructor.
  146. {
  147. TProtobufOutputStreamAdaptor streamAdaptor(stream);
  148. auto result = row.SerializeToZeroCopyStream(&streamAdaptor);
  149. Y_ENSURE(result && !streamAdaptor.HasError(), "Failed to serialize protobuf message");
  150. }
  151. Output_->OnRowFinished(tableIndex);
  152. }
  153. void TLenvalProtoTableWriter::AddRow(Message&& row, size_t tableIndex)
  154. {
  155. TLenvalProtoTableWriter::AddRow(row, tableIndex);
  156. }
  157. void TLenvalProtoTableWriter::Abort()
  158. {
  159. Output_->Abort();
  160. }
  161. ////////////////////////////////////////////////////////////////////////////////
  162. TLenvalProtoSingleTableWriter::TLenvalProtoSingleTableWriter(
  163. THolder<IProxyOutput> output,
  164. const Descriptor* descriptor)
  165. : TLenvalProtoTableWriter(std::move(output), {descriptor})
  166. { }
  167. void TLenvalProtoSingleTableWriter::AddRow(const Message& row, size_t tableIndex)
  168. {
  169. ValidateProtoDescriptor(row, 0, Descriptors_, false);
  170. Y_ABORT_UNLESS(row.GetReflection()->GetUnknownFields(row).empty(),
  171. "Message has unknown fields. This probably means bug in client code.\n"
  172. "Message: %s", row.DebugString().data());
  173. auto* stream = Output_->GetStream(tableIndex);
  174. i32 size = row.ByteSizeLong();
  175. stream->Write(&size, sizeof(size));
  176. // NB: Scope is essential here since output stream adaptor flushes in destructor.
  177. {
  178. TProtobufOutputStreamAdaptor streamAdaptor(stream);
  179. auto result = row.SerializeToZeroCopyStream(&streamAdaptor);
  180. Y_ENSURE(result && !streamAdaptor.HasError(), "Failed to serialize protobuf message");
  181. }
  182. Output_->OnRowFinished(tableIndex);
  183. }
  184. ////////////////////////////////////////////////////////////////////////////////
  185. } // namespace NYT