pb_io.cpp 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. #include "pb_io.h"
  2. #include <library/cpp/binsaver/bin_saver.h>
  3. #include <library/cpp/string_utils/base64/base64.h>
  4. #include <google/protobuf/io/tokenizer.h>
  5. #include <google/protobuf/message.h>
  6. #include <google/protobuf/messagext.h>
  7. #include <google/protobuf/text_format.h>
  8. #include <util/generic/string.h>
  9. #include <util/stream/file.h>
  10. #include <util/stream/str.h>
  11. #include <util/string/cast.h>
  12. namespace NProtoBuf {
  13. class TEnumIdValuePrinter : public google::protobuf::TextFormat::FastFieldValuePrinter {
  14. public:
  15. void PrintEnum(int32 val, const TString& /*name*/, google::protobuf::TextFormat::BaseTextGenerator* generator) const override {
  16. generator->PrintString(ToString(val));
  17. }
  18. };
  19. void ParseFromBase64String(const TStringBuf dataBase64, Message& m, bool allowUneven) {
  20. if (!m.ParseFromString(allowUneven ? Base64DecodeUneven(dataBase64) : Base64StrictDecode(dataBase64))) {
  21. ythrow yexception() << "can't parse " << m.GetTypeName() << " from base64-encoded string";
  22. }
  23. }
  24. bool TryParseFromBase64String(const TStringBuf dataBase64, Message& m, bool allowUneven) {
  25. try {
  26. ParseFromBase64String(dataBase64, m, allowUneven);
  27. return true;
  28. } catch (const std::exception&) {
  29. return false;
  30. }
  31. }
  32. void SerializeToBase64String(const Message& m, TString& dataBase64) {
  33. TString rawData;
  34. if (!m.SerializeToString(&rawData)) {
  35. ythrow yexception() << "can't serialize " << m.GetTypeName();
  36. }
  37. Base64EncodeUrl(rawData, dataBase64);
  38. }
  39. TString SerializeToBase64String(const Message& m) {
  40. TString s;
  41. SerializeToBase64String(m, s);
  42. return s;
  43. }
  44. bool TrySerializeToBase64String(const Message& m, TString& dataBase64) {
  45. try {
  46. SerializeToBase64String(m, dataBase64);
  47. return true;
  48. } catch (const std::exception&) {
  49. return false;
  50. }
  51. }
  52. const TString ShortUtf8DebugString(const Message& message) {
  53. TextFormat::Printer printer;
  54. printer.SetSingleLineMode(true);
  55. printer.SetUseUtf8StringEscaping(true);
  56. TString result;
  57. printer.PrintToString(message, &result);
  58. return result;
  59. }
  60. bool MergePartialFromString(NProtoBuf::Message& m, const TStringBuf serializedProtoMessage) {
  61. google::protobuf::io::CodedInputStream input(reinterpret_cast<const ui8*>(serializedProtoMessage.data()), serializedProtoMessage.size());
  62. bool ok = m.MergePartialFromCodedStream(&input);
  63. ok = ok && input.ConsumedEntireMessage();
  64. return ok;
  65. }
  66. bool MergeFromString(NProtoBuf::Message& m, const TStringBuf serializedProtoMessage) {
  67. return MergePartialFromString(m, serializedProtoMessage) && m.IsInitialized();
  68. }
  69. } // end of namespace NProtoBuf
  70. namespace {
  71. class TErrorCollector: public NProtoBuf::io::ErrorCollector {
  72. public:
  73. TErrorCollector(const NProtoBuf::Message& m, IOutputStream* errorOut, IOutputStream* warningOut)
  74. : TypeName_(m.GetTypeName())
  75. {
  76. ErrorOut_ = errorOut ? errorOut : &Cerr;
  77. WarningOut_ = warningOut ? warningOut : &Cerr;
  78. }
  79. void AddError(int line, int column, const TProtoStringType& message) override {
  80. PrintErrorMessage(ErrorOut_, "Error", line, column, message);
  81. }
  82. void AddWarning(int line, int column, const TProtoStringType& message) override {
  83. PrintErrorMessage(WarningOut_, "Warning", line, column, message);
  84. }
  85. private:
  86. void PrintErrorMessage(IOutputStream* out, TStringBuf errorLevel, int line, int column, const TProtoStringType& message) {
  87. (*out) << errorLevel << " parsing text-format ";
  88. if (line >= 0) {
  89. (*out) << TypeName_ << ": " << (line + 1) << ":" << (column + 1) << ": " << message;
  90. } else {
  91. (*out) << TypeName_ << ": " << message;
  92. }
  93. out->Flush();
  94. }
  95. private:
  96. const TProtoStringType TypeName_;
  97. IOutputStream* ErrorOut_;
  98. IOutputStream* WarningOut_;
  99. };
  100. } // end of anonymous namespace
  101. int operator&(NProtoBuf::Message& m, IBinSaver& f) {
  102. TStringStream ss;
  103. if (f.IsReading()) {
  104. f.Add(0, &ss.Str());
  105. m.ParseFromArcadiaStream(&ss);
  106. } else {
  107. m.SerializeToArcadiaStream(&ss);
  108. f.Add(0, &ss.Str());
  109. }
  110. return 0;
  111. }
  112. void SerializeToTextFormat(const NProtoBuf::Message& m, IOutputStream& out) {
  113. NProtoBuf::io::TCopyingOutputStreamAdaptor adaptor(&out);
  114. if (!NProtoBuf::TextFormat::Print(m, &adaptor)) {
  115. ythrow yexception() << "SerializeToTextFormat failed on Print";
  116. }
  117. }
  118. void SerializeToTextFormat(const NProtoBuf::Message& m, const TString& fileName) {
  119. /* TUnbufferedFileOutput is unbuffered, but TCopyingOutputStreamAdaptor adds
  120. * a buffer on top of it. */
  121. TUnbufferedFileOutput stream(fileName);
  122. SerializeToTextFormat(m, stream);
  123. }
  124. void SerializeToTextFormatWithEnumId(const NProtoBuf::Message& m, IOutputStream& out) {
  125. google::protobuf::TextFormat::Printer printer;
  126. printer.SetDefaultFieldValuePrinter(new NProtoBuf::TEnumIdValuePrinter());
  127. NProtoBuf::io::TCopyingOutputStreamAdaptor adaptor(&out);
  128. if (!printer.Print(m, &adaptor)) {
  129. ythrow yexception() << "SerializeToTextFormatWithEnumId failed on Print";
  130. }
  131. }
  132. void SerializeToTextFormatPretty(const NProtoBuf::Message& m, IOutputStream& out) {
  133. google::protobuf::TextFormat::Printer printer;
  134. printer.SetUseUtf8StringEscaping(true);
  135. printer.SetUseShortRepeatedPrimitives(true);
  136. NProtoBuf::io::TCopyingOutputStreamAdaptor adaptor(&out);
  137. if (!printer.Print(m, &adaptor)) {
  138. ythrow yexception() << "SerializeToTextFormatPretty failed on Print";
  139. }
  140. }
  141. static void ConfigureParser(const EParseFromTextFormatOptions options,
  142. NProtoBuf::TextFormat::Parser& p) {
  143. if (options & EParseFromTextFormatOption::AllowUnknownField) {
  144. p.AllowUnknownField(true);
  145. }
  146. }
  147. void ParseFromTextFormat(IInputStream& in, NProtoBuf::Message& m,
  148. const EParseFromTextFormatOptions options, IOutputStream* warningStream) {
  149. NProtoBuf::io::TCopyingInputStreamAdaptor adaptor(&in);
  150. NProtoBuf::TextFormat::Parser p;
  151. ConfigureParser(options, p);
  152. TStringStream errorLog;
  153. THolder<TErrorCollector> errorCollector;
  154. errorCollector = MakeHolder<TErrorCollector>(m, &errorLog, warningStream);
  155. p.RecordErrorsTo(errorCollector.Get());
  156. if (!p.Parse(&adaptor, &m)) {
  157. // remove everything that may have been read
  158. m.Clear();
  159. ythrow yexception() << errorLog.Str();
  160. }
  161. }
  162. void ParseFromTextFormat(const TString& fileName, NProtoBuf::Message& m,
  163. const EParseFromTextFormatOptions options, IOutputStream* warningStream) {
  164. /* TUnbufferedFileInput is unbuffered, but TCopyingInputStreamAdaptor adds
  165. * a buffer on top of it. */
  166. TUnbufferedFileInput stream(fileName);
  167. ParseFromTextFormat(stream, m, options, warningStream);
  168. }
  169. bool TryParseFromTextFormat(const TString& fileName, NProtoBuf::Message& m,
  170. const EParseFromTextFormatOptions options, IOutputStream* warningStream) {
  171. try {
  172. ParseFromTextFormat(fileName, m, options, warningStream);
  173. } catch (std::exception&) {
  174. return false;
  175. }
  176. return true;
  177. }
  178. bool TryParseFromTextFormat(IInputStream& in, NProtoBuf::Message& m,
  179. const EParseFromTextFormatOptions options, IOutputStream* warningStream) {
  180. try {
  181. ParseFromTextFormat(in, m, options, warningStream);
  182. } catch (std::exception&) {
  183. return false;
  184. }
  185. return true;
  186. }
  187. void MergeFromTextFormat(IInputStream& in, NProtoBuf::Message& m,
  188. const EParseFromTextFormatOptions options) {
  189. NProtoBuf::io::TCopyingInputStreamAdaptor adaptor(&in);
  190. NProtoBuf::TextFormat::Parser p;
  191. ConfigureParser(options, p);
  192. if (!p.Merge(&adaptor, &m)) {
  193. ythrow yexception() << "MergeFromTextFormat failed on Merge for " << m.GetTypeName();
  194. }
  195. }
  196. void MergeFromTextFormat(const TString& fileName, NProtoBuf::Message& m,
  197. const EParseFromTextFormatOptions options) {
  198. /* TUnbufferedFileInput is unbuffered, but TCopyingInputStreamAdaptor adds
  199. * a buffer on top of it. */
  200. TUnbufferedFileInput stream(fileName);
  201. MergeFromTextFormat(stream, m, options);
  202. }
  203. bool TryMergeFromTextFormat(const TString& fileName, NProtoBuf::Message& m,
  204. const EParseFromTextFormatOptions options) {
  205. try {
  206. MergeFromTextFormat(fileName, m, options);
  207. } catch (std::exception&) {
  208. return false;
  209. }
  210. return true;
  211. }
  212. bool TryMergeFromTextFormat(IInputStream& in, NProtoBuf::Message& m,
  213. const EParseFromTextFormatOptions options) {
  214. try {
  215. MergeFromTextFormat(in, m, options);
  216. } catch (std::exception&) {
  217. return false;
  218. }
  219. return true;
  220. }