decode.go 18 KB


  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package protojson
  5. import (
  6. "encoding/base64"
  7. "fmt"
  8. "math"
  9. "strconv"
  10. "strings"
  11. "google.golang.org/protobuf/internal/encoding/json"
  12. "google.golang.org/protobuf/internal/encoding/messageset"
  13. "google.golang.org/protobuf/internal/errors"
  14. "google.golang.org/protobuf/internal/flags"
  15. "google.golang.org/protobuf/internal/pragma"
  16. "google.golang.org/protobuf/internal/set"
  17. "google.golang.org/protobuf/proto"
  18. pref "google.golang.org/protobuf/reflect/protoreflect"
  19. "google.golang.org/protobuf/reflect/protoregistry"
  20. )
  21. // Unmarshal reads the given []byte into the given proto.Message.
  22. func Unmarshal(b []byte, m proto.Message) error {
  23. return UnmarshalOptions{}.Unmarshal(b, m)
  24. }
// UnmarshalOptions is a configurable JSON format parser.
//
// The zero value is ready to use: UnmarshalOptions.Unmarshal substitutes
// protoregistry.GlobalTypes when Resolver is nil.
type UnmarshalOptions struct {
	pragma.NoUnkeyedLiterals

	// If AllowPartial is set, input for messages that will result in missing
	// required fields will not return an error.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored instead of being
	// reported as an error.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling
	// google.protobuf.Any messages or extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	Resolver interface {
		protoregistry.MessageTypeResolver
		protoregistry.ExtensionTypeResolver
	}

	// decoder is the JSON token stream for the Unmarshal call in progress.
	// Unmarshal has a value receiver and assigns this field on its own copy
	// of the options, so each call works with its own decoder.
	decoder *json.Decoder
}
  42. // Unmarshal reads the given []byte and populates the given proto.Message using
  43. // options in UnmarshalOptions object. It will clear the message first before
  44. // setting the fields. If it returns an error, the given message may be
  45. // partially set.
  46. func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
  47. proto.Reset(m)
  48. if o.Resolver == nil {
  49. o.Resolver = protoregistry.GlobalTypes
  50. }
  51. o.decoder = json.NewDecoder(b)
  52. if err := o.unmarshalMessage(m.ProtoReflect(), false); err != nil {
  53. return err
  54. }
  55. // Check for EOF.
  56. val, err := o.decoder.Read()
  57. if err != nil {
  58. return err
  59. }
  60. if val.Type() != json.EOF {
  61. return unexpectedJSONError{val}
  62. }
  63. if o.AllowPartial {
  64. return nil
  65. }
  66. return proto.IsInitialized(m)
  67. }
