build.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. // Copyright 2020 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package protobuild constructs messages.
  5. //
  6. // This package is used to construct multiple types of message with a similar shape
  7. // from a common template.
  8. package protobuild
  9. import (
  10. "fmt"
  11. "math"
  12. "reflect"
  13. "google.golang.org/protobuf/reflect/protoreflect"
  14. "google.golang.org/protobuf/reflect/protoregistry"
  15. )
// A Value is a value assignable to a field.
// A Value may be a value accepted by protoreflect.ValueOf. In addition:
//
// • An int may be assigned to any numeric field, or to an enum field as the
// enum value number. An int that is negative for an unsigned field, or
// outside the int32 range for a 32-bit signed field, causes a panic.
//
// • A float64 may be assigned to a double or float field.
//
// • Either a string or []byte may be assigned to a string or bytes field.
//
// • A string containing the value name may be assigned to an enum field.
//
// • A slice may be assigned to a list, and a map may be assigned to a map.
//
// • A Message may be assigned to a message field. The elements of a
// message-typed list, and the values of a message-typed map, must likewise
// be Messages.
type Value interface{}
// A Message is a template to apply to a message. Keys are field names, including
// extension names.
//
// A key that is not the name of a regular field is matched against the short
// (unqualified) names of the target message's extensions registered in
// protoregistry.GlobalTypes. The special key Unknown supplies the message's
// unknown fields.
type Message map[protoreflect.Name]Value
// Unknown is a key associated with the unknown fields of a message.
// The value must be a []byte holding raw encoded fields; Build installs it
// with the message's SetUnknown method and panics if the value is not a
// []byte.
//
// The constant is deliberately untyped so that it compares directly with the
// protoreflect.Name keys of a Message.
const Unknown = "@unknown"
  35. // Build applies the template to a message.
  36. func (template Message) Build(m protoreflect.Message) {
  37. md := m.Descriptor()
  38. fields := md.Fields()
  39. exts := make(map[protoreflect.Name]protoreflect.FieldDescriptor)
  40. protoregistry.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(xt protoreflect.ExtensionType) bool {
  41. xd := xt.TypeDescriptor()
  42. exts[xd.Name()] = xd
  43. return true
  44. })
  45. for k, v := range template {
  46. if k == Unknown {
  47. m.SetUnknown(protoreflect.RawFields(v.([]byte)))
  48. continue
  49. }
  50. fd := fields.ByName(k)
  51. if fd == nil {
  52. fd = exts[k]
  53. }
  54. if fd == nil {
  55. panic(fmt.Sprintf("%v.%v: not found", md.FullName(), k))
  56. }
  57. switch {
  58. case fd.IsList():
  59. list := m.Mutable(fd).List()
  60. s := reflect.ValueOf(v)
  61. for i := 0; i < s.Len(); i++ {
  62. if fd.Message() == nil {
  63. list.Append(fieldValue(fd, s.Index(i).Interface()))
  64. } else {
  65. e := list.NewElement()
  66. s.Index(i).Interface().(Message).Build(e.Message())
  67. list.Append(e)
  68. }
  69. }
  70. case fd.IsMap():
  71. mapv := m.Mutable(fd).Map()
  72. rm := reflect.ValueOf(v)
  73. for _, k := range rm.MapKeys() {
  74. mk := fieldValue(fd.MapKey(), k.Interface()).MapKey()
  75. if fd.MapValue().Message() == nil {
  76. mv := fieldValue(fd.MapValue(), rm.MapIndex(k).Interface())
  77. mapv.Set(mk, mv)
  78. } else if mapv.Has(mk) {
  79. mv := mapv.Get(mk).Message()
  80. rm.MapIndex(k).Interface().(Message).Build(mv)
  81. } else {
  82. mv := mapv.NewValue()
  83. rm.MapIndex(k).Interface().(Message).Build(mv.Message())
  84. mapv.Set(mk, mv)
  85. }
  86. }
  87. default:
  88. if fd.Message() == nil {
  89. m.Set(fd, fieldValue(fd, v))
  90. } else {
  91. v.(Message).Build(m.Mutable(fd).Message())
  92. }
  93. }
  94. }
  95. }
  96. func fieldValue(fd protoreflect.FieldDescriptor, v interface{}) protoreflect.Value {
  97. switch o := v.(type) {
  98. case int:
  99. switch fd.Kind() {
  100. case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
  101. if o < math.MinInt32 || math.MaxInt32 < o {
  102. panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, int32(math.MinInt32), int32(math.MaxInt32)))
  103. }
  104. v = int32(o)
  105. case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
  106. if o < 0 || math.MaxUint32 < 0 {
  107. panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, uint32(0), uint32(math.MaxUint32)))
  108. }
  109. v = uint32(o)
  110. case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
  111. v = int64(o)
  112. case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
  113. if o < 0 {
  114. panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, uint64(0), uint64(math.MaxUint64)))
  115. }
  116. v = uint64(o)
  117. case protoreflect.FloatKind:
  118. v = float32(o)
  119. case protoreflect.DoubleKind:
  120. v = float64(o)
  121. case protoreflect.EnumKind:
  122. v = protoreflect.EnumNumber(o)
  123. default:
  124. panic(fmt.Sprintf("%v: invalid value type int", fd.FullName()))
  125. }
  126. case float64:
  127. switch fd.Kind() {
  128. case protoreflect.FloatKind:
  129. v = float32(o)
  130. }
  131. case string:
  132. switch fd.Kind() {
  133. case protoreflect.BytesKind:
  134. v = []byte(o)
  135. case protoreflect.EnumKind:
  136. v = fd.Enum().Values().ByName(protoreflect.Name(o)).Number()
  137. }
  138. case []byte:
  139. return protoreflect.ValueOf(append([]byte{}, o...))
  140. }
  141. return protoreflect.ValueOf(v)
  142. }