123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- // Copyright 2018 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package proto_test
- import (
- "bytes"
- "fmt"
- "reflect"
- "testing"
- "google.golang.org/protobuf/encoding/prototext"
- "google.golang.org/protobuf/proto"
- "google.golang.org/protobuf/reflect/protoreflect"
- "google.golang.org/protobuf/testing/protopack"
- "google.golang.org/protobuf/internal/errors"
- testpb "google.golang.org/protobuf/internal/testprotos/test"
- test3pb "google.golang.org/protobuf/internal/testprotos/test3"
- )
- func TestDecode(t *testing.T) {
- for _, test := range testValidMessages {
- if len(test.decodeTo) == 0 {
- t.Errorf("%v: no test message types", test.desc)
- }
- for _, want := range test.decodeTo {
- t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
- opts := test.unmarshalOptions
- opts.AllowPartial = test.partial
- wire := append(([]byte)(nil), test.wire...)
- got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
- if err := opts.Unmarshal(wire, got); err != nil {
- t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, prototext.Format(want))
- return
- }
- // Aliasing check: Unmarshal shouldn't modify the original wire
- // bytes, and modifying the original wire bytes shouldn't affect
- // the unmarshaled message.
- if !bytes.Equal(test.wire, wire) {
- t.Errorf("Unmarshal unexpectedly modified its input")
- }
- for i := range wire {
- wire[i] = 0
- }
- if !proto.Equal(got, want) && got.ProtoReflect().IsValid() && want.ProtoReflect().IsValid() {
- t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", prototext.Format(got), prototext.Format(want))
- }
- })
- }
- }
- }
- func TestDecodeRequiredFieldChecks(t *testing.T) {
- for _, test := range testValidMessages {
- if !test.partial {
- continue
- }
- for _, m := range test.decodeTo {
- t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
- opts := test.unmarshalOptions
- opts.AllowPartial = false
- got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
- if err := proto.Unmarshal(test.wire, got); err == nil {
- t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", prototext.Format(got))
- }
- })
- }
- }
- }
- func TestDecodeInvalidMessages(t *testing.T) {
- for _, test := range testInvalidMessages {
- if len(test.decodeTo) == 0 {
- t.Errorf("%v: no test message types", test.desc)
- }
- for _, want := range test.decodeTo {
- t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
- opts := test.unmarshalOptions
- opts.AllowPartial = test.partial
- got := want.ProtoReflect().New().Interface()
- if err := opts.Unmarshal(test.wire, got); err == nil {
- t.Errorf("Unmarshal unexpectedly succeeded\ninput bytes: [%x]\nMessage:\n%v", test.wire, prototext.Format(got))
- } else if !errors.Is(err, proto.Error) {
- t.Errorf("Unmarshal error is not a proto.Error: %v", err)
- }
- })
- }
- }
- }
- func TestDecodeZeroLengthBytes(t *testing.T) {
- // Verify that proto3 bytes fields don't give the mistaken
- // impression that they preserve presence.
- wire := protopack.Message{
- protopack.Tag{94, protopack.BytesType}, protopack.Bytes(nil),
- }.Marshal()
- m := &test3pb.TestAllTypes{}
- if err := proto.Unmarshal(wire, m); err != nil {
- t.Fatal(err)
- }
- if m.OptionalBytes != nil {
- t.Errorf("unmarshal zero-length proto3 bytes field: got %v, want nil", m.OptionalBytes)
- }
- }
- func TestDecodeOneofNilWrapper(t *testing.T) {
- wire := protopack.Message{
- protopack.Tag{111, protopack.VarintType}, protopack.Varint(1111),
- }.Marshal()
- m := &testpb.TestAllTypes{OneofField: (*testpb.TestAllTypes_OneofUint32)(nil)}
- if err := proto.Unmarshal(wire, m); err != nil {
- t.Fatal(err)
- }
- if got := m.GetOneofUint32(); got != 1111 {
- t.Errorf("GetOneofUint32() = %v, want %v", got, 1111)
- }
- }
- func TestDecodeEmptyBytes(t *testing.T) {
- // There's really nothing wrong with a nil entry in a [][]byte,
- // but we take care to produce non-nil []bytes for zero-length
- // byte strings, so test for it.
- m := &testpb.TestAllTypes{}
- b := protopack.Message{
- protopack.Tag{45, protopack.BytesType}, protopack.Bytes(nil),
- }.Marshal()
- if err := proto.Unmarshal(b, m); err != nil {
- t.Fatal(err)
- }
- if m.RepeatedBytes[0] == nil {
- t.Errorf("unmarshaling repeated bytes field containing zero-length value: Got nil bytes, want non-nil")
- }
- }
- func build(m proto.Message, opts ...buildOpt) proto.Message {
- for _, opt := range opts {
- opt(m)
- }
- return m
- }
- type buildOpt func(proto.Message)
- func unknown(raw protoreflect.RawFields) buildOpt {
- return func(m proto.Message) {
- m.ProtoReflect().SetUnknown(raw)
- }
- }
- func extend(desc protoreflect.ExtensionType, value interface{}) buildOpt {
- return func(m proto.Message) {
- proto.SetExtension(m, desc, value)
- }
- }
|