scimpl_protobuf.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. #include "scheme.h"
  2. #include <util/generic/vector.h>
  3. #include <util/generic/yexception.h>
  4. #include <google/protobuf/descriptor.h>
  5. #include <google/protobuf/message.h>
  6. #include <google/protobuf/reflection.h>
  7. using namespace google::protobuf;
  8. namespace NSc {
  9. TValue TValue::From(const Message& msg, bool mapAsDict) {
  10. TValue v;
  11. const Reflection* r = msg.GetReflection();
  12. TVector<const FieldDescriptor*> fields;
  13. TVector<const FieldDescriptor*>::iterator it;
  14. int i1;
  15. r->ListFields(msg, &fields);
  16. for (it = fields.begin(), i1 = 0; it != fields.end(); ++it, ++i1) {
  17. const FieldDescriptor* field = *it;
  18. try {
  19. if (field->is_repeated()) {
  20. if (field->is_map() && mapAsDict) {
  21. auto& elem = v[field->name()];
  22. for (int i2 = 0; i2 < r->FieldSize(msg, field); ++i2) {
  23. auto val = FromRepeatedField(msg, field, i2);
  24. if (val.IsDict()) {
  25. elem[TStringBuf(val["key"])] = val["value"];
  26. }
  27. }
  28. } else {
  29. for (int i2 = 0; i2 < r->FieldSize(msg, field); ++i2)
  30. v[field->name()][i2] = FromRepeatedField(msg, field, i2);
  31. }
  32. } else {
  33. v[field->name()] = FromField(msg, field);
  34. }
  35. } catch (...) {
  36. /* conversion failed, skip this field */
  37. }
  38. }
  39. return v;
  40. }
  41. TValue TValue::FromField(const Message& msg, const FieldDescriptor* field) {
  42. TValue v;
  43. const Reflection* r = msg.GetReflection();
  44. switch (field->cpp_type()) {
  45. case FieldDescriptor::CPPTYPE_INT32:
  46. v = r->GetInt32(msg, field);
  47. break;
  48. case FieldDescriptor::CPPTYPE_INT64:
  49. v = r->GetInt64(msg, field);
  50. break;
  51. case FieldDescriptor::CPPTYPE_UINT32:
  52. v = r->GetUInt32(msg, field);
  53. break;
  54. case FieldDescriptor::CPPTYPE_UINT64:
  55. v = r->GetUInt64(msg, field);
  56. break;
  57. case FieldDescriptor::CPPTYPE_DOUBLE:
  58. v = r->GetDouble(msg, field);
  59. break;
  60. case FieldDescriptor::CPPTYPE_FLOAT:
  61. v = r->GetFloat(msg, field);
  62. break;
  63. case FieldDescriptor::CPPTYPE_BOOL:
  64. v.SetBool(r->GetBool(msg, field));
  65. break;
  66. case FieldDescriptor::CPPTYPE_ENUM:
  67. v = r->GetEnum(msg, field)->name();
  68. break;
  69. case FieldDescriptor::CPPTYPE_STRING:
  70. v = r->GetString(msg, field);
  71. break;
  72. case FieldDescriptor::CPPTYPE_MESSAGE:
  73. v = From(r->GetMessage(msg, field));
  74. break;
  75. default:
  76. ythrow TSchemeException() << "field " << field->full_name() << " unexpected type " << (int)field->cpp_type();
  77. }
  78. return v;
  79. }
  80. TValue TValue::FromRepeatedField(const Message& msg, const FieldDescriptor* field, int index) {
  81. TValue v;
  82. const Reflection* r = msg.GetReflection();
  83. switch (field->cpp_type()) {
  84. case FieldDescriptor::CPPTYPE_INT32:
  85. v = r->GetRepeatedInt32(msg, field, index);
  86. break;
  87. case FieldDescriptor::CPPTYPE_INT64:
  88. v = r->GetRepeatedInt64(msg, field, index);
  89. break;
  90. case FieldDescriptor::CPPTYPE_UINT32:
  91. v = r->GetRepeatedUInt32(msg, field, index);
  92. break;
  93. case FieldDescriptor::CPPTYPE_UINT64:
  94. v = r->GetRepeatedUInt64(msg, field, index);
  95. break;
  96. case FieldDescriptor::CPPTYPE_DOUBLE:
  97. v = r->GetRepeatedDouble(msg, field, index);
  98. break;
  99. case FieldDescriptor::CPPTYPE_FLOAT:
  100. v = r->GetRepeatedFloat(msg, field, index);
  101. break;
  102. case FieldDescriptor::CPPTYPE_BOOL:
  103. v.SetBool(r->GetRepeatedBool(msg, field, index));
  104. break;
  105. case FieldDescriptor::CPPTYPE_ENUM:
  106. v = r->GetRepeatedEnum(msg, field, index)->name();
  107. break;
  108. case FieldDescriptor::CPPTYPE_STRING:
  109. v = r->GetRepeatedString(msg, field, index);
  110. break;
  111. case FieldDescriptor::CPPTYPE_MESSAGE:
  112. v = From(r->GetRepeatedMessage(msg, field, index));
  113. break;
  114. default:
  115. ythrow TSchemeException() << "field " << field->full_name() << " unexpected type " << (int)field->cpp_type();
  116. }
  117. return v;
  118. }
  119. void TValue::To(Message& msg, const TProtoOpts& opts) const {
  120. msg.Clear();
  121. if (IsNull()) {
  122. return;
  123. }
  124. if (!IsDict()) {
  125. ythrow TSchemeException() << "expected dictionary";
  126. }
  127. const Descriptor* descriptor = msg.GetDescriptor();
  128. for (int i = 0, count = descriptor->field_count(); i < count; ++i) {
  129. const FieldDescriptor* field = descriptor->field(i);
  130. if (field->is_map()) {
  131. ToMapField(msg, field, opts);
  132. } else if (field->is_repeated()) {
  133. ToRepeatedField(msg, field, opts);
  134. } else {
  135. ToField(msg, field, opts);
  136. }
  137. }
  138. }
  139. void TValue::ValueToField(const TValue& value, Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
  140. const TString& name = field->name();
  141. if (value.IsNull()) {
  142. if (field->is_required() && !field->has_default_value()) {
  143. ythrow TSchemeException() << "has no value for required field " << name;
  144. }
  145. return;
  146. }
  147. const Reflection* reflection = msg.GetReflection();
  148. switch (field->cpp_type()) {
  149. case FieldDescriptor::CPPTYPE_INT32:
  150. reflection->SetInt32(&msg, field, value.ForceIntNumber());
  151. break;
  152. case FieldDescriptor::CPPTYPE_INT64:
  153. reflection->SetInt64(&msg, field, value.ForceIntNumber());
  154. break;
  155. case FieldDescriptor::CPPTYPE_UINT32:
  156. reflection->SetUInt32(&msg, field, value.ForceIntNumber());
  157. break;
  158. case FieldDescriptor::CPPTYPE_UINT64:
  159. reflection->SetUInt64(&msg, field, value.ForceIntNumber());
  160. break;
  161. case FieldDescriptor::CPPTYPE_DOUBLE:
  162. reflection->SetDouble(&msg, field, value.ForceNumber());
  163. break;
  164. case FieldDescriptor::CPPTYPE_FLOAT:
  165. reflection->SetFloat(&msg, field, value.ForceNumber());
  166. break;
  167. case FieldDescriptor::CPPTYPE_BOOL:
  168. reflection->SetBool(&msg, field, value.IsTrue());
  169. break;
  170. case FieldDescriptor::CPPTYPE_STRING:
  171. reflection->SetString(&msg, field, value.ForceString());
  172. break;
  173. case FieldDescriptor::CPPTYPE_ENUM:
  174. value.ToEnumField(msg, field, opts);
  175. break;
  176. case FieldDescriptor::CPPTYPE_MESSAGE:
  177. value.To(*reflection->MutableMessage(&msg, field), opts);
  178. break;
  179. default:
  180. ythrow TSchemeException()
  181. << "field " << field->full_name()
  182. << " unexpected type " << (int)field->cpp_type();
  183. }
  184. }
  185. void TValue::ToField(Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
  186. const TString& name = field->name();
  187. const TValue& value = Get(name);
  188. ValueToField(value, msg, field, opts);
  189. }
  190. void TValue::ToEnumField(Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
  191. const EnumDescriptor* enumField = field->enum_type();
  192. const EnumValueDescriptor* enumFieldValue = IsString()
  193. ? enumField->FindValueByName(ForceString())
  194. : enumField->FindValueByNumber(ForceIntNumber());
  195. if (!enumFieldValue) {
  196. if (opts.UnknownEnumValueIsDefault) {
  197. enumFieldValue = field->default_value_enum();
  198. } else {
  199. ythrow TSchemeException() << "invalid value of enum field " << field->name();
  200. }
  201. }
  202. const Reflection* reflection = msg.GetReflection();
  203. if (field->is_repeated()) {
  204. reflection->AddEnum(&msg, field, enumFieldValue);
  205. } else {
  206. reflection->SetEnum(&msg, field, enumFieldValue);
  207. }
  208. }
  209. void TValue::ToRepeatedField(Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
  210. const TString& name = field->name();
  211. const TValue& fieldValue = Get(name);
  212. if (fieldValue.IsNull()) {
  213. return;
  214. }
  215. if (!fieldValue.IsArray()) {
  216. if (opts.SkipTypeMismatch) {
  217. return; // leave repeated field empty
  218. } else {
  219. ythrow TSchemeException() << "invalid type of repeated field " << name << ": not an array";
  220. }
  221. }
  222. const Reflection* reflection = msg.GetReflection();
  223. for (const TValue& value : fieldValue.GetArray()) {
  224. switch (field->cpp_type()) {
  225. case FieldDescriptor::CPPTYPE_INT32:
  226. reflection->AddInt32(&msg, field, value.ForceIntNumber());
  227. break;
  228. case FieldDescriptor::CPPTYPE_INT64:
  229. reflection->AddInt64(&msg, field, value.ForceIntNumber());
  230. break;
  231. case FieldDescriptor::CPPTYPE_UINT32:
  232. reflection->AddUInt32(&msg, field, value.ForceIntNumber());
  233. break;
  234. case FieldDescriptor::CPPTYPE_UINT64:
  235. reflection->AddUInt64(&msg, field, value.ForceIntNumber());
  236. break;
  237. case FieldDescriptor::CPPTYPE_DOUBLE:
  238. reflection->AddDouble(&msg, field, value.ForceNumber());
  239. break;
  240. case FieldDescriptor::CPPTYPE_FLOAT:
  241. reflection->AddFloat(&msg, field, value.ForceNumber());
  242. break;
  243. case FieldDescriptor::CPPTYPE_BOOL:
  244. reflection->AddBool(&msg, field, value.IsTrue());
  245. break;
  246. case FieldDescriptor::CPPTYPE_STRING:
  247. reflection->AddString(&msg, field, value.ForceString());
  248. break;
  249. case FieldDescriptor::CPPTYPE_ENUM:
  250. value.ToEnumField(msg, field, opts);
  251. break;
  252. case FieldDescriptor::CPPTYPE_MESSAGE:
  253. value.To(*reflection->AddMessage(&msg, field));
  254. break;
  255. default:
  256. ythrow TSchemeException()
  257. << "field " << field->full_name()
  258. << " unexpected type " << (int)field->cpp_type();
  259. }
  260. }
  261. }
  262. void TValue::ToMapField(Message& msg, const FieldDescriptor* field, const TProtoOpts& opts) const {
  263. const TString& name = field->name();
  264. const TValue& fieldValue = Get(name);
  265. if (fieldValue.IsNull()) {
  266. return;
  267. }
  268. if (fieldValue.IsArray()) {
  269. // read dict from key, value array
  270. ToRepeatedField(msg, field, opts);
  271. return;
  272. }
  273. if (!fieldValue.IsDict()) {
  274. if (opts.SkipTypeMismatch) {
  275. return; // leave map field empty
  276. } else {
  277. ythrow TSchemeException() << "invalid type of map field " << name << ": not dict or array";
  278. }
  279. }
  280. const Reflection* reflection = msg.GetReflection();
  281. auto mutableField = reflection->GetMutableRepeatedFieldRef<Message>(&msg, field);
  282. for (const auto& value : fieldValue.GetDict()) {
  283. THolder<Message> entry(mutableField.NewMessage());
  284. auto entryDesc = entry->GetDescriptor();
  285. auto keyField = entryDesc->FindFieldByNumber(1);
  286. auto valueField = entryDesc->FindFieldByNumber(2);
  287. auto entryReflection = entry->GetReflection();
  288. entryReflection->SetString(entry.Get(), keyField, TString(value.first));
  289. ValueToField(value.second, *entry, valueField, opts);
  290. mutableField.Add(*entry);
  291. }
  292. }
  293. }