decode.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package protojson
  5. import (
  6. "encoding/base64"
  7. "fmt"
  8. "math"
  9. "strconv"
  10. "strings"
  11. "google.golang.org/protobuf/internal/encoding/json"
  12. "google.golang.org/protobuf/internal/errors"
  13. "google.golang.org/protobuf/internal/pragma"
  14. "google.golang.org/protobuf/internal/set"
  15. "google.golang.org/protobuf/proto"
  16. pref "google.golang.org/protobuf/reflect/protoreflect"
  17. "google.golang.org/protobuf/reflect/protoregistry"
  18. )
  19. // Unmarshal reads the given []byte into the given proto.Message.
  20. func Unmarshal(b []byte, m proto.Message) error {
  21. return UnmarshalOptions{}.Unmarshal(b, m)
  22. }
  23. // UnmarshalOptions is a configurable JSON format parser.
  24. type UnmarshalOptions struct {
  25. pragma.NoUnkeyedLiterals
  26. // If AllowPartial is set, input for messages that will result in missing
  27. // required fields will not return an error.
  28. AllowPartial bool
  29. // If DiscardUnknown is set, unknown fields are ignored.
  30. DiscardUnknown bool
  31. // Resolver is used for looking up types when unmarshaling
  32. // google.protobuf.Any messages or extension fields.
  33. // If nil, this defaults to using protoregistry.GlobalTypes.
  34. Resolver interface {
  35. protoregistry.MessageTypeResolver
  36. protoregistry.ExtensionTypeResolver
  37. }
  38. decoder *json.Decoder
  39. }
  40. // Unmarshal reads the given []byte and populates the given proto.Message using
  41. // options in UnmarshalOptions object. It will clear the message first before
  42. // setting the fields. If it returns an error, the given message may be
  43. // partially set.
  44. func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
  45. // TODO: Determine if we would like to have an option for merging or only
  46. // have merging behavior. We should at least be consistent with textproto
  47. // marshaling.
  48. proto.Reset(m)
  49. if o.Resolver == nil {
  50. o.Resolver = protoregistry.GlobalTypes
  51. }
  52. o.decoder = json.NewDecoder(b)
  53. if err := o.unmarshalMessage(m.ProtoReflect(), false); err != nil {
  54. return err
  55. }
  56. // Check for EOF.
  57. val, err := o.decoder.Read()
  58. if err != nil {
  59. return err
  60. }
  61. if val.Type() != json.EOF {
  62. return unexpectedJSONError{val}
  63. }
  64. if o.AllowPartial {
  65. return nil
  66. }
  67. return proto.IsInitialized(m)
  68. }
  69. // unexpectedJSONError is an error that contains the unexpected json.Value. This
  70. // is returned by methods to provide callers the read json.Value that it did not
  71. // expect.
  72. // TODO: Consider moving this to internal/encoding/json for consistency with
  73. // errors that package returns.
  74. type unexpectedJSONError struct {
  75. value json.Value
  76. }
  77. func (e unexpectedJSONError) Error() string {
  78. return newError("unexpected value %s", e.value).Error()
  79. }
  80. // newError returns an error object. If one of the values passed in is of
  81. // json.Value type, it produces an error with position info.
  82. func newError(f string, x ...interface{}) error {
  83. var hasValue bool
  84. var line, column int
  85. for i := 0; i < len(x); i++ {
  86. if val, ok := x[i].(json.Value); ok {
  87. line, column = val.Position()
  88. hasValue = true
  89. break
  90. }
  91. }
  92. e := errors.New(f, x...)
  93. if hasValue {
  94. return errors.New("(line %d:%d): %v", line, column, e)
  95. }
  96. return e
  97. }
  98. // unmarshalMessage unmarshals a message into the given protoreflect.Message.
  99. func (o UnmarshalOptions) unmarshalMessage(m pref.Message, skipTypeURL bool) error {
  100. if isCustomType(m.Descriptor().FullName()) {
  101. return o.unmarshalCustomType(m)
  102. }
  103. jval, err := o.decoder.Read()
  104. if err != nil {
  105. return err
  106. }
  107. if jval.Type() != json.StartObject {
  108. return unexpectedJSONError{jval}
  109. }
  110. if err := o.unmarshalFields(m, skipTypeURL); err != nil {
  111. return err
  112. }
  113. return nil
  114. }
  115. // unmarshalFields unmarshals the fields into the given protoreflect.Message.
  116. func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) error {
  117. var seenNums set.Ints
  118. var seenOneofs set.Ints
  119. messageDesc := m.Descriptor()
  120. fieldDescs := messageDesc.Fields()
  121. Loop:
  122. for {
  123. // Read field name.
  124. jval, err := o.decoder.Read()
  125. if err != nil {
  126. return err
  127. }
  128. switch jval.Type() {
  129. default:
  130. return unexpectedJSONError{jval}
  131. case json.EndObject:
  132. break Loop
  133. case json.Name:
  134. // Continue below.
  135. }
  136. name, err := jval.Name()
  137. if err != nil {
  138. return err
  139. }
  140. // Unmarshaling a non-custom embedded message in Any will contain the
  141. // JSON field "@type" which should be skipped because it is not a field
  142. // of the embedded message, but simply an artifact of the Any format.
  143. if skipTypeURL && name == "@type" {
  144. o.decoder.Read()
  145. continue
  146. }
  147. // Get the FieldDescriptor.
  148. var fd pref.FieldDescriptor
  149. if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
  150. // Only extension names are in [name] format.
  151. extName := pref.FullName(name[1 : len(name)-1])
  152. extType, err := o.findExtension(extName)
  153. if err != nil && err != protoregistry.NotFound {
  154. return errors.New("unable to resolve [%v]: %v", extName, err)
  155. }
  156. fd = extType
  157. } else {
  158. // The name can either be the JSON name or the proto field name.
  159. fd = fieldDescs.ByJSONName(name)
  160. if fd == nil {
  161. fd = fieldDescs.ByName(pref.Name(name))
  162. }
  163. }
  164. if fd == nil {
  165. // Field is unknown.
  166. if o.DiscardUnknown {
  167. if err := skipJSONValue(o.decoder); err != nil {
  168. return err
  169. }
  170. continue
  171. }
  172. return newError("%v contains unknown field %s", messageDesc.FullName(), jval)
  173. }
  174. // Do not allow duplicate fields.
  175. num := uint64(fd.Number())
  176. if seenNums.Has(num) {
  177. return newError("%v contains repeated field %s", messageDesc.FullName(), jval)
  178. }
  179. seenNums.Set(num)
  180. // No need to set values for JSON null unless the field type is
  181. // google.protobuf.Value or google.protobuf.NullValue.
  182. if o.decoder.Peek() == json.Null && !isKnownValue(fd) && !isNullValue(fd) {
  183. o.decoder.Read()
  184. continue
  185. }
  186. switch {
  187. case fd.IsList():
  188. list := m.Mutable(fd).List()
  189. if err := o.unmarshalList(list, fd); err != nil {
  190. return errors.New("%v|%q: %v", fd.FullName(), name, err)
  191. }
  192. case fd.IsMap():
  193. mmap := m.Mutable(fd).Map()
  194. if err := o.unmarshalMap(mmap, fd); err != nil {
  195. return errors.New("%v|%q: %v", fd.FullName(), name, err)
  196. }
  197. default:
  198. // If field is a oneof, check if it has already been set.
  199. if od := fd.ContainingOneof(); od != nil {
  200. idx := uint64(od.Index())
  201. if seenOneofs.Has(idx) {
  202. return errors.New("%v: oneof is already set", od.FullName())
  203. }
  204. seenOneofs.Set(idx)
  205. }
  206. // Required or optional fields.
  207. if err := o.unmarshalSingular(m, fd); err != nil {
  208. return errors.New("%v|%q: %v", fd.FullName(), name, err)
  209. }
  210. }
  211. }
  212. return nil
  213. }
  214. // findExtension returns protoreflect.ExtensionType from the resolver if found.
  215. func (o UnmarshalOptions) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
  216. xt, err := o.Resolver.FindExtensionByName(xtName)
  217. if err == nil {
  218. return xt, nil
  219. }
  220. // Check if this is a MessageSet extension field.
  221. xt, err = o.Resolver.FindExtensionByName(xtName + ".message_set_extension")
  222. if err == nil && isMessageSetExtension(xt) {
  223. return xt, nil
  224. }
  225. return nil, protoregistry.NotFound
  226. }
  227. func isKnownValue(fd pref.FieldDescriptor) bool {
  228. md := fd.Message()
  229. return md != nil && md.FullName() == "google.protobuf.Value"
  230. }
  231. func isNullValue(fd pref.FieldDescriptor) bool {
  232. ed := fd.Enum()
  233. return ed != nil && ed.FullName() == "google.protobuf.NullValue"
  234. }
  235. // unmarshalSingular unmarshals to the non-repeated field specified by the given
  236. // FieldDescriptor.
  237. func (o UnmarshalOptions) unmarshalSingular(m pref.Message, fd pref.FieldDescriptor) error {
  238. var val pref.Value
  239. var err error
  240. switch fd.Kind() {
  241. case pref.MessageKind, pref.GroupKind:
  242. m2 := m.NewMessage(fd)
  243. err = o.unmarshalMessage(m2, false)
  244. val = pref.ValueOf(m2)
  245. default:
  246. val, err = o.unmarshalScalar(fd)
  247. }
  248. if err != nil {
  249. return err
  250. }
  251. m.Set(fd, val)
  252. return nil
  253. }
  254. // unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by
  255. // the given FieldDescriptor.
  256. func (o UnmarshalOptions) unmarshalScalar(fd pref.FieldDescriptor) (pref.Value, error) {
  257. const b32 int = 32
  258. const b64 int = 64
  259. jval, err := o.decoder.Read()
  260. if err != nil {
  261. return pref.Value{}, err
  262. }
  263. kind := fd.Kind()
  264. switch kind {
  265. case pref.BoolKind:
  266. return unmarshalBool(jval)
  267. case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
  268. return unmarshalInt(jval, b32)
  269. case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
  270. return unmarshalInt(jval, b64)
  271. case pref.Uint32Kind, pref.Fixed32Kind:
  272. return unmarshalUint(jval, b32)
  273. case pref.Uint64Kind, pref.Fixed64Kind:
  274. return unmarshalUint(jval, b64)
  275. case pref.FloatKind:
  276. return unmarshalFloat(jval, b32)
  277. case pref.DoubleKind:
  278. return unmarshalFloat(jval, b64)
  279. case pref.StringKind:
  280. pval, err := unmarshalString(jval)
  281. if err != nil {
  282. return pval, err
  283. }
  284. return pval, nil
  285. case pref.BytesKind:
  286. return unmarshalBytes(jval)
  287. case pref.EnumKind:
  288. return unmarshalEnum(jval, fd)
  289. }
  290. panic(fmt.Sprintf("invalid scalar kind %v", kind))
  291. }
  292. func unmarshalBool(jval json.Value) (pref.Value, error) {
  293. if jval.Type() != json.Bool {
  294. return pref.Value{}, unexpectedJSONError{jval}
  295. }
  296. b, err := jval.Bool()
  297. return pref.ValueOf(b), err
  298. }
  299. func unmarshalInt(jval json.Value, bitSize int) (pref.Value, error) {
  300. switch jval.Type() {
  301. case json.Number:
  302. return getInt(jval, bitSize)
  303. case json.String:
  304. // Decode number from string.
  305. s := strings.TrimSpace(jval.String())
  306. if len(s) != len(jval.String()) {
  307. return pref.Value{}, errors.New("invalid number %v", jval.Raw())
  308. }
  309. dec := json.NewDecoder([]byte(s))
  310. jval, err := dec.Read()
  311. if err != nil {
  312. return pref.Value{}, err
  313. }
  314. return getInt(jval, bitSize)
  315. }
  316. return pref.Value{}, unexpectedJSONError{jval}
  317. }
  318. func getInt(jval json.Value, bitSize int) (pref.Value, error) {
  319. n, err := jval.Int(bitSize)
  320. if err != nil {
  321. return pref.Value{}, err
  322. }
  323. if bitSize == 32 {
  324. return pref.ValueOf(int32(n)), nil
  325. }
  326. return pref.ValueOf(n), nil
  327. }
  328. func unmarshalUint(jval json.Value, bitSize int) (pref.Value, error) {
  329. switch jval.Type() {
  330. case json.Number:
  331. return getUint(jval, bitSize)
  332. case json.String:
  333. // Decode number from string.
  334. s := strings.TrimSpace(jval.String())
  335. if len(s) != len(jval.String()) {
  336. return pref.Value{}, errors.New("invalid number %v", jval.Raw())
  337. }
  338. dec := json.NewDecoder([]byte(s))
  339. jval, err := dec.Read()
  340. if err != nil {
  341. return pref.Value{}, err
  342. }
  343. return getUint(jval, bitSize)
  344. }
  345. return pref.Value{}, unexpectedJSONError{jval}
  346. }
  347. func getUint(jval json.Value, bitSize int) (pref.Value, error) {
  348. n, err := jval.Uint(bitSize)
  349. if err != nil {
  350. return pref.Value{}, err
  351. }
  352. if bitSize == 32 {
  353. return pref.ValueOf(uint32(n)), nil
  354. }
  355. return pref.ValueOf(n), nil
  356. }
  357. func unmarshalFloat(jval json.Value, bitSize int) (pref.Value, error) {
  358. switch jval.Type() {
  359. case json.Number:
  360. return getFloat(jval, bitSize)
  361. case json.String:
  362. s := jval.String()
  363. switch s {
  364. case "NaN":
  365. if bitSize == 32 {
  366. return pref.ValueOf(float32(math.NaN())), nil
  367. }
  368. return pref.ValueOf(math.NaN()), nil
  369. case "Infinity":
  370. if bitSize == 32 {
  371. return pref.ValueOf(float32(math.Inf(+1))), nil
  372. }
  373. return pref.ValueOf(math.Inf(+1)), nil
  374. case "-Infinity":
  375. if bitSize == 32 {
  376. return pref.ValueOf(float32(math.Inf(-1))), nil
  377. }
  378. return pref.ValueOf(math.Inf(-1)), nil
  379. }
  380. // Decode number from string.
  381. if len(s) != len(strings.TrimSpace(s)) {
  382. return pref.Value{}, errors.New("invalid number %v", jval.Raw())
  383. }
  384. dec := json.NewDecoder([]byte(s))
  385. jval, err := dec.Read()
  386. if err != nil {
  387. return pref.Value{}, err
  388. }
  389. return getFloat(jval, bitSize)
  390. }
  391. return pref.Value{}, unexpectedJSONError{jval}
  392. }
  393. func getFloat(jval json.Value, bitSize int) (pref.Value, error) {
  394. n, err := jval.Float(bitSize)
  395. if err != nil {
  396. return pref.Value{}, err
  397. }
  398. if bitSize == 32 {
  399. return pref.ValueOf(float32(n)), nil
  400. }
  401. return pref.ValueOf(n), nil
  402. }
  403. func unmarshalString(jval json.Value) (pref.Value, error) {
  404. if jval.Type() != json.String {
  405. return pref.Value{}, unexpectedJSONError{jval}
  406. }
  407. return pref.ValueOf(jval.String()), nil
  408. }
  409. func unmarshalBytes(jval json.Value) (pref.Value, error) {
  410. if jval.Type() != json.String {
  411. return pref.Value{}, unexpectedJSONError{jval}
  412. }
  413. s := jval.String()
  414. enc := base64.StdEncoding
  415. if strings.ContainsAny(s, "-_") {
  416. enc = base64.URLEncoding
  417. }
  418. if len(s)%4 != 0 {
  419. enc = enc.WithPadding(base64.NoPadding)
  420. }
  421. b, err := enc.DecodeString(s)
  422. if err != nil {
  423. return pref.Value{}, err
  424. }
  425. return pref.ValueOf(b), nil
  426. }
  427. func unmarshalEnum(jval json.Value, fd pref.FieldDescriptor) (pref.Value, error) {
  428. switch jval.Type() {
  429. case json.String:
  430. // Lookup EnumNumber based on name.
  431. s := jval.String()
  432. if enumVal := fd.Enum().Values().ByName(pref.Name(s)); enumVal != nil {
  433. return pref.ValueOf(enumVal.Number()), nil
  434. }
  435. return pref.Value{}, newError("invalid enum value %q", jval)
  436. case json.Number:
  437. n, err := jval.Int(32)
  438. if err != nil {
  439. return pref.Value{}, err
  440. }
  441. return pref.ValueOf(pref.EnumNumber(n)), nil
  442. case json.Null:
  443. // This is only valid for google.protobuf.NullValue.
  444. if isNullValue(fd) {
  445. return pref.ValueOf(pref.EnumNumber(0)), nil
  446. }
  447. }
  448. return pref.Value{}, unexpectedJSONError{jval}
  449. }
  450. func (o UnmarshalOptions) unmarshalList(list pref.List, fd pref.FieldDescriptor) error {
  451. jval, err := o.decoder.Read()
  452. if err != nil {
  453. return err
  454. }
  455. if jval.Type() != json.StartArray {
  456. return unexpectedJSONError{jval}
  457. }
  458. switch fd.Kind() {
  459. case pref.MessageKind, pref.GroupKind:
  460. for {
  461. m := list.NewMessage()
  462. err := o.unmarshalMessage(m, false)
  463. if err != nil {
  464. if e, ok := err.(unexpectedJSONError); ok {
  465. if e.value.Type() == json.EndArray {
  466. // Done with list.
  467. return nil
  468. }
  469. }
  470. return err
  471. }
  472. list.Append(pref.ValueOf(m))
  473. }
  474. default:
  475. for {
  476. val, err := o.unmarshalScalar(fd)
  477. if err != nil {
  478. if e, ok := err.(unexpectedJSONError); ok {
  479. if e.value.Type() == json.EndArray {
  480. // Done with list.
  481. return nil
  482. }
  483. }
  484. return err
  485. }
  486. list.Append(val)
  487. }
  488. }
  489. return nil
  490. }
  491. func (o UnmarshalOptions) unmarshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
  492. jval, err := o.decoder.Read()
  493. if err != nil {
  494. return err
  495. }
  496. if jval.Type() != json.StartObject {
  497. return unexpectedJSONError{jval}
  498. }
  499. // Determine ahead whether map entry is a scalar type or a message type in
  500. // order to call the appropriate unmarshalMapValue func inside the for loop
  501. // below.
  502. var unmarshalMapValue func() (pref.Value, error)
  503. switch fd.MapValue().Kind() {
  504. case pref.MessageKind, pref.GroupKind:
  505. unmarshalMapValue = func() (pref.Value, error) {
  506. m := mmap.NewMessage()
  507. if err := o.unmarshalMessage(m, false); err != nil {
  508. return pref.Value{}, err
  509. }
  510. return pref.ValueOf(m), nil
  511. }
  512. default:
  513. unmarshalMapValue = func() (pref.Value, error) {
  514. return o.unmarshalScalar(fd.MapValue())
  515. }
  516. }
  517. Loop:
  518. for {
  519. // Read field name.
  520. jval, err := o.decoder.Read()
  521. if err != nil {
  522. return err
  523. }
  524. switch jval.Type() {
  525. default:
  526. return unexpectedJSONError{jval}
  527. case json.EndObject:
  528. break Loop
  529. case json.Name:
  530. // Continue.
  531. }
  532. name, err := jval.Name()
  533. if err != nil {
  534. return err
  535. }
  536. // Unmarshal field name.
  537. pkey, err := unmarshalMapKey(name, fd.MapKey())
  538. if err != nil {
  539. return err
  540. }
  541. // Check for duplicate field name.
  542. if mmap.Has(pkey) {
  543. return newError("duplicate map key %q", jval)
  544. }
  545. // Read and unmarshal field value.
  546. pval, err := unmarshalMapValue()
  547. if err != nil {
  548. return err
  549. }
  550. mmap.Set(pkey, pval)
  551. }
  552. return nil
  553. }
  554. // unmarshalMapKey converts given string into a protoreflect.MapKey. A map key type is any
  555. // integral or string type.
  556. func unmarshalMapKey(name string, fd pref.FieldDescriptor) (pref.MapKey, error) {
  557. const b32 = 32
  558. const b64 = 64
  559. const base10 = 10
  560. kind := fd.Kind()
  561. switch kind {
  562. case pref.StringKind:
  563. return pref.ValueOf(name).MapKey(), nil
  564. case pref.BoolKind:
  565. switch name {
  566. case "true":
  567. return pref.ValueOf(true).MapKey(), nil
  568. case "false":
  569. return pref.ValueOf(false).MapKey(), nil
  570. }
  571. return pref.MapKey{}, errors.New("invalid value for boolean key %q", name)
  572. case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
  573. n, err := strconv.ParseInt(name, base10, b32)
  574. if err != nil {
  575. return pref.MapKey{}, err
  576. }
  577. return pref.ValueOf(int32(n)).MapKey(), nil
  578. case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
  579. n, err := strconv.ParseInt(name, base10, b64)
  580. if err != nil {
  581. return pref.MapKey{}, err
  582. }
  583. return pref.ValueOf(int64(n)).MapKey(), nil
  584. case pref.Uint32Kind, pref.Fixed32Kind:
  585. n, err := strconv.ParseUint(name, base10, b32)
  586. if err != nil {
  587. return pref.MapKey{}, err
  588. }
  589. return pref.ValueOf(uint32(n)).MapKey(), nil
  590. case pref.Uint64Kind, pref.Fixed64Kind:
  591. n, err := strconv.ParseUint(name, base10, b64)
  592. if err != nil {
  593. return pref.MapKey{}, err
  594. }
  595. return pref.ValueOf(uint64(n)).MapKey(), nil
  596. }
  597. panic(fmt.Sprintf("%s: invalid kind %s for map key", fd.FullName(), kind))
  598. }