xform.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package protocmp provides protobuf specific options for the
  5. // [github.com/google/go-cmp/cmp] package.
  6. //
// The primary feature is the [Transform] option, which transforms [proto.Message]
// types into a [Message] map that is suitable for cmp to introspect upon.
  9. // All other options in this package must be used in conjunction with [Transform].
  10. package protocmp
  11. import (
  12. "reflect"
  13. "strconv"
  14. "github.com/google/go-cmp/cmp"
  15. "google.golang.org/protobuf/encoding/protowire"
  16. "google.golang.org/protobuf/internal/genid"
  17. "google.golang.org/protobuf/internal/msgfmt"
  18. "google.golang.org/protobuf/proto"
  19. "google.golang.org/protobuf/reflect/protoreflect"
  20. "google.golang.org/protobuf/reflect/protoregistry"
  21. "google.golang.org/protobuf/runtime/protoiface"
  22. "google.golang.org/protobuf/runtime/protoimpl"
  23. )
  24. var (
  25. enumV2Type = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem()
  26. messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem()
  27. messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem()
  28. )
  29. // Enum is a dynamic representation of a protocol buffer enum that is
  30. // suitable for [cmp.Equal] and [cmp.Diff] to compare upon.
  31. type Enum struct {
  32. num protoreflect.EnumNumber
  33. ed protoreflect.EnumDescriptor
  34. }
  35. // Descriptor returns the enum descriptor.
  36. // It returns nil for a zero Enum value.
  37. func (e Enum) Descriptor() protoreflect.EnumDescriptor {
  38. return e.ed
  39. }
  40. // Number returns the enum value as an integer.
  41. func (e Enum) Number() protoreflect.EnumNumber {
  42. return e.num
  43. }
  44. // Equal reports whether e1 and e2 represent the same enum value.
  45. func (e1 Enum) Equal(e2 Enum) bool {
  46. if e1.ed.FullName() != e2.ed.FullName() {
  47. return false
  48. }
  49. return e1.num == e2.num
  50. }
  51. // String returns the name of the enum value if known (e.g., "ENUM_VALUE"),
  52. // otherwise it returns the formatted decimal enum number (e.g., "14").
  53. func (e Enum) String() string {
  54. if ev := e.ed.Values().ByNumber(e.num); ev != nil {
  55. return string(ev.Name())
  56. }
  57. return strconv.Itoa(int(e.num))
  58. }
  59. const (
  60. // messageTypeKey indicates the protobuf message type.
  61. // The value type is always messageMeta.
  62. // From the public API, it presents itself as only the type, but the
  63. // underlying data structure holds arbitrary metadata about the message.
  64. messageTypeKey = "@type"
  65. // messageInvalidKey indicates that the message is invalid.
  66. // The value is always the boolean "true".
  67. messageInvalidKey = "@invalid"
  68. )
  69. type messageMeta struct {
  70. m proto.Message
  71. md protoreflect.MessageDescriptor
  72. xds map[string]protoreflect.ExtensionDescriptor
  73. }
  74. func (t messageMeta) String() string {
  75. return string(t.md.FullName())
  76. }
  77. func (t1 messageMeta) Equal(t2 messageMeta) bool {
  78. return t1.md.FullName() == t2.md.FullName()
  79. }
  80. // Message is a dynamic representation of a protocol buffer message that is
  81. // suitable for [cmp.Equal] and [cmp.Diff] to directly operate upon.
  82. //
  83. // Every populated known field (excluding extension fields) is stored in the map
  84. // with the key being the short name of the field (e.g., "field_name") and
  85. // the value determined by the kind and cardinality of the field.
  86. //
  87. // Singular scalars are represented by the same Go type as [protoreflect.Value],
  88. // singular messages are represented by the [Message] type,
  89. // singular enums are represented by the [Enum] type,
  90. // list fields are represented as a Go slice, and
  91. // map fields are represented as a Go map.
  92. //
  93. // Every populated extension field is stored in the map with the key being the
  94. // full name of the field surrounded by brackets (e.g., "[extension.full.name]")
  95. // and the value determined according to the same rules as known fields.
  96. //
  97. // Every unknown field is stored in the map with the key being the field number
  98. // encoded as a decimal string (e.g., "132") and the value being the raw bytes
  99. // of the encoded field (as the [protoreflect.RawFields] type).
  100. //
  101. // Message values must not be created by or mutated by users.
  102. type Message map[string]interface{}
  103. // Unwrap returns the original message value.
  104. // It returns nil if this Message was not constructed from another message.
  105. func (m Message) Unwrap() proto.Message {
  106. mm, _ := m[messageTypeKey].(messageMeta)
  107. return mm.m
  108. }
  109. // Descriptor return the message descriptor.
  110. // It returns nil for a zero Message value.
  111. func (m Message) Descriptor() protoreflect.MessageDescriptor {
  112. mm, _ := m[messageTypeKey].(messageMeta)
  113. return mm.md
  114. }
  115. // ProtoReflect returns a reflective view of m.
  116. // It only implements the read-only operations of [protoreflect.Message].
  117. // Calling any mutating operations on m panics.
  118. func (m Message) ProtoReflect() protoreflect.Message {
  119. return (reflectMessage)(m)
  120. }
  121. // ProtoMessage is a marker method from the legacy message interface.
  122. func (m Message) ProtoMessage() {}
  123. // Reset is the required Reset method from the legacy message interface.
  124. func (m Message) Reset() {
  125. panic("invalid mutation of a read-only message")
  126. }
  127. // String returns a formatted string for the message.
  128. // It is intended for human debugging and has no guarantees about its
  129. // exact format or the stability of its output.
  130. func (m Message) String() string {
  131. switch {
  132. case m == nil:
  133. return "<nil>"
  134. case !m.ProtoReflect().IsValid():
  135. return "<invalid>"
  136. default:
  137. return msgfmt.Format(m)
  138. }
  139. }
  140. type option struct{}
  141. // Transform returns a [cmp.Option] that converts each [proto.Message] to a [Message].
  142. // The transformation does not mutate nor alias any converted messages.
  143. //
  144. // The google.protobuf.Any message is automatically unmarshaled such that the
  145. // "value" field is a [Message] representing the underlying message value
  146. // assuming it could be resolved and properly unmarshaled.
  147. //
  148. // This does not directly transform higher-order composite Go types.
  149. // For example, []*foopb.Message is not transformed into []Message,
  150. // but rather the individual message elements of the slice are transformed.
  151. //
  152. // Note that there are currently no custom options for Transform,
  153. // but the use of an unexported type keeps the future open.
  154. func Transform(...option) cmp.Option {
  155. // addrType returns a pointer to t if t isn't a pointer or interface.
  156. addrType := func(t reflect.Type) reflect.Type {
  157. if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr {
  158. return t
  159. }
  160. return reflect.PtrTo(t)
  161. }
  162. // TODO: Should this transform protoreflect.Enum types to Enum as well?
  163. return cmp.FilterPath(func(p cmp.Path) bool {
  164. ps := p.Last()
  165. if isMessageType(addrType(ps.Type())) {
  166. return true
  167. }
  168. // Check whether the concrete values of an interface both satisfy
  169. // the Message interface.
  170. if ps.Type().Kind() == reflect.Interface {
  171. vx, vy := ps.Values()
  172. if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() {
  173. return false
  174. }
  175. return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type()))
  176. }
  177. return false
  178. }, cmp.Transformer("protocmp.Transform", func(v interface{}) Message {
  179. // For user convenience, shallow copy the message value if necessary
  180. // in order for it to implement the message interface.
  181. if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) {
  182. pv := reflect.New(rv.Type())
  183. pv.Elem().Set(rv)
  184. v = pv.Interface()
  185. }
  186. m := protoimpl.X.MessageOf(v)
  187. switch {
  188. case m == nil:
  189. return nil
  190. case !m.IsValid():
  191. return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
  192. default:
  193. return transformMessage(m)
  194. }
  195. }))
  196. }
  197. func isMessageType(t reflect.Type) bool {
  198. // Avoid transforming the Message itself.
  199. if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) {
  200. return false
  201. }
  202. return t.Implements(messageV1Type) || t.Implements(messageV2Type)
  203. }
  204. func transformMessage(m protoreflect.Message) Message {
  205. mx := Message{}
  206. mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
  207. // Handle known and extension fields.
  208. m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
  209. s := fd.TextName()
  210. if fd.IsExtension() {
  211. mt.xds[s] = fd
  212. }
  213. switch {
  214. case fd.IsList():
  215. mx[s] = transformList(fd, v.List())
  216. case fd.IsMap():
  217. mx[s] = transformMap(fd, v.Map())
  218. default:
  219. mx[s] = transformSingular(fd, v)
  220. }
  221. return true
  222. })
  223. // Handle unknown fields.
  224. for b := m.GetUnknown(); len(b) > 0; {
  225. num, _, n := protowire.ConsumeField(b)
  226. s := strconv.Itoa(int(num))
  227. b2, _ := mx[s].(protoreflect.RawFields)
  228. mx[s] = append(b2, b[:n]...)
  229. b = b[n:]
  230. }
  231. // Expand Any messages.
  232. if mt.md.FullName() == genid.Any_message_fullname {
  233. // TODO: Expose Transform option to specify a custom resolver?
  234. s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string)
  235. b, _ := mx[string(genid.Any_Value_field_name)].([]byte)
  236. mt, err := protoregistry.GlobalTypes.FindMessageByURL(s)
  237. if mt != nil && err == nil {
  238. m2 := mt.New()
  239. err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface())
  240. if err == nil {
  241. mx[string(genid.Any_Value_field_name)] = transformMessage(m2)
  242. }
  243. }
  244. }
  245. mx[messageTypeKey] = mt
  246. return mx
  247. }
  248. func transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} {
  249. t := protoKindToGoType(fd.Kind())
  250. rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len())
  251. for i := 0; i < lv.Len(); i++ {
  252. v := reflect.ValueOf(transformSingular(fd, lv.Get(i)))
  253. rv.Index(i).Set(v)
  254. }
  255. return rv.Interface()
  256. }
  257. func transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} {
  258. kfd := fd.MapKey()
  259. vfd := fd.MapValue()
  260. kt := protoKindToGoType(kfd.Kind())
  261. vt := protoKindToGoType(vfd.Kind())
  262. rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len())
  263. mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
  264. kv := reflect.ValueOf(transformSingular(kfd, k.Value()))
  265. vv := reflect.ValueOf(transformSingular(vfd, v))
  266. rv.SetMapIndex(kv, vv)
  267. return true
  268. })
  269. return rv.Interface()
  270. }
  271. func transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} {
  272. switch fd.Kind() {
  273. case protoreflect.EnumKind:
  274. return Enum{num: v.Enum(), ed: fd.Enum()}
  275. case protoreflect.MessageKind, protoreflect.GroupKind:
  276. return transformMessage(v.Message())
  277. case protoreflect.BytesKind:
  278. // The protoreflect API does not specify whether an empty bytes is
  279. // guaranteed to be nil or not. Always return non-nil bytes to avoid
  280. // leaking information about the concrete proto.Message implementation.
  281. if len(v.Bytes()) == 0 {
  282. return []byte{}
  283. }
  284. return v.Bytes()
  285. default:
  286. return v.Interface()
  287. }
  288. }
  289. func protoKindToGoType(k protoreflect.Kind) reflect.Type {
  290. switch k {
  291. case protoreflect.BoolKind:
  292. return reflect.TypeOf(bool(false))
  293. case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
  294. return reflect.TypeOf(int32(0))
  295. case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
  296. return reflect.TypeOf(int64(0))
  297. case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
  298. return reflect.TypeOf(uint32(0))
  299. case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
  300. return reflect.TypeOf(uint64(0))
  301. case protoreflect.FloatKind:
  302. return reflect.TypeOf(float32(0))
  303. case protoreflect.DoubleKind:
  304. return reflect.TypeOf(float64(0))
  305. case protoreflect.StringKind:
  306. return reflect.TypeOf(string(""))
  307. case protoreflect.BytesKind:
  308. return reflect.TypeOf([]byte(nil))
  309. case protoreflect.EnumKind:
  310. return reflect.TypeOf(Enum{})
  311. case protoreflect.MessageKind, protoreflect.GroupKind:
  312. return reflect.TypeOf(Message{})
  313. default:
  314. panic("invalid kind")
  315. }
  316. }