Переглянути джерело

protoapi: new package for v1 types that generated messages depend on

Create a new package protoapi that contains the minimum set of types depended
upon by generated messages so that a generated message implemented the v2
APIs do not need to import the v1 proto runtime.

Some types like proto.Buffer are not included here since those dependencies
are going away with PR#760.

Explicitly document that no one should import this package directly,
so we have the flexibility to remove this package if necessary.

Change-Id: Iddd0c697c9170b809a587bee626347e4ffdddbc7
Reviewed-on: https://go-review.googlesource.com/c/151347
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 7 роки тому
батько
коміт
87f1426e53

+ 20 - 18
proto/clone.go

@@ -12,6 +12,9 @@ import (
 	"log"
 	"reflect"
 	"strings"
+
+	"github.com/golang/protobuf/protoapi"
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
 // Clone returns a deep copy of a protocol buffer.
@@ -83,12 +86,10 @@ func mergeStruct(out, in reflect.Value) {
 
 	if emIn, err := extendable(in.Addr().Interface()); err == nil {
 		emOut, _ := extendable(out.Addr().Interface())
-		mIn, muIn := emIn.extensionsRead()
-		if mIn != nil {
-			mOut := emOut.extensionsWrite()
-			muIn.Lock()
-			mergeExtension(mOut, mIn)
-			muIn.Unlock()
+		if emIn.HasInit() {
+			emIn.Lock()
+			mergeExtension(emOut, emIn)
+			emIn.Unlock()
 		}
 	}
 
@@ -208,19 +209,20 @@ func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) {
 	}
 }
 
-func mergeExtension(out, in map[int32]Extension) {
-	for extNum, eIn := range in {
-		eOut := Extension{desc: eIn.desc}
-		if eIn.value != nil {
-			v := reflect.New(reflect.TypeOf(eIn.value)).Elem()
-			mergeAny(v, reflect.ValueOf(eIn.value), false, nil)
-			eOut.value = v.Interface()
+func mergeExtension(out, in protoapi.ExtensionFields) {
+	in.Range(func(extNum protoreflect.FieldNumber, eIn Extension) bool {
+		eOut := Extension{Desc: eIn.Desc}
+		if eIn.Value != nil {
+			v := reflect.New(reflect.TypeOf(eIn.Value)).Elem()
+			mergeAny(v, reflect.ValueOf(eIn.Value), false, nil)
+			eOut.Value = v.Interface()
 		}
-		if eIn.enc != nil {
-			eOut.enc = make([]byte, len(eIn.enc))
-			copy(eOut.enc, eIn.enc)
+		if eIn.Raw != nil {
+			eOut.Raw = make([]byte, len(eIn.Raw))
+			copy(eOut.Raw, eIn.Raw)
 		}
 
-		out[extNum] = eOut
-	}
+		out.Set(extNum, eOut)
+		return true
+	})
 }

+ 10 - 10
proto/discard.go

@@ -10,6 +10,8 @@ import (
 	"strings"
 	"sync"
 	"sync/atomic"
+
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
 type generatedDiscarder interface {
@@ -96,13 +98,12 @@ func (di *discardInfo) discard(src pointer) {
 	// For proto2 messages, only discard unknown fields in message extensions
 	// that have been accessed via GetExtension.
 	if em, err := extendable(src.asPointerTo(di.typ).Interface()); err == nil {
-		// Ignore lock since DiscardUnknown is not concurrency safe.
-		emm, _ := em.extensionsRead()
-		for _, mx := range emm {
-			if m, ok := mx.value.(Message); ok {
+		em.Range(func(_ protoreflect.FieldNumber, mx Extension) bool {
+			if m, ok := mx.Value.(Message); ok {
 				DiscardUnknown(m)
 			}
-		}
+			return true
+		})
 	}
 
 	if di.unrecognized.IsValid() {
@@ -312,12 +313,11 @@ func discardLegacy(m Message) {
 	// For proto2 messages, only discard unknown fields in message extensions
 	// that have been accessed via GetExtension.
 	if em, err := extendable(m); err == nil {
-		// Ignore lock since discardLegacy is not concurrency safe.
-		emm, _ := em.extensionsRead()
-		for _, mx := range emm {
-			if m, ok := mx.value.(Message); ok {
+		em.Range(func(_ protoreflect.FieldNumber, mx Extension) bool {
+			if m, ok := mx.Value.(Message); ok {
 				discardLegacy(m)
 			}
-		}
+			return true
+		})
 	}
 }

+ 31 - 25
proto/equal.go

@@ -11,6 +11,9 @@ import (
 	"log"
 	"reflect"
 	"strings"
+
+	"github.com/golang/protobuf/protoapi"
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
 /*
@@ -91,14 +94,18 @@ func equalStruct(v1, v2 reflect.Value) bool {
 
 	if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() {
 		em2 := v2.FieldByName("XXX_InternalExtensions")
-		if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) {
+		m1 := protoapi.ExtensionFieldsOf(em1.Addr().Interface())
+		m2 := protoapi.ExtensionFieldsOf(em2.Addr().Interface())
+		if !equalExtensions(v1.Type(), m1, m2) {
 			return false
 		}
 	}
 
 	if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
 		em2 := v2.FieldByName("XXX_extensions")
-		if !equalExtMap(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
+		m1 := protoapi.ExtensionFieldsOf(em1.Addr().Interface())
+		m2 := protoapi.ExtensionFieldsOf(em2.Addr().Interface())
+		if !equalExtensions(v1.Type(), m1, m2) {
 			return false
 		}
 	}
@@ -200,32 +207,26 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
 	return false
 }
 
-// base is the struct type that the extensions are based on.
-// x1 and x2 are InternalExtensions.
-func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool {
-	em1, _ := x1.extensionsRead()
-	em2, _ := x2.extensionsRead()
-	return equalExtMap(base, em1, em2)
-}
-
-func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
-	if len(em1) != len(em2) {
+func equalExtensions(base reflect.Type, em1, em2 protoapi.ExtensionFields) bool {
+	if em1.Len() != em2.Len() {
 		return false
 	}
 
-	for extNum, e1 := range em1 {
-		e2, ok := em2[extNum]
-		if !ok {
+	equal := true
+	em1.Range(func(extNum protoreflect.FieldNumber, e1 Extension) bool {
+		if !em2.Has(extNum) {
+			equal = false
 			return false
 		}
+		e2 := em2.Get(extNum)
 
-		m1 := extensionAsLegacyType(e1.value)
-		m2 := extensionAsLegacyType(e2.value)
+		m1 := extensionAsLegacyType(e1.Value)
+		m2 := extensionAsLegacyType(e2.Value)
 
 		if m1 == nil && m2 == nil {
 			// Both have only encoded form.
-			if bytes.Equal(e1.enc, e2.enc) {
-				continue
+			if bytes.Equal(e1.Raw, e2.Raw) {
+				return true
 			}
 			// The bytes are different, but the extensions might still be
 			// equal. We need to decode them to compare.
@@ -234,16 +235,17 @@ func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
 		if m1 != nil && m2 != nil {
 			// Both are unencoded.
 			if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
+				equal = false
 				return false
 			}
-			continue
+			return true
 		}
 
 		// At least one is encoded. To do a semantically correct comparison
 		// we need to unmarshal them first.
 		var desc *ExtensionDesc
 		if m := extensionMaps[base]; m != nil {
-			desc = m[extNum]
+			desc = m[int32(extNum)]
 		}
 		if desc == nil {
 			// If both have only encoded form and the bytes are the same,
@@ -251,24 +253,28 @@ func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
 			// We don't know how to decode it, so just compare them as byte
 			// slices.
 			log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
+			equal = false
 			return false
 		}
 		var err error
 		if m1 == nil {
-			m1, err = decodeExtension(e1.enc, desc)
+			m1, err = decodeExtension(e1.Raw, desc)
 		}
 		if m2 == nil && err == nil {
-			m2, err = decodeExtension(e2.enc, desc)
+			m2, err = decodeExtension(e2.Raw, desc)
 		}
 		if err != nil {
 			// The encoded form is invalid.
 			log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
+			equal = false
 			return false
 		}
 		if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
+			equal = false
 			return false
 		}
-	}
+		return true
+	})
 
-	return true
+	return equal
 }

+ 92 - 216
proto/extensions.go

@@ -15,68 +15,30 @@ import (
 	"reflect"
 	"strconv"
 	"sync"
+
+	"github.com/golang/protobuf/protoapi"
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
 // ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message.
 var ErrMissingExtension = errors.New("proto: missing extension")
 
-// ExtensionRange represents a range of message extensions for a protocol buffer.
-// Used in code generated by the protocol compiler.
-type ExtensionRange struct {
-	Start, End int32 // both inclusive
-}
-
-// extendableProto is an interface implemented by any protocol buffer generated by the current
-// proto compiler that may be extended.
-type extendableProto interface {
-	Message
-	ExtensionRangeArray() []ExtensionRange
-	extensionsWrite() map[int32]Extension
-	extensionsRead() (map[int32]Extension, sync.Locker)
-}
-
-// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
-// version of the proto compiler that may be extended.
-type extendableProtoV1 interface {
-	Message
-	ExtensionRangeArray() []ExtensionRange
-	ExtensionMap() map[int32]Extension
-}
-
-// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
-type extensionAdapter struct {
-	extendableProtoV1
-}
-
-func (e extensionAdapter) extensionsWrite() map[int32]Extension {
-	return e.ExtensionMap()
-}
-
-func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
-	return e.ExtensionMap(), notLocker{}
-}
-
-// notLocker is a sync.Locker whose Lock and Unlock methods are nops.
-type notLocker struct{}
-
-func (n notLocker) Lock()   {}
-func (n notLocker) Unlock() {}
-
-// extendable returns the extendableProto interface for the given generated proto message.
-// If the proto message has the old extension format, it returns a wrapper that implements
-// the extendableProto interface.
-func extendable(p interface{}) (extendableProto, error) {
-	switch p := p.(type) {
-	case extendableProto:
-		if isNilPtr(p) {
-			return nil, fmt.Errorf("proto: nil %T is not extendable", p)
-		}
-		return p, nil
-	case extendableProtoV1:
-		if isNilPtr(p) {
-			return nil, fmt.Errorf("proto: nil %T is not extendable", p)
+func extendable(p interface{}) (protoapi.ExtensionFields, error) {
+	type extendableProto interface {
+		Message
+		ExtensionRangeArray() []ExtensionRange
+	}
+	if _, ok := p.(extendableProto); ok {
+		v := reflect.ValueOf(p)
+		if v.Kind() == reflect.Ptr && !v.IsNil() {
+			v = v.Elem()
+			if v := v.FieldByName("XXX_InternalExtensions"); v.IsValid() {
+				return protoapi.ExtensionFieldsOf(v.Addr().Interface()), nil
+			}
+			if v := v.FieldByName("XXX_extensions"); v.IsValid() {
+				return protoapi.ExtensionFieldsOf(v.Addr().Interface()), nil
+			}
 		}
-		return extensionAdapter{p}, nil
 	}
 	// Don't allocate a specific error containing %T:
 	// this is the hot path for Clone and MarshalText.
@@ -85,129 +47,47 @@ func extendable(p interface{}) (extendableProto, error) {
 
 var errNotExtendable = errors.New("proto: not an extendable proto.Message")
 
-func isNilPtr(x interface{}) bool {
-	v := reflect.ValueOf(x)
-	return v.Kind() == reflect.Ptr && v.IsNil()
-}
-
-// XXX_InternalExtensions is an internal representation of proto extensions.
-//
-// Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
-// thus gaining the unexported 'extensions' method, which can be called only from the proto package.
-//
-// The methods of XXX_InternalExtensions are not concurrency safe in general,
-// but calls to logically read-only methods such as has and get may be executed concurrently.
-type XXX_InternalExtensions struct {
-	// The struct must be indirect so that if a user inadvertently copies a
-	// generated message and its embedded XXX_InternalExtensions, they
-	// avoid the mayhem of a copied mutex.
-	//
-	// The mutex serializes all logically read-only operations to p.extensionMap.
-	// It is up to the client to ensure that write operations to p.extensionMap are
-	// mutually exclusive with other accesses.
-	p *struct {
-		mu           sync.Mutex
-		extensionMap map[int32]Extension
-	}
-}
-
-// extensionsWrite returns the extension map, creating it on first use.
-func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
-	if e.p == nil {
-		e.p = new(struct {
-			mu           sync.Mutex
-			extensionMap map[int32]Extension
-		})
-		e.p.extensionMap = make(map[int32]Extension)
-	}
-	return e.p.extensionMap
-}
-
-// extensionsRead returns the extensions map for read-only use.  It may be nil.
-// The caller must hold the returned mutex's lock when accessing Elements within the map.
-func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
-	if e.p == nil {
-		return nil, nil
-	}
-	return e.p.extensionMap, &e.p.mu
-}
-
-// ExtensionDesc represents an extension specification.
-// Used in generated code from the protocol compiler.
-type ExtensionDesc struct {
-	ExtendedType  Message     // nil pointer to the type that is being extended
-	ExtensionType interface{} // nil pointer to the extension type
-	Field         int32       // field number
-	Name          string      // fully-qualified name of extension, for text formatting
-	Tag           string      // protobuf tag style
-	Filename      string      // name of the file in which the extension is defined
-}
+type (
+	ExtensionRange         = protoapi.ExtensionRange
+	ExtensionDesc          = protoapi.ExtensionDesc
+	Extension              = protoapi.ExtensionField
+	XXX_InternalExtensions = protoapi.XXX_InternalExtensions
+)
 
-func (ed *ExtensionDesc) repeated() bool {
+func isRepeatedExtension(ed *ExtensionDesc) bool {
 	t := reflect.TypeOf(ed.ExtensionType)
 	return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
 }
 
-// Extension represents an extension in a message.
-type Extension struct {
-	// When an extension is stored in a message using SetExtension
-	// only desc and value are set. When the message is marshaled
-	// enc will be set to the encoded form of the message.
-	//
-	// When a message is unmarshaled and contains extensions, each
-	// extension will have only enc set. When such an extension is
-	// accessed using GetExtension (or GetExtensions) desc and value
-	// will be set.
-	desc *ExtensionDesc
-
-	// value is a concrete value for the extension field. Let the type of
-	// desc.ExtensionType be the "API type" and the type of Extension.value
-	// be the "storage type". The API type and storage type are the same except:
-	//	* For scalars (except []byte), the API type uses *T,
-	//	while the storage type uses T.
-	//	* For repeated fields, the API type uses []T, while the storage type
-	//	uses *[]T.
-	//
-	// The reason for the divergence is so that the storage type more naturally
-	// matches what is expected of when retrieving the values through the
-	// protobuf reflection APIs.
-	//
-	// The value may only be populated if desc is also populated.
-	value interface{}
-
-	// enc is the raw bytes for the extension field.
-	enc []byte
-}
-
 // SetRawExtension is for testing only.
 func SetRawExtension(base Message, id int32, b []byte) {
 	epb, err := extendable(base)
 	if err != nil {
 		return
 	}
-	extmap := epb.extensionsWrite()
-	extmap[id] = Extension{enc: b}
+	epb.Set(protoreflect.FieldNumber(id), Extension{Raw: b})
 }
 
 // isExtensionField returns true iff the given field number is in an extension range.
-func isExtensionField(pb extendableProto, field int32) bool {
-	for _, er := range pb.ExtensionRangeArray() {
-		if er.Start <= field && field <= er.End {
-			return true
+func isExtensionField(pb Message, field int32) bool {
+	m, ok := pb.(interface{ ExtensionRangeArray() []ExtensionRange })
+	if ok {
+		for _, er := range m.ExtensionRangeArray() {
+			if er.Start <= field && field <= er.End {
+				return true
+			}
 		}
 	}
 	return false
 }
 
-// checkExtensionTypes checks that the given extension is valid for pb.
-func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
-	var pbi interface{} = pb
+// checkExtensionTypeAndRanges checks that the given extension is valid for pb.
+func checkExtensionTypeAndRanges(pb Message, extension *ExtensionDesc) error {
 	// Check the extended type.
-	if ea, ok := pbi.(extensionAdapter); ok {
-		pbi = ea.extendableProtoV1
-	}
-	if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
-		return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
+	if extension.ExtendedType != nil {
+		if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
+			return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
+		}
 	}
 	// Check the range.
 	if !isExtensionField(pb, extension.Field) {
@@ -229,8 +109,8 @@ var extProp = struct {
 	m: make(map[extPropKey]*Properties),
 }
 
-func extensionProperties(ed *ExtensionDesc) *Properties {
-	key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
+func extensionProperties(pb Message, ed *ExtensionDesc) *Properties {
+	key := extPropKey{base: reflect.TypeOf(pb), field: ed.Field}
 
 	extProp.RLock()
 	if prop, ok := extProp.m[key]; ok {
@@ -259,14 +139,12 @@ func HasExtension(pb Message, extension *ExtensionDesc) bool {
 	if err != nil {
 		return false
 	}
-	extmap, mu := epb.extensionsRead()
-	if extmap == nil {
+	if !epb.HasInit() {
 		return false
 	}
-	mu.Lock()
-	_, ok := extmap[extension.Field]
-	mu.Unlock()
-	return ok
+	epb.Lock()
+	defer epb.Unlock()
+	return epb.Has(protoreflect.FieldNumber(extension.Field))
 }
 
 // ClearExtension removes the given extension from pb.
@@ -276,8 +154,7 @@ func ClearExtension(pb Message, extension *ExtensionDesc) {
 		return
 	}
 	// TODO: Check types, field numbers, etc.?
-	extmap := epb.extensionsWrite()
-	delete(extmap, extension.Field)
+	epb.Clear(protoreflect.FieldNumber(extension.Field))
 }
 
 // GetExtension retrieves a proto2 extended field from pb.
@@ -295,66 +172,63 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 		return nil, err
 	}
 
-	if extension.ExtendedType != nil {
-		// can only check type if this is a complete descriptor
-		if err := checkExtensionTypes(epb, extension); err != nil {
-			return nil, err
-		}
+	// can only check type if this is a complete descriptor
+	if err := checkExtensionTypeAndRanges(pb, extension); err != nil {
+		return nil, err
 	}
 
-	emap, mu := epb.extensionsRead()
-	if emap == nil {
-		return defaultExtensionValue(extension)
+	if !epb.HasInit() {
+		return defaultExtensionValue(pb, extension)
 	}
-	mu.Lock()
-	defer mu.Unlock()
-	e, ok := emap[extension.Field]
-	if !ok {
+	epb.Lock()
+	defer epb.Unlock()
+	if !epb.Has(protoreflect.FieldNumber(extension.Field)) {
 		// defaultExtensionValue returns the default value or
 		// ErrMissingExtension if there is no default.
-		return defaultExtensionValue(extension)
+		return defaultExtensionValue(pb, extension)
 	}
+	e := epb.Get(protoreflect.FieldNumber(extension.Field))
 
-	if e.value != nil {
+	if e.Value != nil {
 		// Already decoded. Check the descriptor, though.
-		if e.desc != extension {
+		if e.Desc != extension {
 			// This shouldn't happen. If it does, it means that
 			// GetExtension was called twice with two different
 			// descriptors with the same field number.
 			return nil, errors.New("proto: descriptor conflict")
 		}
-		return extensionAsLegacyType(e.value), nil
+		return extensionAsLegacyType(e.Value), nil
 	}
 
 	if extension.ExtensionType == nil {
 		// incomplete descriptor
-		return e.enc, nil
+		return e.Raw, nil
 	}
 
-	v, err := decodeExtension(e.enc, extension)
+	v, err := decodeExtension(e.Raw, extension)
 	if err != nil {
 		return nil, err
 	}
 
 	// Remember the decoded version and drop the encoded version.
 	// That way it is safe to mutate what we return.
-	e.value = extensionAsStorageType(v)
-	e.desc = extension
-	e.enc = nil
-	emap[extension.Field] = e
-	return extensionAsLegacyType(e.value), nil
+	e.Value = extensionAsStorageType(v)
+	e.Desc = extension
+	e.Raw = nil
+	epb.Set(protoreflect.FieldNumber(extension.Field), e)
+	return extensionAsLegacyType(e.Value), nil
 }
 
 // defaultExtensionValue returns the default value for extension.
 // If no default for an extension is defined ErrMissingExtension is returned.
-func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
+func defaultExtensionValue(pb Message, extension *ExtensionDesc) (interface{}, error) {
 	if extension.ExtensionType == nil {
 		// incomplete descriptor, so no default
 		return nil, ErrMissingExtension
 	}
 
 	t := reflect.TypeOf(extension.ExtensionType)
-	props := extensionProperties(extension)
+	props := extensionProperties(pb, extension)
 
 	sf, _, err := fieldDefault(t, props)
 	if err != nil {
@@ -376,7 +250,7 @@ func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
 	value.Set(reflect.New(value.Type().Elem()))
 	if sf.kind == reflect.Int32 {
 		// We may have an int32 or an enum, but the underlying data is int32.
-		// Since we can't set an int32 into a non int32 reflect.value directly
+		// Since we can't set an int32 into a non int32 reflect.Value directly
 		// set it as a int32.
 		value.Elem().SetInt(int64(sf.value.(int32)))
 	} else {
@@ -418,13 +292,13 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
 // The returned slice has the same length as es; missing extensions will appear as nil elements.
 func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
-	epb, err := extendable(pb)
+	_, err = extendable(pb)
 	if err != nil {
 		return nil, err
 	}
 	extensions = make([]interface{}, len(es))
 	for i, e := range es {
-		extensions[i], err = GetExtension(epb, e)
+		extensions[i], err = GetExtension(pb, e)
 		if err == ErrMissingExtension {
 			err = nil
 		}
@@ -445,24 +319,24 @@ func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
 	}
 	registeredExtensions := RegisteredExtensions(pb)
 
-	emap, mu := epb.extensionsRead()
-	if emap == nil {
+	if !epb.HasInit() {
 		return nil, nil
 	}
-	mu.Lock()
-	defer mu.Unlock()
-	extensions := make([]*ExtensionDesc, 0, len(emap))
-	for extid, e := range emap {
-		desc := e.desc
+	epb.Lock()
+	defer epb.Unlock()
+	extensions := make([]*ExtensionDesc, 0, epb.Len())
+	epb.Range(func(extid protoreflect.FieldNumber, e Extension) bool {
+		desc := e.Desc
 		if desc == nil {
-			desc = registeredExtensions[extid]
+			desc = registeredExtensions[int32(extid)]
 			if desc == nil {
-				desc = &ExtensionDesc{Field: extid}
+				desc = &ExtensionDesc{Field: int32(extid)}
 			}
 		}
 
 		extensions = append(extensions, desc)
-	}
+		return true
+	})
 	return extensions, nil
 }
 
@@ -472,7 +346,7 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
 	if err != nil {
 		return err
 	}
-	if err := checkExtensionTypes(epb, extension); err != nil {
+	if err := checkExtensionTypeAndRanges(pb, extension); err != nil {
 		return err
 	}
 	typ := reflect.TypeOf(extension.ExtensionType)
@@ -488,8 +362,10 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
 		return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
 	}
 
-	extmap := epb.extensionsWrite()
-	extmap[extension.Field] = Extension{desc: extension, value: extensionAsStorageType(value)}
+	epb.Set(protoreflect.FieldNumber(extension.Field), Extension{
+		Desc:  extension,
+		Value: extensionAsStorageType(value),
+	})
 	return nil
 }
 
@@ -499,10 +375,10 @@ func ClearAllExtensions(pb Message) {
 	if err != nil {
 		return
 	}
-	m := epb.extensionsWrite()
-	for k := range m {
-		delete(m, k)
-	}
+	epb.Range(func(k protoreflect.FieldNumber, _ Extension) bool {
+		epb.Clear(k)
+		return true
+	})
 }
 
 // A global registry of extensions.
@@ -532,7 +408,7 @@ func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
 }
 
 // extensionAsLegacyType converts an value in the storage type as the API type.
-// See Extension.value.
+// See Extension.Value.
 func extensionAsLegacyType(v interface{}) interface{} {
 	switch rv := reflect.ValueOf(v); rv.Kind() {
 	case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
@@ -555,7 +431,7 @@ func extensionAsLegacyType(v interface{}) interface{} {
 }
 
 // extensionAsStorageType converts an value in the API type as the storage type.
-// See Extension.value.
+// See Extension.Value.
 func extensionAsStorageType(v interface{}) interface{} {
 	switch rv := reflect.ValueOf(v); rv.Kind() {
 	case reflect.Ptr:

+ 1 - 1
proto/extensions_test.go

@@ -377,7 +377,7 @@ func TestNilMessage(t *testing.T) {
 	desc := pb.E_Ext_More
 
 	isNotExtendable := func(err error) bool {
-		return strings.Contains(fmt.Sprint(err), "not extendable")
+		return strings.Contains(fmt.Sprint(err), "not an extendable")
 	}
 
 	if proto.HasExtension(nilMsg, desc) {

+ 3 - 5
proto/lib.go

@@ -248,6 +248,8 @@ import (
 	// Add a bogus dependency on the v2 API to ensure the Go toolchain does not
 	// remove our dependency from the go.mod file.
 	_ "github.com/golang/protobuf/v2/reflect/protoreflect"
+
+	"github.com/golang/protobuf/protoapi"
 )
 
 // RequiredNotSetError is an error type returned by either Marshal or Unmarshal.
@@ -312,11 +314,7 @@ func (nf *nonFatal) Merge(err error) (ok bool) {
 }
 
 // Message is implemented by generated protocol buffer messages.
-type Message interface {
-	Reset()
-	String() string
-	ProtoMessage()
-}
+type Message = protoapi.Message
 
 // A Buffer is a buffer manager for marshaling and unmarshaling
 // protocol buffers.  It may be reused between invocations to

+ 11 - 14
proto/message_set.go

@@ -10,6 +10,9 @@ package proto
 
 import (
 	"errors"
+
+	"github.com/golang/protobuf/protoapi"
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
 // errNoMessageTypeID occurs when a protocol buffer does not have a message type ID.
@@ -115,32 +118,26 @@ func skipVarint(buf []byte) []byte {
 // unmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
 // It is called by Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
 func unmarshalMessageSet(buf []byte, exts interface{}) error {
-	var m map[int32]Extension
-	switch exts := exts.(type) {
-	case *XXX_InternalExtensions:
-		m = exts.extensionsWrite()
-	case map[int32]Extension:
-		m = exts
-	default:
-		return errors.New("proto: not an extension map")
-	}
+	m := protoapi.ExtensionFieldsOf(exts)
 
 	ms := new(messageSet)
 	if err := Unmarshal(buf, ms); err != nil {
 		return err
 	}
 	for _, item := range ms.Item {
-		id := *item.TypeId
+		id := protoreflect.FieldNumber(*item.TypeId)
 		msg := item.Message
 
 		// Restore wire type and field number varint, plus length varint.
 		// Be careful to preserve duplicate items.
 		b := EncodeVarint(uint64(id)<<3 | WireBytes)
-		if ext, ok := m[id]; ok {
+		if m.Has(id) {
+			ext := m.Get(id)
+
 			// Existing data; rip off the tag and length varint
 			// so we join the new data correctly.
-			// We can assume that ext.enc is set because we are unmarshaling.
-			o := ext.enc[len(b):]   // skip wire type and field number
+			// We can assume that ext.Raw is set because we are unmarshaling.
+			o := ext.Raw[len(b):]   // skip wire type and field number
 			_, n := DecodeVarint(o) // calculate length of length varint
 			o = o[n:]               // skip length varint
 			msg = append(o, msg...) // join old data and new data
@@ -148,7 +145,7 @@ func unmarshalMessageSet(buf []byte, exts interface{}) error {
 		b = append(b, EncodeVarint(uint64(len(msg)))...)
 		b = append(b, msg...)
 
-		m[id] = Extension{enc: b}
+		m.Set(id, Extension{Raw: b})
 	}
 	return nil
 }

+ 85 - 74
proto/table_marshal.go

@@ -15,6 +15,9 @@ import (
 	"sync"
 	"sync/atomic"
 	"unicode/utf8"
+
+	"github.com/golang/protobuf/protoapi"
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
 // a sizer takes a pointer to a field and the size of its tag, computes the size of
@@ -2362,82 +2365,86 @@ func makeOneOfMarshaler(fi *marshalFieldInfo, f *reflect.StructField) (sizer, ma
 
 // sizeExtensions computes the size of encoded data for a XXX_InternalExtensions field.
 func (u *marshalInfo) sizeExtensions(ext *XXX_InternalExtensions) int {
-	m, mu := ext.extensionsRead()
-	if m == nil {
+	m := protoapi.ExtensionFieldsOf(ext)
+	if !m.HasInit() {
 		return 0
 	}
-	mu.Lock()
+	m.Lock()
+	defer m.Unlock()
 
 	n := 0
-	for _, e := range m {
-		if e.value == nil || e.desc == nil {
+	m.Range(func(_ protoreflect.FieldNumber, e Extension) bool {
+		if e.Value == nil || e.Desc == nil {
 			// Extension is only in its encoded form.
-			n += len(e.enc)
-			continue
+			n += len(e.Raw)
+			return true
 		}
 
 		// We don't skip extensions that have an encoded form set,
 		// because the extension value may have been mutated after
 		// the last time this function was called.
-		ei := u.getExtElemInfo(e.desc)
-		v := e.value
+		ei := u.getExtElemInfo(e.Desc)
+		v := e.Value
 		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		n += ei.sizer(p, ei.tagsize)
-	}
-	mu.Unlock()
+		return true
+	})
 	return n
 }
 
 // appendExtensions marshals a XXX_InternalExtensions field to the end of byte slice b.
 func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, deterministic bool) ([]byte, error) {
-	m, mu := ext.extensionsRead()
-	if m == nil {
+	m := protoapi.ExtensionFieldsOf(ext)
+	if !m.HasInit() {
 		return b, nil
 	}
-	mu.Lock()
-	defer mu.Unlock()
+	m.Lock()
+	defer m.Unlock()
 
 	var err error
 	var nerr nonFatal
 
 	// Fast-path for common cases: zero or one extensions.
 	// Don't bother sorting the keys.
-	if len(m) <= 1 {
-		for _, e := range m {
-			if e.value == nil || e.desc == nil {
+	if m.Len() <= 1 {
+		m.Range(func(_ protoreflect.FieldNumber, e Extension) bool {
+			if e.Value == nil || e.Desc == nil {
 				// Extension is only in its encoded form.
-				b = append(b, e.enc...)
-				continue
+				b = append(b, e.Raw...)
+				return true
 			}
 
 			// We don't skip extensions that have an encoded form set,
 			// because the extension value may have been mutated after
 			// the last time this function was called.
 
-			ei := u.getExtElemInfo(e.desc)
-			v := e.value
+			ei := u.getExtElemInfo(e.Desc)
+			v := e.Value
 			p := toAddrPointer(&v, ei.isptr, ei.deref)
 			b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
 			if !nerr.Merge(err) {
-				return b, err
+				return false
 			}
-		}
-		return b, nerr.E
+			err = nerr.E
+			return true
+		})
+		return b, err
 	}
 
 	// Sort the keys to provide a deterministic encoding.
 	// Not sure this is required, but the old code does it.
-	keys := make([]int, 0, len(m))
-	for k := range m {
+	keys := make([]int, 0, m.Len())
+	m.Range(func(k protoreflect.FieldNumber, _ Extension) bool {
 		keys = append(keys, int(k))
-	}
+		return true
+	})
 	sort.Ints(keys)
 
 	for _, k := range keys {
-		e := m[int32(k)]
-		if e.value == nil || e.desc == nil {
+		e := m.Get(protoreflect.FieldNumber(k))
+		if e.Value == nil || e.Desc == nil {
 			// Extension is only in its encoded form.
-			b = append(b, e.enc...)
+			b = append(b, e.Raw...)
 			continue
 		}
 
@@ -2445,8 +2452,8 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 		// because the extension value may have been mutated after
 		// the last time this function was called.
 
-		ei := u.getExtElemInfo(e.desc)
-		v := e.value
+		ei := u.getExtElemInfo(e.Desc)
+		v := e.Value
 		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
 		if !nerr.Merge(err) {
@@ -2467,100 +2474,104 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 // sizeMessageSet computes the size of encoded data for a XXX_InternalExtensions field
 // in message set format (above).
 func (u *marshalInfo) sizeMessageSet(ext *XXX_InternalExtensions) int {
-	m, mu := ext.extensionsRead()
-	if m == nil {
+	m := protoapi.ExtensionFieldsOf(ext)
+	if !m.HasInit() {
 		return 0
 	}
-	mu.Lock()
+	m.Lock()
+	defer m.Unlock()
 
 	n := 0
-	for id, e := range m {
+	m.Range(func(id protoreflect.FieldNumber, e Extension) bool {
 		n += 2                          // start group, end group. tag = 1 (size=1)
 		n += SizeVarint(uint64(id)) + 1 // type_id, tag = 2 (size=1)
 
-		if e.value == nil || e.desc == nil {
+		if e.Value == nil || e.Desc == nil {
 			// Extension is only in its encoded form.
-			msgWithLen := skipVarint(e.enc) // skip old tag, but leave the length varint
+			msgWithLen := skipVarint(e.Raw) // skip old tag, but leave the length varint
 			siz := len(msgWithLen)
 			n += siz + 1 // message, tag = 3 (size=1)
-			continue
+			return true
 		}
 
 		// We don't skip extensions that have an encoded form set,
 		// because the extension value may have been mutated after
 		// the last time this function was called.
 
-		ei := u.getExtElemInfo(e.desc)
-		v := e.value
+		ei := u.getExtElemInfo(e.Desc)
+		v := e.Value
 		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		n += ei.sizer(p, 1) // message, tag = 3 (size=1)
-	}
-	mu.Unlock()
+		return true
+	})
 	return n
 }
 
 // appendMessageSet marshals a XXX_InternalExtensions field in message set format (above)
 // to the end of byte slice b.
 func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, deterministic bool) ([]byte, error) {
-	m, mu := ext.extensionsRead()
-	if m == nil {
+	m := protoapi.ExtensionFieldsOf(ext)
+	if !m.HasInit() {
 		return b, nil
 	}
-	mu.Lock()
-	defer mu.Unlock()
+	m.Lock()
+	defer m.Unlock()
 
 	var err error
 	var nerr nonFatal
 
 	// Fast-path for common cases: zero or one extensions.
 	// Don't bother sorting the keys.
-	if len(m) <= 1 {
-		for id, e := range m {
+	if m.Len() <= 1 {
+		m.Range(func(id protoreflect.FieldNumber, e Extension) bool {
 			b = append(b, 1<<3|WireStartGroup)
 			b = append(b, 2<<3|WireVarint)
 			b = appendVarint(b, uint64(id))
 
-			if e.value == nil || e.desc == nil {
+			if e.Value == nil || e.Desc == nil {
 				// Extension is only in its encoded form.
-				msgWithLen := skipVarint(e.enc) // skip old tag, but leave the length varint
+				msgWithLen := skipVarint(e.Raw) // skip old tag, but leave the length varint
 				b = append(b, 3<<3|WireBytes)
 				b = append(b, msgWithLen...)
 				b = append(b, 1<<3|WireEndGroup)
-				continue
+				return true
 			}
 
 			// We don't skip extensions that have an encoded form set,
 			// because the extension value may have been mutated after
 			// the last time this function was called.
 
-			ei := u.getExtElemInfo(e.desc)
-			v := e.value
+			ei := u.getExtElemInfo(e.Desc)
+			v := e.Value
 			p := toAddrPointer(&v, ei.isptr, ei.deref)
 			b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
 			if !nerr.Merge(err) {
-				return b, err
+				return false
 			}
 			b = append(b, 1<<3|WireEndGroup)
-		}
-		return b, nerr.E
+			err = nerr.E
+			return true
+		})
+		return b, err
 	}
 
 	// Sort the keys to provide a deterministic encoding.
-	keys := make([]int, 0, len(m))
-	for k := range m {
+	keys := make([]int, 0, m.Len())
+	m.Range(func(k protoreflect.FieldNumber, _ Extension) bool {
 		keys = append(keys, int(k))
-	}
+		return true
+	})
 	sort.Ints(keys)
 
 	for _, id := range keys {
-		e := m[int32(id)]
+		e := m.Get(protoreflect.FieldNumber(id))
 		b = append(b, 1<<3|WireStartGroup)
 		b = append(b, 2<<3|WireVarint)
 		b = appendVarint(b, uint64(id))
 
-		if e.value == nil || e.desc == nil {
+		if e.Value == nil || e.Desc == nil {
 			// Extension is only in its encoded form.
-			msgWithLen := skipVarint(e.enc) // skip old tag, but leave the length varint
+			msgWithLen := skipVarint(e.Raw) // skip old tag, but leave the length varint
 			b = append(b, 3<<3|WireBytes)
 			b = append(b, msgWithLen...)
 			b = append(b, 1<<3|WireEndGroup)
@@ -2571,8 +2582,8 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 		// because the extension value may have been mutated after
 		// the last time this function was called.
 
-		ei := u.getExtElemInfo(e.desc)
-		v := e.value
+		ei := u.getExtElemInfo(e.Desc)
+		v := e.Value
 		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
 		b = append(b, 1<<3|WireEndGroup)
@@ -2591,9 +2602,9 @@ func (u *marshalInfo) sizeV1Extensions(m map[int32]Extension) int {
 
 	n := 0
 	for _, e := range m {
-		if e.value == nil || e.desc == nil {
+		if e.Value == nil || e.Desc == nil {
 			// Extension is only in its encoded form.
-			n += len(e.enc)
+			n += len(e.Raw)
 			continue
 		}
 
@@ -2601,8 +2612,8 @@ func (u *marshalInfo) sizeV1Extensions(m map[int32]Extension) int {
 		// because the extension value may have been mutated after
 		// the last time this function was called.
 
-		ei := u.getExtElemInfo(e.desc)
-		v := e.value
+		ei := u.getExtElemInfo(e.Desc)
+		v := e.Value
 		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		n += ei.sizer(p, ei.tagsize)
 	}
@@ -2626,9 +2637,9 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
 	var nerr nonFatal
 	for _, k := range keys {
 		e := m[int32(k)]
-		if e.value == nil || e.desc == nil {
+		if e.Value == nil || e.Desc == nil {
 			// Extension is only in its encoded form.
-			b = append(b, e.enc...)
+			b = append(b, e.Raw...)
 			continue
 		}
 
@@ -2636,8 +2647,8 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
 		// because the extension value may have been mutated after
 		// the last time this function was called.
 
-		ei := u.getExtElemInfo(e.desc)
-		v := e.value
+		ei := u.getExtElemInfo(e.Desc)
+		v := e.Value
 		p := toAddrPointer(&v, ei.isptr, ei.deref)
 		b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
 		if !nerr.Merge(err) {

+ 4 - 6
proto/table_merge.go

@@ -117,12 +117,10 @@ func (mi *mergeInfo) merge(dst, src pointer) {
 	in := src.asPointerTo(mi.typ).Elem()
 	if emIn, err := extendable(in.Addr().Interface()); err == nil {
 		emOut, _ := extendable(out.Addr().Interface())
-		mIn, muIn := emIn.extensionsRead()
-		if mIn != nil {
-			mOut := emOut.extensionsWrite()
-			muIn.Lock()
-			mergeExtension(mOut, mIn)
-			muIn.Unlock()
+		if emIn.HasInit() {
+			emIn.Lock()
+			mergeExtension(emOut, emIn)
+			emIn.Unlock()
 		}
 	}
 

+ 11 - 12
proto/table_unmarshal.go

@@ -15,6 +15,9 @@ import (
 	"sync"
 	"sync/atomic"
 	"unicode/utf8"
+
+	"github.com/golang/protobuf/protoapi"
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
 // Unmarshal is the entry point from the generated .pb.go files.
@@ -183,26 +186,22 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 		// Keep unrecognized data around.
 		// maybe in extensions, maybe in the unrecognized field.
 		z := m.offset(u.unrecognized).toBytes()
-		var emap map[int32]Extension
+		var emap protoapi.ExtensionFields
 		var e Extension
 		for _, r := range u.extensionRanges {
 			if uint64(r.Start) <= tag && tag <= uint64(r.End) {
 				if u.extensions.IsValid() {
 					mp := m.offset(u.extensions).toExtensions()
-					emap = mp.extensionsWrite()
-					e = emap[int32(tag)]
-					z = &e.enc
+					emap = protoapi.ExtensionFieldsOf(mp)
+					e = emap.Get(protoreflect.FieldNumber(tag))
+					z = &e.Raw
 					break
 				}
 				if u.oldExtensions.IsValid() {
 					p := m.offset(u.oldExtensions).toOldExtensions()
-					emap = *p
-					if emap == nil {
-						emap = map[int32]Extension{}
-						*p = emap
-					}
-					e = emap[int32(tag)]
-					z = &e.enc
+					emap = protoapi.ExtensionFieldsOf(p)
+					e = emap.Get(protoreflect.FieldNumber(tag))
+					z = &e.Raw
 					break
 				}
 				panic("no extensions field available")
@@ -220,7 +219,7 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 		*z = append(*z, b0[:len(b0)-len(b)]...)
 
 		if emap != nil {
-			emap[int32(tag)] = e
+			emap.Set(protoreflect.FieldNumber(tag), e)
 		}
 	}
 	if reqMask != u.reqMask && errLater == nil {

+ 19 - 17
proto/text.go

@@ -18,6 +18,8 @@ import (
 	"reflect"
 	"sort"
 	"strings"
+
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
 var (
@@ -641,11 +643,11 @@ func writeUnknownInt(w *textWriter, x uint64, err error) error {
 	return err
 }
 
-type int32Slice []int32
+type fieldNumSlice []protoreflect.FieldNumber
 
-func (s int32Slice) Len() int           { return len(s) }
-func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] }
-func (s int32Slice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
+func (s fieldNumSlice) Len() int           { return len(s) }
+func (s fieldNumSlice) Less(i, j int) bool { return s[i] < s[j] }
+func (s fieldNumSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
 
 // writeExtensions writes all the extensions in pv.
 // pv is assumed to be a pointer to a protocol message struct that is extendable.
@@ -656,39 +658,39 @@ func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error
 	// Order the extensions by ID.
 	// This isn't strictly necessary, but it will give us
 	// canonical output, which will also make testing easier.
-	m, mu := ep.extensionsRead()
-	if m == nil {
+	if !ep.HasInit() {
 		return nil
 	}
-	mu.Lock()
-	ids := make([]int32, 0, len(m))
-	for id := range m {
+	ep.Lock()
+	ids := make([]protoreflect.FieldNumber, 0, ep.Len())
+	ep.Range(func(id protoreflect.FieldNumber, _ Extension) bool {
 		ids = append(ids, id)
-	}
-	sort.Sort(int32Slice(ids))
-	mu.Unlock()
+		return true
+	})
+	sort.Sort(fieldNumSlice(ids))
+	ep.Unlock()
 
 	for _, extNum := range ids {
-		ext := m[extNum]
+		ext := ep.Get(extNum)
 		var desc *ExtensionDesc
 		if emap != nil {
-			desc = emap[extNum]
+			desc = emap[int32(extNum)]
 		}
 		if desc == nil {
 			// Unknown extension.
-			if err := writeUnknownStruct(w, ext.enc); err != nil {
+			if err := writeUnknownStruct(w, ext.Raw); err != nil {
 				return err
 			}
 			continue
 		}
 
-		pb, err := GetExtension(ep, desc)
+		pb, err := GetExtension(pv.Interface().(Message), desc)
 		if err != nil {
 			return fmt.Errorf("failed getting extension: %v", err)
 		}
 
 		// Repeated extensions will appear as a slice.
-		if !desc.repeated() {
+		if !isRepeatedExtension(desc) {
 			if err := tm.writeExtension(w, desc.Name, pb); err != nil {
 				return err
 			}

+ 1 - 1
proto/text_parser.go

@@ -504,7 +504,7 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
 				return err
 			}
 
-			rep := desc.repeated()
+			rep := isRepeatedExtension(desc)
 
 			// Read the extension structure, and set it in
 			// the value we're constructing.

+ 242 - 0
protoapi/api.go

@@ -0,0 +1,242 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package protoapi contains the set of types referenced by generated messages.
+//
+// WARNING: This package should only ever be imported by generated messages.
+// The compatibility agreement covers nothing except for functionality needed
+// to keep existing generated messages operational.
+package protoapi
+
+import (
+	"fmt"
+	"sync"
+
+	"github.com/golang/protobuf/v2/reflect/protoreflect"
+)
+
+// TODO: How to handle Registration during the v1 to v2 switchover?
+
+type (
+	Message interface {
+		Reset()
+		String() string
+		ProtoMessage()
+	}
+
+	ExtensionRange struct {
+		Start, End int32 // both inclusive
+	}
+
+	ExtensionDesc struct {
+		// Type is the descriptor type for the extension field using the v2 API.
+		// If populated, the information in this field takes precedence over
+		// all other fields in ExtensionDesc.
+		Type protoreflect.ExtensionType
+
+		// ExtendedType is a typed nil-pointer to the parent message type that
+		// is being extended. It is possible for this to be unpopulated in v2
+		// since the message may no longer implement the v1 Message interface.
+		//
+		// Deprecated: Use Type.ExtendedType instead.
+		ExtendedType Message
+
+		// ExtensionType is zero value of the extension type.
+		//
+		// For historical reasons, reflect.TypeOf(ExtensionType) and Type.GoType
+		// may not be identical:
+		//	* for scalars (except []byte), where ExtensionType uses *T,
+		//	while Type.GoType uses T.
+		//	* for repeated fields, where ExtensionType uses []T,
+		//	while Type.GoType uses *[]T.
+		//
+		// Deprecated: Use Type.GoType instead.
+		ExtensionType interface{}
+
+		// Field is the field number of the extension.
+		//
+		// Deprecated: Use Type.Number instead.
+		Field int32 // field number
+
+		// Name is the fully qualified name of extension.
+		//
+		// Deprecated: Use Type.FullName instead.
+		Name string
+
+		// Tag is the protobuf struct tag used in the v1 API.
+		//
+		// Deprecated: Do not use.
+		Tag string
+
+		// Filename is the proto filename in which the extension is defined.
+		//
+		// Deprecated: Use Type.Parent to ascend to the top-most parent and use
+		// protoreflect.FileDescriptor.Path.
+		Filename string
+	}
+
+	ExtensionFields        extensionFields
+	ExtensionField         extensionField
+	XXX_InternalExtensions extensionSyncMap
+)
+
+// ExtensionFieldsOf returns an ExtensionFields abstraction over various
+// internal representations of extension fields.
+func ExtensionFieldsOf(p interface{}) ExtensionFields {
+	switch p := p.(type) {
+	case *map[int32]ExtensionField:
+		return (*extensionMap)(p)
+	case *XXX_InternalExtensions:
+		return (*extensionSyncMap)(p)
+	default:
+		panic(fmt.Sprintf("invalid extension fields type: %T", p))
+	}
+}
+
+type extensionFields interface {
+	Len() int
+	Has(protoreflect.FieldNumber) bool
+	Get(protoreflect.FieldNumber) ExtensionField
+	Set(protoreflect.FieldNumber, ExtensionField)
+	Clear(protoreflect.FieldNumber)
+	Range(f func(protoreflect.FieldNumber, ExtensionField) bool)
+
+	// HasInit and Locker are used by v1 GetExtension to provide
+	// an artificial degree of concurrent safety.
+	HasInit() bool
+	sync.Locker
+}
+
+type extensionField struct {
+	// When an extension is stored in a message using SetExtension
+	// only desc and value are set. When the message is marshaled
+	// Raw will be set to the encoded form of the message.
+	//
+	// When a message is unmarshaled and contains extensions, each
+	// extension will have only Raw set. When such an extension is
+	// accessed using GetExtension (or GetExtensions) desc and value
+	// will be set.
+	Desc *ExtensionDesc // TODO: switch to protoreflect.ExtensionType
+
+	// Value is a concrete value for the extension field. Let the type of
+	// Desc.ExtensionType be the "API type" and the type of Value be the
+	// "storage type". The API type and storage type are the same except:
+	//	* for scalars (except []byte), where the API type uses *T,
+	//	while the storage type uses T.
+	//	* for repeated fields, where the API type uses []T,
+	//	while the storage type uses *[]T.
+	//
+	// The reason for the divergence is so that the storage type more naturally
+	// matches what is expected of when retrieving the values through the
+	// protobuf reflection APIs.
+	//
+	// The Value may only be populated if Desc is also populated.
+	Value interface{} // TODO: switch to protoreflect.Value
+
+	// Raw is the raw encoded bytes for the extension field.
+	// It is possible for Raw to be populated irrespective of whether the
+	// other fields are populated.
+	Raw []byte // TODO: switch to protoreflect.RawFields
+}
+
+type extensionSyncMap struct {
+	p *struct {
+		mu sync.Mutex
+		m  extensionMap
+	}
+}
+
+func (m extensionSyncMap) Len() int {
+	if m.p == nil {
+		return 0
+	}
+	return m.p.m.Len()
+}
+func (m extensionSyncMap) Has(n protoreflect.FieldNumber) bool {
+	if m.p == nil {
+		return false
+	}
+	return m.p.m.Has(n)
+}
+func (m extensionSyncMap) Get(n protoreflect.FieldNumber) ExtensionField {
+	if m.p == nil {
+		return ExtensionField{}
+	}
+	return m.p.m.Get(n)
+}
+func (m *extensionSyncMap) Set(n protoreflect.FieldNumber, x ExtensionField) {
+	if m.p == nil {
+		m.p = new(struct {
+			mu sync.Mutex
+			m  extensionMap
+		})
+	}
+	m.p.m.Set(n, x)
+}
+func (m extensionSyncMap) Clear(n protoreflect.FieldNumber) {
+	if m.p == nil {
+		return
+	}
+	m.p.m.Clear(n)
+}
+func (m extensionSyncMap) Range(f func(protoreflect.FieldNumber, ExtensionField) bool) {
+	if m.p == nil {
+		return
+	}
+	m.p.m.Range(f)
+}
+
+func (m extensionSyncMap) HasInit() bool {
+	return m.p != nil
+}
+func (m extensionSyncMap) Lock() {
+	m.p.mu.Lock()
+}
+func (m extensionSyncMap) Unlock() {
+	m.p.mu.Unlock()
+}
+
+type extensionMap map[int32]ExtensionField
+
+func (m extensionMap) Len() int {
+	return len(m)
+}
+func (m extensionMap) Has(n protoreflect.FieldNumber) bool {
+	_, ok := m[int32(n)]
+	return ok
+}
+func (m extensionMap) Get(n protoreflect.FieldNumber) ExtensionField {
+	return m[int32(n)]
+}
+func (m *extensionMap) Set(n protoreflect.FieldNumber, x ExtensionField) {
+	if *m == nil {
+		*m = make(map[int32]ExtensionField)
+	}
+	(*m)[int32(n)] = x
+}
+func (m *extensionMap) Clear(n protoreflect.FieldNumber) {
+	delete(*m, int32(n))
+}
+func (m extensionMap) Range(f func(protoreflect.FieldNumber, ExtensionField) bool) {
+	for n, x := range m {
+		if !f(protoreflect.FieldNumber(n), x) {
+			return
+		}
+	}
+}
+
+var globalLock sync.Mutex
+
+func (m extensionMap) HasInit() bool {
+	return m != nil
+}
+func (m extensionMap) Lock() {
+	if !m.HasInit() {
+		panic("cannot lock an uninitialized map")
+	}
+	globalLock.Lock()
+}
+func (m extensionMap) Unlock() {
+	globalLock.Lock()
+}