proto_table_writer.cpp 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  #include "proto_table_writer.h"
  #include "node_table_writer.h"
  #include "proto_helpers.h"

  #include <yt/cpp/mapreduce/common/node_builder.h>
  #include <yt/cpp/mapreduce/interface/io.h>
  #include <yt/cpp/mapreduce/io/job_writer.h>

  #include <yt/yt_proto/yt/formats/extension.pb.h>

  #include <google/protobuf/unknown_field_set.h>

  #include <limits>
  9. namespace NYT {
  10. using ::google::protobuf::Descriptor;
  11. using ::google::protobuf::FieldDescriptor;
  12. ////////////////////////////////////////////////////////////////////////////////
  13. TNode MakeNodeFromMessage(const Message& row)
  14. {
  15. TNode node;
  16. TNodeBuilder builder(&node);
  17. builder.OnBeginMap();
  18. auto* descriptor = row.GetDescriptor();
  19. auto* reflection = row.GetReflection();
  20. int count = descriptor->field_count();
  21. for (int i = 0; i < count; ++i) {
  22. auto* fieldDesc = descriptor->field(i);
  23. if (fieldDesc->is_repeated()) {
  24. Y_ENSURE(reflection->FieldSize(row, fieldDesc) == 0, "Storing repeated protobuf fields is not supported yet");
  25. continue;
  26. } else if (!reflection->HasField(row, fieldDesc)) {
  27. continue;
  28. }
  29. TString columnName = fieldDesc->options().GetExtension(column_name);
  30. if (columnName.empty()) {
  31. const auto& keyColumnName = fieldDesc->options().GetExtension(key_column_name);
  32. columnName = keyColumnName.empty() ? fieldDesc->name() : keyColumnName;
  33. }
  34. builder.OnKeyedItem(columnName);
  35. switch (fieldDesc->type()) {
  36. case FieldDescriptor::TYPE_STRING:
  37. case FieldDescriptor::TYPE_BYTES:
  38. builder.OnStringScalar(reflection->GetString(row, fieldDesc));
  39. break;
  40. case FieldDescriptor::TYPE_INT64:
  41. case FieldDescriptor::TYPE_SINT64:
  42. case FieldDescriptor::TYPE_SFIXED64:
  43. builder.OnInt64Scalar(reflection->GetInt64(row, fieldDesc));
  44. break;
  45. case FieldDescriptor::TYPE_INT32:
  46. case FieldDescriptor::TYPE_SINT32:
  47. case FieldDescriptor::TYPE_SFIXED32:
  48. builder.OnInt64Scalar(reflection->GetInt32(row, fieldDesc));
  49. break;
  50. case FieldDescriptor::TYPE_UINT64:
  51. case FieldDescriptor::TYPE_FIXED64:
  52. builder.OnUint64Scalar(reflection->GetUInt64(row, fieldDesc));
  53. break;
  54. case FieldDescriptor::TYPE_UINT32:
  55. case FieldDescriptor::TYPE_FIXED32:
  56. builder.OnUint64Scalar(reflection->GetUInt32(row, fieldDesc));
  57. break;
  58. case FieldDescriptor::TYPE_DOUBLE:
  59. builder.OnDoubleScalar(reflection->GetDouble(row, fieldDesc));
  60. break;
  61. case FieldDescriptor::TYPE_FLOAT:
  62. builder.OnDoubleScalar(reflection->GetFloat(row, fieldDesc));
  63. break;
  64. case FieldDescriptor::TYPE_BOOL:
  65. builder.OnBooleanScalar(reflection->GetBool(row, fieldDesc));
  66. break;
  67. case FieldDescriptor::TYPE_ENUM:
  68. builder.OnStringScalar(reflection->GetEnum(row, fieldDesc)->name());
  69. break;
  70. case FieldDescriptor::TYPE_MESSAGE:
  71. builder.OnStringScalar(reflection->GetMessage(row, fieldDesc).SerializeAsString());
  72. break;
  73. default:
  74. ythrow yexception() << "Invalid field type for column: " << columnName;
  75. break;
  76. }
  77. }
  78. builder.OnEndMap();
  79. return node;
  80. }
  81. ////////////////////////////////////////////////////////////////////////////////
  82. TProtoTableWriter::TProtoTableWriter(
  83. THolder<IProxyOutput> output,
  84. TVector<const Descriptor*>&& descriptors)
  85. : NodeWriter_(new TNodeTableWriter(std::move(output)))
  86. , Descriptors_(std::move(descriptors))
  87. { }
  88. TProtoTableWriter::~TProtoTableWriter()
  89. { }
  90. size_t TProtoTableWriter::GetBufferMemoryUsage() const
  91. {
  92. return NodeWriter_->GetBufferMemoryUsage();
  93. }
  94. size_t TProtoTableWriter::GetTableCount() const
  95. {
  96. return NodeWriter_->GetTableCount();
  97. }
  98. void TProtoTableWriter::FinishTable(size_t tableIndex)
  99. {
  100. NodeWriter_->FinishTable(tableIndex);
  101. }
  102. void TProtoTableWriter::AddRow(const Message& row, size_t tableIndex)
  103. {
  104. NodeWriter_->AddRow(MakeNodeFromMessage(row), tableIndex);
  105. }
  106. void TProtoTableWriter::AddRow(Message&& row, size_t tableIndex)
  107. {
  108. TProtoTableWriter::AddRow(row, tableIndex);
  109. }
  110. void TProtoTableWriter::Abort()
  111. {
  112. NodeWriter_->Abort();
  113. }
  114. ////////////////////////////////////////////////////////////////////////////////
  115. TLenvalProtoTableWriter::TLenvalProtoTableWriter(
  116. THolder<IProxyOutput> output,
  117. TVector<const Descriptor*>&& descriptors)
  118. : Output_(std::move(output))
  119. , Descriptors_(std::move(descriptors))
  120. { }
  121. TLenvalProtoTableWriter::~TLenvalProtoTableWriter()
  122. { }
  123. size_t TLenvalProtoTableWriter::GetBufferMemoryUsage() const
  124. {
  125. return Output_->GetBufferMemoryUsage();
  126. }
  127. size_t TLenvalProtoTableWriter::GetTableCount() const
  128. {
  129. return Output_->GetStreamCount();
  130. }
  131. void TLenvalProtoTableWriter::FinishTable(size_t tableIndex)
  132. {
  133. Output_->GetStream(tableIndex)->Finish();
  134. }
  135. void TLenvalProtoTableWriter::AddRow(const Message& row, size_t tableIndex)
  136. {
  137. ValidateProtoDescriptor(row, tableIndex, Descriptors_, false);
  138. Y_ABORT_UNLESS(row.GetReflection()->GetUnknownFields(row).empty(),
  139. "Message has unknown fields. This probably means bug in client code.\n"
  140. "Message: %s", row.DebugString().data());
  141. auto* stream = Output_->GetStream(tableIndex);
  142. i32 size = row.ByteSize();
  143. stream->Write(&size, sizeof(size));
  144. bool serializedOk = row.SerializeToArcadiaStream(stream);
  145. Y_ENSURE(serializedOk, "Failed to serialize protobuf message");
  146. Output_->OnRowFinished(tableIndex);
  147. }
  148. void TLenvalProtoTableWriter::AddRow(Message&& row, size_t tableIndex)
  149. {
  150. TLenvalProtoTableWriter::AddRow(row, tableIndex);
  151. }
  152. void TLenvalProtoTableWriter::Abort()
  153. {
  154. Output_->Abort();
  155. }
  156. ////////////////////////////////////////////////////////////////////////////////
  157. TLenvalProtoSingleTableWriter::TLenvalProtoSingleTableWriter(
  158. THolder<IProxyOutput> output,
  159. const Descriptor* descriptor)
  160. : TLenvalProtoTableWriter(std::move(output), {descriptor})
  161. { }
  162. void TLenvalProtoSingleTableWriter::AddRow(const Message& row, size_t tableIndex)
  163. {
  164. ValidateProtoDescriptor(row, 0, Descriptors_, false);
  165. Y_ABORT_UNLESS(row.GetReflection()->GetUnknownFields(row).empty(),
  166. "Message has unknown fields. This probably means bug in client code.\n"
  167. "Message: %s", row.DebugString().data());
  168. auto* stream = Output_->GetStream(tableIndex);
  169. i32 size = row.ByteSize();
  170. stream->Write(&size, sizeof(size));
  171. bool serializedOk = row.SerializeToArcadiaStream(stream);
  172. Y_ENSURE(serializedOk, "Failed to serialize protobuf message");
  173. Output_->OnRowFinished(tableIndex);
  174. }
  175. ////////////////////////////////////////////////////////////////////////////////
  176. } // namespace NYT