extensions.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. // Go support for Protocol Buffers - Google's data interchange format
  2. //
  3. // Copyright 2010 The Go Authors. All rights reserved.
  4. // https://github.com/golang/protobuf
  5. //
  6. // Redistribution and use in source and binary forms, with or without
  7. // modification, are permitted provided that the following conditions are
  8. // met:
  9. //
  10. // * Redistributions of source code must retain the above copyright
  11. // notice, this list of conditions and the following disclaimer.
  12. // * Redistributions in binary form must reproduce the above
  13. // copyright notice, this list of conditions and the following disclaimer
  14. // in the documentation and/or other materials provided with the
  15. // distribution.
  16. // * Neither the name of Google Inc. nor the names of its
  17. // contributors may be used to endorse or promote products derived from
  18. // this software without specific prior written permission.
  19. //
  20. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  21. // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  22. // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  23. // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  24. // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  25. // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  26. // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  27. // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  28. // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  29. // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  30. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  31. package proto
  32. /*
  33. * Types and routines for supporting protocol buffer extensions.
  34. */
  35. import (
  36. "errors"
  37. "fmt"
  38. "reflect"
  39. "strconv"
  40. "sync"
  41. )
  42. // ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message.
  43. var ErrMissingExtension = errors.New("proto: missing extension")
  44. // ExtensionRange represents a range of message extensions for a protocol buffer.
  45. // Used in code generated by the protocol compiler.
  46. type ExtensionRange struct {
  47. Start, End int32 // both inclusive
  48. }
  49. // extendableProto is an interface implemented by any protocol buffer generated by the current
  50. // proto compiler that may be extended.
  51. type extendableProto interface {
  52. Message
  53. ExtensionRangeArray() []ExtensionRange
  54. extensionsWrite() map[int32]Extension
  55. extensionsRead() (map[int32]Extension, sync.Locker)
  56. }
  57. // extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
  58. // version of the proto compiler that may be extended.
  59. type extendableProtoV1 interface {
  60. Message
  61. ExtensionRangeArray() []ExtensionRange
  62. ExtensionMap() map[int32]Extension
  63. }
// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
// It lets legacy (V1) generated messages, which expose only ExtensionMap, flow
// through the same extendableProto code paths; extendable produces it.
type extensionAdapter struct {
	extendableProtoV1
}
  68. func (e extensionAdapter) extensionsWrite() map[int32]Extension {
  69. return e.ExtensionMap()
  70. }
  71. func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
  72. return e.ExtensionMap(), notLocker{}
  73. }
  74. // notLocker is a sync.Locker whose Lock and Unlock methods are nops.
  75. type notLocker struct{}
  76. func (n notLocker) Lock() {}
  77. func (n notLocker) Unlock() {}
  78. // extendable returns the extendableProto interface for the given generated proto message.
  79. // If the proto message has the old extension format, it returns a wrapper that implements
  80. // the extendableProto interface.
  81. func extendable(p interface{}) (extendableProto, bool) {
  82. if ep, ok := p.(extendableProto); ok {
  83. return ep, ok
  84. }
  85. if ep, ok := p.(extendableProtoV1); ok {
  86. return extensionAdapter{ep}, ok
  87. }
  88. return nil, false
  89. }
