value_equal.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. // Copyright 2022 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package protoreflect
  5. import (
  6. "bytes"
  7. "fmt"
  8. "math"
  9. "reflect"
  10. "google.golang.org/protobuf/encoding/protowire"
  11. )
  12. // Equal reports whether v1 and v2 are recursively equal.
  13. //
  14. // - Values of different types are always unequal.
  15. //
  16. // - Bytes values are equal if they contain identical bytes.
  17. // Empty bytes (regardless of nil-ness) are considered equal.
  18. //
  19. // - Floating point values are equal if they contain the same value.
  20. // Unlike the == operator, a NaN is equal to another NaN.
  21. //
  22. // - Enums are equal if they contain the same number.
  23. // Since [Value] does not contain an enum descriptor,
  24. // enum values do not consider the type of the enum.
  25. //
  26. // - Other scalar values are equal if they contain the same value.
  27. //
  28. // - [Message] values are equal if they belong to the same message descriptor,
  29. // have the same set of populated known and extension field values,
  30. // and the same set of unknown fields values.
  31. //
  32. // - [List] values are equal if they are the same length and
  33. // each corresponding element is equal.
  34. //
  35. // - [Map] values are equal if they have the same set of keys and
  36. // the corresponding value for each key is equal.
  37. func (v1 Value) Equal(v2 Value) bool {
  38. return equalValue(v1, v2)
  39. }
  40. func equalValue(x, y Value) bool {
  41. eqType := x.typ == y.typ
  42. switch x.typ {
  43. case nilType:
  44. return eqType
  45. case boolType:
  46. return eqType && x.Bool() == y.Bool()
  47. case int32Type, int64Type:
  48. return eqType && x.Int() == y.Int()
  49. case uint32Type, uint64Type:
  50. return eqType && x.Uint() == y.Uint()
  51. case float32Type, float64Type:
  52. return eqType && equalFloat(x.Float(), y.Float())
  53. case stringType:
  54. return eqType && x.String() == y.String()
  55. case bytesType:
  56. return eqType && bytes.Equal(x.Bytes(), y.Bytes())
  57. case enumType:
  58. return eqType && x.Enum() == y.Enum()
  59. default:
  60. switch x := x.Interface().(type) {
  61. case Message:
  62. y, ok := y.Interface().(Message)
  63. return ok && equalMessage(x, y)
  64. case List:
  65. y, ok := y.Interface().(List)
  66. return ok && equalList(x, y)
  67. case Map:
  68. y, ok := y.Interface().(Map)
  69. return ok && equalMap(x, y)
  70. default:
  71. panic(fmt.Sprintf("unknown type: %T", x))
  72. }
  73. }
  74. }
  // equalFloat reports whether x and y hold the same floating point value.
// Unlike ==, two NaNs are considered equal. A NaN never equals a non-NaN.
func equalFloat(x, y float64) bool {
	xIsNaN, yIsNaN := math.IsNaN(x), math.IsNaN(y)
	switch {
	case xIsNaN || yIsNaN:
		return xIsNaN && yIsNaN
	default:
		return x == y
	}
}
  82. // equalMessage compares two messages.
  83. func equalMessage(mx, my Message) bool {
  84. if mx.Descriptor() != my.Descriptor() {
  85. return false
  86. }
  87. nx := 0
  88. equal := true
  89. mx.Range(func(fd FieldDescriptor, vx Value) bool {
  90. nx++
  91. vy := my.Get(fd)
  92. equal = my.Has(fd) && equalValue(vx, vy)
  93. return equal
  94. })
  95. if !equal {
  96. return false
  97. }
  98. ny := 0
  99. my.Range(func(fd FieldDescriptor, vx Value) bool {
  100. ny++
  101. return true
  102. })
  103. if nx != ny {
  104. return false
  105. }
  106. return equalUnknown(mx.GetUnknown(), my.GetUnknown())
  107. }
  108. // equalList compares two lists.
  109. func equalList(x, y List) bool {
  110. if x.Len() != y.Len() {
  111. return false
  112. }
  113. for i := x.Len() - 1; i >= 0; i-- {
  114. if !equalValue(x.Get(i), y.Get(i)) {
  115. return false
  116. }
  117. }
  118. return true
  119. }
  120. // equalMap compares two maps.
  121. func equalMap(x, y Map) bool {
  122. if x.Len() != y.Len() {
  123. return false
  124. }
  125. equal := true
  126. x.Range(func(k MapKey, vx Value) bool {
  127. vy := y.Get(k)
  128. equal = y.Has(k) && equalValue(vx, vy)
  129. return equal
  130. })
  131. return equal
  132. }
  133. // equalUnknown compares unknown fields by direct comparison on the raw bytes
  134. // of each individual field number.
  135. func equalUnknown(x, y RawFields) bool {
  136. if len(x) != len(y) {
  137. return false
  138. }
  139. if bytes.Equal([]byte(x), []byte(y)) {
  140. return true
  141. }
  142. mx := make(map[FieldNumber]RawFields)
  143. my := make(map[FieldNumber]RawFields)
  144. for len(x) > 0 {
  145. fnum, _, n := protowire.ConsumeField(x)
  146. mx[fnum] = append(mx[fnum], x[:n]...)
  147. x = x[n:]
  148. }
  149. for len(y) > 0 {
  150. fnum, _, n := protowire.ConsumeField(y)
  151. my[fnum] = append(my[fnum], y[:n]...)
  152. y = y[n:]
  153. }
  154. return reflect.DeepEqual(mx, my)
  155. }