123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- // Copyright 2019 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package impl
- import (
- "reflect"
- "google.golang.org/protobuf/internal/encoding/wire"
- "google.golang.org/protobuf/internal/mapsort"
- pref "google.golang.org/protobuf/reflect/protoreflect"
- )
- type mapInfo struct {
- goType reflect.Type
- keyWiretag uint64
- valWiretag uint64
- keyFuncs valueCoderFuncs
- valFuncs valueCoderFuncs
- keyZero pref.Value
- keyKind pref.Kind
- }
- func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
- // TODO: Consider generating specialized map coders.
- keyField := fd.MapKey()
- valField := fd.MapValue()
- keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
- valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
- keyFuncs := encoderFuncsForValue(keyField)
- valFuncs := encoderFuncsForValue(valField)
- conv := NewConverter(ft, fd)
- mapi := &mapInfo{
- goType: ft,
- keyWiretag: keyWiretag,
- valWiretag: valWiretag,
- keyFuncs: keyFuncs,
- valFuncs: valFuncs,
- keyZero: keyField.Default(),
- keyKind: keyField.Kind(),
- }
- funcs = pointerCoderFuncs{
- size: func(p pointer, tagsize int, opts marshalOptions) int {
- mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
- return sizeMap(mapv, tagsize, mapi, opts)
- },
- marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
- mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
- return appendMap(b, mapv, wiretag, mapi, opts)
- },
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
- mp := p.AsValueOf(ft)
- if mp.Elem().IsNil() {
- mp.Elem().Set(reflect.MakeMap(mapi.goType))
- }
- mapv := conv.PBValueOf(mp.Elem()).Map()
- return consumeMap(b, mapv, wtyp, mapi, opts)
- },
- }
- if valFuncs.isInit != nil {
- funcs.isInit = func(p pointer) error {
- mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
- return isInitMap(mapv, mapi)
- }
- }
- return funcs
- }
// Sizes, in bytes, of the tags that introduce the key (field 1) and the value
// (field 2) inside an encoded map entry. A tag is the field number shifted left
// by three bits, OR'd with the wire type. For field numbers 1 and 2 that is at
// most 23, so each tag fits in a single varint byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
- func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int {
- if mapv.Len() == 0 {
- return 0
- }
- n := 0
- mapv.Range(func(key pref.MapKey, value pref.Value) bool {
- n += tagsize + wire.SizeBytes(
- mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)+
- mapi.valFuncs.size(value, mapValTagSize, opts))
- return true
- })
- return n
- }
- func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
- if wtyp != wire.BytesType {
- return 0, errUnknown
- }
- b, n := wire.ConsumeBytes(b)
- if n < 0 {
- return 0, wire.ParseError(n)
- }
- var (
- key = mapi.keyZero
- val = mapv.NewValue()
- )
- for len(b) > 0 {
- num, wtyp, n := wire.ConsumeTag(b)
- if n < 0 {
- return 0, wire.ParseError(n)
- }
- b = b[n:]
- err := errUnknown
- switch num {
- case 1:
- var v pref.Value
- v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
- if err != nil {
- break
- }
- key = v
- case 2:
- var v pref.Value
- v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
- if err != nil {
- break
- }
- val = v
- }
- if err == errUnknown {
- n = wire.ConsumeFieldValue(num, wtyp, b)
- if n < 0 {
- return 0, wire.ParseError(n)
- }
- } else if err != nil {
- return 0, err
- }
- b = b[n:]
- }
- mapv.Set(key.MapKey(), val)
- return n, nil
- }
- func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
- if mapv.Len() == 0 {
- return b, nil
- }
- var err error
- fn := func(key pref.MapKey, value pref.Value) bool {
- b = wire.AppendVarint(b, wiretag)
- size := 0
- size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
- size += mapi.valFuncs.size(value, mapValTagSize, opts)
- b = wire.AppendVarint(b, uint64(size))
- b, err = mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
- if err != nil {
- return false
- }
- b, err = mapi.valFuncs.marshal(b, value, mapi.valWiretag, opts)
- if err != nil {
- return false
- }
- return true
- }
- if opts.Deterministic() {
- mapsort.Range(mapv, mapi.keyKind, fn)
- } else {
- mapv.Range(fn)
- }
- return b, err
- }
- func isInitMap(mapv pref.Map, mapi *mapInfo) error {
- var err error
- mapv.Range(func(_ pref.MapKey, value pref.Value) bool {
- err = mapi.valFuncs.isInit(value)
- return err == nil
- })
- return err
- }
|