xform.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package protocmp provides protobuf specific options for the
  5. // [github.com/google/go-cmp/cmp] package.
  6. //
// The primary feature is the [Transform] option, which transforms [proto.Message]
  8. // types into a [Message] map that is suitable for cmp to introspect upon.
  9. // All other options in this package must be used in conjunction with [Transform].
  10. package protocmp
  11. import (
  12. "reflect"
  13. "strconv"
  14. "github.com/google/go-cmp/cmp"
  15. "google.golang.org/protobuf/encoding/protowire"
  16. "google.golang.org/protobuf/internal/genid"
  17. "google.golang.org/protobuf/internal/msgfmt"
  18. "google.golang.org/protobuf/proto"
  19. "google.golang.org/protobuf/reflect/protoreflect"
  20. "google.golang.org/protobuf/reflect/protoregistry"
  21. "google.golang.org/protobuf/runtime/protoiface"
  22. "google.golang.org/protobuf/runtime/protoimpl"
  23. )
  24. var (
  25. enumV2Type = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem()
  26. messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem()
  27. messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem()
  28. )
  29. // Enum is a dynamic representation of a protocol buffer enum that is
  30. // suitable for [cmp.Equal] and [cmp.Diff] to compare upon.
  31. type Enum struct {
  32. num protoreflect.EnumNumber
  33. ed protoreflect.EnumDescriptor
  34. }
  35. // Descriptor returns the enum descriptor.
  36. // It returns nil for a zero Enum value.
  37. func (e Enum) Descriptor() protoreflect.EnumDescriptor {
  38. return e.ed
  39. }
  40. // Number returns the enum value as an integer.
  41. func (e Enum) Number() protoreflect.EnumNumber {
  42. return e.num
  43. }
  44. // Equal reports whether e1 and e2 represent the same enum value.
  45. func (e1 Enum) Equal(e2 Enum) bool {
  46. if e1.ed.FullName() != e2.ed.FullName() {
  47. return false
  48. }
  49. return e1.num == e2.num
  50. }
  51. // String returns the name of the enum value if known (e.g., "ENUM_VALUE"),
  52. // otherwise it returns the formatted decimal enum number (e.g., "14").
  53. func (e Enum) String() string {
  54. if ev := e.ed.Values().ByNumber(e.num); ev != nil {
  55. return string(ev.Name())
  56. }
  57. return strconv.Itoa(int(e.num))
  58. }
  59. const (
  60. // messageTypeKey indicates the protobuf message type.
  61. // The value type is always messageMeta.
  62. // From the public API, it presents itself as only the type, but the
  63. // underlying data structure holds arbitrary metadata about the message.
  64. messageTypeKey = "@type"
  65. // messageInvalidKey indicates that the message is invalid.
  66. // The value is always the boolean "true".
  67. messageInvalidKey = "@invalid"
  68. )
  69. type messageMeta struct {
  70. m proto.Message
  71. md protoreflect.MessageDescriptor
  72. xds map[string]protoreflect.ExtensionDescriptor
  73. }
  74. func (t messageMeta) String() string {
  75. return string(t.md.FullName())
  76. }
  77. func (t1 messageMeta) Equal(t2 messageMeta) bool {
  78. return t1.md.FullName() == t2.md.FullName()
  79. }
  80. // Message is a dynamic representation of a protocol buffer message that is
  81. // suitable for [cmp.Equal] and [cmp.Diff] to directly operate upon.
  82. //
  83. // Every populated known field (excluding extension fields) is stored in the map
  84. // with the key being the short name of the field (e.g., "field_name") and
  85. // the value determined by the kind and cardinality of the field.
  86. //
  87. // Singular scalars are represented by the same Go type as [protoreflect.Value],
  88. // singular messages are represented by the [Message] type,
  89. // singular enums are represented by the [Enum] type,
  90. // list fields are represented as a Go slice, and
  91. // map fields are represented as a Go map.
  92. //
  93. // Every populated extension field is stored in the map with the key being the
  94. // full name of the field surrounded by brackets (e.g., "[extension.full.name]")
  95. // and the value determined according to the same rules as known fields.
  96. //
  97. // Every unknown field is stored in the map with the key being the field number
  98. // encoded as a decimal string (e.g., "132") and the value being the raw bytes
  99. // of the encoded field (as the [protoreflect.RawFields] type).
  100. //
  101. // Message values must not be created by or mutated by users.
  102. type Message map[string]interface{}
  103. // Unwrap returns the original message value.
  104. // It returns nil if this Message was not constructed from another message.
  105. func (m Message) Unwrap() proto.Message {
  106. mm, _ := m[messageTypeKey].(messageMeta)
  107. return mm.m
  108. }
  109. // Descriptor return the message descriptor.
  110. // It returns nil for a zero Message value.
  111. func (m Message) Descriptor() protoreflect.MessageDescriptor {
  112. mm, _ := m[messageTypeKey].(messageMeta)
  113. return mm.md
  114. }
  115. // ProtoReflect returns a reflective view of m.
  116. // It only implements the read-only operations of [protoreflect.Message].
  117. // Calling any mutating operations on m panics.
  118. func (m Message) ProtoReflect() protoreflect.Message {
  119. return (reflectMessage)(m)
  120. }
  121. // ProtoMessage is a marker method from the legacy message interface.
  122. func (m Message) ProtoMessage() {}
  123. // Reset is the required Reset method from the legacy message interface.
  124. func (m Message) Reset() {
  125. panic("invalid mutation of a read-only message")
  126. }
  127. // String returns a formatted string for the message.
  128. // It is intended for human debugging and has no guarantees about its
  129. // exact format or the stability of its output.
  130. func (m Message) String() string {
  131. switch {
  132. case m == nil:
  133. return "<nil>"
  134. case !m.ProtoReflect().IsValid():
  135. return "<invalid>"
  136. default:
  137. return msgfmt.Format(m)
  138. }
  139. }
  140. type transformer struct {
  141. resolver protoregistry.MessageTypeResolver
  142. }
  143. func newTransformer(opts ...option) *transformer {
  144. xf := &transformer{
  145. resolver: protoregistry.GlobalTypes,
  146. }
  147. for _, opt := range opts {
  148. opt(xf)
  149. }
  150. return xf
  151. }
  152. type option func(*transformer)
  153. // MessageTypeResolver overrides the resolver used for messages packed
  154. // inside Any. The default is protoregistry.GlobalTypes, which is
  155. // sufficient for all compiled-in Protobuf messages. Overriding the
  156. // resolver is useful in tests that dynamically create Protobuf
  157. // descriptors and messages, e.g. in proxies using dynamicpb.
  158. func MessageTypeResolver(r protoregistry.MessageTypeResolver) option {
  159. return func(xf *transformer) {
  160. xf.resolver = r
  161. }
  162. }
  163. // Transform returns a [cmp.Option] that converts each [proto.Message] to a [Message].
  164. // The transformation does not mutate nor alias any converted messages.
  165. //
  166. // The google.protobuf.Any message is automatically unmarshaled such that the
  167. // "value" field is a [Message] representing the underlying message value
  168. // assuming it could be resolved and properly unmarshaled.
  169. //
  170. // This does not directly transform higher-order composite Go types.
  171. // For example, []*foopb.Message is not transformed into []Message,
  172. // but rather the individual message elements of the slice are transformed.
  173. func Transform(opts ...option) cmp.Option {
  174. xf := newTransformer(opts...)
  175. // addrType returns a pointer to t if t isn't a pointer or interface.
  176. addrType := func(t reflect.Type) reflect.Type {
  177. if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr {
  178. return t
  179. }
  180. return reflect.PtrTo(t)
  181. }
  182. // TODO: Should this transform protoreflect.Enum types to Enum as well?
  183. return cmp.FilterPath(func(p cmp.Path) bool {
  184. ps := p.Last()
  185. if isMessageType(addrType(ps.Type())) {
  186. return true
  187. }
  188. // Check whether the concrete values of an interface both satisfy
  189. // the Message interface.
  190. if ps.Type().Kind() == reflect.Interface {
  191. vx, vy := ps.Values()
  192. if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() {
  193. return false
  194. }
  195. return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type()))
  196. }
  197. return false
  198. }, cmp.Transformer("protocmp.Transform", func(v interface{}) Message {
  199. // For user convenience, shallow copy the message value if necessary
  200. // in order for it to implement the message interface.
  201. if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) {
  202. pv := reflect.New(rv.Type())
  203. pv.Elem().Set(rv)
  204. v = pv.Interface()
  205. }
  206. m := protoimpl.X.MessageOf(v)
  207. switch {
  208. case m == nil:
  209. return nil
  210. case !m.IsValid():
  211. return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
  212. default:
  213. return xf.transformMessage(m)
  214. }
  215. }))
  216. }
  217. func isMessageType(t reflect.Type) bool {
  218. // Avoid transforming the Message itself.
  219. if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) {
  220. return false
  221. }
  222. return t.Implements(messageV1Type) || t.Implements(messageV2Type)
  223. }
  224. func (xf *transformer) transformMessage(m protoreflect.Message) Message {
  225. mx := Message{}
  226. mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
  227. // Handle known and extension fields.
  228. m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
  229. s := fd.TextName()
  230. if fd.IsExtension() {
  231. mt.xds[s] = fd
  232. }
  233. switch {
  234. case fd.IsList():
  235. mx[s] = xf.transformList(fd, v.List())
  236. case fd.IsMap():
  237. mx[s] = xf.transformMap(fd, v.Map())
  238. default:
  239. mx[s] = xf.transformSingular(fd, v)
  240. }
  241. return true
  242. })
  243. // Handle unknown fields.
  244. for b := m.GetUnknown(); len(b) > 0; {
  245. num, _, n := protowire.ConsumeField(b)
  246. s := strconv.Itoa(int(num))
  247. b2, _ := mx[s].(protoreflect.RawFields)
  248. mx[s] = append(b2, b[:n]...)
  249. b = b[n:]
  250. }
  251. // Expand Any messages.
  252. if mt.md.FullName() == genid.Any_message_fullname {
  253. s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string)
  254. b, _ := mx[string(genid.Any_Value_field_name)].([]byte)
  255. mt, err := xf.resolver.FindMessageByURL(s)
  256. if mt != nil && err == nil {
  257. m2 := mt.New()
  258. err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface())
  259. if err == nil {
  260. mx[string(genid.Any_Value_field_name)] = xf.transformMessage(m2)
  261. }
  262. }
  263. }
  264. mx[messageTypeKey] = mt
  265. return mx
  266. }
  267. func (xf *transformer) transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} {
  268. t := protoKindToGoType(fd.Kind())
  269. rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len())
  270. for i := 0; i < lv.Len(); i++ {
  271. v := reflect.ValueOf(xf.transformSingular(fd, lv.Get(i)))
  272. rv.Index(i).Set(v)
  273. }
  274. return rv.Interface()
  275. }
  276. func (xf *transformer) transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} {
  277. kfd := fd.MapKey()
  278. vfd := fd.MapValue()
  279. kt := protoKindToGoType(kfd.Kind())
  280. vt := protoKindToGoType(vfd.Kind())
  281. rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len())
  282. mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
  283. kv := reflect.ValueOf(xf.transformSingular(kfd, k.Value()))
  284. vv := reflect.ValueOf(xf.transformSingular(vfd, v))
  285. rv.SetMapIndex(kv, vv)
  286. return true
  287. })
  288. return rv.Interface()
  289. }
  290. func (xf *transformer) transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} {
  291. switch fd.Kind() {
  292. case protoreflect.EnumKind:
  293. return Enum{num: v.Enum(), ed: fd.Enum()}
  294. case protoreflect.MessageKind, protoreflect.GroupKind:
  295. return xf.transformMessage(v.Message())
  296. case protoreflect.BytesKind:
  297. // The protoreflect API does not specify whether an empty bytes is
  298. // guaranteed to be nil or not. Always return non-nil bytes to avoid
  299. // leaking information about the concrete proto.Message implementation.
  300. if len(v.Bytes()) == 0 {
  301. return []byte{}
  302. }
  303. return v.Bytes()
  304. default:
  305. return v.Interface()
  306. }
  307. }
  308. func protoKindToGoType(k protoreflect.Kind) reflect.Type {
  309. switch k {
  310. case protoreflect.BoolKind:
  311. return reflect.TypeOf(bool(false))
  312. case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
  313. return reflect.TypeOf(int32(0))
  314. case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
  315. return reflect.TypeOf(int64(0))
  316. case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
  317. return reflect.TypeOf(uint32(0))
  318. case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
  319. return reflect.TypeOf(uint64(0))
  320. case protoreflect.FloatKind:
  321. return reflect.TypeOf(float32(0))
  322. case protoreflect.DoubleKind:
  323. return reflect.TypeOf(float64(0))
  324. case protoreflect.StringKind:
  325. return reflect.TypeOf(string(""))
  326. case protoreflect.BytesKind:
  327. return reflect.TypeOf([]byte(nil))
  328. case protoreflect.EnumKind:
  329. return reflect.TypeOf(Enum{})
  330. case protoreflect.MessageKind, protoreflect.GroupKind:
  331. return reflect.TypeOf(Message{})
  332. default:
  333. panic("invalid kind")
  334. }
  335. }