decode.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "math/bits"
  7. "google.golang.org/protobuf/encoding/protowire"
  8. "google.golang.org/protobuf/internal/errors"
  9. "google.golang.org/protobuf/internal/flags"
  10. "google.golang.org/protobuf/proto"
  11. "google.golang.org/protobuf/reflect/protoreflect"
  12. "google.golang.org/protobuf/reflect/protoregistry"
  13. "google.golang.org/protobuf/runtime/protoiface"
  14. )
  // Sentinel errors returned by the fast-path decoder in this file.
  var (
  	// errDecode reports wire-format input that cannot be parsed: an
  	// unreadable tag, an out-of-range field number, a mismatched or missing
  	// END_GROUP tag, or a field value that runs past the end of the input.
  	errDecode = errors.New("cannot parse invalid wire-format data")

  	// errRecursionDepth reports that nesting exhausted the recursion budget
  	// carried in unmarshalOptions.depth.
  	errRecursionDepth = errors.New("exceeded maximum recursion depth")
  )
  17. type unmarshalOptions struct {
  18. flags protoiface.UnmarshalInputFlags
  19. resolver interface {
  20. FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
  21. FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
  22. }
  23. depth int
  24. }
  25. func (o unmarshalOptions) Options() proto.UnmarshalOptions {
  26. return proto.UnmarshalOptions{
  27. Merge: true,
  28. AllowPartial: true,
  29. DiscardUnknown: o.DiscardUnknown(),
  30. Resolver: o.resolver,
  31. }
  32. }
  33. func (o unmarshalOptions) DiscardUnknown() bool {
  34. return o.flags&protoiface.UnmarshalDiscardUnknown != 0
  35. }
  36. func (o unmarshalOptions) IsDefault() bool {
  37. return o.flags == 0 && o.resolver == protoregistry.GlobalTypes
  38. }
  39. var lazyUnmarshalOptions = unmarshalOptions{
  40. resolver: protoregistry.GlobalTypes,
  41. depth: protowire.DefaultRecursionLimit,
  42. }
  // unmarshalOutput is what the internal unmarshal functions report to their
  // caller alongside an error.
  type unmarshalOutput struct {
  	// n is the number of input bytes consumed.
  	n int
  	// initialized is true when the decoded value is known to be fully
  	// initialized (required fields present, including in sub-messages).
  	initialized bool
  }
  47. // unmarshal is protoreflect.Methods.Unmarshal.
  48. func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
  49. var p pointer
  50. if ms, ok := in.Message.(*messageState); ok {
  51. p = ms.pointer()
  52. } else {
  53. p = in.Message.(*messageReflectWrapper).pointer()
  54. }
  55. out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
  56. flags: in.Flags,
  57. resolver: in.Resolver,
  58. depth: in.Depth,
  59. })
  60. var flags protoiface.UnmarshalOutputFlags
  61. if out.initialized {
  62. flags |= protoiface.UnmarshalInitialized
  63. }
  64. return protoiface.UnmarshalOutput{
  65. Flags: flags,
  66. }, err
  67. }
  // errUnknown is an internal sentinel. A field decoder returns it to mean "this
  // field cannot be handled here" (for example, its wire type does not match).
  // The caller then skips the field and keeps it in the unknown-fields section
  // rather than failing the whole unmarshal. A field that extends past the end
  // of the input, by contrast, is a hard failure.
  //
  // It must never escape to the user.
  var errUnknown = errors.New("unknown")
  75. func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
  76. mi.init()
  77. opts.depth--
  78. if opts.depth < 0 {
  79. return out, errRecursionDepth
  80. }
  81. if flags.ProtoLegacy && mi.isMessageSet {
  82. return unmarshalMessageSet(mi, b, p, opts)
  83. }
  84. initialized := true
  85. var requiredMask uint64
  86. var exts *map[int32]ExtensionField
  87. start := len(b)
  88. for len(b) > 0 {
  89. // Parse the tag (field number and wire type).
  90. var tag uint64
  91. if b[0] < 0x80 {
  92. tag = uint64(b[0])
  93. b = b[1:]
  94. } else if len(b) >= 2 && b[1] < 128 {
  95. tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
  96. b = b[2:]
  97. } else {
  98. var n int
  99. tag, n = protowire.ConsumeVarint(b)
  100. if n < 0 {
  101. return out, errDecode
  102. }
  103. b = b[n:]
  104. }
  105. var num protowire.Number
  106. if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
  107. return out, errDecode
  108. } else {
  109. num = protowire.Number(n)
  110. }
  111. wtyp := protowire.Type(tag & 7)
  112. if wtyp == protowire.EndGroupType {
  113. if num != groupTag {
  114. return out, errDecode
  115. }
  116. groupTag = 0
  117. break
  118. }
  119. var f *coderFieldInfo
  120. if int(num) < len(mi.denseCoderFields) {
  121. f = mi.denseCoderFields[num]
  122. } else {
  123. f = mi.coderFields[num]
  124. }
  125. var n int
  126. err := errUnknown
  127. switch {
  128. case f != nil:
  129. if f.funcs.unmarshal == nil {
  130. break
  131. }
  132. var o unmarshalOutput
  133. o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
  134. n = o.n
  135. if err != nil {
  136. break
  137. }
  138. requiredMask |= f.validation.requiredBit
  139. if f.funcs.isInit != nil && !o.initialized {
  140. initialized = false
  141. }
  142. default:
  143. // Possible extension.
  144. if exts == nil && mi.extensionOffset.IsValid() {
  145. exts = p.Apply(mi.extensionOffset).Extensions()
  146. if *exts == nil {
  147. *exts = make(map[int32]ExtensionField)
  148. }
  149. }
  150. if exts == nil {
  151. break
  152. }
  153. var o unmarshalOutput
  154. o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
  155. if err != nil {
  156. break
  157. }
  158. n = o.n
  159. if !o.initialized {
  160. initialized = false
  161. }
  162. }
  163. if err != nil {
  164. if err != errUnknown {
  165. return out, err
  166. }
  167. n = protowire.ConsumeFieldValue(num, wtyp, b)
  168. if n < 0 {
  169. return out, errDecode
  170. }
  171. if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
  172. u := mi.mutableUnknownBytes(p)
  173. *u = protowire.AppendTag(*u, num, wtyp)
  174. *u = append(*u, b[:n]...)
  175. }
  176. }
  177. b = b[n:]
  178. }
  179. if groupTag != 0 {
  180. return out, errDecode
  181. }
  182. if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
  183. initialized = false
  184. }
  185. if initialized {
  186. out.initialized = true
  187. }
  188. out.n = start - len(b)
  189. return out, nil
  190. }
  191. func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
  192. x := exts[int32(num)]
  193. xt := x.Type()
  194. if xt == nil {
  195. var err error
  196. xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
  197. if err != nil {
  198. if err == protoregistry.NotFound {
  199. return out, errUnknown
  200. }
  201. return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
  202. }
  203. }
  204. xi := getExtensionFieldInfo(xt)
  205. if xi.funcs.unmarshal == nil {
  206. return out, errUnknown
  207. }
  208. if flags.LazyUnmarshalExtensions {
  209. if opts.IsDefault() && x.canLazy(xt) {
  210. out, valid := skipExtension(b, xi, num, wtyp, opts)
  211. switch valid {
  212. case ValidationValid:
  213. if out.initialized {
  214. x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
  215. exts[int32(num)] = x
  216. return out, nil
  217. }
  218. case ValidationInvalid:
  219. return out, errDecode
  220. case ValidationUnknown:
  221. }
  222. }
  223. }
  224. ival := x.Value()
  225. if !ival.IsValid() && xi.unmarshalNeedsValue {
  226. // Create a new message, list, or map value to fill in.
  227. // For enums, create a prototype value to let the unmarshal func know the
  228. // concrete type.
  229. ival = xt.New()
  230. }
  231. v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
  232. if err != nil {
  233. return out, err
  234. }
  235. if xi.funcs.isInit == nil {
  236. out.initialized = true
  237. }
  238. x.Set(xt, v)
  239. exts[int32(num)] = x
  240. return out, nil
  241. }
  242. func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
  243. if xi.validation.mi == nil {
  244. return out, ValidationUnknown
  245. }
  246. xi.validation.mi.init()
  247. switch xi.validation.typ {
  248. case validationTypeMessage:
  249. if wtyp != protowire.BytesType {
  250. return out, ValidationUnknown
  251. }
  252. v, n := protowire.ConsumeBytes(b)
  253. if n < 0 {
  254. return out, ValidationUnknown
  255. }
  256. out, st := xi.validation.mi.validate(v, 0, opts)
  257. out.n = n
  258. return out, st
  259. case validationTypeGroup:
  260. if wtyp != protowire.StartGroupType {
  261. return out, ValidationUnknown
  262. }
  263. out, st := xi.validation.mi.validate(b, num, opts)
  264. return out, st
  265. default:
  266. return out, ValidationUnknown
  267. }
  268. }