decode_test.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto_test
  5. import (
  6. "bytes"
  7. "fmt"
  8. "reflect"
  9. "testing"
  10. "google.golang.org/protobuf/encoding/prototext"
  11. "google.golang.org/protobuf/proto"
  12. "google.golang.org/protobuf/reflect/protoreflect"
  13. "google.golang.org/protobuf/testing/protopack"
  14. "google.golang.org/protobuf/internal/errors"
  15. testpb "google.golang.org/protobuf/internal/testprotos/test"
  16. test3pb "google.golang.org/protobuf/internal/testprotos/test3"
  17. )
  18. func TestDecode(t *testing.T) {
  19. for _, test := range testValidMessages {
  20. if len(test.decodeTo) == 0 {
  21. t.Errorf("%v: no test message types", test.desc)
  22. }
  23. for _, want := range test.decodeTo {
  24. t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
  25. opts := test.unmarshalOptions
  26. opts.AllowPartial = test.partial
  27. wire := append(([]byte)(nil), test.wire...)
  28. got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
  29. if err := opts.Unmarshal(wire, got); err != nil {
  30. t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, prototext.Format(want))
  31. return
  32. }
  33. // Aliasing check: Unmarshal shouldn't modify the original wire
  34. // bytes, and modifying the original wire bytes shouldn't affect
  35. // the unmarshaled message.
  36. if !bytes.Equal(test.wire, wire) {
  37. t.Errorf("Unmarshal unexpectedly modified its input")
  38. }
  39. for i := range wire {
  40. wire[i] = 0
  41. }
  42. if !proto.Equal(got, want) && got.ProtoReflect().IsValid() && want.ProtoReflect().IsValid() {
  43. t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", prototext.Format(got), prototext.Format(want))
  44. }
  45. })
  46. }
  47. }
  48. }
  49. func TestDecodeRequiredFieldChecks(t *testing.T) {
  50. for _, test := range testValidMessages {
  51. if !test.partial {
  52. continue
  53. }
  54. for _, m := range test.decodeTo {
  55. t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
  56. opts := test.unmarshalOptions
  57. opts.AllowPartial = false
  58. got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
  59. if err := proto.Unmarshal(test.wire, got); err == nil {
  60. t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", prototext.Format(got))
  61. }
  62. })
  63. }
  64. }
  65. }
  66. func TestDecodeInvalidMessages(t *testing.T) {
  67. for _, test := range testInvalidMessages {
  68. if len(test.decodeTo) == 0 {
  69. t.Errorf("%v: no test message types", test.desc)
  70. }
  71. for _, want := range test.decodeTo {
  72. t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
  73. opts := test.unmarshalOptions
  74. opts.AllowPartial = test.partial
  75. got := want.ProtoReflect().New().Interface()
  76. if err := opts.Unmarshal(test.wire, got); err == nil {
  77. t.Errorf("Unmarshal unexpectedly succeeded\ninput bytes: [%x]\nMessage:\n%v", test.wire, prototext.Format(got))
  78. } else if !errors.Is(err, proto.Error) {
  79. t.Errorf("Unmarshal error is not a proto.Error: %v", err)
  80. }
  81. })
  82. }
  83. }
  84. }
  85. func TestDecodeZeroLengthBytes(t *testing.T) {
  86. // Verify that proto3 bytes fields don't give the mistaken
  87. // impression that they preserve presence.
  88. wire := protopack.Message{
  89. protopack.Tag{94, protopack.BytesType}, protopack.Bytes(nil),
  90. }.Marshal()
  91. m := &test3pb.TestAllTypes{}
  92. if err := proto.Unmarshal(wire, m); err != nil {
  93. t.Fatal(err)
  94. }
  95. if m.OptionalBytes != nil {
  96. t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes)
  97. }
  98. }
  99. func TestDecodeOneofNilWrapper(t *testing.T) {
  100. wire := protopack.Message{
  101. protopack.Tag{111, protopack.VarintType}, protopack.Varint(1111),
  102. }.Marshal()
  103. m := &testpb.TestAllTypes{OneofField: (*testpb.TestAllTypes_OneofUint32)(nil)}
  104. if err := proto.Unmarshal(wire, m); err != nil {
  105. t.Fatal(err)
  106. }
  107. if got := m.GetOneofUint32(); got != 1111 {
  108. t.Errorf("GetOneofUint32() = %v, want %v", got, 1111)
  109. }
  110. }
  111. func TestDecodeEmptyBytes(t *testing.T) {
  112. // There's really nothing wrong with a nil entry in a [][]byte,
  113. // but we take care to produce non-nil []bytes for zero-length
  114. // byte strings, so test for it.
  115. m := &testpb.TestAllTypes{}
  116. b := protopack.Message{
  117. protopack.Tag{45, protopack.BytesType}, protopack.Bytes(nil),
  118. }.Marshal()
  119. if err := proto.Unmarshal(b, m); err != nil {
  120. t.Fatal(err)
  121. }
  122. if m.RepeatedBytes[0] == nil {
  123. t.Errorf("unmarshaling repeated bytes field containing zero-length value: Got nil bytes, want non-nil")
  124. }
  125. }
  126. func build(m proto.Message, opts ...buildOpt) proto.Message {
  127. for _, opt := range opts {
  128. opt(m)
  129. }
  130. return m
  131. }
  132. type buildOpt func(proto.Message)
  133. func unknown(raw protoreflect.RawFields) buildOpt {
  134. return func(m proto.Message) {
  135. m.ProtoReflect().SetUnknown(raw)
  136. }
  137. }
  138. func extend(desc protoreflect.ExtensionType, value interface{}) buildOpt {
  139. return func(m proto.Message) {
  140. proto.SetExtension(m, desc, value)
  141. }
  142. }