extension.go 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "google.golang.org/protobuf/reflect/protoreflect"
  7. )
  8. // HasExtension reports whether an extension field is populated.
  9. // It returns false if m is invalid or if xt does not extend m.
  10. func HasExtension(m Message, xt protoreflect.ExtensionType) bool {
  11. // Treat nil message interface or descriptor as an empty message; no populated
  12. // fields.
  13. if m == nil || xt == nil {
  14. return false
  15. }
  16. // As a special-case, we reports invalid or mismatching descriptors
  17. // as always not being populated (since they aren't).
  18. mr := m.ProtoReflect()
  19. xd := xt.TypeDescriptor()
  20. if mr.Descriptor() != xd.ContainingMessage() {
  21. return false
  22. }
  23. return mr.Has(xd)
  24. }
  25. // ClearExtension clears an extension field such that subsequent
  26. // [HasExtension] calls return false.
  27. // It panics if m is invalid or if xt does not extend m.
  28. func ClearExtension(m Message, xt protoreflect.ExtensionType) {
  29. m.ProtoReflect().Clear(xt.TypeDescriptor())
  30. }
  31. // GetExtension retrieves the value for an extension field.
  32. // If the field is unpopulated, it returns the default value for
  33. // scalars and an immutable, empty value for lists or messages.
  34. // It panics if xt does not extend m.
  35. func GetExtension(m Message, xt protoreflect.ExtensionType) any {
  36. // Treat nil message interface as an empty message; return the default.
  37. if m == nil {
  38. return xt.InterfaceOf(xt.Zero())
  39. }
  40. return xt.InterfaceOf(m.ProtoReflect().Get(xt.TypeDescriptor()))
  41. }
  42. // SetExtension stores the value of an extension field.
  43. // It panics if m is invalid, xt does not extend m, or if type of v
  44. // is invalid for the specified extension field.
  45. func SetExtension(m Message, xt protoreflect.ExtensionType, v any) {
  46. xd := xt.TypeDescriptor()
  47. pv := xt.ValueOf(v)
  48. // Specially treat an invalid list, map, or message as clear.
  49. isValid := true
  50. switch {
  51. case xd.IsList():
  52. isValid = pv.List().IsValid()
  53. case xd.IsMap():
  54. isValid = pv.Map().IsValid()
  55. case xd.Message() != nil:
  56. isValid = pv.Message().IsValid()
  57. }
  58. if !isValid {
  59. m.ProtoReflect().Clear(xd)
  60. return
  61. }
  62. m.ProtoReflect().Set(xd, pv)
  63. }
  64. // RangeExtensions iterates over every populated extension field in m in an
  65. // undefined order, calling f for each extension type and value encountered.
  66. // It returns immediately if f returns false.
  67. // While iterating, mutating operations may only be performed
  68. // on the current extension field.
  69. func RangeExtensions(m Message, f func(protoreflect.ExtensionType, any) bool) {
  70. // Treat nil message interface as an empty message; nothing to range over.
  71. if m == nil {
  72. return
  73. }
  74. m.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
  75. if fd.IsExtension() {
  76. xt := fd.(protoreflect.ExtensionTypeDescriptor).Type()
  77. vi := xt.InterfaceOf(v)
  78. return f(xt, vi)
  79. }
  80. return true
  81. })
  82. }