ValidatingCodec.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. /**
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * https://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include "ValidatingCodec.hh"
  19. #include <algorithm>
  20. #include <boost/any.hpp>
  21. #include <map>
  22. #include <memory>
  23. #include <utility>
  24. #include "Decoder.hh"
  25. #include "Encoder.hh"
  26. #include "NodeImpl.hh"
  27. #include "ValidSchema.hh"
  28. namespace avro {
  29. using std::make_shared;
  30. namespace parsing {
  31. using std::shared_ptr;
  32. using std::static_pointer_cast;
  33. using std::map;
  34. using std::ostringstream;
  35. using std::pair;
  36. using std::reverse;
  37. using std::string;
  38. using std::vector;
  39. /** Follows the design of Avro Parser in Java. */
  40. ProductionPtr ValidatingGrammarGenerator::generate(const NodePtr &n) {
  41. map<NodePtr, ProductionPtr> m;
  42. ProductionPtr result = doGenerate(n, m);
  43. fixup(result, m);
  44. return result;
  45. }
  46. Symbol ValidatingGrammarGenerator::generate(const ValidSchema &schema) {
  47. ProductionPtr r = generate(schema.root());
  48. return Symbol::rootSymbol(r);
  49. }
  50. ProductionPtr ValidatingGrammarGenerator::doGenerate(const NodePtr &n,
  51. map<NodePtr, ProductionPtr> &m) {
  52. switch (n->type()) {
  53. case AVRO_NULL:
  54. return make_shared<Production>(1, Symbol::nullSymbol());
  55. case AVRO_BOOL:
  56. return make_shared<Production>(1, Symbol::boolSymbol());
  57. case AVRO_INT:
  58. return make_shared<Production>(1, Symbol::intSymbol());
  59. case AVRO_LONG:
  60. return make_shared<Production>(1, Symbol::longSymbol());
  61. case AVRO_FLOAT:
  62. return make_shared<Production>(1, Symbol::floatSymbol());
  63. case AVRO_DOUBLE:
  64. return make_shared<Production>(1, Symbol::doubleSymbol());
  65. case AVRO_STRING:
  66. return make_shared<Production>(1, Symbol::stringSymbol());
  67. case AVRO_BYTES:
  68. return make_shared<Production>(1, Symbol::bytesSymbol());
  69. case AVRO_FIXED: {
  70. ProductionPtr result = make_shared<Production>();
  71. result->push_back(Symbol::sizeCheckSymbol(n->fixedSize()));
  72. result->push_back(Symbol::fixedSymbol());
  73. m[n] = result;
  74. return result;
  75. }
  76. case AVRO_RECORD: {
  77. ProductionPtr result = make_shared<Production>();
  78. m.erase(n);
  79. size_t c = n->leaves();
  80. for (size_t i = 0; i < c; ++i) {
  81. const NodePtr &leaf = n->leafAt(i);
  82. ProductionPtr v = doGenerate(leaf, m);
  83. copy(v->rbegin(), v->rend(), back_inserter(*result));
  84. }
  85. reverse(result->begin(), result->end());
  86. m[n] = result;
  87. return make_shared<Production>(1, Symbol::indirect(result));
  88. }
  89. case AVRO_ENUM: {
  90. ProductionPtr result = make_shared<Production>();
  91. result->push_back(Symbol::sizeCheckSymbol(n->names()));
  92. result->push_back(Symbol::enumSymbol());
  93. m[n] = result;
  94. return result;
  95. }
  96. case AVRO_ARRAY: {
  97. ProductionPtr result = make_shared<Production>();
  98. result->push_back(Symbol::arrayEndSymbol());
  99. result->push_back(Symbol::repeater(doGenerate(n->leafAt(0), m), true));
  100. result->push_back(Symbol::arrayStartSymbol());
  101. return result;
  102. }
  103. case AVRO_MAP: {
  104. ProductionPtr pp = doGenerate(n->leafAt(1), m);
  105. ProductionPtr v(new Production(*pp));
  106. v->push_back(Symbol::stringSymbol());
  107. ProductionPtr result = make_shared<Production>();
  108. result->push_back(Symbol::mapEndSymbol());
  109. result->push_back(Symbol::repeater(v, false));
  110. result->push_back(Symbol::mapStartSymbol());
  111. return result;
  112. }
  113. case AVRO_UNION: {
  114. vector<ProductionPtr> vv;
  115. size_t c = n->leaves();
  116. vv.reserve(c);
  117. for (size_t i = 0; i < c; ++i) {
  118. vv.push_back(doGenerate(n->leafAt(i), m));
  119. }
  120. ProductionPtr result = make_shared<Production>();
  121. result->push_back(Symbol::alternative(vv));
  122. result->push_back(Symbol::unionSymbol());
  123. return result;
  124. }
  125. case AVRO_SYMBOLIC: {
  126. shared_ptr<NodeSymbolic> ns = static_pointer_cast<NodeSymbolic>(n);
  127. NodePtr nn = ns->getNode();
  128. auto it = m.find(nn);
  129. if (it != m.end() && it->second) {
  130. return it->second;
  131. } else {
  132. m[nn] = ProductionPtr();
  133. return make_shared<Production>(1, Symbol::placeholder(nn));
  134. }
  135. }
  136. default:
  137. throw Exception("Unknown node type");
  138. }
  139. }
  140. struct DummyHandler {
  141. static size_t handle(const Symbol &) {
  142. return 0;
  143. }
  144. };
  145. template<typename P>
  146. class ValidatingDecoder : public Decoder {
  147. const shared_ptr<Decoder> base;
  148. DummyHandler handler_;
  149. P parser;
  150. void init(InputStream &is) final;
  151. void decodeNull() final;
  152. bool decodeBool() final;
  153. int32_t decodeInt() final;
  154. int64_t decodeLong() final;
  155. float decodeFloat() final;
  156. double decodeDouble() final;
  157. void decodeString(string &value) final;
  158. void skipString() final;
  159. void decodeBytes(vector<uint8_t> &value) final;
  160. void skipBytes() final;
  161. void decodeFixed(size_t n, vector<uint8_t> &value) final;
  162. void skipFixed(size_t n) final;
  163. size_t decodeEnum() final;
  164. size_t arrayStart() final;
  165. size_t arrayNext() final;
  166. size_t skipArray() final;
  167. size_t mapStart() final;
  168. size_t mapNext() final;
  169. size_t skipMap() final;
  170. size_t decodeUnionIndex() final;
  171. void drain() final {
  172. base->drain();
  173. }
  174. public:
  175. ValidatingDecoder(const ValidSchema &s, const shared_ptr<Decoder> &b) : base(b),
  176. parser(ValidatingGrammarGenerator().generate(s), NULL, handler_) {}
  177. };
  178. template<typename P>
  179. void ValidatingDecoder<P>::init(InputStream &is) {
  180. base->init(is);
  181. }
  182. template<typename P>
  183. void ValidatingDecoder<P>::decodeNull() {
  184. parser.advance(Symbol::Kind::Null);
  185. base->decodeNull();
  186. }
  187. template<typename P>
  188. bool ValidatingDecoder<P>::decodeBool() {
  189. parser.advance(Symbol::Kind::Bool);
  190. return base->decodeBool();
  191. }
  192. template<typename P>
  193. int32_t ValidatingDecoder<P>::decodeInt() {
  194. parser.advance(Symbol::Kind::Int);
  195. return base->decodeInt();
  196. }
  197. template<typename P>
  198. int64_t ValidatingDecoder<P>::decodeLong() {
  199. parser.advance(Symbol::Kind::Long);
  200. return base->decodeLong();
  201. }
  202. template<typename P>
  203. float ValidatingDecoder<P>::decodeFloat() {
  204. parser.advance(Symbol::Kind::Float);
  205. return base->decodeFloat();
  206. }
  207. template<typename P>
  208. double ValidatingDecoder<P>::decodeDouble() {
  209. parser.advance(Symbol::Kind::Double);
  210. return base->decodeDouble();
  211. }
  212. template<typename P>
  213. void ValidatingDecoder<P>::decodeString(string &value) {
  214. parser.advance(Symbol::Kind::String);
  215. base->decodeString(value);
  216. }
  217. template<typename P>
  218. void ValidatingDecoder<P>::skipString() {
  219. parser.advance(Symbol::Kind::String);
  220. base->skipString();
  221. }
  222. template<typename P>
  223. void ValidatingDecoder<P>::decodeBytes(vector<uint8_t> &value) {
  224. parser.advance(Symbol::Kind::Bytes);
  225. base->decodeBytes(value);
  226. }
  227. template<typename P>
  228. void ValidatingDecoder<P>::skipBytes() {
  229. parser.advance(Symbol::Kind::Bytes);
  230. base->skipBytes();
  231. }
  232. template<typename P>
  233. void ValidatingDecoder<P>::decodeFixed(size_t n, vector<uint8_t> &value) {
  234. parser.advance(Symbol::Kind::Fixed);
  235. parser.assertSize(n);
  236. base->decodeFixed(n, value);
  237. }
  238. template<typename P>
  239. void ValidatingDecoder<P>::skipFixed(size_t n) {
  240. parser.advance(Symbol::Kind::Fixed);
  241. parser.assertSize(n);
  242. base->skipFixed(n);
  243. }
  244. template<typename P>
  245. size_t ValidatingDecoder<P>::decodeEnum() {
  246. parser.advance(Symbol::Kind::Enum);
  247. size_t result = base->decodeEnum();
  248. parser.assertLessThanSize(result);
  249. return result;
  250. }
  251. template<typename P>
  252. size_t ValidatingDecoder<P>::arrayStart() {
  253. parser.advance(Symbol::Kind::ArrayStart);
  254. size_t result = base->arrayStart();
  255. parser.pushRepeatCount(result);
  256. if (result == 0) {
  257. parser.popRepeater();
  258. parser.advance(Symbol::Kind::ArrayEnd);
  259. }
  260. return result;
  261. }
  262. template<typename P>
  263. size_t ValidatingDecoder<P>::arrayNext() {
  264. size_t result = base->arrayNext();
  265. parser.nextRepeatCount(result);
  266. if (result == 0) {
  267. parser.popRepeater();
  268. parser.advance(Symbol::Kind::ArrayEnd);
  269. }
  270. return result;
  271. }
  272. template<typename P>
  273. size_t ValidatingDecoder<P>::skipArray() {
  274. parser.advance(Symbol::Kind::ArrayStart);
  275. size_t n = base->skipArray();
  276. if (n == 0) {
  277. parser.pop();
  278. } else {
  279. parser.pushRepeatCount(n);
  280. parser.skip(*base);
  281. }
  282. parser.advance(Symbol::Kind::ArrayEnd);
  283. return 0;
  284. }
  285. template<typename P>
  286. size_t ValidatingDecoder<P>::mapStart() {
  287. parser.advance(Symbol::Kind::MapStart);
  288. size_t result = base->mapStart();
  289. parser.pushRepeatCount(result);
  290. if (result == 0) {
  291. parser.popRepeater();
  292. parser.advance(Symbol::Kind::MapEnd);
  293. }
  294. return result;
  295. }
  296. template<typename P>
  297. size_t ValidatingDecoder<P>::mapNext() {
  298. size_t result = base->mapNext();
  299. parser.nextRepeatCount(result);
  300. if (result == 0) {
  301. parser.popRepeater();
  302. parser.advance(Symbol::Kind::MapEnd);
  303. }
  304. return result;
  305. }
  306. template<typename P>
  307. size_t ValidatingDecoder<P>::skipMap() {
  308. parser.advance(Symbol::Kind::MapStart);
  309. size_t n = base->skipMap();
  310. if (n == 0) {
  311. parser.pop();
  312. } else {
  313. parser.pushRepeatCount(n);
  314. parser.skip(*base);
  315. }
  316. parser.advance(Symbol::Kind::MapEnd);
  317. return 0;
  318. }
  319. template<typename P>
  320. size_t ValidatingDecoder<P>::decodeUnionIndex() {
  321. parser.advance(Symbol::Kind::Union);
  322. size_t result = base->decodeUnionIndex();
  323. parser.selectBranch(result);
  324. return result;
  325. }
  326. template<typename P>
  327. class ValidatingEncoder : public Encoder {
  328. DummyHandler handler_;
  329. P parser_;
  330. EncoderPtr base_;
  331. void init(OutputStream &os) final;
  332. void flush() final;
  333. int64_t byteCount() const final;
  334. void encodeNull() final;
  335. void encodeBool(bool b) final;
  336. void encodeInt(int32_t i) final;
  337. void encodeLong(int64_t l) final;
  338. void encodeFloat(float f) final;
  339. void encodeDouble(double d) final;
  340. void encodeString(const std::string &s) final;
  341. void encodeBytes(const uint8_t *bytes, size_t len) final;
  342. void encodeFixed(const uint8_t *bytes, size_t len) final;
  343. void encodeEnum(size_t e) final;
  344. void arrayStart() final;
  345. void arrayEnd() final;
  346. void mapStart() final;
  347. void mapEnd() final;
  348. void setItemCount(size_t count) final;
  349. void startItem() final;
  350. void encodeUnionIndex(size_t e) final;
  351. public:
  352. ValidatingEncoder(const ValidSchema &schema, EncoderPtr base) : parser_(ValidatingGrammarGenerator().generate(schema), NULL, handler_),
  353. base_(std::move(base)) {}
  354. };
  355. template<typename P>
  356. void ValidatingEncoder<P>::init(OutputStream &os) {
  357. base_->init(os);
  358. }
  359. template<typename P>
  360. void ValidatingEncoder<P>::flush() {
  361. base_->flush();
  362. }
  363. template<typename P>
  364. void ValidatingEncoder<P>::encodeNull() {
  365. parser_.advance(Symbol::Kind::Null);
  366. base_->encodeNull();
  367. }
  368. template<typename P>
  369. void ValidatingEncoder<P>::encodeBool(bool b) {
  370. parser_.advance(Symbol::Kind::Bool);
  371. base_->encodeBool(b);
  372. }
  373. template<typename P>
  374. void ValidatingEncoder<P>::encodeInt(int32_t i) {
  375. parser_.advance(Symbol::Kind::Int);
  376. base_->encodeInt(i);
  377. }
  378. template<typename P>
  379. void ValidatingEncoder<P>::encodeLong(int64_t l) {
  380. parser_.advance(Symbol::Kind::Long);
  381. base_->encodeLong(l);
  382. }
  383. template<typename P>
  384. void ValidatingEncoder<P>::encodeFloat(float f) {
  385. parser_.advance(Symbol::Kind::Float);
  386. base_->encodeFloat(f);
  387. }
  388. template<typename P>
  389. void ValidatingEncoder<P>::encodeDouble(double d) {
  390. parser_.advance(Symbol::Kind::Double);
  391. base_->encodeDouble(d);
  392. }
  393. template<typename P>
  394. void ValidatingEncoder<P>::encodeString(const std::string &s) {
  395. parser_.advance(Symbol::Kind::String);
  396. base_->encodeString(s);
  397. }
  398. template<typename P>
  399. void ValidatingEncoder<P>::encodeBytes(const uint8_t *bytes, size_t len) {
  400. parser_.advance(Symbol::Kind::Bytes);
  401. base_->encodeBytes(bytes, len);
  402. }
  403. template<typename P>
  404. void ValidatingEncoder<P>::encodeFixed(const uint8_t *bytes, size_t len) {
  405. parser_.advance(Symbol::Kind::Fixed);
  406. parser_.assertSize(len);
  407. base_->encodeFixed(bytes, len);
  408. }
  409. template<typename P>
  410. void ValidatingEncoder<P>::encodeEnum(size_t e) {
  411. parser_.advance(Symbol::Kind::Enum);
  412. parser_.assertLessThanSize(e);
  413. base_->encodeEnum(e);
  414. }
  415. template<typename P>
  416. void ValidatingEncoder<P>::arrayStart() {
  417. parser_.advance(Symbol::Kind::ArrayStart);
  418. parser_.pushRepeatCount(0);
  419. base_->arrayStart();
  420. }
  421. template<typename P>
  422. void ValidatingEncoder<P>::arrayEnd() {
  423. parser_.popRepeater();
  424. parser_.advance(Symbol::Kind::ArrayEnd);
  425. base_->arrayEnd();
  426. }
  427. template<typename P>
  428. void ValidatingEncoder<P>::mapStart() {
  429. parser_.advance(Symbol::Kind::MapStart);
  430. parser_.pushRepeatCount(0);
  431. base_->mapStart();
  432. }
  433. template<typename P>
  434. void ValidatingEncoder<P>::mapEnd() {
  435. parser_.popRepeater();
  436. parser_.advance(Symbol::Kind::MapEnd);
  437. base_->mapEnd();
  438. }
  439. template<typename P>
  440. void ValidatingEncoder<P>::setItemCount(size_t count) {
  441. parser_.nextRepeatCount(count);
  442. base_->setItemCount(count);
  443. }
  444. template<typename P>
  445. void ValidatingEncoder<P>::startItem() {
  446. parser_.processImplicitActions();
  447. if (parser_.top() != Symbol::Kind::Repeater) {
  448. throw Exception("startItem at not an item boundary");
  449. }
  450. base_->startItem();
  451. }
  452. template<typename P>
  453. void ValidatingEncoder<P>::encodeUnionIndex(size_t e) {
  454. parser_.advance(Symbol::Kind::Union);
  455. parser_.selectBranch(e);
  456. base_->encodeUnionIndex(e);
  457. }
  458. template<typename P>
  459. int64_t ValidatingEncoder<P>::byteCount() const {
  460. return base_->byteCount();
  461. }
  462. } // namespace parsing
  463. DecoderPtr validatingDecoder(const ValidSchema &s,
  464. const DecoderPtr &base) {
  465. return make_shared<parsing::ValidatingDecoder<parsing::SimpleParser<parsing::DummyHandler>>>(s, base);
  466. }
  467. EncoderPtr validatingEncoder(const ValidSchema &schema, const EncoderPtr &base) {
  468. return make_shared<parsing::ValidatingEncoder<parsing::SimpleParser<parsing::DummyHandler>>>(schema, base);
  469. }
  470. } // namespace avro