// unexpectedJSONError is an error that carries the unexpected json.Value. It
// is returned by methods to give callers the json.Value that was read but not
// expected. For example, unmarshalList type-asserts on this error to detect
// the closing ']' of a JSON array.
//
// TODO: Consider moving this to internal/encoding/json for consistency with
// errors that package returns.
type unexpectedJSONError struct {
	// value is the token that was read but not expected.
	value json.Value
}
  76. func (e unexpectedJSONError) Error() string {
  77. return newError("unexpected value %s", e.value).Error()
  78. }
  79. // newError returns an error object. If one of the values passed in is of
  80. // json.Value type, it produces an error with position info.
  81. func newError(f string, x ...interface{}) error {
  82. var hasValue bool
  83. var line, column int
  84. for i := 0; i < len(x); i++ {
  85. if val, ok := x[i].(json.Value); ok {
  86. line, column = val.Position()
  87. hasValue = true
  88. break
  89. }
  90. }
  91. e := errors.New(f, x...)
  92. if hasValue {
  93. return errors.New("(line %d:%d): %v", line, column, e)
  94. }
  95. return e
  96. }
  97. // unmarshalMessage unmarshals a message into the given protoreflect.Message.
  98. func (o UnmarshalOptions) unmarshalMessage(m pref.Message, skipTypeURL bool) error {
  99. if isCustomType(m.Descriptor().FullName()) {
  100. return o.unmarshalCustomType(m)
  101. }
  102. jval, err := o.decoder.Read()
  103. if err != nil {
  104. return err
  105. }
  106. if jval.Type() != json.StartObject {
  107. return unexpectedJSONError{jval}
  108. }
  109. if err := o.unmarshalFields(m, skipTypeURL); err != nil {
  110. return err
  111. }
  112. return nil
  113. }
  114. // unmarshalFields unmarshals the fields into the given protoreflect.Message.
  115. func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) error {
  116. messageDesc := m.Descriptor()
  117. if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) {
  118. return errors.New("no support for proto1 MessageSets")
  119. }
  120. var seenNums set.Ints
  121. var seenOneofs set.Ints
  122. fieldDescs := messageDesc.Fields()
  123. for {
  124. // Read field name.
  125. jval, err := o.decoder.Read()
  126. if err != nil {
  127. return err
  128. }
  129. switch jval.Type() {
  130. default:
  131. return unexpectedJSONError{jval}
  132. case json.EndObject:
  133. return nil
  134. case json.Name:
  135. // Continue below.
  136. }
  137. name, err := jval.Name()
  138. if err != nil {
  139. return err
  140. }
  141. // Unmarshaling a non-custom embedded message in Any will contain the
  142. // JSON field "@type" which should be skipped because it is not a field
  143. // of the embedded message, but simply an artifact of the Any format.
  144. if skipTypeURL && name == "@type" {
  145. o.decoder.Read()
  146. continue
  147. }
  148. // Get the FieldDescriptor.
  149. var fd pref.FieldDescriptor
  150. if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
  151. // Only extension names are in [name] format.
  152. extName := pref.FullName(name[1 : len(name)-1])
  153. extType, err := o.findExtension(extName)
  154. if err != nil && err != protoregistry.NotFound {
  155. return errors.New("unable to resolve [%v]: %v", extName, err)
  156. }
  157. if extType != nil {
  158. fd = extType.TypeDescriptor()
  159. if !messageDesc.ExtensionRanges().Has(fd.Number()) || fd.ContainingMessage().FullName() != messageDesc.FullName() {
  160. return errors.New("message %v cannot be extended by %v", messageDesc.FullName(), fd.FullName())
  161. }
  162. }
  163. } else {
  164. // The name can either be the JSON name or the proto field name.
  165. fd = fieldDescs.ByJSONName(name)
  166. if fd == nil {
  167. fd = fieldDescs.ByName(pref.Name(name))
  168. if fd == nil {
  169. // The proto name of a group field is in all lowercase,
  170. // while the textual field name is the group message name.
  171. gd := fieldDescs.ByName(pref.Name(strings.ToLower(name)))
  172. if gd != nil && gd.Kind() == pref.GroupKind && gd.Message().Name() == pref.Name(name) {
  173. fd = gd
  174. }
  175. } else if fd.Kind() == pref.GroupKind && fd.Message().Name() != pref.Name(name) {
  176. fd = nil // reset since field name is actually the message name
  177. }
  178. }
  179. if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
  180. fd = nil // reset since the weak reference is not linked in
  181. }
  182. }
  183. if fd == nil {
  184. // Field is unknown.
  185. if o.DiscardUnknown {
  186. if err := skipJSONValue(o.decoder); err != nil {
  187. return err
  188. }
  189. continue
  190. }
  191. return newError("%v contains unknown field %s", messageDesc.FullName(), jval)
  192. }
  193. // Do not allow duplicate fields.
  194. num := uint64(fd.Number())
  195. if seenNums.Has(num) {
  196. return newError("%v contains repeated field %s", messageDesc.FullName(), jval)
  197. }
  198. seenNums.Set(num)
  199. // No need to set values for JSON null unless the field type is
  200. // google.protobuf.Value or google.protobuf.NullValue.
  201. if o.decoder.Peek() == json.Null && !isKnownValue(fd) && !isNullValue(fd) {
  202. o.decoder.Read()
  203. continue
  204. }
  205. switch {
  206. case fd.IsList():
  207. list := m.Mutable(fd).List()
  208. if err := o.unmarshalList(list, fd); err != nil {
  209. return errors.New("%v|%q: %v", fd.FullName(), name, err)
  210. }
  211. case fd.IsMap():
  212. mmap := m.Mutable(fd).Map()
  213. if err := o.unmarshalMap(mmap, fd); err != nil {
  214. return errors.New("%v|%q: %v", fd.FullName(), name, err)
  215. }
  216. default:
  217. // If field is a oneof, check if it has already been set.
  218. if od := fd.ContainingOneof(); od != nil {
  219. idx := uint64(od.Index())
  220. if seenOneofs.Has(idx) {
  221. return errors.New("%v: oneof is already set", od.FullName())
  222. }
  223. seenOneofs.Set(idx)
  224. }
  225. // Required or optional fields.
  226. if err := o.unmarshalSingular(m, fd); err != nil {
  227. return errors.New("%v|%q: %v", fd.FullName(), name, err)
  228. }
  229. }
  230. }
  231. }
  232. // findExtension returns protoreflect.ExtensionType from the resolver if found.
  233. func (o UnmarshalOptions) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
  234. xt, err := o.Resolver.FindExtensionByName(xtName)
  235. if err == nil {
  236. return xt, nil
  237. }
  238. return messageset.FindMessageSetExtension(o.Resolver, xtName)
  239. }
  240. func isKnownValue(fd pref.FieldDescriptor) bool {
  241. md := fd.Message()
  242. return md != nil && md.FullName() == "google.protobuf.Value"
  243. }
  244. func isNullValue(fd pref.FieldDescriptor) bool {
  245. ed := fd.Enum()
  246. return ed != nil && ed.FullName() == "google.protobuf.NullValue"
  247. }
  248. // unmarshalSingular unmarshals to the non-repeated field specified by the given
  249. // FieldDescriptor.
  250. func (o UnmarshalOptions) unmarshalSingular(m pref.Message, fd pref.FieldDescriptor) error {
  251. var val pref.Value
  252. var err error
  253. switch fd.Kind() {
  254. case pref.MessageKind, pref.GroupKind:
  255. val = m.NewField(fd)
  256. err = o.unmarshalMessage(val.Message(), false)
  257. default:
  258. val, err = o.unmarshalScalar(fd)
  259. }
  260. if err != nil {
  261. return err
  262. }
  263. m.Set(fd, val)
  264. return nil
  265. }
  266. // unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by
  267. // the given FieldDescriptor.
  268. func (o UnmarshalOptions) unmarshalScalar(fd pref.FieldDescriptor) (pref.Value, error) {
  269. const b32 int = 32
  270. const b64 int = 64
  271. jval, err := o.decoder.Read()
  272. if err != nil {
  273. return pref.Value{}, err
  274. }
  275. kind := fd.Kind()
  276. switch kind {
  277. case pref.BoolKind:
  278. return unmarshalBool(jval)
  279. case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
  280. return unmarshalInt(jval, b32)
  281. case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
  282. return unmarshalInt(jval, b64)
  283. case pref.Uint32Kind, pref.Fixed32Kind:
  284. return unmarshalUint(jval, b32)
  285. case pref.Uint64Kind, pref.Fixed64Kind:
  286. return unmarshalUint(jval, b64)
  287. case pref.FloatKind:
  288. return unmarshalFloat(jval, b32)
  289. case pref.DoubleKind:
  290. return unmarshalFloat(jval, b64)
  291. case pref.StringKind:
  292. pval, err := unmarshalString(jval)
  293. if err != nil {
  294. return pval, err
  295. }
  296. return pval, nil
  297. case pref.BytesKind:
  298. return unmarshalBytes(jval)
  299. case pref.EnumKind:
  300. return unmarshalEnum(jval, fd)
  301. }
  302. panic(fmt.Sprintf("invalid scalar kind %v", kind))
  303. }
  304. func unmarshalBool(jval json.Value) (pref.Value, error) {
  305. if jval.Type() != json.Bool {
  306. return pref.Value{}, unexpectedJSONError{jval}
  307. }
  308. b, err := jval.Bool()
  309. return pref.ValueOfBool(b), err
  310. }
  311. func unmarshalInt(jval json.Value, bitSize int) (pref.Value, error) {
  312. switch jval.Type() {
  313. case json.Number:
  314. return getInt(jval, bitSize)
  315. case json.String:
  316. // Decode number from string.
  317. s := strings.TrimSpace(jval.String())
  318. if len(s) != len(jval.String()) {
  319. return pref.Value{}, errors.New("invalid number %v", jval.Raw())
  320. }
  321. dec := json.NewDecoder([]byte(s))
  322. jval, err := dec.Read()
  323. if err != nil {
  324. return pref.Value{}, err
  325. }
  326. return getInt(jval, bitSize)
  327. }
  328. return pref.Value{}, unexpectedJSONError{jval}
  329. }
  330. func getInt(jval json.Value, bitSize int) (pref.Value, error) {
  331. n, err := jval.Int(bitSize)
  332. if err != nil {
  333. return pref.Value{}, err
  334. }
  335. if bitSize == 32 {
  336. return pref.ValueOfInt32(int32(n)), nil
  337. }
  338. return pref.ValueOfInt64(n), nil
  339. }
  340. func unmarshalUint(jval json.Value, bitSize int) (pref.Value, error) {
  341. switch jval.Type() {
  342. case json.Number:
  343. return getUint(jval, bitSize)
  344. case json.String:
  345. // Decode number from string.
  346. s := strings.TrimSpace(jval.String())
  347. if len(s) != len(jval.String()) {
  348. return pref.Value{}, errors.New("invalid number %v", jval.Raw())
  349. }
  350. dec := json.NewDecoder([]byte(s))
  351. jval, err := dec.Read()
  352. if err != nil {
  353. return pref.Value{}, err
  354. }
  355. return getUint(jval, bitSize)
  356. }
  357. return pref.Value{}, unexpectedJSONError{jval}
  358. }
  359. func getUint(jval json.Value, bitSize int) (pref.Value, error) {
  360. n, err := jval.Uint(bitSize)
  361. if err != nil {
  362. return pref.Value{}, err
  363. }
  364. if bitSize == 32 {
  365. return pref.ValueOfUint32(uint32(n)), nil
  366. }
  367. return pref.ValueOfUint64(n), nil
  368. }
  369. func unmarshalFloat(jval json.Value, bitSize int) (pref.Value, error) {
  370. switch jval.Type() {
  371. case json.Number:
  372. return getFloat(jval, bitSize)
  373. case json.String:
  374. s := jval.String()
  375. switch s {
  376. case "NaN":
  377. if bitSize == 32 {
  378. return pref.ValueOfFloat32(float32(math.NaN())), nil
  379. }
  380. return pref.ValueOfFloat64(math.NaN()), nil
  381. case "Infinity":
  382. if bitSize == 32 {
  383. return pref.ValueOfFloat32(float32(math.Inf(+1))), nil
  384. }
  385. return pref.ValueOfFloat64(math.Inf(+1)), nil
  386. case "-Infinity":
  387. if bitSize == 32 {
  388. return pref.ValueOfFloat32(float32(math.Inf(-1))), nil
  389. }
  390. return pref.ValueOfFloat64(math.Inf(-1)), nil
  391. }
  392. // Decode number from string.
  393. if len(s) != len(strings.TrimSpace(s)) {
  394. return pref.Value{}, errors.New("invalid number %v", jval.Raw())
  395. }
  396. dec := json.NewDecoder([]byte(s))
  397. jval, err := dec.Read()
  398. if err != nil {
  399. return pref.Value{}, err
  400. }
  401. return getFloat(jval, bitSize)
  402. }
  403. return pref.Value{}, unexpectedJSONError{jval}
  404. }
  405. func getFloat(jval json.Value, bitSize int) (pref.Value, error) {
  406. n, err := jval.Float(bitSize)
  407. if err != nil {
  408. return pref.Value{}, err
  409. }
  410. if bitSize == 32 {
  411. return pref.ValueOfFloat32(float32(n)), nil
  412. }
  413. return pref.ValueOfFloat64(n), nil
  414. }
  415. func unmarshalString(jval json.Value) (pref.Value, error) {
  416. if jval.Type() != json.String {
  417. return pref.Value{}, unexpectedJSONError{jval}
  418. }
  419. return pref.ValueOfString(jval.String()), nil
  420. }
  421. func unmarshalBytes(jval json.Value) (pref.Value, error) {
  422. if jval.Type() != json.String {
  423. return pref.Value{}, unexpectedJSONError{jval}
  424. }
  425. s := jval.String()
  426. enc := base64.StdEncoding
  427. if strings.ContainsAny(s, "-_") {
  428. enc = base64.URLEncoding
  429. }
  430. if len(s)%4 != 0 {
  431. enc = enc.WithPadding(base64.NoPadding)
  432. }
  433. b, err := enc.DecodeString(s)
  434. if err != nil {
  435. return pref.Value{}, err
  436. }
  437. return pref.ValueOfBytes(b), nil
  438. }
  439. func unmarshalEnum(jval json.Value, fd pref.FieldDescriptor) (pref.Value, error) {
  440. switch jval.Type() {
  441. case json.String:
  442. // Lookup EnumNumber based on name.
  443. s := jval.String()
  444. if enumVal := fd.Enum().Values().ByName(pref.Name(s)); enumVal != nil {
  445. return pref.ValueOfEnum(enumVal.Number()), nil
  446. }
  447. return pref.Value{}, newError("invalid enum value %q", jval)
  448. case json.Number:
  449. n, err := jval.Int(32)
  450. if err != nil {
  451. return pref.Value{}, err
  452. }
  453. return pref.ValueOfEnum(pref.EnumNumber(n)), nil
  454. case json.Null:
  455. // This is only valid for google.protobuf.NullValue.
  456. if isNullValue(fd) {
  457. return pref.ValueOfEnum(0), nil
  458. }
  459. }
  460. return pref.Value{}, unexpectedJSONError{jval}
  461. }
  462. func (o UnmarshalOptions) unmarshalList(list pref.List, fd pref.FieldDescriptor) error {
  463. jval, err := o.decoder.Read()
  464. if err != nil {
  465. return err
  466. }
  467. if jval.Type() != json.StartArray {
  468. return unexpectedJSONError{jval}
  469. }
  470. switch fd.Kind() {
  471. case pref.MessageKind, pref.GroupKind:
  472. for {
  473. val := list.NewElement()
  474. err := o.unmarshalMessage(val.Message(), false)
  475. if err != nil {
  476. if e, ok := err.(unexpectedJSONError); ok {
  477. if e.value.Type() == json.EndArray {
  478. // Done with list.
  479. return nil
  480. }
  481. }
  482. return err
  483. }
  484. list.Append(val)
  485. }
  486. default:
  487. for {
  488. val, err := o.unmarshalScalar(fd)
  489. if err != nil {
  490. if e, ok := err.(unexpectedJSONError); ok {
  491. if e.value.Type() == json.EndArray {
  492. // Done with list.
  493. return nil
  494. }
  495. }
  496. return err
  497. }
  498. list.Append(val)
  499. }
  500. }
  501. return nil
  502. }
  503. func (o UnmarshalOptions) unmarshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
  504. jval, err := o.decoder.Read()
  505. if err != nil {
  506. return err
  507. }
  508. if jval.Type() != json.StartObject {
  509. return unexpectedJSONError{jval}
  510. }
  511. // Determine ahead whether map entry is a scalar type or a message type in
  512. // order to call the appropriate unmarshalMapValue func inside the for loop
  513. // below.
  514. var unmarshalMapValue func() (pref.Value, error)
  515. switch fd.MapValue().Kind() {
  516. case pref.MessageKind, pref.GroupKind:
  517. unmarshalMapValue = func() (pref.Value, error) {
  518. val := mmap.NewValue()
  519. if err := o.unmarshalMessage(val.Message(), false); err != nil {
  520. return pref.Value{}, err
  521. }
  522. return val, nil
  523. }
  524. default:
  525. unmarshalMapValue = func() (pref.Value, error) {
  526. return o.unmarshalScalar(fd.MapValue())
  527. }
  528. }
  529. Loop:
  530. for {
  531. // Read field name.
  532. jval, err := o.decoder.Read()
  533. if err != nil {
  534. return err
  535. }
  536. switch jval.Type() {
  537. default:
  538. return unexpectedJSONError{jval}
  539. case json.EndObject:
  540. break Loop
  541. case json.Name:
  542. // Continue.
  543. }
  544. name, err := jval.Name()
  545. if err != nil {
  546. return err
  547. }
  548. // Unmarshal field name.
  549. pkey, err := unmarshalMapKey(name, fd.MapKey())
  550. if err != nil {
  551. return err
  552. }
  553. // Check for duplicate field name.
  554. if mmap.Has(pkey) {
  555. return newError("duplicate map key %q", jval)
  556. }
  557. // Read and unmarshal field value.
  558. pval, err := unmarshalMapValue()
  559. if err != nil {
  560. return err
  561. }
  562. mmap.Set(pkey, pval)
  563. }
  564. return nil
  565. }
  566. // unmarshalMapKey converts given string into a protoreflect.MapKey. A map key type is any
  567. // integral or string type.
  568. func unmarshalMapKey(name string, fd pref.FieldDescriptor) (pref.MapKey, error) {
  569. const b32 = 32
  570. const b64 = 64
  571. const base10 = 10
  572. kind := fd.Kind()
  573. switch kind {
  574. case pref.StringKind:
  575. return pref.ValueOfString(name).MapKey(), nil
  576. case pref.BoolKind:
  577. switch name {
  578. case "true":
  579. return pref.ValueOfBool(true).MapKey(), nil
  580. case "false":
  581. return pref.ValueOfBool(false).MapKey(), nil
  582. }
  583. return pref.MapKey{}, errors.New("invalid value for boolean key %q", name)
  584. case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
  585. n, err := strconv.ParseInt(name, base10, b32)
  586. if err != nil {
  587. return pref.MapKey{}, err
  588. }
  589. return pref.ValueOfInt32(int32(n)).MapKey(), nil
  590. case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
  591. n, err := strconv.ParseInt(name, base10, b64)
  592. if err != nil {
  593. return pref.MapKey{}, err
  594. }
  595. return pref.ValueOfInt64(int64(n)).MapKey(), nil
  596. case pref.Uint32Kind, pref.Fixed32Kind:
  597. n, err := strconv.ParseUint(name, base10, b32)
  598. if err != nil {
  599. return pref.MapKey{}, err
  600. }
  601. return pref.ValueOfUint32(uint32(n)).MapKey(), nil
  602. case pref.Uint64Kind, pref.Fixed64Kind:
  603. n, err := strconv.ParseUint(name, base10, b64)
  604. if err != nil {
  605. return pref.MapKey{}, err
  606. }
  607. return pref.ValueOfUint64(uint64(n)).MapKey(), nil
  608. }
  609. panic(fmt.Sprintf("%s: invalid kind %s for map key", fd.FullName(), kind))
  610. }