codec_map.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "sort"
  9. "google.golang.org/protobuf/internal/encoding/wire"
  10. "google.golang.org/protobuf/proto"
  11. pref "google.golang.org/protobuf/reflect/protoreflect"
  12. )
  13. var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
  14. type mapInfo struct {
  15. goType reflect.Type
  16. keyWiretag uint64
  17. valWiretag uint64
  18. keyFuncs ifaceCoderFuncs
  19. valFuncs ifaceCoderFuncs
  20. keyZero interface{}
  21. valZero interface{}
  22. newVal func() interface{}
  23. }
  24. func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
  25. // TODO: Consider generating specialized map coders.
  26. keyField := fd.MapKey()
  27. valField := fd.MapValue()
  28. keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
  29. valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
  30. keyFuncs := encoderFuncsForValue(keyField, ft.Key())
  31. valFuncs := encoderFuncsForValue(valField, ft.Elem())
  32. mapi := &mapInfo{
  33. goType: ft,
  34. keyWiretag: keyWiretag,
  35. valWiretag: valWiretag,
  36. keyFuncs: keyFuncs,
  37. valFuncs: valFuncs,
  38. keyZero: reflect.Zero(ft.Key()).Interface(),
  39. valZero: reflect.Zero(ft.Elem()).Interface(),
  40. }
  41. switch valField.Kind() {
  42. case pref.GroupKind, pref.MessageKind:
  43. mapi.newVal = func() interface{} {
  44. return reflect.New(ft.Elem().Elem()).Interface()
  45. }
  46. }
  47. funcs = pointerCoderFuncs{
  48. size: func(p pointer, tagsize int, opts marshalOptions) int {
  49. return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
  50. },
  51. marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
  52. return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
  53. },
  54. unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
  55. return consumeMap(b, p, wtyp, mapi, opts)
  56. },
  57. }
  58. if valFuncs.isInit != nil {
  59. funcs.isInit = func(p pointer) error {
  60. return isInitMap(p, ft, valFuncs.isInit)
  61. }
  62. }
  63. return funcs
  64. }
  65. const (
  66. mapKeyTagSize = 1 // field 1, tag size 1.
  67. mapValTagSize = 1 // field 2, tag size 2.
  68. )
  69. func consumeMap(b []byte, p pointer, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
  70. mp := p.AsValueOf(mapi.goType)
  71. if mp.Elem().IsNil() {
  72. mp.Elem().Set(reflect.MakeMap(mapi.goType))
  73. }
  74. m := mp.Elem()
  75. if wtyp != wire.BytesType {
  76. return 0, errUnknown
  77. }
  78. b, n := wire.ConsumeBytes(b)
  79. if n < 0 {
  80. return 0, wire.ParseError(n)
  81. }
  82. var (
  83. key = mapi.keyZero
  84. val = mapi.valZero
  85. )
  86. if mapi.newVal != nil {
  87. val = mapi.newVal()
  88. }
  89. for len(b) > 0 {
  90. num, wtyp, n := wire.ConsumeTag(b)
  91. if n < 0 {
  92. return 0, wire.ParseError(n)
  93. }
  94. b = b[n:]
  95. err := errUnknown
  96. switch num {
  97. case 1:
  98. var v interface{}
  99. v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
  100. if err != nil {
  101. break
  102. }
  103. key = v
  104. case 2:
  105. var v interface{}
  106. v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
  107. if err != nil {
  108. break
  109. }
  110. val = v
  111. }
  112. if err == errUnknown {
  113. n = wire.ConsumeFieldValue(num, wtyp, b)
  114. if n < 0 {
  115. return 0, wire.ParseError(n)
  116. }
  117. } else if err != nil {
  118. return 0, err
  119. }
  120. b = b[n:]
  121. }
  122. m.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(val))
  123. return n, nil
  124. }
  125. func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
  126. m := p.AsValueOf(goType).Elem()
  127. n := 0
  128. if m.Len() == 0 {
  129. return 0
  130. }
  131. iter := mapRange(m)
  132. for iter.Next() {
  133. ki := iter.Key().Interface()
  134. vi := iter.Value().Interface()
  135. size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
  136. n += wire.SizeBytes(size) + tagsize
  137. }
  138. return n
  139. }
  140. func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
  141. m := p.AsValueOf(goType).Elem()
  142. var err error
  143. if m.Len() == 0 {
  144. return b, nil
  145. }
  146. if opts.Deterministic() {
  147. keys := m.MapKeys()
  148. sort.Sort(mapKeys(keys))
  149. for _, k := range keys {
  150. b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
  151. if err != nil {
  152. return b, err
  153. }
  154. }
  155. return b, nil
  156. }
  157. iter := mapRange(m)
  158. for iter.Next() {
  159. b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
  160. if err != nil {
  161. return b, err
  162. }
  163. }
  164. return b, nil
  165. }
  166. func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
  167. ki := key.Interface()
  168. vi := value.Interface()
  169. b = wire.AppendVarint(b, wiretag)
  170. size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
  171. b = wire.AppendVarint(b, uint64(size))
  172. b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
  173. if err != nil {
  174. return b, err
  175. }
  176. b, err = valFuncs.marshal(b, vi, valWiretag, opts)
  177. if err != nil {
  178. return b, err
  179. }
  180. return b, nil
  181. }
  182. func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
  183. m := p.AsValueOf(goType).Elem()
  184. if m.Len() == 0 {
  185. return nil
  186. }
  187. iter := mapRange(m)
  188. for iter.Next() {
  189. if err := isInit(iter.Value().Interface()); err != nil {
  190. return err
  191. }
  192. }
  193. return nil
  194. }
  195. // mapKeys returns a sort.Interface to be used for sorting the map keys.
  196. // Map fields may have key types of non-float scalars, strings and enums.
  197. func mapKeys(vs []reflect.Value) sort.Interface {
  198. s := mapKeySorter{vs: vs}
  199. // Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
  200. if len(vs) == 0 {
  201. return s
  202. }
  203. switch vs[0].Kind() {
  204. case reflect.Int32, reflect.Int64:
  205. s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
  206. case reflect.Uint32, reflect.Uint64:
  207. s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
  208. case reflect.Bool:
  209. s.less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } // false < true
  210. case reflect.String:
  211. s.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
  212. default:
  213. panic(fmt.Sprintf("unsupported map key type: %v", vs[0].Kind()))
  214. }
  215. return s
  216. }
  217. type mapKeySorter struct {
  218. vs []reflect.Value
  219. less func(a, b reflect.Value) bool
  220. }
  221. func (s mapKeySorter) Len() int { return len(s.vs) }
  222. func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
  223. func (s mapKeySorter) Less(i, j int) bool {
  224. return s.less(s.vs[i], s.vs[j])
  225. }