// XXX_InternalExtensions is an internal representation of proto extensions.
//
// Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
// thus gaining the unexported extensionsRead and extensionsWrite methods, which can be
// called only from the proto package.
//
// The zero value is ready to use: p stays nil until extensionsWrite first
// allocates the storage, and extensionsRead reports a nil map until then.
//
// The methods of XXX_InternalExtensions are not concurrency safe in general,
// but logically read-only operations such as HasExtension and GetExtension may be
// executed concurrently. (GetExtension still writes the decoded value back into
// the map, which is why those operations take the mutex.)
type XXX_InternalExtensions struct {
	// The struct must be indirect so that if a user inadvertently copies a
	// generated message and its embedded XXX_InternalExtensions, they
	// avoid the mayhem of a copied mutex.
	//
	// The mutex serializes all logically read-only operations to p.extensionMap.
	// It is up to the client to ensure that write operations to p.extensionMap are
	// mutually exclusive with other accesses.
	//
	// This anonymous struct type is spelled out again in extensionsWrite; the two
	// must stay identical.
	p *struct {
		mu           sync.Mutex
		extensionMap map[int32]Extension
	}
}
  110. // extensionsWrite returns the extension map, creating it on first use.
  111. func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
  112. if e.p == nil {
  113. e.p = new(struct {
  114. mu sync.Mutex
  115. extensionMap map[int32]Extension
  116. })
  117. e.p.extensionMap = make(map[int32]Extension)
  118. }
  119. return e.p.extensionMap
  120. }
  121. // extensionsRead returns the extensions map for read-only use. It may be nil.
  122. // The caller must hold the returned mutex's lock when accessing Elements within the map.
  123. func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
  124. if e.p == nil {
  125. return nil, nil
  126. }
  127. return e.p.extensionMap, &e.p.mu
  128. }
  129. var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
  130. var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
// ExtensionDesc represents an extension specification.
// Used in generated code from the protocol compiler.
//
// ExtendedType and ExtensionType are consulted only for their dynamic Go
// types. checkExtensionTypes compares reflect.TypeOf(ExtendedType) with the
// message being accessed. RegisterExtension takes its Elem(), so it must be a
// pointer. The reflect.Type of ExtensionType, together with Tag, is fed to
// Properties.Init to build the encoder and decoder for the extension's values.
type ExtensionDesc struct {
	ExtendedType  Message     // nil pointer to the type that is being extended
	ExtensionType interface{} // nil pointer to the extension type
	Field         int32       // field number
	Name          string      // fully-qualified name of extension, for text formatting
	Tag           string      // protobuf tag style
}
  140. func (ed *ExtensionDesc) repeated() bool {
  141. t := reflect.TypeOf(ed.ExtensionType)
  142. return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
  143. }
// Extension represents an extension in a message.
//
// An entry counts as decoded only when both desc and value are non-nil.
// Otherwise, only enc is meaningful.
type Extension struct {
	// When an extension is stored in a message using SetExtension
	// only desc and value are set. When the message is marshaled
	// enc will be set to the encoded form of the message.
	//
	// When a message is unmarshaled and contains extensions, each
	// extension will have only enc set. When such an extension is
	// accessed using GetExtension (or GetExtensions) desc and value
	// will be set (and enc is cleared).
	desc  *ExtensionDesc // descriptor used to decode or set value
	value interface{}    // decoded value; its type matches desc.ExtensionType
	enc   []byte         // encoded bytes, starting with the field's tag varint
}
  158. // SetRawExtension is for testing only.
  159. func SetRawExtension(base Message, id int32, b []byte) {
  160. epb, ok := extendable(base)
  161. if !ok {
  162. return
  163. }
  164. extmap := epb.extensionsWrite()
  165. extmap[id] = Extension{enc: b}
  166. }
  167. // isExtensionField returns true iff the given field number is in an extension range.
  168. func isExtensionField(pb extendableProto, field int32) bool {
  169. for _, er := range pb.ExtensionRangeArray() {
  170. if er.Start <= field && field <= er.End {
  171. return true
  172. }
  173. }
  174. return false
  175. }
  176. // checkExtensionTypes checks that the given extension is valid for pb.
  177. func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
  178. var pbi interface{} = pb
  179. // Check the extended type.
  180. if ea, ok := pbi.(extensionAdapter); ok {
  181. pbi = ea.extendableProtoV1
  182. }
  183. if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
  184. return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
  185. }
  186. // Check the range.
  187. if !isExtensionField(pb, extension.Field) {
  188. return errors.New("proto: bad extension number; not in declared ranges")
  189. }
  190. return nil
  191. }
