extensions.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. // Copyright 2010 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. /*
  6. * Types and routines for supporting protocol buffer extensions.
  7. */
  8. import (
  9. "errors"
  10. "fmt"
  11. "io"
  12. "reflect"
  13. "sync"
  14. "github.com/golang/protobuf/internal/wire"
  15. "google.golang.org/protobuf/reflect/protoreflect"
  16. "google.golang.org/protobuf/runtime/protoiface"
  17. "google.golang.org/protobuf/runtime/protoimpl"
  18. )
  // ErrMissingExtension is reported by GetExtension when pb holds the requested
  // extension neither decoded nor in raw form and the descriptor supplies no
  // default value (including when the descriptor is incomplete).
  var ErrMissingExtension error = errors.New("proto: missing extension")
  21. func extensionFieldsOf(p interface{}) *extensionMap {
  22. if p, ok := p.(*map[int32]Extension); ok {
  23. return (*extensionMap)(p)
  24. }
  25. panic(fmt.Sprintf("invalid extension fields type: %T", p))
  26. }
  27. type extensionMap map[int32]Extension
  28. func (m extensionMap) Len() int {
  29. return len(m)
  30. }
  31. func (m extensionMap) Has(n protoreflect.FieldNumber) bool {
  32. _, ok := m[int32(n)]
  33. return ok
  34. }
  35. func (m extensionMap) Get(n protoreflect.FieldNumber) Extension {
  36. return m[int32(n)]
  37. }
  38. func (m *extensionMap) Set(n protoreflect.FieldNumber, x Extension) {
  39. if *m == nil {
  40. *m = make(map[int32]Extension)
  41. }
  42. (*m)[int32(n)] = x
  43. }
  44. func (m *extensionMap) Clear(n protoreflect.FieldNumber) {
  45. delete(*m, int32(n))
  46. }
  47. func (m extensionMap) Range(f func(protoreflect.FieldNumber, Extension) bool) {
  48. for n, x := range m {
  49. if !f(protoreflect.FieldNumber(n), x) {
  50. return
  51. }
  52. }
  53. }
  54. func extendable(p interface{}) (*extensionMap, error) {
  55. type extendableProto interface {
  56. Message
  57. ExtensionRangeArray() []ExtensionRange
  58. }
  59. if _, ok := p.(extendableProto); ok {
  60. v := reflect.ValueOf(p)
  61. if v.Kind() == reflect.Ptr && !v.IsNil() {
  62. v = v.Elem()
  63. if vf := extensionFieldsValue(v); vf.IsValid() {
  64. return extensionFieldsOf(vf.Addr().Interface()), nil
  65. }
  66. }
  67. }
  68. // Don't allocate a specific error containing %T:
  69. // this is the hot path for Clone and MarshalText.
  70. return nil, errNotExtendable
  71. }
  // errNotExtendable is the single preallocated error extendable returns for
  // any value that cannot carry extensions, so the failure path never allocates.
  var errNotExtendable error = errors.New("proto: not an extendable proto.Message")
  73. type (
  74. ExtensionRange = protoiface.ExtensionRangeV1
  75. ExtensionDesc = protoimpl.ExtensionInfo
  76. Extension = protoimpl.ExtensionFieldV1
  77. XXX_InternalExtensions = protoimpl.ExtensionFields
  78. )
  79. func isRepeatedExtension(ed *ExtensionDesc) bool {
  80. t := reflect.TypeOf(ed.ExtensionType)
  81. return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
  82. }
  83. // SetRawExtension is for testing only.
  84. func SetRawExtension(base Message, id int32, b []byte) {
  85. v := reflect.ValueOf(base)
  86. if !v.IsValid() || v.Kind() != reflect.Ptr || v.IsNil() || v.Elem().Kind() != reflect.Struct {
  87. return
  88. }
  89. v = unknownFieldsValue(v.Elem())
  90. if !v.IsValid() {
  91. return
  92. }
  93. // Verify that the raw field is valid.
  94. for b0 := b; len(b0) > 0; {
  95. fieldNum, _, n := wire.ConsumeField(b0)
  96. if int32(fieldNum) != id {
  97. panic(fmt.Sprintf("mismatching field number: got %d, want %d", fieldNum, id))
  98. }
  99. b0 = b0[n:]
  100. }
  101. fnum := protoreflect.FieldNumber(id)
  102. v.SetBytes(append(removeRawFields(v.Bytes(), fnum), b...))
  103. }
  104. func removeRawFields(b []byte, fnum protoreflect.FieldNumber) []byte {
  105. out := b[:0]
  106. for len(b) > 0 {
  107. got, _, n := wire.ConsumeField(b)
  108. if got != fnum {
  109. out = append(out, b[:n]...)
  110. }
  111. b = b[n:]
  112. }
  113. return out
  114. }
  115. // isExtensionField returns true iff the given field number is in an extension range.
  116. func isExtensionField(pb Message, field int32) bool {
  117. m, ok := pb.(interface{ ExtensionRangeArray() []ExtensionRange })
  118. if ok {
  119. for _, er := range m.ExtensionRangeArray() {
  120. if er.Start <= field && field <= er.End {
  121. return true
  122. }
  123. }
  124. }
  125. return false
  126. }
  127. // checkExtensionTypeAndRanges checks that the given extension is valid for pb.
  128. func checkExtensionTypeAndRanges(pb Message, extension *ExtensionDesc) error {
  129. // Check the extended type.
  130. if extension.ExtendedType != nil {
  131. if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
  132. return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
  133. }
  134. }
  135. // Check the range.
  136. if !isExtensionField(pb, extension.Field) {
  137. return errors.New("proto: bad extension number; not in declared ranges")
  138. }
  139. return nil
  140. }
  // extPropKey is the cache key for extension Properties. The dynamic type of
  // the extended message plus the extension's field number identify an
  // extension uniquely.
  type extPropKey struct {
  	base  reflect.Type // dynamic type of the extended message
  	field int32        // extension field number
  }
  146. var extProp = struct {
  147. sync.RWMutex
  148. m map[extPropKey]*Properties
  149. }{
  150. m: make(map[extPropKey]*Properties),
  151. }
  152. func extensionProperties(pb Message, ed *ExtensionDesc) *Properties {
  153. key := extPropKey{base: reflect.TypeOf(pb), field: ed.Field}
  154. extProp.RLock()
  155. if prop, ok := extProp.m[key]; ok {
  156. extProp.RUnlock()
  157. return prop
  158. }
  159. extProp.RUnlock()
  160. extProp.Lock()
  161. defer extProp.Unlock()
  162. // Check again.
  163. if prop, ok := extProp.m[key]; ok {
  164. return prop
  165. }
  166. prop := new(Properties)
  167. prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
  168. extProp.m[key] = prop
  169. return prop
  170. }
  171. // HasExtension returns whether the given extension is present in pb.
  172. func HasExtension(pb Message, extension *ExtensionDesc) bool {
  173. // TODO: Check types, field numbers, etc.?
  174. epb, err := extendable(pb)
  175. if err != nil || epb == nil {
  176. return false
  177. }
  178. if epb.Has(protoreflect.FieldNumber(extension.Field)) {
  179. return true
  180. }
  181. // Check whether this field exists in raw form.
  182. unrecognized := unknownFieldsValue(reflect.ValueOf(pb).Elem())
  183. fnum := protoreflect.FieldNumber(extension.Field)
  184. for b := unrecognized.Bytes(); len(b) > 0; {
  185. got, _, n := wire.ConsumeField(b)
  186. if got == fnum {
  187. return true
  188. }
  189. b = b[n:]
  190. }
  191. return false
  192. }
  193. // ClearExtension removes the given extension from pb.
  194. func ClearExtension(pb Message, extension *ExtensionDesc) {
  195. epb, err := extendable(pb)
  196. if err != nil {
  197. return
  198. }
  199. // TODO: Check types, field numbers, etc.?
  200. epb.Clear(protoreflect.FieldNumber(extension.Field))
  201. }
  202. // GetExtension retrieves a proto2 extended field from pb.
  203. //
  204. // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
  205. // then GetExtension parses the encoded field and returns a Go value of the specified type.
  206. // If the field is not present, then the default value is returned (if one is specified),
  207. // otherwise ErrMissingExtension is reported.
  208. //
  209. // If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil),
  210. // then GetExtension returns the raw encoded bytes of the field extension.
  211. func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
  212. epb, err := extendable(pb)
  213. if err != nil {
  214. return nil, err
  215. }
  216. // can only check type if this is a complete descriptor
  217. if err := checkExtensionTypeAndRanges(pb, extension); err != nil {
  218. return nil, err
  219. }
  220. unrecognized := unknownFieldsValue(reflect.ValueOf(pb).Elem())
  221. var out []byte
  222. fnum := protoreflect.FieldNumber(extension.Field)
  223. for b := unrecognized.Bytes(); len(b) > 0; {
  224. got, _, n := wire.ConsumeField(b)
  225. if got == fnum {
  226. out = append(out, b[:n]...)
  227. }
  228. b = b[n:]
  229. }
  230. if !epb.Has(protoreflect.FieldNumber(extension.Field)) && len(out) == 0 {
  231. // defaultExtensionValue returns the default value or
  232. // ErrMissingExtension if there is no default.
  233. return defaultExtensionValue(pb, extension)
  234. }
  235. e := epb.Get(protoreflect.FieldNumber(extension.Field))
  236. if e.HasValue() {
  237. // Already decoded. Check the descriptor, though.
  238. if protoimpl.X.ExtensionDescFromType(e.GetType()) != extension {
  239. // This shouldn't happen. If it does, it means that
  240. // GetExtension was called twice with two different
  241. // descriptors with the same field number.
  242. return nil, errors.New("proto: descriptor conflict")
  243. }
  244. return extensionAsLegacyType(e.GetValue()), nil
  245. }
  246. // Descriptor without type information.
  247. if extension.ExtensionType == nil {
  248. return out, nil
  249. }
  250. // TODO: Remove this logic for automatically unmarshaling the unknown fields.
  251. v, err := decodeExtension(out, extension)
  252. if err != nil {
  253. return nil, err
  254. }
  255. // Remember the decoded version and drop the encoded version.
  256. // That way it is safe to mutate what we return.
  257. e.SetType(extension)
  258. e.SetEagerValue(extensionAsStorageType(v))
  259. unrecognized.SetBytes(removeRawFields(unrecognized.Bytes(), fnum))
  260. epb.Set(protoreflect.FieldNumber(extension.Field), e)
  261. return extensionAsLegacyType(e.GetValue()), nil
  262. }
  263. // defaultExtensionValue returns the default value for extension.
  264. // If no default for an extension is defined ErrMissingExtension is returned.
  265. func defaultExtensionValue(pb Message, extension *ExtensionDesc) (interface{}, error) {
  266. if extension.ExtensionType == nil {
  267. // incomplete descriptor, so no default
  268. return nil, ErrMissingExtension
  269. }
  270. t := reflect.TypeOf(extension.ExtensionType)
  271. props := extensionProperties(pb, extension)
  272. sf, _, err := fieldDefault(t, props)
  273. if err != nil {
  274. return nil, err
  275. }
  276. if sf == nil || sf.value == nil {
  277. // There is no default value.
  278. return nil, ErrMissingExtension
  279. }
  280. if t.Kind() != reflect.Ptr {
  281. // We do not need to return a Ptr, we can directly return sf.value.
  282. return sf.value, nil
  283. }
  284. // We need to return an interface{} that is a pointer to sf.value.
  285. value := reflect.New(t).Elem()
  286. value.Set(reflect.New(value.Type().Elem()))
  287. if sf.kind == reflect.Int32 {
  288. // We may have an int32 or an enum, but the underlying data is int32.
  289. // Since we can't set an int32 into a non int32 reflect.Value directly
  290. // set it as a int32.
  291. value.Elem().SetInt(int64(sf.value.(int32)))
  292. } else {
  293. value.Elem().Set(reflect.ValueOf(sf.value))
  294. }
  295. return value.Interface(), nil
  296. }
  297. // decodeExtension decodes an extension encoded in b.
  298. func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
  299. t := reflect.TypeOf(extension.ExtensionType)
  300. unmarshal := typeUnmarshaler(t, extension.Tag)
  301. // t is a pointer to a struct, pointer to basic type or a slice.
  302. // Allocate space to store the pointer/slice.
  303. value := reflect.New(t).Elem()
  304. var err error
  305. for {
  306. x, n := decodeVarint(b)
  307. if n == 0 {
  308. return nil, io.ErrUnexpectedEOF
  309. }
  310. b = b[n:]
  311. wire := int(x) & 7
  312. b, err = unmarshal(b, valToPointer(value.Addr()), wire)
  313. if err != nil {
  314. return nil, err
  315. }
  316. if len(b) == 0 {
  317. break
  318. }
  319. }
  320. return value.Interface(), nil
  321. }
  322. // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
  323. // The returned slice has the same length as es; missing extensions will appear as nil elements.
  324. func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
  325. _, err = extendable(pb)
  326. if err != nil {
  327. return nil, err
  328. }
  329. extensions = make([]interface{}, len(es))
  330. for i, e := range es {
  331. extensions[i], err = GetExtension(pb, e)
  332. if err == ErrMissingExtension {
  333. err = nil
  334. }
  335. if err != nil {
  336. return
  337. }
  338. }
  339. return
  340. }
  341. // ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
  342. // For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
  343. // just the Field field, which defines the extension's field number.
  344. func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
  345. epb, err := extendable(pb)
  346. if err != nil {
  347. return nil, err
  348. }
  349. registeredExtensions := RegisteredExtensions(pb)
  350. if epb == nil {
  351. return nil, nil
  352. }
  353. extensions := make([]*ExtensionDesc, 0, epb.Len())
  354. epb.Range(func(extid protoreflect.FieldNumber, e Extension) bool {
  355. desc := protoimpl.X.ExtensionDescFromType(e.GetType())
  356. if desc == nil {
  357. desc = registeredExtensions[int32(extid)]
  358. if desc == nil {
  359. desc = &ExtensionDesc{Field: int32(extid)}
  360. }
  361. }
  362. extensions = append(extensions, desc)
  363. return true
  364. })
  365. unrecognized := unknownFieldsValue(reflect.ValueOf(pb).Elem())
  366. if b := unrecognized.Bytes(); len(b) > 0 {
  367. fieldNums := make(map[int32]bool)
  368. for len(b) > 0 {
  369. fnum, _, n := wire.ConsumeField(b)
  370. if isExtensionField(pb, int32(fnum)) {
  371. fieldNums[int32(fnum)] = true
  372. }
  373. b = b[n:]
  374. }
  375. for id := range fieldNums {
  376. desc := registeredExtensions[id]
  377. if desc == nil {
  378. desc = &ExtensionDesc{Field: id}
  379. }
  380. extensions = append(extensions, desc)
  381. }
  382. }
  383. return extensions, nil
  384. }
  385. // SetExtension sets the specified extension of pb to the specified value.
  386. func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
  387. epb, err := extendable(pb)
  388. if err != nil {
  389. return err
  390. }
  391. if err := checkExtensionTypeAndRanges(pb, extension); err != nil {
  392. return err
  393. }
  394. typ := reflect.TypeOf(extension.ExtensionType)
  395. if typ != reflect.TypeOf(value) {
  396. return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", value, extension.ExtensionType)
  397. }
  398. // nil extension values need to be caught early, because the
  399. // encoder can't distinguish an ErrNil due to a nil extension
  400. // from an ErrNil due to a missing field. Extensions are
  401. // always optional, so the encoder would just swallow the error
  402. // and drop all the extensions from the encoded message.
  403. if reflect.ValueOf(value).IsNil() {
  404. return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
  405. }
  406. var x Extension
  407. x.SetType(extension)
  408. x.SetEagerValue(extensionAsStorageType(value))
  409. epb.Set(protoreflect.FieldNumber(extension.Field), x)
  410. return nil
  411. }
  412. // ClearAllExtensions clears all extensions from pb.
  413. func ClearAllExtensions(pb Message) {
  414. epb, err := extendable(pb)
  415. if err != nil {
  416. return
  417. }
  418. epb.Range(func(k protoreflect.FieldNumber, _ Extension) bool {
  419. epb.Clear(k)
  420. return true
  421. })
  422. }
  // extensionAsLegacyType converts a value from its storage representation to
  // the form exposed by the legacy API (see Extension.Value):
  //   - values of bool, int32, int64, uint32, uint64, float32, float64 or
  //     string kind (including named types such as enums) become a pointer to
  //     a copy of the value;
  //   - a pointer to a slice becomes the slice itself, and a nil pointer maps
  //     to a nil slice of that type;
  //   - anything else, including nil, is returned unchanged.
  //
  // extensionAsStorageType performs the reverse mapping.
  func extensionAsLegacyType(v interface{}) interface{} {
  	rv := reflect.ValueOf(v)
  	kind := rv.Kind()
  	if kind == reflect.Ptr {
  		sliceType := rv.Type().Elem()
  		if sliceType.Kind() != reflect.Slice {
  			return v
  		}
  		if rv.IsNil() {
  			return reflect.Zero(sliceType).Interface()
  		}
  		return rv.Elem().Interface()
  	}
  	switch kind {
  	case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64,
  		reflect.Float32, reflect.Float64, reflect.String:
  		boxed := reflect.New(rv.Type())
  		boxed.Elem().Set(rv)
  		return boxed.Interface()
  	}
  	return v
  }
  // extensionAsStorageType converts a value from the legacy API form to its
  // storage representation (see Extension.Value):
  //   - a pointer to a bool, int32, int64, uint32, uint64, float32, float64 or
  //     string kind becomes the pointed-to value, and a nil pointer maps to
  //     that type's zero value;
  //   - a slice whose element kind is not uint8 becomes a pointer to a copy of
  //     the slice header (the backing array is shared);
  //   - anything else, including nil and []byte, is returned unchanged.
  //
  // extensionAsLegacyType performs the reverse mapping.
  func extensionAsStorageType(v interface{}) interface{} {
  	rv := reflect.ValueOf(v)
  	switch rv.Kind() {
  	case reflect.Slice:
  		if rv.Type().Elem().Kind() == reflect.Uint8 {
  			return v
  		}
  		boxed := reflect.New(rv.Type())
  		boxed.Elem().Set(rv)
  		return boxed.Interface()
  	case reflect.Ptr:
  		switch rv.Type().Elem().Kind() {
  		case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64,
  			reflect.Float32, reflect.Float64, reflect.String:
  			if rv.IsNil() {
  				return reflect.Zero(rv.Type().Elem()).Interface()
  			}
  			return rv.Elem().Interface()
  		}
  	}
  	return v
  }