codec_map.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "reflect"
  7. "google.golang.org/protobuf/internal/encoding/wire"
  8. "google.golang.org/protobuf/internal/mapsort"
  9. pref "google.golang.org/protobuf/reflect/protoreflect"
  10. )
  11. type mapInfo struct {
  12. goType reflect.Type
  13. keyWiretag uint64
  14. valWiretag uint64
  15. keyFuncs valueCoderFuncs
  16. valFuncs valueCoderFuncs
  17. keyZero pref.Value
  18. keyKind pref.Kind
  19. }
  20. func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
  21. // TODO: Consider generating specialized map coders.
  22. keyField := fd.MapKey()
  23. valField := fd.MapValue()
  24. keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
  25. valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
  26. keyFuncs := encoderFuncsForValue(keyField)
  27. valFuncs := encoderFuncsForValue(valField)
  28. conv := NewConverter(ft, fd)
  29. mapi := &mapInfo{
  30. goType: ft,
  31. keyWiretag: keyWiretag,
  32. valWiretag: valWiretag,
  33. keyFuncs: keyFuncs,
  34. valFuncs: valFuncs,
  35. keyZero: keyField.Default(),
  36. keyKind: keyField.Kind(),
  37. }
  38. funcs = pointerCoderFuncs{
  39. size: func(p pointer, tagsize int, opts marshalOptions) int {
  40. mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
  41. return sizeMap(mapv, tagsize, mapi, opts)
  42. },
  43. marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
  44. mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
  45. return appendMap(b, mapv, wiretag, mapi, opts)
  46. },
  47. unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
  48. mp := p.AsValueOf(ft)
  49. if mp.Elem().IsNil() {
  50. mp.Elem().Set(reflect.MakeMap(mapi.goType))
  51. }
  52. mapv := conv.PBValueOf(mp.Elem()).Map()
  53. return consumeMap(b, mapv, wtyp, mapi, opts)
  54. },
  55. }
  56. if valFuncs.isInit != nil {
  57. funcs.isInit = func(p pointer) error {
  58. mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
  59. return isInitMap(mapv, mapi)
  60. }
  61. }
  62. return funcs
  63. }
  64. const (
  65. mapKeyTagSize = 1 // field 1, tag size 1.
  66. mapValTagSize = 1 // field 2, tag size 2.
  67. )
  68. func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int {
  69. if mapv.Len() == 0 {
  70. return 0
  71. }
  72. n := 0
  73. mapv.Range(func(key pref.MapKey, value pref.Value) bool {
  74. n += tagsize + wire.SizeBytes(
  75. mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)+
  76. mapi.valFuncs.size(value, mapValTagSize, opts))
  77. return true
  78. })
  79. return n
  80. }
  81. func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
  82. if wtyp != wire.BytesType {
  83. return 0, errUnknown
  84. }
  85. b, n := wire.ConsumeBytes(b)
  86. if n < 0 {
  87. return 0, wire.ParseError(n)
  88. }
  89. var (
  90. key = mapi.keyZero
  91. val = mapv.NewValue()
  92. )
  93. for len(b) > 0 {
  94. num, wtyp, n := wire.ConsumeTag(b)
  95. if n < 0 {
  96. return 0, wire.ParseError(n)
  97. }
  98. b = b[n:]
  99. err := errUnknown
  100. switch num {
  101. case 1:
  102. var v pref.Value
  103. v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
  104. if err != nil {
  105. break
  106. }
  107. key = v
  108. case 2:
  109. var v pref.Value
  110. v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
  111. if err != nil {
  112. break
  113. }
  114. val = v
  115. }
  116. if err == errUnknown {
  117. n = wire.ConsumeFieldValue(num, wtyp, b)
  118. if n < 0 {
  119. return 0, wire.ParseError(n)
  120. }
  121. } else if err != nil {
  122. return 0, err
  123. }
  124. b = b[n:]
  125. }
  126. mapv.Set(key.MapKey(), val)
  127. return n, nil
  128. }
  129. func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
  130. if mapv.Len() == 0 {
  131. return b, nil
  132. }
  133. var err error
  134. fn := func(key pref.MapKey, value pref.Value) bool {
  135. b = wire.AppendVarint(b, wiretag)
  136. size := 0
  137. size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
  138. size += mapi.valFuncs.size(value, mapValTagSize, opts)
  139. b = wire.AppendVarint(b, uint64(size))
  140. b, err = mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
  141. if err != nil {
  142. return false
  143. }
  144. b, err = mapi.valFuncs.marshal(b, value, mapi.valWiretag, opts)
  145. if err != nil {
  146. return false
  147. }
  148. return true
  149. }
  150. if opts.Deterministic() {
  151. mapsort.Range(mapv, mapi.keyKind, fn)
  152. } else {
  153. mapv.Range(fn)
  154. }
  155. return b, err
  156. }
  157. func isInitMap(mapv pref.Map, mapi *mapInfo) error {
  158. var err error
  159. mapv.Range(func(_ pref.MapKey, value pref.Value) bool {
  160. err = mapi.valFuncs.isInit(value)
  161. return err == nil
  162. })
  163. return err
  164. }