// extPropKey is sufficient to uniquely identify an extension.
// extensionProperties uses it as the key of the extProp cache.
type extPropKey struct {
	base  reflect.Type // reflect.TypeOf(ExtensionDesc.ExtendedType)
	field int32        // ExtensionDesc.Field
}
  197. var extProp = struct {
  198. sync.RWMutex
  199. m map[extPropKey]*Properties
  200. }{
  201. m: make(map[extPropKey]*Properties),
  202. }
  203. func extensionProperties(ed *ExtensionDesc) *Properties {
  204. key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
  205. extProp.RLock()
  206. if prop, ok := extProp.m[key]; ok {
  207. extProp.RUnlock()
  208. return prop
  209. }
  210. extProp.RUnlock()
  211. extProp.Lock()
  212. defer extProp.Unlock()
  213. // Check again.
  214. if prop, ok := extProp.m[key]; ok {
  215. return prop
  216. }
  217. prop := new(Properties)
  218. prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
  219. extProp.m[key] = prop
  220. return prop
  221. }
  222. // encode encodes any unmarshaled (unencoded) extensions in e.
  223. func encodeExtensions(e *XXX_InternalExtensions) error {
  224. m, mu := e.extensionsRead()
  225. if m == nil {
  226. return nil // fast path
  227. }
  228. mu.Lock()
  229. defer mu.Unlock()
  230. return encodeExtensionsMap(m)
  231. }
  232. // encode encodes any unmarshaled (unencoded) extensions in e.
  233. func encodeExtensionsMap(m map[int32]Extension) error {
  234. for k, e := range m {
  235. if e.value == nil || e.desc == nil {
  236. // Extension is only in its encoded form.
  237. continue
  238. }
  239. // We don't skip extensions that have an encoded form set,
  240. // because the extension value may have been mutated after
  241. // the last time this function was called.
  242. et := reflect.TypeOf(e.desc.ExtensionType)
  243. props := extensionProperties(e.desc)
  244. p := NewBuffer(nil)
  245. // If e.value has type T, the encoder expects a *struct{ X T }.
  246. // Pass a *T with a zero field and hope it all works out.
  247. x := reflect.New(et)
  248. x.Elem().Set(reflect.ValueOf(e.value))
  249. if err := props.enc(p, props, toStructPointer(x)); err != nil {
  250. return err
  251. }
  252. e.enc = p.buf
  253. m[k] = e
  254. }
  255. return nil
  256. }
  257. func extensionsSize(e *XXX_InternalExtensions) (n int) {
  258. m, mu := e.extensionsRead()
  259. if m == nil {
  260. return 0
  261. }
  262. mu.Lock()
  263. defer mu.Unlock()
  264. return extensionsMapSize(m)
  265. }
  266. func extensionsMapSize(m map[int32]Extension) (n int) {
  267. for _, e := range m {
  268. if e.value == nil || e.desc == nil {
  269. // Extension is only in its encoded form.
  270. n += len(e.enc)
  271. continue
  272. }
  273. // We don't skip extensions that have an encoded form set,
  274. // because the extension value may have been mutated after
  275. // the last time this function was called.
  276. et := reflect.TypeOf(e.desc.ExtensionType)
  277. props := extensionProperties(e.desc)
  278. // If e.value has type T, the encoder expects a *struct{ X T }.
  279. // Pass a *T with a zero field and hope it all works out.
  280. x := reflect.New(et)
  281. x.Elem().Set(reflect.ValueOf(e.value))
  282. n += props.size(props, toStructPointer(x))
  283. }
  284. return
  285. }
  286. // HasExtension returns whether the given extension is present in pb.
  287. func HasExtension(pb Message, extension *ExtensionDesc) bool {
  288. // TODO: Check types, field numbers, etc.?
  289. epb, ok := extendable(pb)
  290. if !ok {
  291. return false
  292. }
  293. extmap, mu := epb.extensionsRead()
  294. if extmap == nil {
  295. return false
  296. }
  297. mu.Lock()
  298. _, ok = extmap[extension.Field]
  299. mu.Unlock()
  300. return ok
  301. }
  302. // ClearExtension removes the given extension from pb.
  303. func ClearExtension(pb Message, extension *ExtensionDesc) {
  304. epb, ok := extendable(pb)
  305. if !ok {
  306. return
  307. }
  308. // TODO: Check types, field numbers, etc.?
  309. extmap := epb.extensionsWrite()
  310. delete(extmap, extension.Field)
  311. }
  312. // GetExtension parses and returns the given extension of pb.
  313. // If the extension is not present and has no default value it returns ErrMissingExtension.
  314. func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
  315. epb, ok := extendable(pb)
  316. if !ok {
  317. return nil, errors.New("proto: not an extendable proto")
  318. }
  319. if err := checkExtensionTypes(epb, extension); err != nil {
  320. return nil, err
  321. }
  322. emap, mu := epb.extensionsRead()
  323. if emap == nil {
  324. return defaultExtensionValue(extension)
  325. }
  326. mu.Lock()
  327. defer mu.Unlock()
  328. e, ok := emap[extension.Field]
  329. if !ok {
  330. // defaultExtensionValue returns the default value or
  331. // ErrMissingExtension if there is no default.
  332. return defaultExtensionValue(extension)
  333. }
  334. if e.value != nil {
  335. // Already decoded. Check the descriptor, though.
  336. if e.desc != extension {
  337. // This shouldn't happen. If it does, it means that
  338. // GetExtension was called twice with two different
  339. // descriptors with the same field number.
  340. return nil, errors.New("proto: descriptor conflict")
  341. }
  342. return e.value, nil
  343. }
  344. v, err := decodeExtension(e.enc, extension)
  345. if err != nil {
  346. return nil, err
  347. }
  348. // Remember the decoded version and drop the encoded version.
  349. // That way it is safe to mutate what we return.
  350. e.value = v
  351. e.desc = extension
  352. e.enc = nil
  353. emap[extension.Field] = e
  354. return e.value, nil
  355. }
  356. // defaultExtensionValue returns the default value for extension.
  357. // If no default for an extension is defined ErrMissingExtension is returned.
  358. func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
  359. t := reflect.TypeOf(extension.ExtensionType)
  360. props := extensionProperties(extension)
  361. sf, _, err := fieldDefault(t, props)
  362. if err != nil {
  363. return nil, err
  364. }
  365. if sf == nil || sf.value == nil {
  366. // There is no default value.
  367. return nil, ErrMissingExtension
  368. }
  369. if t.Kind() != reflect.Ptr {
  370. // We do not need to return a Ptr, we can directly return sf.value.
  371. return sf.value, nil
  372. }
  373. // We need to return an interface{} that is a pointer to sf.value.
  374. value := reflect.New(t).Elem()
  375. value.Set(reflect.New(value.Type().Elem()))
  376. if sf.kind == reflect.Int32 {
  377. // We may have an int32 or an enum, but the underlying data is int32.
  378. // Since we can't set an int32 into a non int32 reflect.value directly
  379. // set it as a int32.
  380. value.Elem().SetInt(int64(sf.value.(int32)))
  381. } else {
  382. value.Elem().Set(reflect.ValueOf(sf.value))
  383. }
  384. return value.Interface(), nil
  385. }
  386. // decodeExtension decodes an extension encoded in b.
  387. func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
  388. o := NewBuffer(b)
  389. t := reflect.TypeOf(extension.ExtensionType)
  390. props := extensionProperties(extension)
  391. // t is a pointer to a struct, pointer to basic type or a slice.
  392. // Allocate a "field" to store the pointer/slice itself; the
  393. // pointer/slice will be stored here. We pass
  394. // the address of this field to props.dec.
  395. // This passes a zero field and a *t and lets props.dec
  396. // interpret it as a *struct{ x t }.
  397. value := reflect.New(t).Elem()
  398. for {
  399. // Discard wire type and field number varint. It isn't needed.
  400. if _, err := o.DecodeVarint(); err != nil {
  401. return nil, err
  402. }
  403. if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
  404. return nil, err
  405. }
  406. if o.index >= len(o.buf) {
  407. break
  408. }
  409. }
  410. return value.Interface(), nil
  411. }
  412. // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
  413. // The returned slice has the same length as es; missing extensions will appear as nil elements.
  414. func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
  415. epb, ok := extendable(pb)
  416. if !ok {
  417. return nil, errors.New("proto: not an extendable proto")
  418. }
  419. extensions = make([]interface{}, len(es))
  420. for i, e := range es {
  421. extensions[i], err = GetExtension(epb, e)
  422. if err == ErrMissingExtension {
  423. err = nil
  424. }
  425. if err != nil {
  426. return
  427. }
  428. }
  429. return
  430. }
  431. // ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
  432. // For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
  433. // just the Field field, which defines the extension's field number.
  434. func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
  435. epb, ok := extendable(pb)
  436. if !ok {
  437. return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb)
  438. }
  439. registeredExtensions := RegisteredExtensions(pb)
  440. emap, mu := epb.extensionsRead()
  441. if emap == nil {
  442. return nil, nil
  443. }
  444. mu.Lock()
  445. defer mu.Unlock()
  446. extensions := make([]*ExtensionDesc, 0, len(emap))
  447. for extid, e := range emap {
  448. desc := e.desc
  449. if desc == nil {
  450. desc = registeredExtensions[extid]
  451. if desc == nil {
  452. desc = &ExtensionDesc{Field: extid}
  453. }
  454. }
  455. extensions = append(extensions, desc)
  456. }
  457. return extensions, nil
  458. }
  459. // SetExtension sets the specified extension of pb to the specified value.
  460. func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
  461. epb, ok := extendable(pb)
  462. if !ok {
  463. return errors.New("proto: not an extendable proto")
  464. }
  465. if err := checkExtensionTypes(epb, extension); err != nil {
  466. return err
  467. }
  468. typ := reflect.TypeOf(extension.ExtensionType)
  469. if typ != reflect.TypeOf(value) {
  470. return errors.New("proto: bad extension value type")
  471. }
  472. // nil extension values need to be caught early, because the
  473. // encoder can't distinguish an ErrNil due to a nil extension
  474. // from an ErrNil due to a missing field. Extensions are
  475. // always optional, so the encoder would just swallow the error
  476. // and drop all the extensions from the encoded message.
  477. if reflect.ValueOf(value).IsNil() {
  478. return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
  479. }
  480. extmap := epb.extensionsWrite()
  481. extmap[extension.Field] = Extension{desc: extension, value: value}
  482. return nil
  483. }
  484. // ClearAllExtensions clears all extensions from pb.
  485. func ClearAllExtensions(pb Message) {
  486. epb, ok := extendable(pb)
  487. if !ok {
  488. return
  489. }
  490. m := epb.extensionsWrite()
  491. for k := range m {
  492. delete(m, k)
  493. }
  494. }
  495. // A global registry of extensions.
  496. // The generated code will register the generated descriptors by calling RegisterExtension.
  497. var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
  498. // RegisterExtension is called from the generated code.
  499. func RegisterExtension(desc *ExtensionDesc) {
  500. st := reflect.TypeOf(desc.ExtendedType).Elem()
  501. m := extensionMaps[st]
  502. if m == nil {
  503. m = make(map[int32]*ExtensionDesc)
  504. extensionMaps[st] = m
  505. }
  506. if _, ok := m[desc.Field]; ok {
  507. panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
  508. }
  509. m[desc.Field] = desc
  510. }
  511. // RegisteredExtensions returns a map of the registered extensions of a
  512. // protocol buffer struct, indexed by the extension number.
  513. // The argument pb should be a nil pointer to the struct type.
  514. func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
  515. return extensionMaps[reflect.TypeOf(pb).Elem()]
  516. }