is_equal.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. #include "is_equal.h"
  2. #include "traits.h"
  3. #include <google/protobuf/descriptor.h>
  4. #include <util/generic/yexception.h>
  5. #include <util/string/cast.h>
  6. #include <util/string/vector.h>
  7. namespace NProtoBuf {
  8. template <bool useDefault>
  9. static bool IsEqualImpl(const Message& m1, const Message& m2, TVector<TString>* differentPath);
  10. namespace {
  11. template <FieldDescriptor::CppType CppType, bool useDefault>
  12. struct TCompareValue {
  13. typedef typename TCppTypeTraits<CppType>::T T;
  14. static inline bool IsEqual(T value1, T value2, TVector<TString>*) {
  15. return value1 == value2;
  16. }
  17. };
  18. template <bool useDefault>
  19. struct TCompareValue<FieldDescriptor::CPPTYPE_MESSAGE, useDefault> {
  20. static inline bool IsEqual(const Message* value1, const Message* value2, TVector<TString>* differentPath) {
  21. return NProtoBuf::IsEqualImpl<useDefault>(*value1, *value2, differentPath);
  22. }
  23. };
  24. template <FieldDescriptor::CppType CppType, bool useDefault>
  25. class TCompareField {
  26. typedef TCppTypeTraits<CppType> TTraits;
  27. typedef TCompareValue<CppType, useDefault> TCompare;
  28. public:
  29. static inline bool IsEqual(const Message& m1, const Message& m2, const FieldDescriptor& field, TVector<TString>* differentPath) {
  30. if (field.is_repeated())
  31. return IsEqualRepeated(m1, m2, &field, differentPath);
  32. else
  33. return IsEqualSingle(m1, m2, &field, differentPath);
  34. }
  35. private:
  36. static bool IsEqualSingle(const Message& m1, const Message& m2, const FieldDescriptor* field, TVector<TString>* differentPath) {
  37. bool has1 = m1.GetReflection()->HasField(m1, field);
  38. bool has2 = m2.GetReflection()->HasField(m2, field);
  39. if (has1 != has2) {
  40. if (!useDefault || field->is_required()) {
  41. return false;
  42. }
  43. } else if (!has1)
  44. return true;
  45. return TCompare::IsEqual(TTraits::Get(m1, field),
  46. TTraits::Get(m2, field),
  47. differentPath);
  48. }
  49. static bool IsEqualRepeated(const Message& m1, const Message& m2, const FieldDescriptor* field, TVector<TString>* differentPath) {
  50. int fieldSize = m1.GetReflection()->FieldSize(m1, field);
  51. if (fieldSize != m2.GetReflection()->FieldSize(m2, field))
  52. return false;
  53. for (int i = 0; i < fieldSize; ++i)
  54. if (!IsEqualRepeatedValue(m1, m2, field, i, differentPath)) {
  55. if (!!differentPath) {
  56. differentPath->push_back(ToString(i));
  57. }
  58. return false;
  59. }
  60. return true;
  61. }
  62. static inline bool IsEqualRepeatedValue(const Message& m1, const Message& m2, const FieldDescriptor* field, int index, TVector<TString>* differentPath) {
  63. return TCompare::IsEqual(TTraits::GetRepeated(m1, field, index),
  64. TTraits::GetRepeated(m2, field, index),
  65. differentPath);
  66. }
  67. };
  68. template <bool useDefault>
  69. bool IsEqualField(const Message& m1, const Message& m2, const FieldDescriptor& field, TVector<TString>* differentPath) {
  70. #define CASE_CPPTYPE(cpptype) \
  71. case FieldDescriptor::CPPTYPE_##cpptype: { \
  72. bool r = TCompareField<FieldDescriptor::CPPTYPE_##cpptype, useDefault>::IsEqual(m1, m2, field, differentPath); \
  73. if (!r && !!differentPath) { \
  74. differentPath->push_back(field.name()); \
  75. } \
  76. return r; \
  77. }
  78. switch (field.cpp_type()) {
  79. CASE_CPPTYPE(INT32)
  80. CASE_CPPTYPE(INT64)
  81. CASE_CPPTYPE(UINT32)
  82. CASE_CPPTYPE(UINT64)
  83. CASE_CPPTYPE(DOUBLE)
  84. CASE_CPPTYPE(FLOAT)
  85. CASE_CPPTYPE(BOOL)
  86. CASE_CPPTYPE(ENUM)
  87. CASE_CPPTYPE(STRING)
  88. CASE_CPPTYPE(MESSAGE)
  89. default:
  90. ythrow yexception() << "Unsupported cpp-type field comparison";
  91. }
  92. #undef CASE_CPPTYPE
  93. }
  94. }
  95. template <bool useDefault>
  96. bool IsEqualImpl(const Message& m1, const Message& m2, TVector<TString>* differentPath) {
  97. const Descriptor* descr = m1.GetDescriptor();
  98. if (descr != m2.GetDescriptor()) {
  99. return false;
  100. }
  101. for (int i = 0; i < descr->field_count(); ++i)
  102. if (!IsEqualField<useDefault>(m1, m2, *descr->field(i), differentPath)) {
  103. return false;
  104. }
  105. return true;
  106. }
  107. bool IsEqual(const Message& m1, const Message& m2) {
  108. return IsEqualImpl<false>(m1, m2, nullptr);
  109. }
  110. bool IsEqual(const Message& m1, const Message& m2, TString* differentPath) {
  111. TVector<TString> differentPathVector;
  112. TVector<TString>* differentPathVectorPtr = !!differentPath ? &differentPathVector : nullptr;
  113. bool r = IsEqualImpl<false>(m1, m2, differentPathVectorPtr);
  114. if (!r && differentPath) {
  115. *differentPath = JoinStrings(differentPathVector.rbegin(), differentPathVector.rend(), "/");
  116. }
  117. return r;
  118. }
  119. bool IsEqualDefault(const Message& m1, const Message& m2) {
  120. return IsEqualImpl<true>(m1, m2, nullptr);
  121. }
  122. template <bool useDefault>
  123. static bool IsEqualFieldImpl(
  124. const Message& m1,
  125. const Message& m2,
  126. const FieldDescriptor& field,
  127. TVector<TString>* differentPath) {
  128. const Descriptor* descr = m1.GetDescriptor();
  129. if (descr != m2.GetDescriptor()) {
  130. return false;
  131. }
  132. return IsEqualField<useDefault>(m1, m2, field, differentPath);
  133. }
  134. bool IsEqualField(const Message& m1, const Message& m2, const FieldDescriptor& field) {
  135. return IsEqualFieldImpl<false>(m1, m2, field, nullptr);
  136. }
  137. bool IsEqualFieldDefault(const Message& m1, const Message& m2, const FieldDescriptor& field) {
  138. return IsEqualFieldImpl<true>(m1, m2, field, nullptr);
  139. }
  140. }