extension.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "google.golang.org/protobuf/reflect/protoreflect"
  7. )
  8. // HasExtension reports whether an extension field is populated.
  9. // It returns false if m is invalid or if xt does not extend m.
  10. func HasExtension(m Message, xt protoreflect.ExtensionType) bool {
  11. // Treat nil message interface or descriptor as an empty message; no populated
  12. // fields.
  13. if m == nil || xt == nil {
  14. return false
  15. }
  16. // As a special-case, we reports invalid or mismatching descriptors
  17. // as always not being populated (since they aren't).
  18. mr := m.ProtoReflect()
  19. xd := xt.TypeDescriptor()
  20. if mr.Descriptor() != xd.ContainingMessage() {
  21. return false
  22. }
  23. return mr.Has(xd)
  24. }
  25. // ClearExtension clears an extension field such that subsequent
  26. // [HasExtension] calls return false.
  27. // It panics if m is invalid or if xt does not extend m.
  28. func ClearExtension(m Message, xt protoreflect.ExtensionType) {
  29. m.ProtoReflect().Clear(xt.TypeDescriptor())
  30. }
  31. // GetExtension retrieves the value for an extension field.
  32. // If the field is unpopulated, it returns the default value for
  33. // scalars and an immutable, empty value for lists or messages.
  34. // It panics if xt does not extend m.
  35. //
  36. // The type of the value is dependent on the field type of the extension.
  37. // For extensions generated by protoc-gen-go, the Go type is as follows:
  38. //
  39. // ╔═══════════════════╤═════════════════════════╗
  40. // ║ Go type │ Protobuf kind ║
  41. // ╠═══════════════════╪═════════════════════════╣
  42. // ║ bool │ bool ║
  43. // ║ int32 │ int32, sint32, sfixed32 ║
  44. // ║ int64 │ int64, sint64, sfixed64 ║
  45. // ║ uint32 │ uint32, fixed32 ║
  46. // ║ uint64 │ uint64, fixed64 ║
  47. // ║ float32 │ float ║
  48. // ║ float64 │ double ║
  49. // ║ string │ string ║
  50. // ║ []byte │ bytes ║
  51. // ║ protoreflect.Enum │ enum ║
  52. // ║ proto.Message │ message, group ║
  53. // ╚═══════════════════╧═════════════════════════╝
  54. //
  55. // The protoreflect.Enum and proto.Message types are the concrete Go type
  56. // associated with the named enum or message. Repeated fields are represented
  57. // using a Go slice of the base element type.
  58. //
  59. // If a generated extension descriptor variable is directly passed to
  60. // GetExtension, then the call should be followed immediately by a
  61. // type assertion to the expected output value. For example:
  62. //
  63. // mm := proto.GetExtension(m, foopb.E_MyExtension).(*foopb.MyMessage)
  64. //
  65. // This pattern enables static analysis tools to verify that the asserted type
  66. // matches the Go type associated with the extension field and
  67. // also enables a possible future migration to a type-safe extension API.
  68. //
  69. // Since singular messages are the most common extension type, the pattern of
  70. // calling HasExtension followed by GetExtension may be simplified to:
  71. //
  72. // if mm := proto.GetExtension(m, foopb.E_MyExtension).(*foopb.MyMessage); mm != nil {
  73. // ... // make use of mm
  74. // }
  75. //
  76. // The mm variable is non-nil if and only if HasExtension reports true.
  77. func GetExtension(m Message, xt protoreflect.ExtensionType) any {
  78. // Treat nil message interface as an empty message; return the default.
  79. if m == nil {
  80. return xt.InterfaceOf(xt.Zero())
  81. }
  82. return xt.InterfaceOf(m.ProtoReflect().Get(xt.TypeDescriptor()))
  83. }
  84. // SetExtension stores the value of an extension field.
  85. // It panics if m is invalid, xt does not extend m, or if type of v
  86. // is invalid for the specified extension field.
  87. //
  88. // The type of the value is dependent on the field type of the extension.
  89. // For extensions generated by protoc-gen-go, the Go type is as follows:
  90. //
  91. // ╔═══════════════════╤═════════════════════════╗
  92. // ║ Go type │ Protobuf kind ║
  93. // ╠═══════════════════╪═════════════════════════╣
  94. // ║ bool │ bool ║
  95. // ║ int32 │ int32, sint32, sfixed32 ║
  96. // ║ int64 │ int64, sint64, sfixed64 ║
  97. // ║ uint32 │ uint32, fixed32 ║
  98. // ║ uint64 │ uint64, fixed64 ║
  99. // ║ float32 │ float ║
  100. // ║ float64 │ double ║
  101. // ║ string │ string ║
  102. // ║ []byte │ bytes ║
  103. // ║ protoreflect.Enum │ enum ║
  104. // ║ proto.Message │ message, group ║
  105. // ╚═══════════════════╧═════════════════════════╝
  106. //
  107. // The protoreflect.Enum and proto.Message types are the concrete Go type
  108. // associated with the named enum or message. Repeated fields are represented
  109. // using a Go slice of the base element type.
  110. //
  111. // If a generated extension descriptor variable is directly passed to
  112. // SetExtension (e.g., foopb.E_MyExtension), then the value should be a
  113. // concrete type that matches the expected Go type for the extension descriptor
  114. // so that static analysis tools can verify type correctness.
  115. // This also enables a possible future migration to a type-safe extension API.
  116. func SetExtension(m Message, xt protoreflect.ExtensionType, v any) {
  117. xd := xt.TypeDescriptor()
  118. pv := xt.ValueOf(v)
  119. // Specially treat an invalid list, map, or message as clear.
  120. isValid := true
  121. switch {
  122. case xd.IsList():
  123. isValid = pv.List().IsValid()
  124. case xd.IsMap():
  125. isValid = pv.Map().IsValid()
  126. case xd.Message() != nil:
  127. isValid = pv.Message().IsValid()
  128. }
  129. if !isValid {
  130. m.ProtoReflect().Clear(xd)
  131. return
  132. }
  133. m.ProtoReflect().Set(xd, pv)
  134. }
  135. // RangeExtensions iterates over every populated extension field in m in an
  136. // undefined order, calling f for each extension type and value encountered.
  137. // It returns immediately if f returns false.
  138. // While iterating, mutating operations may only be performed
  139. // on the current extension field.
  140. func RangeExtensions(m Message, f func(protoreflect.ExtensionType, any) bool) {
  141. // Treat nil message interface as an empty message; nothing to range over.
  142. if m == nil {
  143. return
  144. }
  145. m.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
  146. if fd.IsExtension() {
  147. xt := fd.(protoreflect.ExtensionTypeDescriptor).Type()
  148. vi := xt.InterfaceOf(v)
  149. return f(xt, vi)
  150. }
  151. return true
  152. })
  153. }