decode.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "google.golang.org/protobuf/internal/encoding/messageset"
  7. "google.golang.org/protobuf/internal/encoding/wire"
  8. "google.golang.org/protobuf/internal/errors"
  9. "google.golang.org/protobuf/internal/flags"
  10. "google.golang.org/protobuf/internal/pragma"
  11. "google.golang.org/protobuf/reflect/protoreflect"
  12. "google.golang.org/protobuf/reflect/protoregistry"
  13. "google.golang.org/protobuf/runtime/protoiface"
  14. )
  15. // UnmarshalOptions configures the unmarshaler.
  16. //
  17. // Example usage:
  18. // err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
  19. type UnmarshalOptions struct {
  20. pragma.NoUnkeyedLiterals
  21. // Merge merges the input into the destination message.
  22. // The default behavior is to always reset the message before unmarshaling,
  23. // unless Merge is specified.
  24. Merge bool
  25. // AllowPartial accepts input for messages that will result in missing
  26. // required fields. If AllowPartial is false (the default), Unmarshal will
  27. // return an error if there are any missing required fields.
  28. AllowPartial bool
  29. // If DiscardUnknown is set, unknown fields are ignored.
  30. DiscardUnknown bool
  31. // Resolver is used for looking up types when unmarshaling extension fields.
  32. // If nil, this defaults to using protoregistry.GlobalTypes.
  33. Resolver interface {
  34. protoregistry.ExtensionTypeResolver
  35. }
  36. }
  37. var _ = protoiface.UnmarshalOptions(UnmarshalOptions{})
  38. // Unmarshal parses the wire-format message in b and places the result in m.
  39. func Unmarshal(b []byte, m Message) error {
  40. return UnmarshalOptions{}.Unmarshal(b, m)
  41. }
  42. // Unmarshal parses the wire-format message in b and places the result in m.
  43. func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
  44. if o.Resolver == nil {
  45. o.Resolver = protoregistry.GlobalTypes
  46. }
  47. if !o.Merge {
  48. Reset(m)
  49. }
  50. err := o.unmarshalMessage(b, m.ProtoReflect())
  51. if err != nil {
  52. return err
  53. }
  54. if o.AllowPartial {
  55. return nil
  56. }
  57. return IsInitialized(m)
  58. }
  59. func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
  60. if methods := protoMethods(m); methods != nil && methods.Unmarshal != nil &&
  61. !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
  62. return methods.Unmarshal(b, m, protoiface.UnmarshalOptions(o))
  63. }
  64. return o.unmarshalMessageSlow(b, m)
  65. }
  66. func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
  67. md := m.Descriptor()
  68. if messageset.IsMessageSet(md) {
  69. return unmarshalMessageSet(b, m, o)
  70. }
  71. fields := md.Fields()
  72. for len(b) > 0 {
  73. // Parse the tag (field number and wire type).
  74. num, wtyp, tagLen := wire.ConsumeTag(b)
  75. if tagLen < 0 {
  76. return wire.ParseError(tagLen)
  77. }
  78. // Find the field descriptor for this field number.
  79. fd := fields.ByNumber(num)
  80. if fd == nil && md.ExtensionRanges().Has(num) {
  81. extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
  82. if err != nil && err != protoregistry.NotFound {
  83. return err
  84. }
  85. if extType != nil {
  86. fd = extType.TypeDescriptor()
  87. }
  88. }
  89. var err error
  90. if fd == nil {
  91. err = errUnknown
  92. } else if flags.ProtoLegacy {
  93. if fd.IsWeak() && fd.Message().IsPlaceholder() {
  94. err = errUnknown // weak referent is not linked in
  95. }
  96. }
  97. // Parse the field value.
  98. var valLen int
  99. switch {
  100. case err != nil:
  101. case fd.IsList():
  102. valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
  103. case fd.IsMap():
  104. valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
  105. default:
  106. valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
  107. }
  108. if err != nil {
  109. if err != errUnknown {
  110. return err
  111. }
  112. valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
  113. if valLen < 0 {
  114. return wire.ParseError(valLen)
  115. }
  116. m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
  117. }
  118. b = b[tagLen+valLen:]
  119. }
  120. return nil
  121. }
  122. func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp wire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
  123. v, n, err := o.unmarshalScalar(b, wtyp, fd)
  124. if err != nil {
  125. return 0, err
  126. }
  127. switch fd.Kind() {
  128. case protoreflect.GroupKind, protoreflect.MessageKind:
  129. m2 := m.Mutable(fd).Message()
  130. if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
  131. return n, err
  132. }
  133. default:
  134. // Non-message scalars replace the previous value.
  135. m.Set(fd, v)
  136. }
  137. return n, nil
  138. }
  139. func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
  140. if wtyp != wire.BytesType {
  141. return 0, errUnknown
  142. }
  143. b, n = wire.ConsumeBytes(b)
  144. if n < 0 {
  145. return 0, wire.ParseError(n)
  146. }
  147. var (
  148. keyField = fd.MapKey()
  149. valField = fd.MapValue()
  150. key protoreflect.Value
  151. val protoreflect.Value
  152. haveKey bool
  153. haveVal bool
  154. )
  155. switch valField.Kind() {
  156. case protoreflect.GroupKind, protoreflect.MessageKind:
  157. val = mapv.NewValue()
  158. }
  159. // Map entries are represented as a two-element message with fields
  160. // containing the key and value.
  161. for len(b) > 0 {
  162. num, wtyp, n := wire.ConsumeTag(b)
  163. if n < 0 {
  164. return 0, wire.ParseError(n)
  165. }
  166. b = b[n:]
  167. err = errUnknown
  168. switch num {
  169. case 1:
  170. key, n, err = o.unmarshalScalar(b, wtyp, keyField)
  171. if err != nil {
  172. break
  173. }
  174. haveKey = true
  175. case 2:
  176. var v protoreflect.Value
  177. v, n, err = o.unmarshalScalar(b, wtyp, valField)
  178. if err != nil {
  179. break
  180. }
  181. switch valField.Kind() {
  182. case protoreflect.GroupKind, protoreflect.MessageKind:
  183. if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
  184. return 0, err
  185. }
  186. default:
  187. val = v
  188. }
  189. haveVal = true
  190. }
  191. if err == errUnknown {
  192. n = wire.ConsumeFieldValue(num, wtyp, b)
  193. if n < 0 {
  194. return 0, wire.ParseError(n)
  195. }
  196. } else if err != nil {
  197. return 0, err
  198. }
  199. b = b[n:]
  200. }
  201. // Every map entry should have entries for key and value, but this is not strictly required.
  202. if !haveKey {
  203. key = keyField.Default()
  204. }
  205. if !haveVal {
  206. switch valField.Kind() {
  207. case protoreflect.GroupKind, protoreflect.MessageKind:
  208. default:
  209. val = valField.Default()
  210. }
  211. }
  212. mapv.Set(key.MapKey(), val)
  213. return n, nil
  214. }
  // errUnknown is an internal sentinel. Field decoders return it to signal
  // that the current field should be treated as unknown data (kept in the
  // message's unknown-field set or skipped) rather than as a decode failure.
  // It is never returned from an exported function.
  var errUnknown error = errors.New("BUG: internal error (unknown)")