123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388 |
- // Copyright 2019 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package impl
- import (
- "reflect"
- "sort"
- "google.golang.org/protobuf/encoding/protowire"
- "google.golang.org/protobuf/internal/genid"
- "google.golang.org/protobuf/reflect/protoreflect"
- )
- type mapInfo struct {
- goType reflect.Type
- keyWiretag uint64
- valWiretag uint64
- keyFuncs valueCoderFuncs
- valFuncs valueCoderFuncs
- keyZero protoreflect.Value
- keyKind protoreflect.Kind
- conv *mapConverter
- }
- func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
- // TODO: Consider generating specialized map coders.
- keyField := fd.MapKey()
- valField := fd.MapValue()
- keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
- valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
- keyFuncs := encoderFuncsForValue(keyField)
- valFuncs := encoderFuncsForValue(valField)
- conv := newMapConverter(ft, fd)
- mapi := &mapInfo{
- goType: ft,
- keyWiretag: keyWiretag,
- valWiretag: valWiretag,
- keyFuncs: keyFuncs,
- valFuncs: valFuncs,
- keyZero: keyField.Default(),
- keyKind: keyField.Kind(),
- conv: conv,
- }
- if valField.Kind() == protoreflect.MessageKind {
- valueMessage = getMessageInfo(ft.Elem())
- }
- funcs = pointerCoderFuncs{
- size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
- return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
- },
- marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
- return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
- },
- unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
- mp := p.AsValueOf(ft)
- if mp.Elem().IsNil() {
- mp.Elem().Set(reflect.MakeMap(mapi.goType))
- }
- if f.mi == nil {
- return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
- } else {
- return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
- }
- },
- }
- switch valField.Kind() {
- case protoreflect.MessageKind:
- funcs.merge = mergeMapOfMessage
- case protoreflect.BytesKind:
- funcs.merge = mergeMapOfBytes
- default:
- funcs.merge = mergeMap
- }
- if valFuncs.isInit != nil {
- funcs.isInit = func(p pointer, f *coderFieldInfo) error {
- return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
- }
- }
- return valueMessage, funcs
- }
// Encoded sizes, in bytes, of the tags of a map entry's key (field 1) and
// value (field 2). A tag is the field number shifted left by three bits and
// OR-ed with the wire type, so any field number below 16 fits in a
// single-byte varint regardless of wire type.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
- func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
- if mapv.Len() == 0 {
- return 0
- }
- n := 0
- iter := mapRange(mapv)
- for iter.Next() {
- key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
- keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
- var valSize int
- value := mapi.conv.valConv.PBValueOf(iter.Value())
- if f.mi == nil {
- valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
- } else {
- p := pointerOfValue(iter.Value())
- valSize += mapValTagSize
- valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
- }
- n += f.tagsize + protowire.SizeBytes(keySize+valSize)
- }
- return n
- }
- func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
- if wtyp != protowire.BytesType {
- return out, errUnknown
- }
- b, n := protowire.ConsumeBytes(b)
- if n < 0 {
- return out, errDecode
- }
- var (
- key = mapi.keyZero
- val = mapi.conv.valConv.New()
- )
- for len(b) > 0 {
- num, wtyp, n := protowire.ConsumeTag(b)
- if n < 0 {
- return out, errDecode
- }
- if num > protowire.MaxValidNumber {
- return out, errDecode
- }
- b = b[n:]
- err := errUnknown
- switch num {
- case genid.MapEntry_Key_field_number:
- var v protoreflect.Value
- var o unmarshalOutput
- v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
- if err != nil {
- break
- }
- key = v
- n = o.n
- case genid.MapEntry_Value_field_number:
- var v protoreflect.Value
- var o unmarshalOutput
- v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
- if err != nil {
- break
- }
- val = v
- n = o.n
- }
- if err == errUnknown {
- n = protowire.ConsumeFieldValue(num, wtyp, b)
- if n < 0 {
- return out, errDecode
- }
- } else if err != nil {
- return out, err
- }
- b = b[n:]
- }
- mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
- out.n = n
- return out, nil
- }
- func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
- if wtyp != protowire.BytesType {
- return out, errUnknown
- }
- b, n := protowire.ConsumeBytes(b)
- if n < 0 {
- return out, errDecode
- }
- var (
- key = mapi.keyZero
- val = reflect.New(f.mi.GoReflectType.Elem())
- )
- for len(b) > 0 {
- num, wtyp, n := protowire.ConsumeTag(b)
- if n < 0 {
- return out, errDecode
- }
- if num > protowire.MaxValidNumber {
- return out, errDecode
- }
- b = b[n:]
- err := errUnknown
- switch num {
- case 1:
- var v protoreflect.Value
- var o unmarshalOutput
- v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
- if err != nil {
- break
- }
- key = v
- n = o.n
- case 2:
- if wtyp != protowire.BytesType {
- break
- }
- var v []byte
- v, n = protowire.ConsumeBytes(b)
- if n < 0 {
- return out, errDecode
- }
- var o unmarshalOutput
- o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
- if o.initialized {
- // Consider this map item initialized so long as we see
- // an initialized value.
- out.initialized = true
- }
- }
- if err == errUnknown {
- n = protowire.ConsumeFieldValue(num, wtyp, b)
- if n < 0 {
- return out, errDecode
- }
- } else if err != nil {
- return out, err
- }
- b = b[n:]
- }
- mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
- out.n = n
- return out, nil
- }
- func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
- if f.mi == nil {
- key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
- val := mapi.conv.valConv.PBValueOf(valrv)
- size := 0
- size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
- size += mapi.valFuncs.size(val, mapValTagSize, opts)
- b = protowire.AppendVarint(b, uint64(size))
- b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
- if err != nil {
- return nil, err
- }
- return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
- } else {
- key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
- val := pointerOfValue(valrv)
- valSize := f.mi.sizePointer(val, opts)
- size := 0
- size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
- size += mapValTagSize + protowire.SizeBytes(valSize)
- b = protowire.AppendVarint(b, uint64(size))
- b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
- if err != nil {
- return nil, err
- }
- b = protowire.AppendVarint(b, mapi.valWiretag)
- b = protowire.AppendVarint(b, uint64(valSize))
- return f.mi.marshalAppendPointer(b, val, opts)
- }
- }
- func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
- if mapv.Len() == 0 {
- return b, nil
- }
- if opts.Deterministic() {
- return appendMapDeterministic(b, mapv, mapi, f, opts)
- }
- iter := mapRange(mapv)
- for iter.Next() {
- var err error
- b = protowire.AppendVarint(b, f.wiretag)
- b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
- if err != nil {
- return b, err
- }
- }
- return b, nil
- }
- func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
- keys := mapv.MapKeys()
- sort.Slice(keys, func(i, j int) bool {
- switch keys[i].Kind() {
- case reflect.Bool:
- return !keys[i].Bool() && keys[j].Bool()
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- return keys[i].Int() < keys[j].Int()
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- return keys[i].Uint() < keys[j].Uint()
- case reflect.Float32, reflect.Float64:
- return keys[i].Float() < keys[j].Float()
- case reflect.String:
- return keys[i].String() < keys[j].String()
- default:
- panic("invalid kind: " + keys[i].Kind().String())
- }
- })
- for _, key := range keys {
- var err error
- b = protowire.AppendVarint(b, f.wiretag)
- b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
- if err != nil {
- return b, err
- }
- }
- return b, nil
- }
- func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
- if mi := f.mi; mi != nil {
- mi.init()
- if !mi.needsInitCheck {
- return nil
- }
- iter := mapRange(mapv)
- for iter.Next() {
- val := pointerOfValue(iter.Value())
- if err := mi.checkInitializedPointer(val); err != nil {
- return err
- }
- }
- } else {
- iter := mapRange(mapv)
- for iter.Next() {
- val := mapi.conv.valConv.PBValueOf(iter.Value())
- if err := mapi.valFuncs.isInit(val); err != nil {
- return err
- }
- }
- }
- return nil
- }
- func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
- dstm := dst.AsValueOf(f.ft).Elem()
- srcm := src.AsValueOf(f.ft).Elem()
- if srcm.Len() == 0 {
- return
- }
- if dstm.IsNil() {
- dstm.Set(reflect.MakeMap(f.ft))
- }
- iter := mapRange(srcm)
- for iter.Next() {
- dstm.SetMapIndex(iter.Key(), iter.Value())
- }
- }
- func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
- dstm := dst.AsValueOf(f.ft).Elem()
- srcm := src.AsValueOf(f.ft).Elem()
- if srcm.Len() == 0 {
- return
- }
- if dstm.IsNil() {
- dstm.Set(reflect.MakeMap(f.ft))
- }
- iter := mapRange(srcm)
- for iter.Next() {
- dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
- }
- }
- func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
- dstm := dst.AsValueOf(f.ft).Elem()
- srcm := src.AsValueOf(f.ft).Elem()
- if srcm.Len() == 0 {
- return
- }
- if dstm.IsNil() {
- dstm.Set(reflect.MakeMap(f.ft))
- }
- iter := mapRange(srcm)
- for iter.Next() {
- val := reflect.New(f.ft.Elem().Elem())
- if f.mi != nil {
- f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
- } else {
- opts.Merge(asMessage(val), asMessage(iter.Value()))
- }
- dstm.SetMapIndex(iter.Key(), val)
- }
- }
|