codec_map.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "sort"
  9. "google.golang.org/protobuf/internal/encoding/wire"
  10. pref "google.golang.org/protobuf/reflect/protoreflect"
  11. )
  12. type mapInfo struct {
  13. goType reflect.Type
  14. keyWiretag uint64
  15. valWiretag uint64
  16. keyFuncs ifaceCoderFuncs
  17. valFuncs ifaceCoderFuncs
  18. keyZero interface{}
  19. valZero interface{}
  20. newVal func() interface{}
  21. }
  22. func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
  23. // TODO: Consider generating specialized map coders.
  24. keyField := fd.MapKey()
  25. valField := fd.MapValue()
  26. keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
  27. valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
  28. keyFuncs := encoderFuncsForValue(keyField, ft.Key())
  29. valFuncs := encoderFuncsForValue(valField, ft.Elem())
  30. mapi := &mapInfo{
  31. goType: ft,
  32. keyWiretag: keyWiretag,
  33. valWiretag: valWiretag,
  34. keyFuncs: keyFuncs,
  35. valFuncs: valFuncs,
  36. keyZero: reflect.Zero(ft.Key()).Interface(),
  37. valZero: reflect.Zero(ft.Elem()).Interface(),
  38. }
  39. switch valField.Kind() {
  40. case pref.GroupKind, pref.MessageKind:
  41. mapi.newVal = func() interface{} {
  42. return reflect.New(ft.Elem().Elem()).Interface()
  43. }
  44. }
  45. funcs = pointerCoderFuncs{
  46. size: func(p pointer, tagsize int, opts marshalOptions) int {
  47. return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
  48. },
  49. marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
  50. return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
  51. },
  52. unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
  53. return consumeMap(b, p, wtyp, mapi, opts)
  54. },
  55. }
  56. if valFuncs.isInit != nil {
  57. funcs.isInit = func(p pointer) error {
  58. return isInitMap(p, ft, valFuncs.isInit)
  59. }
  60. }
  61. return funcs
  62. }
  63. const (
  64. mapKeyTagSize = 1 // field 1, tag size 1.
  65. mapValTagSize = 1 // field 2, tag size 2.
  66. )
  67. func consumeMap(b []byte, p pointer, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
  68. mp := p.AsValueOf(mapi.goType)
  69. if mp.Elem().IsNil() {
  70. mp.Elem().Set(reflect.MakeMap(mapi.goType))
  71. }
  72. m := mp.Elem()
  73. if wtyp != wire.BytesType {
  74. return 0, errUnknown
  75. }
  76. b, n := wire.ConsumeBytes(b)
  77. if n < 0 {
  78. return 0, wire.ParseError(n)
  79. }
  80. var (
  81. key = mapi.keyZero
  82. val = mapi.valZero
  83. )
  84. if mapi.newVal != nil {
  85. val = mapi.newVal()
  86. }
  87. for len(b) > 0 {
  88. num, wtyp, n := wire.ConsumeTag(b)
  89. if n < 0 {
  90. return 0, wire.ParseError(n)
  91. }
  92. b = b[n:]
  93. err := errUnknown
  94. switch num {
  95. case 1:
  96. var v interface{}
  97. v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
  98. if err != nil {
  99. break
  100. }
  101. key = v
  102. case 2:
  103. var v interface{}
  104. v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
  105. if err != nil {
  106. break
  107. }
  108. val = v
  109. }
  110. if err == errUnknown {
  111. n = wire.ConsumeFieldValue(num, wtyp, b)
  112. if n < 0 {
  113. return 0, wire.ParseError(n)
  114. }
  115. } else if err != nil {
  116. return 0, err
  117. }
  118. b = b[n:]
  119. }
  120. m.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(val))
  121. return n, nil
  122. }
  123. func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
  124. m := p.AsValueOf(goType).Elem()
  125. n := 0
  126. if m.Len() == 0 {
  127. return 0
  128. }
  129. iter := mapRange(m)
  130. for iter.Next() {
  131. ki := iter.Key().Interface()
  132. vi := iter.Value().Interface()
  133. size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
  134. n += wire.SizeBytes(size) + tagsize
  135. }
  136. return n
  137. }
  138. func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
  139. m := p.AsValueOf(goType).Elem()
  140. var err error
  141. if m.Len() == 0 {
  142. return b, nil
  143. }
  144. if opts.Deterministic() {
  145. keys := m.MapKeys()
  146. sort.Sort(mapKeys(keys))
  147. for _, k := range keys {
  148. b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
  149. if err != nil {
  150. return b, err
  151. }
  152. }
  153. return b, nil
  154. }
  155. iter := mapRange(m)
  156. for iter.Next() {
  157. b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
  158. if err != nil {
  159. return b, err
  160. }
  161. }
  162. return b, nil
  163. }
  164. func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
  165. ki := key.Interface()
  166. vi := value.Interface()
  167. b = wire.AppendVarint(b, wiretag)
  168. size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
  169. b = wire.AppendVarint(b, uint64(size))
  170. b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
  171. if err != nil {
  172. return b, err
  173. }
  174. b, err = valFuncs.marshal(b, vi, valWiretag, opts)
  175. if err != nil {
  176. return b, err
  177. }
  178. return b, nil
  179. }
  180. func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
  181. m := p.AsValueOf(goType).Elem()
  182. if m.Len() == 0 {
  183. return nil
  184. }
  185. iter := mapRange(m)
  186. for iter.Next() {
  187. if err := isInit(iter.Value().Interface()); err != nil {
  188. return err
  189. }
  190. }
  191. return nil
  192. }
  193. // mapKeys returns a sort.Interface to be used for sorting the map keys.
  194. // Map fields may have key types of non-float scalars, strings and enums.
  195. func mapKeys(vs []reflect.Value) sort.Interface {
  196. s := mapKeySorter{vs: vs}
  197. // Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
  198. if len(vs) == 0 {
  199. return s
  200. }
  201. switch vs[0].Kind() {
  202. case reflect.Int32, reflect.Int64:
  203. s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
  204. case reflect.Uint32, reflect.Uint64:
  205. s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
  206. case reflect.Bool:
  207. s.less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } // false < true
  208. case reflect.String:
  209. s.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
  210. default:
  211. panic(fmt.Sprintf("unsupported map key type: %v", vs[0].Kind()))
  212. }
  213. return s
  214. }
  215. type mapKeySorter struct {
  216. vs []reflect.Value
  217. less func(a, b reflect.Value) bool
  218. }
  219. func (s mapKeySorter) Len() int { return len(s.vs) }
  220. func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
  221. func (s mapKeySorter) Less(i, j int) bool {
  222. return s.less(s.vs[i], s.vs[j])
  223. }