codec_messageset.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. // Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "sort"
  7. "google.golang.org/protobuf/internal/encoding/messageset"
  8. "google.golang.org/protobuf/internal/encoding/wire"
  9. "google.golang.org/protobuf/internal/errors"
  10. "google.golang.org/protobuf/internal/flags"
  11. )
  12. func makeMessageSetFieldCoder(mi *MessageInfo) pointerCoderFuncs {
  13. return pointerCoderFuncs{
  14. size: func(p pointer, tagsize int, opts marshalOptions) int {
  15. return sizeMessageSet(mi, p, tagsize, opts)
  16. },
  17. marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
  18. return marshalMessageSet(mi, b, p, wiretag, opts)
  19. },
  20. unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
  21. return unmarshalMessageSet(mi, b, p, wtyp, opts)
  22. },
  23. }
  24. }
  25. func sizeMessageSet(mi *MessageInfo, p pointer, tagsize int, opts marshalOptions) (n int) {
  26. ext := *p.Extensions()
  27. if ext == nil {
  28. return 0
  29. }
  30. for _, x := range ext {
  31. xi := mi.extensionFieldInfo(x.GetType())
  32. if xi.funcs.size == nil {
  33. continue
  34. }
  35. num, _ := wire.DecodeTag(xi.wiretag)
  36. n += messageset.SizeField(num)
  37. n += xi.funcs.size(x.Value(), wire.SizeTag(messageset.FieldMessage), opts)
  38. }
  39. return n
  40. }
  41. func marshalMessageSet(mi *MessageInfo, b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
  42. if !flags.ProtoLegacy {
  43. return b, errors.New("no support for message_set_wire_format")
  44. }
  45. ext := *p.Extensions()
  46. if ext == nil {
  47. return b, nil
  48. }
  49. switch len(ext) {
  50. case 0:
  51. return b, nil
  52. case 1:
  53. // Fast-path for one extension: Don't bother sorting the keys.
  54. for _, x := range ext {
  55. var err error
  56. b, err = marshalMessageSetField(mi, b, x, opts)
  57. if err != nil {
  58. return b, err
  59. }
  60. }
  61. return b, nil
  62. default:
  63. // Sort the keys to provide a deterministic encoding.
  64. // Not sure this is required, but the old code does it.
  65. keys := make([]int, 0, len(ext))
  66. for k := range ext {
  67. keys = append(keys, int(k))
  68. }
  69. sort.Ints(keys)
  70. for _, k := range keys {
  71. var err error
  72. b, err = marshalMessageSetField(mi, b, ext[int32(k)], opts)
  73. if err != nil {
  74. return b, err
  75. }
  76. }
  77. return b, nil
  78. }
  79. }
  80. func marshalMessageSetField(mi *MessageInfo, b []byte, x ExtensionField, opts marshalOptions) ([]byte, error) {
  81. xi := mi.extensionFieldInfo(x.GetType())
  82. num, _ := wire.DecodeTag(xi.wiretag)
  83. b = messageset.AppendFieldStart(b, num)
  84. b, err := xi.funcs.marshal(b, x.Value(), wire.EncodeTag(messageset.FieldMessage, wire.BytesType), opts)
  85. if err != nil {
  86. return b, err
  87. }
  88. b = messageset.AppendFieldEnd(b)
  89. return b, nil
  90. }
  91. func unmarshalMessageSet(mi *MessageInfo, b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
  92. if !flags.ProtoLegacy {
  93. return 0, errors.New("no support for message_set_wire_format")
  94. }
  95. if wtyp != wire.StartGroupType {
  96. return 0, errUnknown
  97. }
  98. ep := p.Extensions()
  99. if *ep == nil {
  100. *ep = make(map[int32]ExtensionField)
  101. }
  102. ext := *ep
  103. num, v, n, err := messageset.ConsumeFieldValue(b, true)
  104. if err != nil {
  105. return 0, err
  106. }
  107. if _, err := mi.unmarshalExtension(v, num, wire.BytesType, ext, opts); err != nil {
  108. return 0, err
  109. }
  110. return n, nil
  111. }