Compiler.cc 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. /**
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * https://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
#include <boost/algorithm/string/replace.hpp>
#include <limits>
#include <sstream>
#include <unordered_set>
#include <utility>
#include "Compiler.hh"
#include "CustomAttributes.hh"
#include "NodeConcepts.hh"
#include "Schema.hh"
#include "Stream.hh"
#include "Types.hh"
#include "ValidSchema.hh"
#include "json/JsonDom.hh"
  30. using std::make_pair;
  31. using std::map;
  32. using std::pair;
  33. using std::string;
  34. using std::vector;
  35. namespace avro {
  36. using json::Array;
  37. using json::Entity;
  38. using json::EntityType;
  39. using json::Object;
  40. using SymbolTable = map<Name, NodePtr>;
  41. // #define DEBUG_VERBOSE
  42. static NodePtr makePrimitive(const string &t) {
  43. if (t == "null") {
  44. return NodePtr(new NodePrimitive(AVRO_NULL));
  45. } else if (t == "boolean") {
  46. return NodePtr(new NodePrimitive(AVRO_BOOL));
  47. } else if (t == "int") {
  48. return NodePtr(new NodePrimitive(AVRO_INT));
  49. } else if (t == "long") {
  50. return NodePtr(new NodePrimitive(AVRO_LONG));
  51. } else if (t == "float") {
  52. return NodePtr(new NodePrimitive(AVRO_FLOAT));
  53. } else if (t == "double") {
  54. return NodePtr(new NodePrimitive(AVRO_DOUBLE));
  55. } else if (t == "string") {
  56. return NodePtr(new NodePrimitive(AVRO_STRING));
  57. } else if (t == "bytes") {
  58. return NodePtr(new NodePrimitive(AVRO_BYTES));
  59. } else {
  60. return NodePtr();
  61. }
  62. }
  63. static NodePtr makeNode(const json::Entity &e, SymbolTable &st, const string &ns);
  64. template<typename T>
  65. concepts::SingleAttribute<T> asSingleAttribute(const T &t) {
  66. concepts::SingleAttribute<T> n;
  67. n.add(t);
  68. return n;
  69. }
  70. static bool isFullName(const string &s) {
  71. return s.find('.') != string::npos;
  72. }
  73. static Name getName(const string &name, const string &ns) {
  74. return (isFullName(name)) ? Name(name) : Name(name, ns);
  75. }
  76. static NodePtr makeNode(const string &t, SymbolTable &st, const string &ns) {
  77. NodePtr result = makePrimitive(t);
  78. if (result) {
  79. return result;
  80. }
  81. Name n = getName(t, ns);
  82. auto it = st.find(n);
  83. if (it != st.end()) {
  84. return NodePtr(new NodeSymbolic(asSingleAttribute(n), it->second));
  85. }
  86. throw Exception(boost::format("Unknown type: %1%") % n.fullname());
  87. }
  88. /** Returns "true" if the field is in the container */
  89. // e.g.: can be false for non-mandatory fields
  90. bool containsField(const Object &m, const string &fieldName) {
  91. auto it = m.find(fieldName);
  92. return (it != m.end());
  93. }
  94. json::Object::const_iterator findField(const Entity &e,
  95. const Object &m, const string &fieldName);
  96. template<typename T>
  97. void ensureType(const Entity &e, const string &name) {
  98. if (e.type() != json::type_traits<T>::type()) {
  99. throw Exception(boost::format("Json field \"%1%\" is not a %2%: %3%") % name % json::type_traits<T>::name() % e.toString());
  100. }
  101. }
  102. string getStringField(const Entity &e, const Object &m,
  103. const string &fieldName) {
  104. auto it = findField(e, m, fieldName);
  105. ensureType<string>(it->second, fieldName);
  106. return it->second.stringValue();
  107. }
  108. const Array &getArrayField(const Entity &e, const Object &m,
  109. const string &fieldName);
  110. int64_t getLongField(const Entity &e, const Object &m,
  111. const string &fieldName) {
  112. auto it = findField(e, m, fieldName);
  113. ensureType<int64_t>(it->second, fieldName);
  114. return it->second.longValue();
  115. }
  116. // Unescape double quotes (") for de-serialization. This method complements the
  117. // method NodeImpl::escape() which is used for serialization.
  118. static void unescape(string &s) {
  119. boost::replace_all(s, "\\\"", "\"");
  120. }
  121. string getDocField(const Entity &e, const Object &m) {
  122. string doc = getStringField(e, m, "doc");
  123. unescape(doc);
  124. return doc;
  125. }
  126. struct Field {
  127. const string name;
  128. const NodePtr schema;
  129. const GenericDatum defaultValue;
  130. const CustomAttributes customAttributes;
  131. Field(string n, NodePtr v, GenericDatum dv, const CustomAttributes& ca) : name(std::move(n)), schema(std::move(v)), defaultValue(std::move(dv)), customAttributes(std::move(ca)) {}
  132. };
  133. static void assertType(const Entity &e, EntityType et) {
  134. if (e.type() != et) {
  135. throw Exception(boost::format("Unexpected type for default value: "
  136. "Expected %1%, but found %2% in line %3%")
  137. % json::typeToString(et) % json::typeToString(e.type()) % e.line());
  138. }
  139. }
  140. static vector<uint8_t> toBin(const string &s) {
  141. vector<uint8_t> result(s.size());
  142. if (!s.empty()) {
  143. std::copy(s.c_str(), s.c_str() + s.size(), result.data());
  144. }
  145. return result;
  146. }
  147. static GenericDatum makeGenericDatum(NodePtr n,
  148. const Entity &e, const SymbolTable &st) {
  149. Type t = n->type();
  150. EntityType dt = e.type();
  151. if (t == AVRO_SYMBOLIC) {
  152. n = st.find(n->name())->second;
  153. t = n->type();
  154. }
  155. switch (t) {
  156. case AVRO_STRING:
  157. assertType(e, json::EntityType::String);
  158. return GenericDatum(e.stringValue());
  159. case AVRO_BYTES:
  160. assertType(e, json::EntityType::String);
  161. return GenericDatum(toBin(e.bytesValue()));
  162. case AVRO_INT:
  163. assertType(e, json::EntityType::Long);
  164. return GenericDatum(static_cast<int32_t>(e.longValue()));
  165. case AVRO_LONG:
  166. assertType(e, json::EntityType::Long);
  167. return GenericDatum(e.longValue());
  168. case AVRO_FLOAT:
  169. if (dt == json::EntityType::Long) {
  170. return GenericDatum(static_cast<float>(e.longValue()));
  171. }
  172. assertType(e, json::EntityType::Double);
  173. return GenericDatum(static_cast<float>(e.doubleValue()));
  174. case AVRO_DOUBLE:
  175. if (dt == json::EntityType::Long) {
  176. return GenericDatum(static_cast<double>(e.longValue()));
  177. }
  178. assertType(e, json::EntityType::Double);
  179. return GenericDatum(e.doubleValue());
  180. case AVRO_BOOL:
  181. assertType(e, json::EntityType::Bool);
  182. return GenericDatum(e.boolValue());
  183. case AVRO_NULL:
  184. assertType(e, json::EntityType::Null);
  185. return GenericDatum();
  186. case AVRO_RECORD: {
  187. assertType(e, json::EntityType::Obj);
  188. GenericRecord result(n);
  189. const map<string, Entity> &v = e.objectValue();
  190. for (size_t i = 0; i < n->leaves(); ++i) {
  191. auto it = v.find(n->nameAt(i));
  192. if (it == v.end()) {
  193. throw Exception(boost::format(
  194. "No value found in default for %1%")
  195. % n->nameAt(i));
  196. }
  197. result.setFieldAt(i,
  198. makeGenericDatum(n->leafAt(i), it->second, st));
  199. }
  200. return GenericDatum(n, result);
  201. }
  202. case AVRO_ENUM:
  203. assertType(e, json::EntityType::String);
  204. return GenericDatum(n, GenericEnum(n, e.stringValue()));
  205. case AVRO_ARRAY: {
  206. assertType(e, json::EntityType::Arr);
  207. GenericArray result(n);
  208. const vector<Entity> &elements = e.arrayValue();
  209. for (const auto &element : elements) {
  210. result.value().push_back(makeGenericDatum(n->leafAt(0), element, st));
  211. }
  212. return GenericDatum(n, result);
  213. }
  214. case AVRO_MAP: {
  215. assertType(e, json::EntityType::Obj);
  216. GenericMap result(n);
  217. const map<string, Entity> &v = e.objectValue();
  218. for (const auto &it : v) {
  219. result.value().push_back(make_pair(it.first,
  220. makeGenericDatum(n->leafAt(1), it.second, st)));
  221. }
  222. return GenericDatum(n, result);
  223. }
  224. case AVRO_UNION: {
  225. GenericUnion result(n);
  226. result.selectBranch(0);
  227. result.datum() = makeGenericDatum(n->leafAt(0), e, st);
  228. return GenericDatum(n, result);
  229. }
  230. case AVRO_FIXED:
  231. assertType(e, json::EntityType::String);
  232. return GenericDatum(n, GenericFixed(n, toBin(e.bytesValue())));
  233. default: throw Exception(boost::format("Unknown type: %1%") % t);
  234. }
  235. }
  236. static const std::unordered_set<std::string>& getKnownFields() {
  237. // return known fields
  238. static const std::unordered_set<std::string> kKnownFields =
  239. {"name", "type", "default", "doc", "size", "logicalType",
  240. "values", "precision", "scale", "namespace"};
  241. return kKnownFields;
  242. }
  243. static void getCustomAttributes(const Object& m, CustomAttributes &customAttributes)
  244. {
  245. // Don't add known fields on primitive type and fixed type into custom
  246. // fields.
  247. const std::unordered_set<std::string>& kKnownFields = getKnownFields();
  248. for (const auto &entry : m) {
  249. if (kKnownFields.find(entry.first) == kKnownFields.end()) {
  250. customAttributes.addAttribute(entry.first, entry.second.stringValue());
  251. }
  252. }
  253. }
  254. static Field makeField(const Entity &e, SymbolTable &st, const string &ns) {
  255. const Object &m = e.objectValue();
  256. const string &n = getStringField(e, m, "name");
  257. auto it = findField(e, m, "type");
  258. auto it2 = m.find("default");
  259. NodePtr node = makeNode(it->second, st, ns);
  260. if (containsField(m, "doc")) {
  261. node->setDoc(getDocField(e, m));
  262. }
  263. GenericDatum d = (it2 == m.end()) ? GenericDatum() : makeGenericDatum(node, it2->second, st);
  264. // Get custom attributes
  265. CustomAttributes customAttributes;
  266. getCustomAttributes(m, customAttributes);
  267. return Field(n, node, d, customAttributes);
  268. }
  269. // Extended makeRecordNode (with doc).
  270. static NodePtr makeRecordNode(const Entity &e, const Name &name,
  271. const string *doc, const Object &m,
  272. SymbolTable &st, const string &ns) {
  273. const Array &v = getArrayField(e, m, "fields");
  274. concepts::MultiAttribute<string> fieldNames;
  275. concepts::MultiAttribute<NodePtr> fieldValues;
  276. concepts::MultiAttribute<CustomAttributes> customAttributes;
  277. vector<GenericDatum> defaultValues;
  278. for (const auto &it : v) {
  279. Field f = makeField(it, st, ns);
  280. fieldNames.add(f.name);
  281. fieldValues.add(f.schema);
  282. defaultValues.push_back(f.defaultValue);
  283. customAttributes.add(f.customAttributes);
  284. }
  285. NodeRecord *node;
  286. if (doc == nullptr) {
  287. node = new NodeRecord(asSingleAttribute(name), fieldValues, fieldNames,
  288. defaultValues, customAttributes);
  289. } else {
  290. node = new NodeRecord(asSingleAttribute(name), asSingleAttribute(*doc),
  291. fieldValues, fieldNames, defaultValues, customAttributes);
  292. }
  293. return NodePtr(node);
  294. }
  295. static LogicalType makeLogicalType(const Entity &e, const Object &m) {
  296. if (!containsField(m, "logicalType")) {
  297. return LogicalType(LogicalType::NONE);
  298. }
  299. const std::string &typeField = getStringField(e, m, "logicalType");
  300. if (typeField == "decimal") {
  301. LogicalType decimalType(LogicalType::DECIMAL);
  302. try {
  303. decimalType.setPrecision(getLongField(e, m, "precision"));
  304. if (containsField(m, "scale")) {
  305. decimalType.setScale(getLongField(e, m, "scale"));
  306. }
  307. } catch (Exception &ex) {
  308. // If any part of the logical type is malformed, per the standard we
  309. // must ignore the whole attribute.
  310. return LogicalType(LogicalType::NONE);
  311. }
  312. return decimalType;
  313. }
  314. LogicalType::Type t = LogicalType::NONE;
  315. if (typeField == "date")
  316. t = LogicalType::DATE;
  317. else if (typeField == "time-millis")
  318. t = LogicalType::TIME_MILLIS;
  319. else if (typeField == "time-micros")
  320. t = LogicalType::TIME_MICROS;
  321. else if (typeField == "timestamp-millis")
  322. t = LogicalType::TIMESTAMP_MILLIS;
  323. else if (typeField == "timestamp-micros")
  324. t = LogicalType::TIMESTAMP_MICROS;
  325. else if (typeField == "duration")
  326. t = LogicalType::DURATION;
  327. else if (typeField == "uuid")
  328. t = LogicalType::UUID;
  329. return LogicalType(t);
  330. }
  331. static NodePtr makeEnumNode(const Entity &e,
  332. const Name &name, const Object &m) {
  333. const Array &v = getArrayField(e, m, "symbols");
  334. concepts::MultiAttribute<string> symbols;
  335. for (const auto &it : v) {
  336. if (it.type() != json::EntityType::String) {
  337. throw Exception(boost::format("Enum symbol not a string: %1%") % it.toString());
  338. }
  339. symbols.add(it.stringValue());
  340. }
  341. NodePtr node = NodePtr(new NodeEnum(asSingleAttribute(name), symbols));
  342. if (containsField(m, "doc")) {
  343. node->setDoc(getDocField(e, m));
  344. }
  345. return node;
  346. }
  347. static NodePtr makeFixedNode(const Entity &e,
  348. const Name &name, const Object &m) {
  349. int v = static_cast<int>(getLongField(e, m, "size"));
  350. if (v <= 0) {
  351. throw Exception(boost::format("Size for fixed is not positive: %1%") % e.toString());
  352. }
  353. NodePtr node =
  354. NodePtr(new NodeFixed(asSingleAttribute(name), asSingleAttribute(v)));
  355. if (containsField(m, "doc")) {
  356. node->setDoc(getDocField(e, m));
  357. }
  358. return node;
  359. }
  360. static NodePtr makeArrayNode(const Entity &e, const Object &m,
  361. SymbolTable &st, const string &ns) {
  362. auto it = findField(e, m, "items");
  363. NodePtr node = NodePtr(new NodeArray(
  364. asSingleAttribute(makeNode(it->second, st, ns))));
  365. if (containsField(m, "doc")) {
  366. node->setDoc(getDocField(e, m));
  367. }
  368. return node;
  369. }
  370. static NodePtr makeMapNode(const Entity &e, const Object &m,
  371. SymbolTable &st, const string &ns) {
  372. auto it = findField(e, m, "values");
  373. NodePtr node = NodePtr(new NodeMap(
  374. asSingleAttribute(makeNode(it->second, st, ns))));
  375. if (containsField(m, "doc")) {
  376. node->setDoc(getDocField(e, m));
  377. }
  378. return node;
  379. }
  380. static Name getName(const Entity &e, const Object &m, const string &ns) {
  381. const string &name = getStringField(e, m, "name");
  382. if (isFullName(name)) {
  383. return Name(name);
  384. } else {
  385. auto it = m.find("namespace");
  386. if (it != m.end()) {
  387. if (it->second.type() != json::type_traits<string>::type()) {
  388. throw Exception(boost::format(
  389. "Json field \"%1%\" is not a %2%: %3%")
  390. % "namespace" % json::type_traits<string>::name() % it->second.toString());
  391. }
  392. Name result = Name(name, it->second.stringValue());
  393. return result;
  394. }
  395. return Name(name, ns);
  396. }
  397. }
  398. static NodePtr makeNode(const Entity &e, const Object &m,
  399. SymbolTable &st, const string &ns) {
  400. const string &type = getStringField(e, m, "type");
  401. NodePtr result;
  402. if (type == "record" || type == "error" || type == "enum" || type == "fixed") {
  403. Name nm = getName(e, m, ns);
  404. if (type == "record" || type == "error") {
  405. result = NodePtr(new NodeRecord());
  406. st[nm] = result;
  407. // Get field doc
  408. if (containsField(m, "doc")) {
  409. string doc = getDocField(e, m);
  410. NodePtr r = makeRecordNode(e, nm, &doc, m, st, nm.ns());
  411. (std::dynamic_pointer_cast<NodeRecord>(r))->swap(*std::dynamic_pointer_cast<NodeRecord>(result));
  412. } else { // No doc
  413. NodePtr r =
  414. makeRecordNode(e, nm, nullptr, m, st, nm.ns());
  415. (std::dynamic_pointer_cast<NodeRecord>(r))
  416. ->swap(*std::dynamic_pointer_cast<NodeRecord>(result));
  417. }
  418. } else {
  419. result = (type == "enum") ? makeEnumNode(e, nm, m) : makeFixedNode(e, nm, m);
  420. st[nm] = result;
  421. }
  422. } else if (type == "array") {
  423. result = makeArrayNode(e, m, st, ns);
  424. } else if (type == "map") {
  425. result = makeMapNode(e, m, st, ns);
  426. } else {
  427. result = makePrimitive(type);
  428. }
  429. if (result) {
  430. try {
  431. result->setLogicalType(makeLogicalType(e, m));
  432. } catch (Exception &ex) {
  433. // Per the standard we must ignore the logical type attribute if it
  434. // is malformed.
  435. }
  436. return result;
  437. }
  438. throw Exception(boost::format("Unknown type definition: %1%")
  439. % e.toString());
  440. }
  441. static NodePtr makeNode(const Entity &e, const Array &m,
  442. SymbolTable &st, const string &ns) {
  443. concepts::MultiAttribute<NodePtr> mm;
  444. for (const auto &it : m) {
  445. mm.add(makeNode(it, st, ns));
  446. }
  447. return NodePtr(new NodeUnion(mm));
  448. }
  449. static NodePtr makeNode(const json::Entity &e, SymbolTable &st, const string &ns) {
  450. switch (e.type()) {
  451. case json::EntityType::String: return makeNode(e.stringValue(), st, ns);
  452. case json::EntityType::Obj: return makeNode(e, e.objectValue(), st, ns);
  453. case json::EntityType::Arr: return makeNode(e, e.arrayValue(), st, ns);
  454. default: throw Exception(boost::format("Invalid Avro type: %1%") % e.toString());
  455. }
  456. }
  457. json::Object::const_iterator findField(const Entity &e, const Object &m, const string &fieldName) {
  458. auto it = m.find(fieldName);
  459. if (it == m.end()) {
  460. throw Exception(boost::format("Missing Json field \"%1%\": %2%") % fieldName % e.toString());
  461. } else {
  462. return it;
  463. }
  464. }
  465. const Array &getArrayField(const Entity &e, const Object &m, const string &fieldName) {
  466. auto it = findField(e, m, fieldName);
  467. ensureType<Array>(it->second, fieldName);
  468. return it->second.arrayValue();
  469. }
  470. ValidSchema compileJsonSchemaFromStream(InputStream &is) {
  471. json::Entity e = json::loadEntity(is);
  472. SymbolTable st;
  473. NodePtr n = makeNode(e, st, "");
  474. return ValidSchema(n);
  475. }
  476. AVRO_DECL ValidSchema compileJsonSchemaFromFile(const char *filename) {
  477. std::unique_ptr<InputStream> s = fileInputStream(filename);
  478. return compileJsonSchemaFromStream(*s);
  479. }
  480. AVRO_DECL ValidSchema compileJsonSchemaFromMemory(const uint8_t *input, size_t len) {
  481. return compileJsonSchemaFromStream(*memoryInputStream(input, len));
  482. }
  483. AVRO_DECL ValidSchema compileJsonSchemaFromString(const char *input) {
  484. return compileJsonSchemaFromMemory(reinterpret_cast<const uint8_t *>(input),
  485. ::strlen(input));
  486. }
  487. AVRO_DECL ValidSchema compileJsonSchemaFromString(const string &input) {
  488. return compileJsonSchemaFromMemory(
  489. reinterpret_cast<const uint8_t *>(input.data()), input.size());
  490. }
  491. static ValidSchema compile(std::istream &is) {
  492. std::unique_ptr<InputStream> in = istreamInputStream(is);
  493. return compileJsonSchemaFromStream(*in);
  494. }
  495. void compileJsonSchema(std::istream &is, ValidSchema &schema) {
  496. if (!is.good()) {
  497. throw Exception("Input stream is not good");
  498. }
  499. schema = compile(is);
  500. }
  501. AVRO_DECL bool compileJsonSchema(std::istream &is, ValidSchema &schema, string &error) {
  502. try {
  503. compileJsonSchema(is, schema);
  504. return true;
  505. } catch (const Exception &e) {
  506. error = e.what();
  507. return false;
  508. }
  509. }
  510. } // namespace avro