mutator.cc 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812
  1. // Copyright 2016 Google Inc. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "src/mutator.h"
  15. #include <algorithm>
  16. #include <bitset>
  17. #include <iostream>
  18. #include <map>
  19. #include <memory>
  20. #include <random>
  21. #include <string>
  22. #include <utility>
  23. #include <vector>
  24. #include "src/field_instance.h"
  25. #include "src/utf8_fix.h"
  26. #include "src/weighted_reservoir_sampler.h"
  27. namespace protobuf_mutator {
  28. using google::protobuf::Any;
  29. using protobuf::Descriptor;
  30. using protobuf::FieldDescriptor;
  31. using protobuf::FileDescriptor;
  32. using protobuf::Message;
  33. using protobuf::OneofDescriptor;
  34. using protobuf::Reflection;
  35. using protobuf::util::MessageDifferencer;
  36. using std::placeholders::_1;
  37. namespace {
  // Nesting depth beyond which PostProcessing clears optional message fields.
  constexpr int kMaxInitializeDepth = 200;

  // Weight given to every (field, mutation) candidate in MutationSampler.
  constexpr uint64_t kDefaultMutateWeight = 1000000;

  // Kinds of mutation that can be applied to a single field.
  enum class Mutation : uint8_t {
    None = 0,
    Add = 1,     // Adds a new field (default value, possibly mutated).
    Mutate = 2,  // Changes the contents of an existing field.
    Delete = 3,  // Removes a field.
    Copy = 4,    // Overwrites a field with a value taken from another field.
    Clone = 5,   // Creates a new field holding a value copied from another.
    Last = Clone,
  };

  // One bit per Mutation enumerator, indexed by its underlying value.
  using MutationBitset = std::bitset<static_cast<size_t>(Mutation::Last) + 1>;
  50. using Messages = std::vector<Message*>;
  51. using ConstMessages = std::vector<const Message*>;
  52. // Return random integer from [0, count)
  53. size_t GetRandomIndex(RandomEngine* random, size_t count) {
  54. assert(count > 0);
  55. if (count == 1) return 0;
  56. return std::uniform_int_distribution<size_t>(0, count - 1)(*random);
  57. }
  58. // Flips random bit in the buffer.
  59. void FlipBit(size_t size, uint8_t* bytes, RandomEngine* random) {
  60. size_t bit = GetRandomIndex(random, size * 8);
  61. bytes[bit / 8] ^= (1u << (bit % 8));
  62. }
  63. // Flips random bit in the value.
  64. template <class T>
  65. T FlipBit(T value, RandomEngine* random) {
  66. FlipBit(sizeof(value), reinterpret_cast<uint8_t*>(&value), random);
  67. return value;
  68. }
  69. // Return true with probability about 1-of-n.
  70. bool GetRandomBool(RandomEngine* random, size_t n = 2) {
  71. return GetRandomIndex(random, n) == 0;
  72. }
  73. bool IsProto3SimpleField(const FieldDescriptor& field) {
  74. assert(field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 ||
  75. field.file()->syntax() == FileDescriptor::SYNTAX_PROTO2);
  76. return field.file()->syntax() == FileDescriptor::SYNTAX_PROTO3 &&
  77. field.cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE &&
  78. !field.containing_oneof() && !field.is_repeated();
  79. }
  80. struct CreateDefaultField : public FieldFunction<CreateDefaultField> {
  81. template <class T>
  82. void ForType(const FieldInstance& field) const {
  83. T value;
  84. field.GetDefault(&value);
  85. field.Create(value);
  86. }
  87. };
  88. struct DeleteField : public FieldFunction<DeleteField> {
  89. template <class T>
  90. void ForType(const FieldInstance& field) const {
  91. field.Delete();
  92. }
  93. };
  94. struct CopyField : public FieldFunction<CopyField> {
  95. template <class T>
  96. void ForType(const ConstFieldInstance& source,
  97. const FieldInstance& field) const {
  98. T value;
  99. source.Load(&value);
  100. field.Store(value);
  101. }
  102. };
  103. struct AppendField : public FieldFunction<AppendField> {
  104. template <class T>
  105. void ForType(const ConstFieldInstance& source,
  106. const FieldInstance& field) const {
  107. T value;
  108. source.Load(&value);
  109. field.Create(value);
  110. }
  111. };
  112. class CanCopyAndDifferentField
  113. : public FieldFunction<CanCopyAndDifferentField, bool> {
  114. public:
  115. template <class T>
  116. bool ForType(const ConstFieldInstance& src, const ConstFieldInstance& dst,
  117. int size_increase_hint) const {
  118. T s;
  119. src.Load(&s);
  120. if (!dst.CanStore(s)) return false;
  121. T d;
  122. dst.Load(&d);
  123. return SizeDiff(s, d) <= size_increase_hint && !IsEqual(s, d);
  124. }
  125. private:
  126. bool IsEqual(const ConstFieldInstance::Enum& a,
  127. const ConstFieldInstance::Enum& b) const {
  128. assert(a.count == b.count);
  129. return a.index == b.index;
  130. }
  131. bool IsEqual(const std::unique_ptr<Message>& a,
  132. const std::unique_ptr<Message>& b) const {
  133. return MessageDifferencer::Equals(*a, *b);
  134. }
  135. template <class T>
  136. bool IsEqual(const T& a, const T& b) const {
  137. return a == b;
  138. }
  139. int64_t SizeDiff(const std::unique_ptr<Message>& src,
  140. const std::unique_ptr<Message>& dst) const {
  141. return src->ByteSizeLong() - dst->ByteSizeLong();
  142. }
  143. int64_t SizeDiff(const TProtoStringType& src, const TProtoStringType& dst) const {
  144. return src.size() - dst.size();
  145. }
  146. template <class T>
  147. int64_t SizeDiff(const T&, const T&) const {
  148. return 0;
  149. }
  150. };
  151. // Selects random field and mutation from the given proto message.
  152. class MutationSampler {
  153. public:
  154. MutationSampler(bool keep_initialized, MutationBitset allowed_mutations,
  155. RandomEngine* random)
  156. : keep_initialized_(keep_initialized),
  157. allowed_mutations_(allowed_mutations),
  158. random_(random),
  159. sampler_(random) {}
  160. // Returns selected field.
  161. const FieldInstance& field() const { return sampler_.selected().field; }
  162. // Returns selected mutation.
  163. Mutation mutation() const { return sampler_.selected().mutation; }
  164. void Sample(Message* message) {
  165. SampleImpl(message);
  166. assert(mutation() != Mutation::None ||
  167. !allowed_mutations_[static_cast<size_t>(Mutation::Mutate)] ||
  168. message->GetDescriptor()->field_count() == 0);
  169. }
  170. private:
  171. void SampleImpl(Message* message) {
  172. const Descriptor* descriptor = message->GetDescriptor();
  173. const Reflection* reflection = message->GetReflection();
  174. int field_count = descriptor->field_count();
  175. for (int i = 0; i < field_count; ++i) {
  176. const FieldDescriptor* field = descriptor->field(i);
  177. if (const OneofDescriptor* oneof = field->containing_oneof()) {
  178. // Handle entire oneof group on the first field.
  179. if (field->index_in_oneof() == 0) {
  180. assert(oneof->field_count());
  181. const FieldDescriptor* current_field =
  182. reflection->GetOneofFieldDescriptor(*message, oneof);
  183. for (;;) {
  184. const FieldDescriptor* add_field =
  185. oneof->field(GetRandomIndex(random_, oneof->field_count()));
  186. if (add_field != current_field) {
  187. Try({message, add_field}, Mutation::Add);
  188. Try({message, add_field}, Mutation::Clone);
  189. break;
  190. }
  191. if (oneof->field_count() < 2) break;
  192. }
  193. if (current_field) {
  194. if (current_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
  195. Try({message, current_field}, Mutation::Mutate);
  196. Try({message, current_field}, Mutation::Delete);
  197. Try({message, current_field}, Mutation::Copy);
  198. }
  199. }
  200. } else {
  201. if (field->is_repeated()) {
  202. int field_size = reflection->FieldSize(*message, field);
  203. size_t random_index = GetRandomIndex(random_, field_size + 1);
  204. Try({message, field, random_index}, Mutation::Add);
  205. Try({message, field, random_index}, Mutation::Clone);
  206. if (field_size) {
  207. size_t random_index = GetRandomIndex(random_, field_size);
  208. if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
  209. Try({message, field, random_index}, Mutation::Mutate);
  210. Try({message, field, random_index}, Mutation::Delete);
  211. Try({message, field, random_index}, Mutation::Copy);
  212. }
  213. } else {
  214. if (reflection->HasField(*message, field) ||
  215. IsProto3SimpleField(*field)) {
  216. if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE)
  217. Try({message, field}, Mutation::Mutate);
  218. if (!IsProto3SimpleField(*field) &&
  219. (!field->is_required() || !keep_initialized_)) {
  220. Try({message, field}, Mutation::Delete);
  221. }
  222. Try({message, field}, Mutation::Copy);
  223. } else {
  224. Try({message, field}, Mutation::Add);
  225. Try({message, field}, Mutation::Clone);
  226. }
  227. }
  228. }
  229. if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
  230. if (field->is_repeated()) {
  231. const int field_size = reflection->FieldSize(*message, field);
  232. for (int j = 0; j < field_size; ++j)
  233. SampleImpl(reflection->MutableRepeatedMessage(message, field, j));
  234. } else if (reflection->HasField(*message, field)) {
  235. SampleImpl(reflection->MutableMessage(message, field));
  236. }
  237. }
  238. }
  239. }
  240. void Try(const FieldInstance& field, Mutation mutation) {
  241. assert(mutation != Mutation::None);
  242. if (!allowed_mutations_[static_cast<size_t>(mutation)]) return;
  243. sampler_.Try(kDefaultMutateWeight, {field, mutation});
  244. }
  245. bool keep_initialized_ = false;
  246. MutationBitset allowed_mutations_;
  247. RandomEngine* random_;
  248. struct Result {
  249. Result() = default;
  250. Result(const FieldInstance& f, Mutation m) : field(f), mutation(m) {}
  251. FieldInstance field;
  252. Mutation mutation = Mutation::None;
  253. };
  254. WeightedReservoirSampler<Result, RandomEngine> sampler_;
  255. };
  256. // Selects random field of compatible type to use for clone mutations.
  257. class DataSourceSampler {
  258. public:
  259. DataSourceSampler(const ConstFieldInstance& match, RandomEngine* random,
  260. int size_increase_hint)
  261. : match_(match),
  262. random_(random),
  263. size_increase_hint_(size_increase_hint),
  264. sampler_(random) {}
  265. void Sample(const Message& message) { SampleImpl(message); }
  266. // Returns selected field.
  267. const ConstFieldInstance& field() const {
  268. assert(!IsEmpty());
  269. return sampler_.selected();
  270. }
  271. bool IsEmpty() const { return sampler_.IsEmpty(); }
  272. private:
  273. void SampleImpl(const Message& message) {
  274. const Descriptor* descriptor = message.GetDescriptor();
  275. const Reflection* reflection = message.GetReflection();
  276. int field_count = descriptor->field_count();
  277. for (int i = 0; i < field_count; ++i) {
  278. const FieldDescriptor* field = descriptor->field(i);
  279. if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
  280. if (field->is_repeated()) {
  281. const int field_size = reflection->FieldSize(message, field);
  282. for (int j = 0; j < field_size; ++j) {
  283. SampleImpl(reflection->GetRepeatedMessage(message, field, j));
  284. }
  285. } else if (reflection->HasField(message, field)) {
  286. SampleImpl(reflection->GetMessage(message, field));
  287. }
  288. }
  289. if (field->cpp_type() != match_.cpp_type()) continue;
  290. if (match_.cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
  291. if (field->enum_type() != match_.enum_type()) continue;
  292. } else if (match_.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
  293. if (field->message_type() != match_.message_type()) continue;
  294. }
  295. if (field->is_repeated()) {
  296. if (int field_size = reflection->FieldSize(message, field)) {
  297. ConstFieldInstance source(&message, field,
  298. GetRandomIndex(random_, field_size));
  299. if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
  300. sampler_.Try(field_size, source);
  301. }
  302. } else {
  303. if (reflection->HasField(message, field)) {
  304. ConstFieldInstance source(&message, field);
  305. if (CanCopyAndDifferentField()(source, match_, size_increase_hint_))
  306. sampler_.Try(1, source);
  307. }
  308. }
  309. }
  310. }
  311. ConstFieldInstance match_;
  312. RandomEngine* random_;
  313. int size_increase_hint_;
  314. WeightedReservoirSampler<ConstFieldInstance, RandomEngine> sampler_;
  315. };
  316. using UnpackedAny =
  317. std::unordered_map<const Message*, std::unique_ptr<Message>>;
  318. const Descriptor* GetAnyTypeDescriptor(const Any& any) {
  319. TProtoStringType type_name;
  320. if (!Any::ParseAnyTypeUrl(TProtoStringType(any.type_url()), &type_name))
  321. return nullptr;
  322. return any.descriptor()->file()->pool()->FindMessageTypeByName(type_name);
  323. }
  324. std::unique_ptr<Message> UnpackAny(const Any& any) {
  325. const Descriptor* desc = GetAnyTypeDescriptor(any);
  326. if (!desc) return {};
  327. std::unique_ptr<Message> message(
  328. any.GetReflection()->GetMessageFactory()->GetPrototype(desc)->New());
  329. message->ParsePartialFromString(TProtoStringType(any.value()));
  330. return message;
  331. }
  332. const Any* CastToAny(const Message* message) {
  333. return Any::GetDescriptor() == message->GetDescriptor()
  334. ? static_cast<const Any*>(message)
  335. : nullptr;
  336. }
  337. Any* CastToAny(Message* message) {
  338. return Any::GetDescriptor() == message->GetDescriptor()
  339. ? static_cast<Any*>(message)
  340. : nullptr;
  341. }
  342. std::unique_ptr<Message> UnpackIfAny(const Message& message) {
  343. if (const Any* any = CastToAny(&message)) return UnpackAny(*any);
  344. return {};
  345. }
  346. void UnpackAny(const Message& message, UnpackedAny* result) {
  347. if (std::unique_ptr<Message> any = UnpackIfAny(message)) {
  348. UnpackAny(*any, result);
  349. result->emplace(&message, std::move(any));
  350. return;
  351. }
  352. const Descriptor* descriptor = message.GetDescriptor();
  353. const Reflection* reflection = message.GetReflection();
  354. for (int i = 0; i < descriptor->field_count(); ++i) {
  355. const FieldDescriptor* field = descriptor->field(i);
  356. if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
  357. if (field->is_repeated()) {
  358. const int field_size = reflection->FieldSize(message, field);
  359. for (int j = 0; j < field_size; ++j) {
  360. UnpackAny(reflection->GetRepeatedMessage(message, field, j), result);
  361. }
  362. } else if (reflection->HasField(message, field)) {
  363. UnpackAny(reflection->GetMessage(message, field), result);
  364. }
  365. }
  366. }
  367. }
  368. class PostProcessing {
  369. public:
  370. using PostProcessors =
  371. std::unordered_multimap<const Descriptor*, Mutator::PostProcess>;
  372. PostProcessing(bool keep_initialized, const PostProcessors& post_processors,
  373. const UnpackedAny& any, RandomEngine* random)
  374. : keep_initialized_(keep_initialized),
  375. post_processors_(post_processors),
  376. any_(any),
  377. random_(random) {}
  378. void Run(Message* message, int max_depth) {
  379. --max_depth;
  380. const Descriptor* descriptor = message->GetDescriptor();
  381. // Apply custom mutators in nested messages before packing any.
  382. const Reflection* reflection = message->GetReflection();
  383. for (int i = 0; i < descriptor->field_count(); i++) {
  384. const FieldDescriptor* field = descriptor->field(i);
  385. if (keep_initialized_ &&
  386. (field->is_required() || descriptor->options().map_entry()) &&
  387. !reflection->HasField(*message, field)) {
  388. CreateDefaultField()(FieldInstance(message, field));
  389. }
  390. if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;
  391. if (max_depth < 0 && !field->is_required()) {
  392. // Clear deep optional fields to avoid stack overflow.
  393. reflection->ClearField(message, field);
  394. if (field->is_repeated())
  395. assert(!reflection->FieldSize(*message, field));
  396. else
  397. assert(!reflection->HasField(*message, field));
  398. continue;
  399. }
  400. if (field->is_repeated()) {
  401. const int field_size = reflection->FieldSize(*message, field);
  402. for (int j = 0; j < field_size; ++j) {
  403. Message* nested_message =
  404. reflection->MutableRepeatedMessage(message, field, j);
  405. Run(nested_message, max_depth);
  406. }
  407. } else if (reflection->HasField(*message, field)) {
  408. Message* nested_message = reflection->MutableMessage(message, field);
  409. Run(nested_message, max_depth);
  410. }
  411. }
  412. if (Any* any = CastToAny(message)) {
  413. if (max_depth < 0) {
  414. // Clear deep Any fields to avoid stack overflow.
  415. any->Clear();
  416. } else {
  417. auto It = any_.find(message);
  418. if (It != any_.end()) {
  419. Run(It->second.get(), max_depth);
  420. TProtoStringType value;
  421. It->second->SerializePartialToString(&value);
  422. *any->mutable_value() = value;
  423. }
  424. }
  425. }
  426. // Call user callback after message trimmed, initialized and packed.
  427. auto range = post_processors_.equal_range(descriptor);
  428. for (auto it = range.first; it != range.second; ++it)
  429. it->second(message, (*random_)());
  430. }
  431. private:
  432. bool keep_initialized_;
  433. const PostProcessors& post_processors_;
  434. const UnpackedAny& any_;
  435. RandomEngine* random_;
  436. };
  437. } // namespace
  438. class FieldMutator {
  439. public:
  440. FieldMutator(int size_increase_hint, bool enforce_changes,
  441. bool enforce_utf8_strings, const ConstMessages& sources,
  442. Mutator* mutator)
  443. : size_increase_hint_(size_increase_hint),
  444. enforce_changes_(enforce_changes),
  445. enforce_utf8_strings_(enforce_utf8_strings),
  446. sources_(sources),
  447. mutator_(mutator) {}
  448. void Mutate(int32_t* value) const {
  449. RepeatMutate(value, std::bind(&Mutator::MutateInt32, mutator_, _1));
  450. }
  451. void Mutate(int64_t* value) const {
  452. RepeatMutate(value, std::bind(&Mutator::MutateInt64, mutator_, _1));
  453. }
  454. void Mutate(uint32_t* value) const {
  455. RepeatMutate(value, std::bind(&Mutator::MutateUInt32, mutator_, _1));
  456. }
  457. void Mutate(uint64_t* value) const {
  458. RepeatMutate(value, std::bind(&Mutator::MutateUInt64, mutator_, _1));
  459. }
  460. void Mutate(float* value) const {
  461. RepeatMutate(value, std::bind(&Mutator::MutateFloat, mutator_, _1));
  462. }
  463. void Mutate(double* value) const {
  464. RepeatMutate(value, std::bind(&Mutator::MutateDouble, mutator_, _1));
  465. }
  466. void Mutate(bool* value) const {
  467. RepeatMutate(value, std::bind(&Mutator::MutateBool, mutator_, _1));
  468. }
  469. void Mutate(FieldInstance::Enum* value) const {
  470. RepeatMutate(&value->index,
  471. std::bind(&Mutator::MutateEnum, mutator_, _1, value->count));
  472. assert(value->index < value->count);
  473. }
  474. void Mutate(TProtoStringType* value) const {
  475. if (enforce_utf8_strings_) {
  476. RepeatMutate(value, std::bind(&Mutator::MutateUtf8String, mutator_, _1,
  477. size_increase_hint_));
  478. } else {
  479. RepeatMutate(value, std::bind(&Mutator::MutateString, mutator_, _1,
  480. size_increase_hint_));
  481. }
  482. }
  483. void Mutate(std::unique_ptr<Message>* message) const {
  484. assert(!enforce_changes_);
  485. assert(*message);
  486. if (GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_))
  487. return;
  488. mutator_->MutateImpl(sources_, {message->get()}, false,
  489. size_increase_hint_);
  490. }
  491. private:
  492. template <class T, class F>
  493. void RepeatMutate(T* value, F mutate) const {
  494. if (!enforce_changes_ &&
  495. GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_)) {
  496. return;
  497. }
  498. T tmp = *value;
  499. for (int i = 0; i < 10; ++i) {
  500. *value = mutate(*value);
  501. if (!enforce_changes_ || *value != tmp) return;
  502. }
  503. }
  504. int size_increase_hint_;
  505. size_t enforce_changes_;
  506. bool enforce_utf8_strings_;
  507. const ConstMessages& sources_;
  508. Mutator* mutator_;
  509. };
  510. namespace {
  511. struct MutateField : public FieldFunction<MutateField> {
  512. template <class T>
  513. void ForType(const FieldInstance& field, int size_increase_hint,
  514. const ConstMessages& sources, Mutator* mutator) const {
  515. T value;
  516. field.Load(&value);
  517. FieldMutator(size_increase_hint, true, field.EnforceUtf8(), sources,
  518. mutator)
  519. .Mutate(&value);
  520. field.Store(value);
  521. }
  522. };
  523. struct CreateField : public FieldFunction<CreateField> {
  524. public:
  525. template <class T>
  526. void ForType(const FieldInstance& field, int size_increase_hint,
  527. const ConstMessages& sources, Mutator* mutator) const {
  528. T value;
  529. field.GetDefault(&value);
  530. FieldMutator field_mutator(size_increase_hint,
  531. false /* defaults could be useful */,
  532. field.EnforceUtf8(), sources, mutator);
  533. field_mutator.Mutate(&value);
  534. field.Create(value);
  535. }
  536. };
  537. } // namespace
  538. void Mutator::Seed(uint32_t value) { random_.seed(value); }
  539. void Mutator::Fix(Message* message) {
  540. UnpackedAny any;
  541. UnpackAny(*message, &any);
  542. PostProcessing(keep_initialized_, post_processors_, any, &random_)
  543. .Run(message, kMaxInitializeDepth);
  544. assert(IsInitialized(*message));
  545. }
  546. void Mutator::Mutate(Message* message, size_t max_size_hint) {
  547. UnpackedAny any;
  548. UnpackAny(*message, &any);
  549. Messages messages;
  550. messages.reserve(any.size() + 1);
  551. messages.push_back(message);
  552. for (const auto& kv : any) messages.push_back(kv.second.get());
  553. ConstMessages sources(messages.begin(), messages.end());
  554. MutateImpl(sources, messages, false,
  555. static_cast<int>(max_size_hint) -
  556. static_cast<int>(message->ByteSizeLong()));
  557. PostProcessing(keep_initialized_, post_processors_, any, &random_)
  558. .Run(message, kMaxInitializeDepth);
  559. assert(IsInitialized(*message));
  560. }
  561. void Mutator::CrossOver(const Message& message1, Message* message2,
  562. size_t max_size_hint) {
  563. UnpackedAny any;
  564. UnpackAny(*message2, &any);
  565. Messages messages;
  566. messages.reserve(any.size() + 1);
  567. messages.push_back(message2);
  568. for (auto& kv : any) messages.push_back(kv.second.get());
  569. UnpackAny(message1, &any);
  570. ConstMessages sources;
  571. sources.reserve(any.size() + 2);
  572. sources.push_back(&message1);
  573. sources.push_back(message2);
  574. for (const auto& kv : any) sources.push_back(kv.second.get());
  575. MutateImpl(sources, messages, true,
  576. static_cast<int>(max_size_hint) -
  577. static_cast<int>(message2->ByteSizeLong()));
  578. PostProcessing(keep_initialized_, post_processors_, any, &random_)
  579. .Run(message2, kMaxInitializeDepth);
  580. assert(IsInitialized(*message2));
  581. }
  582. void Mutator::RegisterPostProcessor(const Descriptor* desc,
  583. PostProcess callback) {
  584. post_processors_.emplace(desc, callback);
  585. }
  586. bool Mutator::MutateImpl(const ConstMessages& sources, const Messages& messages,
  587. bool copy_clone_only, int size_increase_hint) {
  588. MutationBitset mutations;
  589. if (copy_clone_only) {
  590. mutations[static_cast<size_t>(Mutation::Copy)] = true;
  591. mutations[static_cast<size_t>(Mutation::Clone)] = true;
  592. } else if (size_increase_hint <= 16) {
  593. mutations[static_cast<size_t>(Mutation::Delete)] = true;
  594. } else {
  595. mutations.set();
  596. mutations[static_cast<size_t>(Mutation::Copy)] = false;
  597. mutations[static_cast<size_t>(Mutation::Clone)] = false;
  598. }
  599. while (mutations.any()) {
  600. MutationSampler mutation(keep_initialized_, mutations, &random_);
  601. for (Message* message : messages) mutation.Sample(message);
  602. switch (mutation.mutation()) {
  603. case Mutation::None:
  604. return true;
  605. case Mutation::Add:
  606. CreateField()(mutation.field(), size_increase_hint, sources, this);
  607. return true;
  608. case Mutation::Mutate:
  609. MutateField()(mutation.field(), size_increase_hint, sources, this);
  610. return true;
  611. case Mutation::Delete:
  612. DeleteField()(mutation.field());
  613. return true;
  614. case Mutation::Clone: {
  615. CreateDefaultField()(mutation.field());
  616. DataSourceSampler source_sampler(mutation.field(), &random_,
  617. size_increase_hint);
  618. for (const Message* source : sources) source_sampler.Sample(*source);
  619. if (source_sampler.IsEmpty()) {
  620. if (!IsProto3SimpleField(*mutation.field().descriptor()))
  621. return true; // CreateField is enough for proto2.
  622. break;
  623. }
  624. CopyField()(source_sampler.field(), mutation.field());
  625. return true;
  626. }
  627. case Mutation::Copy: {
  628. DataSourceSampler source_sampler(mutation.field(), &random_,
  629. size_increase_hint);
  630. for (const Message* source : sources) source_sampler.Sample(*source);
  631. if (source_sampler.IsEmpty()) break;
  632. CopyField()(source_sampler.field(), mutation.field());
  633. return true;
  634. }
  635. default:
  636. assert(false && "unexpected mutation");
  637. return false;
  638. }
  639. // Don't try same mutation next time.
  640. mutations[static_cast<size_t>(mutation.mutation())] = false;
  641. }
  642. return false;
  643. }
  644. int32_t Mutator::MutateInt32(int32_t value) { return FlipBit(value, &random_); }
  645. int64_t Mutator::MutateInt64(int64_t value) { return FlipBit(value, &random_); }
  646. uint32_t Mutator::MutateUInt32(uint32_t value) {
  647. return FlipBit(value, &random_);
  648. }
  649. uint64_t Mutator::MutateUInt64(uint64_t value) {
  650. return FlipBit(value, &random_);
  651. }
  652. float Mutator::MutateFloat(float value) { return FlipBit(value, &random_); }
  653. double Mutator::MutateDouble(double value) { return FlipBit(value, &random_); }
  654. bool Mutator::MutateBool(bool value) { return !value; }
  655. size_t Mutator::MutateEnum(size_t index, size_t item_count) {
  656. if (item_count <= 1) return 0;
  657. return (index + 1 + GetRandomIndex(&random_, item_count - 1)) % item_count;
  658. }
  659. TProtoStringType Mutator::MutateString(const TProtoStringType& value,
  660. int size_increase_hint) {
  661. TProtoStringType result = value;
  662. while (!result.empty() && GetRandomBool(&random_)) {
  663. result.erase(GetRandomIndex(&random_, result.size()), 1);
  664. }
  665. while (size_increase_hint > 0 &&
  666. result.size() < static_cast<size_t>(size_increase_hint) &&
  667. GetRandomBool(&random_)) {
  668. size_t index = GetRandomIndex(&random_, result.size() + 1);
  669. result.insert(result.begin() + index, GetRandomIndex(&random_, 1 << 8));
  670. }
  671. if (result != value) return result;
  672. if (result.empty()) {
  673. result.push_back(GetRandomIndex(&random_, 1 << 8));
  674. return result;
  675. }
  676. if (!result.empty())
  677. FlipBit(result.size(), reinterpret_cast<uint8_t*>(&result[0]), &random_);
  678. return result;
  679. }
  680. TProtoStringType Mutator::MutateUtf8String(const TProtoStringType& value,
  681. int size_increase_hint) {
  682. TProtoStringType str = MutateString(value, size_increase_hint);
  683. FixUtf8String(&str, &random_);
  684. return str;
  685. }
  686. bool Mutator::IsInitialized(const Message& message) const {
  687. if (!keep_initialized_ || message.IsInitialized()) return true;
  688. std::cerr << "Uninitialized: " << message.DebugString() << "\n";
  689. return false;
  690. }
  691. } // namespace protobuf_mutator