encode_map.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "sort"
  9. "google.golang.org/protobuf/internal/encoding/wire"
  10. "google.golang.org/protobuf/proto"
  11. pref "google.golang.org/protobuf/reflect/protoreflect"
  12. )
  13. var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
  14. func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
  15. // TODO: Consider generating specialized map coders.
  16. keyField := fd.MapKey()
  17. valField := fd.MapValue()
  18. keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
  19. valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
  20. keyFuncs := encoderFuncsForValue(keyField, ft.Key())
  21. valFuncs := encoderFuncsForValue(valField, ft.Elem())
  22. funcs = pointerCoderFuncs{
  23. size: func(p pointer, tagsize int, opts marshalOptions) int {
  24. return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
  25. },
  26. marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
  27. return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
  28. },
  29. }
  30. if valFuncs.isInit != nil {
  31. funcs.isInit = func(p pointer) error {
  32. return isInitMap(p, ft, valFuncs.isInit)
  33. }
  34. }
  35. return funcs
  36. }
  37. const (
  38. mapKeyTagSize = 1 // field 1, tag size 1.
  39. mapValTagSize = 1 // field 2, tag size 2.
  40. )
  41. func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
  42. m := p.AsValueOf(goType).Elem()
  43. n := 0
  44. if m.Len() == 0 {
  45. return 0
  46. }
  47. iter := mapRange(m)
  48. for iter.Next() {
  49. ki := iter.Key().Interface()
  50. vi := iter.Value().Interface()
  51. size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
  52. n += wire.SizeBytes(size) + tagsize
  53. }
  54. return n
  55. }
  56. func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
  57. m := p.AsValueOf(goType).Elem()
  58. var err error
  59. if m.Len() == 0 {
  60. return b, nil
  61. }
  62. if opts.Deterministic() {
  63. keys := m.MapKeys()
  64. sort.Sort(mapKeys(keys))
  65. for _, k := range keys {
  66. b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
  67. if err != nil {
  68. return b, err
  69. }
  70. }
  71. return b, nil
  72. }
  73. iter := mapRange(m)
  74. for iter.Next() {
  75. b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
  76. if err != nil {
  77. return b, err
  78. }
  79. }
  80. return b, nil
  81. }
  82. func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
  83. ki := key.Interface()
  84. vi := value.Interface()
  85. b = wire.AppendVarint(b, wiretag)
  86. size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
  87. b = wire.AppendVarint(b, uint64(size))
  88. b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
  89. if err != nil {
  90. return b, err
  91. }
  92. b, err = valFuncs.marshal(b, vi, valWiretag, opts)
  93. if err != nil {
  94. return b, err
  95. }
  96. return b, nil
  97. }
  98. func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
  99. m := p.AsValueOf(goType).Elem()
  100. if m.Len() == 0 {
  101. return nil
  102. }
  103. iter := mapRange(m)
  104. for iter.Next() {
  105. if err := isInit(iter.Value().Interface()); err != nil {
  106. return err
  107. }
  108. }
  109. return nil
  110. }
  111. // mapKeys returns a sort.Interface to be used for sorting the map keys.
  112. // Map fields may have key types of non-float scalars, strings and enums.
  113. func mapKeys(vs []reflect.Value) sort.Interface {
  114. s := mapKeySorter{vs: vs}
  115. // Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
  116. if len(vs) == 0 {
  117. return s
  118. }
  119. switch vs[0].Kind() {
  120. case reflect.Int32, reflect.Int64:
  121. s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
  122. case reflect.Uint32, reflect.Uint64:
  123. s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
  124. case reflect.Bool:
  125. s.less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } // false < true
  126. case reflect.String:
  127. s.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
  128. default:
  129. panic(fmt.Sprintf("unsupported map key type: %v", vs[0].Kind()))
  130. }
  131. return s
  132. }
  133. type mapKeySorter struct {
  134. vs []reflect.Value
  135. less func(a, b reflect.Value) bool
  136. }
  137. func (s mapKeySorter) Len() int { return len(s.vs) }
  138. func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
  139. func (s mapKeySorter) Less(i, j int) bool {
  140. return s.less(s.vs[i], s.vs[j])
  141. }