|
|
@@ -15,68 +15,30 @@ import (
|
|
|
"reflect"
|
|
|
"strconv"
|
|
|
"sync"
|
|
|
+
|
|
|
+ "github.com/golang/protobuf/protoapi"
|
|
|
+ "github.com/golang/protobuf/v2/reflect/protoreflect"
|
|
|
)
|
|
|
|
|
|
// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message.
|
|
|
var ErrMissingExtension = errors.New("proto: missing extension")
|
|
|
|
|
|
-// ExtensionRange represents a range of message extensions for a protocol buffer.
|
|
|
-// Used in code generated by the protocol compiler.
|
|
|
-type ExtensionRange struct {
|
|
|
- Start, End int32 // both inclusive
|
|
|
-}
|
|
|
-
|
|
|
-// extendableProto is an interface implemented by any protocol buffer generated by the current
|
|
|
-// proto compiler that may be extended.
|
|
|
-type extendableProto interface {
|
|
|
- Message
|
|
|
- ExtensionRangeArray() []ExtensionRange
|
|
|
- extensionsWrite() map[int32]Extension
|
|
|
- extensionsRead() (map[int32]Extension, sync.Locker)
|
|
|
-}
|
|
|
-
|
|
|
-// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
|
|
|
-// version of the proto compiler that may be extended.
|
|
|
-type extendableProtoV1 interface {
|
|
|
- Message
|
|
|
- ExtensionRangeArray() []ExtensionRange
|
|
|
- ExtensionMap() map[int32]Extension
|
|
|
-}
|
|
|
-
|
|
|
-// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
|
|
|
-type extensionAdapter struct {
|
|
|
- extendableProtoV1
|
|
|
-}
|
|
|
-
|
|
|
-func (e extensionAdapter) extensionsWrite() map[int32]Extension {
|
|
|
- return e.ExtensionMap()
|
|
|
-}
|
|
|
-
|
|
|
-func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
|
|
|
- return e.ExtensionMap(), notLocker{}
|
|
|
-}
|
|
|
-
|
|
|
-// notLocker is a sync.Locker whose Lock and Unlock methods are nops.
|
|
|
-type notLocker struct{}
|
|
|
-
|
|
|
-func (n notLocker) Lock() {}
|
|
|
-func (n notLocker) Unlock() {}
|
|
|
-
|
|
|
-// extendable returns the extendableProto interface for the given generated proto message.
|
|
|
-// If the proto message has the old extension format, it returns a wrapper that implements
|
|
|
-// the extendableProto interface.
|
|
|
-func extendable(p interface{}) (extendableProto, error) {
|
|
|
- switch p := p.(type) {
|
|
|
- case extendableProto:
|
|
|
- if isNilPtr(p) {
|
|
|
- return nil, fmt.Errorf("proto: nil %T is not extendable", p)
|
|
|
- }
|
|
|
- return p, nil
|
|
|
- case extendableProtoV1:
|
|
|
- if isNilPtr(p) {
|
|
|
- return nil, fmt.Errorf("proto: nil %T is not extendable", p)
|
|
|
+func extendable(p interface{}) (protoapi.ExtensionFields, error) {
|
|
|
+ type extendableProto interface {
|
|
|
+ Message
|
|
|
+ ExtensionRangeArray() []ExtensionRange
|
|
|
+ }
|
|
|
+ if _, ok := p.(extendableProto); ok {
|
|
|
+ v := reflect.ValueOf(p)
|
|
|
+ if v.Kind() == reflect.Ptr && !v.IsNil() {
|
|
|
+ v = v.Elem()
|
|
|
+ if v := v.FieldByName("XXX_InternalExtensions"); v.IsValid() {
|
|
|
+ return protoapi.ExtensionFieldsOf(v.Addr().Interface()), nil
|
|
|
+ }
|
|
|
+ if v := v.FieldByName("XXX_extensions"); v.IsValid() {
|
|
|
+ return protoapi.ExtensionFieldsOf(v.Addr().Interface()), nil
|
|
|
+ }
|
|
|
}
|
|
|
- return extensionAdapter{p}, nil
|
|
|
}
|
|
|
// Don't allocate a specific error containing %T:
|
|
|
// this is the hot path for Clone and MarshalText.
|
|
|
@@ -85,129 +47,47 @@ func extendable(p interface{}) (extendableProto, error) {
|
|
|
|
|
|
var errNotExtendable = errors.New("proto: not an extendable proto.Message")
|
|
|
|
|
|
-func isNilPtr(x interface{}) bool {
|
|
|
- v := reflect.ValueOf(x)
|
|
|
- return v.Kind() == reflect.Ptr && v.IsNil()
|
|
|
-}
|
|
|
-
|
|
|
-// XXX_InternalExtensions is an internal representation of proto extensions.
|
|
|
-//
|
|
|
-// Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
|
|
|
-// thus gaining the unexported 'extensions' method, which can be called only from the proto package.
|
|
|
-//
|
|
|
-// The methods of XXX_InternalExtensions are not concurrency safe in general,
|
|
|
-// but calls to logically read-only methods such as has and get may be executed concurrently.
|
|
|
-type XXX_InternalExtensions struct {
|
|
|
- // The struct must be indirect so that if a user inadvertently copies a
|
|
|
- // generated message and its embedded XXX_InternalExtensions, they
|
|
|
- // avoid the mayhem of a copied mutex.
|
|
|
- //
|
|
|
- // The mutex serializes all logically read-only operations to p.extensionMap.
|
|
|
- // It is up to the client to ensure that write operations to p.extensionMap are
|
|
|
- // mutually exclusive with other accesses.
|
|
|
- p *struct {
|
|
|
- mu sync.Mutex
|
|
|
- extensionMap map[int32]Extension
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-// extensionsWrite returns the extension map, creating it on first use.
|
|
|
-func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
|
|
|
- if e.p == nil {
|
|
|
- e.p = new(struct {
|
|
|
- mu sync.Mutex
|
|
|
- extensionMap map[int32]Extension
|
|
|
- })
|
|
|
- e.p.extensionMap = make(map[int32]Extension)
|
|
|
- }
|
|
|
- return e.p.extensionMap
|
|
|
-}
|
|
|
-
|
|
|
-// extensionsRead returns the extensions map for read-only use. It may be nil.
|
|
|
-// The caller must hold the returned mutex's lock when accessing Elements within the map.
|
|
|
-func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
|
|
|
- if e.p == nil {
|
|
|
- return nil, nil
|
|
|
- }
|
|
|
- return e.p.extensionMap, &e.p.mu
|
|
|
-}
|
|
|
-
|
|
|
-// ExtensionDesc represents an extension specification.
|
|
|
-// Used in generated code from the protocol compiler.
|
|
|
-type ExtensionDesc struct {
|
|
|
- ExtendedType Message // nil pointer to the type that is being extended
|
|
|
- ExtensionType interface{} // nil pointer to the extension type
|
|
|
- Field int32 // field number
|
|
|
- Name string // fully-qualified name of extension, for text formatting
|
|
|
- Tag string // protobuf tag style
|
|
|
- Filename string // name of the file in which the extension is defined
|
|
|
-}
|
|
|
+type (
|
|
|
+ ExtensionRange = protoapi.ExtensionRange
|
|
|
+ ExtensionDesc = protoapi.ExtensionDesc
|
|
|
+ Extension = protoapi.ExtensionField
|
|
|
+ XXX_InternalExtensions = protoapi.XXX_InternalExtensions
|
|
|
+)
|
|
|
|
|
|
-func (ed *ExtensionDesc) repeated() bool {
|
|
|
+func isRepeatedExtension(ed *ExtensionDesc) bool {
|
|
|
t := reflect.TypeOf(ed.ExtensionType)
|
|
|
return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
|
|
|
}
|
|
|
|
|
|
-// Extension represents an extension in a message.
|
|
|
-type Extension struct {
|
|
|
- // When an extension is stored in a message using SetExtension
|
|
|
- // only desc and value are set. When the message is marshaled
|
|
|
- // enc will be set to the encoded form of the message.
|
|
|
- //
|
|
|
- // When a message is unmarshaled and contains extensions, each
|
|
|
- // extension will have only enc set. When such an extension is
|
|
|
- // accessed using GetExtension (or GetExtensions) desc and value
|
|
|
- // will be set.
|
|
|
- desc *ExtensionDesc
|
|
|
-
|
|
|
- // value is a concrete value for the extension field. Let the type of
|
|
|
- // desc.ExtensionType be the "API type" and the type of Extension.value
|
|
|
- // be the "storage type". The API type and storage type are the same except:
|
|
|
- // * For scalars (except []byte), the API type uses *T,
|
|
|
- // while the storage type uses T.
|
|
|
- // * For repeated fields, the API type uses []T, while the storage type
|
|
|
- // uses *[]T.
|
|
|
- //
|
|
|
- // The reason for the divergence is so that the storage type more naturally
|
|
|
- // matches what is expected of when retrieving the values through the
|
|
|
- // protobuf reflection APIs.
|
|
|
- //
|
|
|
- // The value may only be populated if desc is also populated.
|
|
|
- value interface{}
|
|
|
-
|
|
|
- // enc is the raw bytes for the extension field.
|
|
|
- enc []byte
|
|
|
-}
|
|
|
-
|
|
|
// SetRawExtension is for testing only.
|
|
|
func SetRawExtension(base Message, id int32, b []byte) {
|
|
|
epb, err := extendable(base)
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
- extmap := epb.extensionsWrite()
|
|
|
- extmap[id] = Extension{enc: b}
|
|
|
+ epb.Set(protoreflect.FieldNumber(id), Extension{Raw: b})
|
|
|
}
|
|
|
|
|
|
// isExtensionField returns true iff the given field number is in an extension range.
|
|
|
-func isExtensionField(pb extendableProto, field int32) bool {
|
|
|
- for _, er := range pb.ExtensionRangeArray() {
|
|
|
- if er.Start <= field && field <= er.End {
|
|
|
- return true
|
|
|
+func isExtensionField(pb Message, field int32) bool {
|
|
|
+ m, ok := pb.(interface{ ExtensionRangeArray() []ExtensionRange })
|
|
|
+ if ok {
|
|
|
+ for _, er := range m.ExtensionRangeArray() {
|
|
|
+ if er.Start <= field && field <= er.End {
|
|
|
+ return true
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
-// checkExtensionTypes checks that the given extension is valid for pb.
|
|
|
-func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
|
|
|
- var pbi interface{} = pb
|
|
|
+// checkExtensionTypeAndRanges checks that the given extension is valid for pb.
|
|
|
+func checkExtensionTypeAndRanges(pb Message, extension *ExtensionDesc) error {
|
|
|
// Check the extended type.
|
|
|
- if ea, ok := pbi.(extensionAdapter); ok {
|
|
|
- pbi = ea.extendableProtoV1
|
|
|
- }
|
|
|
- if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
|
|
|
- return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
|
|
|
+ if extension.ExtendedType != nil {
|
|
|
+ if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
|
|
|
+ return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
|
|
|
+ }
|
|
|
}
|
|
|
// Check the range.
|
|
|
if !isExtensionField(pb, extension.Field) {
|
|
|
@@ -229,8 +109,8 @@ var extProp = struct {
|
|
|
m: make(map[extPropKey]*Properties),
|
|
|
}
|
|
|
|
|
|
-func extensionProperties(ed *ExtensionDesc) *Properties {
|
|
|
- key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
|
|
|
+func extensionProperties(pb Message, ed *ExtensionDesc) *Properties {
|
|
|
+ key := extPropKey{base: reflect.TypeOf(pb), field: ed.Field}
|
|
|
|
|
|
extProp.RLock()
|
|
|
if prop, ok := extProp.m[key]; ok {
|
|
|
@@ -259,14 +139,12 @@ func HasExtension(pb Message, extension *ExtensionDesc) bool {
|
|
|
if err != nil {
|
|
|
return false
|
|
|
}
|
|
|
- extmap, mu := epb.extensionsRead()
|
|
|
- if extmap == nil {
|
|
|
+ if !epb.HasInit() {
|
|
|
return false
|
|
|
}
|
|
|
- mu.Lock()
|
|
|
- _, ok := extmap[extension.Field]
|
|
|
- mu.Unlock()
|
|
|
- return ok
|
|
|
+ epb.Lock()
|
|
|
+ defer epb.Unlock()
|
|
|
+ return epb.Has(protoreflect.FieldNumber(extension.Field))
|
|
|
}
|
|
|
|
|
|
// ClearExtension removes the given extension from pb.
|
|
|
@@ -276,8 +154,7 @@ func ClearExtension(pb Message, extension *ExtensionDesc) {
|
|
|
return
|
|
|
}
|
|
|
// TODO: Check types, field numbers, etc.?
|
|
|
- extmap := epb.extensionsWrite()
|
|
|
- delete(extmap, extension.Field)
|
|
|
+ epb.Clear(protoreflect.FieldNumber(extension.Field))
|
|
|
}
|
|
|
|
|
|
// GetExtension retrieves a proto2 extended field from pb.
|
|
|
@@ -295,66 +172,63 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- if extension.ExtendedType != nil {
|
|
|
- // can only check type if this is a complete descriptor
|
|
|
- if err := checkExtensionTypes(epb, extension); err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
+ // can only check type if this is a complete descriptor
|
|
|
+ if err := checkExtensionTypeAndRanges(pb, extension); err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- emap, mu := epb.extensionsRead()
|
|
|
- if emap == nil {
|
|
|
- return defaultExtensionValue(extension)
|
|
|
+ if !epb.HasInit() {
|
|
|
+ return defaultExtensionValue(pb, extension)
|
|
|
}
|
|
|
- mu.Lock()
|
|
|
- defer mu.Unlock()
|
|
|
- e, ok := emap[extension.Field]
|
|
|
- if !ok {
|
|
|
+ epb.Lock()
|
|
|
+ defer epb.Unlock()
|
|
|
+ if !epb.Has(protoreflect.FieldNumber(extension.Field)) {
|
|
|
// defaultExtensionValue returns the default value or
|
|
|
// ErrMissingExtension if there is no default.
|
|
|
- return defaultExtensionValue(extension)
|
|
|
+ return defaultExtensionValue(pb, extension)
|
|
|
}
|
|
|
+ e := epb.Get(protoreflect.FieldNumber(extension.Field))
|
|
|
|
|
|
- if e.value != nil {
|
|
|
+ if e.Value != nil {
|
|
|
// Already decoded. Check the descriptor, though.
|
|
|
- if e.desc != extension {
|
|
|
+ if e.Desc != extension {
|
|
|
// This shouldn't happen. If it does, it means that
|
|
|
// GetExtension was called twice with two different
|
|
|
// descriptors with the same field number.
|
|
|
return nil, errors.New("proto: descriptor conflict")
|
|
|
}
|
|
|
- return extensionAsLegacyType(e.value), nil
|
|
|
+ return extensionAsLegacyType(e.Value), nil
|
|
|
}
|
|
|
|
|
|
if extension.ExtensionType == nil {
|
|
|
// incomplete descriptor
|
|
|
- return e.enc, nil
|
|
|
+ return e.Raw, nil
|
|
|
}
|
|
|
|
|
|
- v, err := decodeExtension(e.enc, extension)
|
|
|
+ v, err := decodeExtension(e.Raw, extension)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// Remember the decoded version and drop the encoded version.
|
|
|
// That way it is safe to mutate what we return.
|
|
|
- e.value = extensionAsStorageType(v)
|
|
|
- e.desc = extension
|
|
|
- e.enc = nil
|
|
|
- emap[extension.Field] = e
|
|
|
- return extensionAsLegacyType(e.value), nil
|
|
|
+ e.Value = extensionAsStorageType(v)
|
|
|
+ e.Desc = extension
|
|
|
+ e.Raw = nil
|
|
|
+ epb.Set(protoreflect.FieldNumber(extension.Field), e)
|
|
|
+ return extensionAsLegacyType(e.Value), nil
|
|
|
}
|
|
|
|
|
|
// defaultExtensionValue returns the default value for extension.
|
|
|
// If no default for an extension is defined ErrMissingExtension is returned.
|
|
|
-func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
|
|
|
+func defaultExtensionValue(pb Message, extension *ExtensionDesc) (interface{}, error) {
|
|
|
if extension.ExtensionType == nil {
|
|
|
// incomplete descriptor, so no default
|
|
|
return nil, ErrMissingExtension
|
|
|
}
|
|
|
|
|
|
t := reflect.TypeOf(extension.ExtensionType)
|
|
|
- props := extensionProperties(extension)
|
|
|
+ props := extensionProperties(pb, extension)
|
|
|
|
|
|
sf, _, err := fieldDefault(t, props)
|
|
|
if err != nil {
|
|
|
@@ -376,7 +250,7 @@ func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
|
|
|
value.Set(reflect.New(value.Type().Elem()))
|
|
|
if sf.kind == reflect.Int32 {
|
|
|
// We may have an int32 or an enum, but the underlying data is int32.
|
|
|
- // Since we can't set an int32 into a non int32 reflect.value directly
|
|
|
+ // Since we can't set an int32 into a non int32 reflect.Value directly
|
|
|
// set it as a int32.
|
|
|
value.Elem().SetInt(int64(sf.value.(int32)))
|
|
|
} else {
|
|
|
@@ -418,13 +292,13 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
|
|
|
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
|
|
|
// The returned slice has the same length as es; missing extensions will appear as nil elements.
|
|
|
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
|
|
|
- epb, err := extendable(pb)
|
|
|
+ _, err = extendable(pb)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
extensions = make([]interface{}, len(es))
|
|
|
for i, e := range es {
|
|
|
- extensions[i], err = GetExtension(epb, e)
|
|
|
+ extensions[i], err = GetExtension(pb, e)
|
|
|
if err == ErrMissingExtension {
|
|
|
err = nil
|
|
|
}
|
|
|
@@ -445,24 +319,24 @@ func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
|
|
|
}
|
|
|
registeredExtensions := RegisteredExtensions(pb)
|
|
|
|
|
|
- emap, mu := epb.extensionsRead()
|
|
|
- if emap == nil {
|
|
|
+ if !epb.HasInit() {
|
|
|
return nil, nil
|
|
|
}
|
|
|
- mu.Lock()
|
|
|
- defer mu.Unlock()
|
|
|
- extensions := make([]*ExtensionDesc, 0, len(emap))
|
|
|
- for extid, e := range emap {
|
|
|
- desc := e.desc
|
|
|
+ epb.Lock()
|
|
|
+ defer epb.Unlock()
|
|
|
+ extensions := make([]*ExtensionDesc, 0, epb.Len())
|
|
|
+ epb.Range(func(extid protoreflect.FieldNumber, e Extension) bool {
|
|
|
+ desc := e.Desc
|
|
|
if desc == nil {
|
|
|
- desc = registeredExtensions[extid]
|
|
|
+ desc = registeredExtensions[int32(extid)]
|
|
|
if desc == nil {
|
|
|
- desc = &ExtensionDesc{Field: extid}
|
|
|
+ desc = &ExtensionDesc{Field: int32(extid)}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
extensions = append(extensions, desc)
|
|
|
- }
|
|
|
+ return true
|
|
|
+ })
|
|
|
return extensions, nil
|
|
|
}
|
|
|
|
|
|
@@ -472,7 +346,7 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if err := checkExtensionTypes(epb, extension); err != nil {
|
|
|
+ if err := checkExtensionTypeAndRanges(pb, extension); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
typ := reflect.TypeOf(extension.ExtensionType)
|
|
|
@@ -488,8 +362,10 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
|
|
|
return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
|
|
|
}
|
|
|
|
|
|
- extmap := epb.extensionsWrite()
|
|
|
- extmap[extension.Field] = Extension{desc: extension, value: extensionAsStorageType(value)}
|
|
|
+ epb.Set(protoreflect.FieldNumber(extension.Field), Extension{
|
|
|
+ Desc: extension,
|
|
|
+ Value: extensionAsStorageType(value),
|
|
|
+ })
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -499,10 +375,10 @@ func ClearAllExtensions(pb Message) {
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
- m := epb.extensionsWrite()
|
|
|
- for k := range m {
|
|
|
- delete(m, k)
|
|
|
- }
|
|
|
+ epb.Range(func(k protoreflect.FieldNumber, _ Extension) bool {
|
|
|
+ epb.Clear(k)
|
|
|
+ return true
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
// A global registry of extensions.
|
|
|
@@ -532,7 +408,7 @@ func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
|
|
|
}
|
|
|
|
|
|
// extensionAsLegacyType converts an value in the storage type as the API type.
|
|
|
-// See Extension.value.
|
|
|
+// See Extension.Value.
|
|
|
func extensionAsLegacyType(v interface{}) interface{} {
|
|
|
switch rv := reflect.ValueOf(v); rv.Kind() {
|
|
|
case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
|
|
|
@@ -555,7 +431,7 @@ func extensionAsLegacyType(v interface{}) interface{} {
|
|
|
}
|
|
|
|
|
|
// extensionAsStorageType converts an value in the API type as the storage type.
|
|
|
-// See Extension.value.
|
|
|
+// See Extension.Value.
|
|
|
func extensionAsStorageType(v interface{}) interface{} {
|
|
|
switch rv := reflect.ValueOf(v); rv.Kind() {
|
|
|
case reflect.Ptr:
|