Sfoglia il codice sorgente

net/proto2: remove <message>.ExtensionMap() from generated messages

Turn generated message struct field XXX_Extensions map[int32]Extension
into an embedded proto.InternalExtensions  struct

InternalExtensions is a struct without any exported fields and methods.
This effectively makes the representation of the extension map private.
The proto package can access InternalExtensions by checking that the
generated struct has the method 'extmap() proto.InternalExtensions'.

Also lock accesses to the extension map.

This change bumps the Go protobuf generated code version number. Any
.pb.go files generated with this version of the proto package or later
will require this version or later of the proto package to compile.
matloob@google.com 9 anni fa
parent
commit
e51d002c61

+ 6 - 11
jsonpb/jsonpb.go

@@ -233,12 +233,14 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU
 	}
 	}
 
 
 	// Handle proto2 extensions.
 	// Handle proto2 extensions.
-	if ep, ok := v.(extendableProto); ok {
+	if ep, ok := v.(proto.Message); ok {
 		extensions := proto.RegisteredExtensions(v)
 		extensions := proto.RegisteredExtensions(v)
-		extensionMap := ep.ExtensionMap()
 		// Sort extensions for stable output.
 		// Sort extensions for stable output.
-		ids := make([]int32, 0, len(extensionMap))
-		for id := range extensionMap {
+		ids := make([]int32, 0, len(extensions))
+		for id, desc := range extensions {
+			if !proto.HasExtension(ep, desc) {
+				continue
+			}
 			ids = append(ids, id)
 			ids = append(ids, id)
 		}
 		}
 		sort.Sort(int32Slice(ids))
 		sort.Sort(int32Slice(ids))
@@ -767,13 +769,6 @@ func acceptedJSONFieldNames(prop *proto.Properties) fieldNames {
 	return opts
 	return opts
 }
 }
 
 
-// extendableProto is an interface implemented by any protocol buffer that may be extended.
-type extendableProto interface {
-	proto.Message
-	ExtensionRangeArray() []proto.ExtensionRange
-	ExtensionMap() map[int32]proto.Extension
-}
-
 // Writer wrapper inspired by https://blog.golang.org/errors-are-values
 // Writer wrapper inspired by https://blog.golang.org/errors-are-values
 type errWriter struct {
 type errWriter struct {
 	writer io.Writer
 	writer io.Writer

+ 9 - 3
proto/clone.go

@@ -84,9 +84,15 @@ func mergeStruct(out, in reflect.Value) {
 		mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
 		mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
 	}
 	}
 
 
-	if emIn, ok := in.Addr().Interface().(extendableProto); ok {
-		emOut := out.Addr().Interface().(extendableProto)
-		mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
+	if emIn, ok := extendable(in.Addr().Interface()); ok {
+		emOut, _ := extendable(out.Addr().Interface())
+		mIn, muIn := emIn.extensionsRead()
+		if mIn != nil {
+			mOut := emOut.extensionsWrite()
+			muIn.Lock()
+			mergeExtension(mOut, mIn)
+			muIn.Unlock()
+		}
 	}
 	}
 
 
 	uf := in.FieldByName("XXX_unrecognized")
 	uf := in.FieldByName("XXX_unrecognized")

+ 4 - 3
proto/decode.go

@@ -390,11 +390,12 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
 		if !ok {
 		if !ok {
 			// Maybe it's an extension?
 			// Maybe it's an extension?
 			if prop.extendable {
 			if prop.extendable {
-				if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) {
+				if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) {
 					if err = o.skip(st, tag, wire); err == nil {
 					if err = o.skip(st, tag, wire); err == nil {
-						ext := e.ExtensionMap()[int32(tag)] // may be missing
+						extmap := e.extensionsWrite()
+						ext := extmap[int32(tag)] // may be missing
 						ext.enc = append(ext.enc, o.buf[oi:o.index]...)
 						ext.enc = append(ext.enc, o.buf[oi:o.index]...)
-						e.ExtensionMap()[int32(tag)] = ext
+						extmap[int32(tag)] = ext
 					}
 					}
 					continue
 					continue
 				}
 				}

+ 24 - 4
proto/encode.go

@@ -1073,10 +1073,25 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) {
 
 
 // Encode an extension map.
 // Encode an extension map.
 func (o *Buffer) enc_map(p *Properties, base structPointer) error {
 func (o *Buffer) enc_map(p *Properties, base structPointer) error {
-	v := *structPointer_ExtMap(base, p.field)
-	if err := encodeExtensionMap(v); err != nil {
+	exts := structPointer_ExtMap(base, p.field)
+	if err := encodeExtensionsMap(*exts); err != nil {
 		return err
 		return err
 	}
 	}
+
+	return o.enc_map_body(*exts)
+}
+
+func (o *Buffer) enc_exts(p *Properties, base structPointer) error {
+	exts := structPointer_Extensions(base, p.field)
+	if err := encodeExtensions(exts); err != nil {
+		return err
+	}
+	v, _ := exts.extensionsRead()
+
+	return o.enc_map_body(v)
+}
+
+func (o *Buffer) enc_map_body(v map[int32]Extension) error {
 	// Fast-path for common cases: zero or one extensions.
 	// Fast-path for common cases: zero or one extensions.
 	if len(v) <= 1 {
 	if len(v) <= 1 {
 		for _, e := range v {
 		for _, e := range v {
@@ -1099,8 +1114,13 @@ func (o *Buffer) enc_map(p *Properties, base structPointer) error {
 }
 }
 
 
 func size_map(p *Properties, base structPointer) int {
 func size_map(p *Properties, base structPointer) int {
-	v := *structPointer_ExtMap(base, p.field)
-	return sizeExtensionMap(v)
+	v := structPointer_ExtMap(base, p.field)
+	return extensionsMapSize(*v)
+}
+
+func size_exts(p *Properties, base structPointer) int {
+	v := structPointer_Extensions(base, p.field)
+	return extensionsSize(v)
 }
 }
 
 
 // Encode a map field.
 // Encode a map field.

+ 7 - 5
proto/equal.go

@@ -121,9 +121,9 @@ func equalStruct(v1, v2 reflect.Value) bool {
 		}
 		}
 	}
 	}
 
 
-	if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
-		em2 := v2.FieldByName("XXX_extensions")
-		if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
+	if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() {
+		em2 := v2.FieldByName("XXX_InternalExtensions")
+		if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) {
 			return false
 			return false
 		}
 		}
 	}
 	}
@@ -223,8 +223,10 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
 }
 }
 
 
 // base is the struct type that the extensions are based on.
 // base is the struct type that the extensions are based on.
-// em1 and em2 are extension maps.
-func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
+// x1 and x2 are InternalExtensions.
+func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool {
+	em1, _ := x1.extensionsRead()
+	em2, _ := x2.extensionsRead()
 	if len(em1) != len(em2) {
 	if len(em1) != len(em2) {
 		return false
 		return false
 	}
 	}

+ 142 - 18
proto/extensions.go

@@ -52,14 +52,99 @@ type ExtensionRange struct {
 	Start, End int32 // both inclusive
 	Start, End int32 // both inclusive
 }
 }
 
 
-// extendableProto is an interface implemented by any protocol buffer that may be extended.
+// extendableProto is an interface implemented by any protocol buffer generated by the current
+// proto compiler that may be extended.
 type extendableProto interface {
 type extendableProto interface {
+	Message
+	ExtensionRangeArray() []ExtensionRange
+	extensionsWrite() map[int32]Extension
+	extensionsRead() (map[int32]Extension, sync.Locker)
+}
+
+// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
+// version of the proto compiler that may be extended.
+type extendableProtoV1 interface {
 	Message
 	Message
 	ExtensionRangeArray() []ExtensionRange
 	ExtensionRangeArray() []ExtensionRange
 	ExtensionMap() map[int32]Extension
 	ExtensionMap() map[int32]Extension
 }
 }
 
 
+// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
+type extensionAdapter struct {
+	extendableProtoV1
+}
+
+func (e extensionAdapter) extensionsWrite() map[int32]Extension {
+	return e.ExtensionMap()
+}
+
+func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
+	return e.ExtensionMap(), notLocker{}
+}
+
+// notLocker is a sync.Locker whose Lock and Unlock methods are nops.
+type notLocker struct{}
+
+func (n notLocker) Lock()   {}
+func (n notLocker) Unlock() {}
+
+// extendable returns the extendableProto interface for the given generated proto message.
+// If the proto message has the old extension format, it returns a wrapper that implements
+// the extendableProto interface.
+func extendable(p interface{}) (extendableProto, bool) {
+	if ep, ok := p.(extendableProto); ok {
+		return ep, ok
+	}
+	if ep, ok := p.(extendableProtoV1); ok {
+		return extensionAdapter{ep}, ok
+	}
+	return nil, false
+}
+
+// XXX_InternalExtensions is an internal representation of proto extensions.
+//
+// Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
+// thus gaining the unexported 'extensions' method, which can be called only from the proto package.
+//
+// The methods of XXX_InternalExtensions are not concurrency safe in general,
+// but calls to logically read-only methods such as has and get may be executed concurrently.
+type XXX_InternalExtensions struct {
+	// The struct must be indirect so that if a user inadvertently copies a
+	// generated message and its embedded XXX_InternalExtensions, they
+	// avoid the mayhem of a copied mutex.
+	//
+	// The mutex serializes all logically read-only operations to p.extensionMap.
+	// It is up to the client to ensure that write operations to p.extensionMap are
+	// mutually exclusive with other accesses.
+	p *struct {
+		mu           sync.Mutex
+		extensionMap map[int32]Extension
+	}
+}
+
+// extensionsWrite returns the extension map, creating it on first use.
+func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
+	if e.p == nil {
+		e.p = new(struct {
+			mu           sync.Mutex
+			extensionMap map[int32]Extension
+		})
+		e.p.extensionMap = make(map[int32]Extension)
+	}
+	return e.p.extensionMap
+}
+
+// extensionsRead returns the extensions map for read-only use.  It may be nil.
+// The caller must hold the returned mutex's lock when accessing Elements within the map.
+func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
+	if e.p == nil {
+		return nil, nil
+	}
+	return e.p.extensionMap, &e.p.mu
+}
+
 var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
 var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
+var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
 
 
 // ExtensionDesc represents an extension specification.
 // ExtensionDesc represents an extension specification.
 // Used in generated code from the protocol compiler.
 // Used in generated code from the protocol compiler.
@@ -93,11 +178,12 @@ type Extension struct {
 
 
 // SetRawExtension is for testing only.
 // SetRawExtension is for testing only.
 func SetRawExtension(base Message, id int32, b []byte) {
 func SetRawExtension(base Message, id int32, b []byte) {
-	epb, ok := base.(extendableProto)
+	epb, ok := extendable(base)
 	if !ok {
 	if !ok {
 		return
 		return
 	}
 	}
-	epb.ExtensionMap()[id] = Extension{enc: b}
+	extmap := epb.extensionsWrite()
+	extmap[id] = Extension{enc: b}
 }
 }
 
 
 // isExtensionField returns true iff the given field number is in an extension range.
 // isExtensionField returns true iff the given field number is in an extension range.
@@ -112,8 +198,12 @@ func isExtensionField(pb extendableProto, field int32) bool {
 
 
 // checkExtensionTypes checks that the given extension is valid for pb.
 // checkExtensionTypes checks that the given extension is valid for pb.
 func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
 func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
+	var pbi interface{} = pb
 	// Check the extended type.
 	// Check the extended type.
-	if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
+	if ea, ok := pbi.(extensionAdapter); ok {
+		pbi = ea.extendableProtoV1
+	}
+	if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
 		return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
 		return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
 	}
 	}
 	// Check the range.
 	// Check the range.
@@ -159,8 +249,19 @@ func extensionProperties(ed *ExtensionDesc) *Properties {
 	return prop
 	return prop
 }
 }
 
 
-// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m.
-func encodeExtensionMap(m map[int32]Extension) error {
+// encode encodes any unmarshaled (unencoded) extensions in e.
+func encodeExtensions(e *XXX_InternalExtensions) error {
+	m, mu := e.extensionsRead()
+	if m == nil {
+		return nil // fast path
+	}
+	mu.Lock()
+	defer mu.Unlock()
+	return encodeExtensionsMap(m)
+}
+
+// encode encodes any unmarshaled (unencoded) extensions in e.
+func encodeExtensionsMap(m map[int32]Extension) error {
 	for k, e := range m {
 	for k, e := range m {
 		if e.value == nil || e.desc == nil {
 		if e.value == nil || e.desc == nil {
 			// Extension is only in its encoded form.
 			// Extension is only in its encoded form.
@@ -188,7 +289,17 @@ func encodeExtensionMap(m map[int32]Extension) error {
 	return nil
 	return nil
 }
 }
 
 
-func sizeExtensionMap(m map[int32]Extension) (n int) {
+func extensionsSize(e *XXX_InternalExtensions) (n int) {
+	m, mu := e.extensionsRead()
+	if m == nil {
+		return 0
+	}
+	mu.Lock()
+	defer mu.Unlock()
+	return extensionsMapSize(m)
+}
+
+func extensionsMapSize(m map[int32]Extension) (n int) {
 	for _, e := range m {
 	for _, e := range m {
 		if e.value == nil || e.desc == nil {
 		if e.value == nil || e.desc == nil {
 			// Extension is only in its encoded form.
 			// Extension is only in its encoded form.
@@ -215,28 +326,35 @@ func sizeExtensionMap(m map[int32]Extension) (n int) {
 // HasExtension returns whether the given extension is present in pb.
 // HasExtension returns whether the given extension is present in pb.
 func HasExtension(pb Message, extension *ExtensionDesc) bool {
 func HasExtension(pb Message, extension *ExtensionDesc) bool {
 	// TODO: Check types, field numbers, etc.?
 	// TODO: Check types, field numbers, etc.?
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 	if !ok {
 		return false
 		return false
 	}
 	}
-	_, ok = epb.ExtensionMap()[extension.Field]
+	extmap, mu := epb.extensionsRead()
+	if extmap == nil {
+		return false
+	}
+	mu.Lock()
+	_, ok = extmap[extension.Field]
+	mu.Unlock()
 	return ok
 	return ok
 }
 }
 
 
 // ClearExtension removes the given extension from pb.
 // ClearExtension removes the given extension from pb.
 func ClearExtension(pb Message, extension *ExtensionDesc) {
 func ClearExtension(pb Message, extension *ExtensionDesc) {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 	if !ok {
 		return
 		return
 	}
 	}
 	// TODO: Check types, field numbers, etc.?
 	// TODO: Check types, field numbers, etc.?
-	delete(epb.ExtensionMap(), extension.Field)
+	extmap := epb.extensionsWrite()
+	delete(extmap, extension.Field)
 }
 }
 
 
 // GetExtension parses and returns the given extension of pb.
 // GetExtension parses and returns the given extension of pb.
 // If the extension is not present and has no default value it returns ErrMissingExtension.
 // If the extension is not present and has no default value it returns ErrMissingExtension.
 func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 	if !ok {
 		return nil, errors.New("proto: not an extendable proto")
 		return nil, errors.New("proto: not an extendable proto")
 	}
 	}
@@ -245,7 +363,12 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	emap := epb.ExtensionMap()
+	emap, mu := epb.extensionsRead()
+	if emap == nil {
+		return defaultExtensionValue(extension)
+	}
+	mu.Lock()
+	defer mu.Unlock()
 	e, ok := emap[extension.Field]
 	e, ok := emap[extension.Field]
 	if !ok {
 	if !ok {
 		// defaultExtensionValue returns the default value or
 		// defaultExtensionValue returns the default value or
@@ -349,7 +472,7 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
 // The returned slice has the same length as es; missing extensions will appear as nil elements.
 // The returned slice has the same length as es; missing extensions will appear as nil elements.
 func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
 func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 	if !ok {
 		return nil, errors.New("proto: not an extendable proto")
 		return nil, errors.New("proto: not an extendable proto")
 	}
 	}
@@ -368,7 +491,7 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e
 
 
 // SetExtension sets the specified extension of pb to the specified value.
 // SetExtension sets the specified extension of pb to the specified value.
 func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
 func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 	if !ok {
 		return errors.New("proto: not an extendable proto")
 		return errors.New("proto: not an extendable proto")
 	}
 	}
@@ -388,17 +511,18 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
 		return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
 		return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
 	}
 	}
 
 
-	epb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
+	extmap := epb.extensionsWrite()
+	extmap[extension.Field] = Extension{desc: extension, value: value}
 	return nil
 	return nil
 }
 }
 
 
 // ClearAllExtensions clears all extensions from pb.
 // ClearAllExtensions clears all extensions from pb.
 func ClearAllExtensions(pb Message) {
 func ClearAllExtensions(pb Message) {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 	if !ok {
 		return
 		return
 	}
 	}
-	m := epb.ExtensionMap()
+	m := epb.extensionsWrite()
 	for k := range m {
 	for k := range m {
 		delete(m, k)
 		delete(m, k)
 	}
 	}

+ 4 - 0
proto/lib.go

@@ -889,6 +889,10 @@ func isProto3Zero(v reflect.Value) bool {
 	return false
 	return false
 }
 }
 
 
+// ProtoPackageIsVersion2 is referenced from generated protocol buffer files
+// to assert that that code is compatible with this version of the proto package.
+const ProtoPackageIsVersion2 = true
+
 // ProtoPackageIsVersion1 is referenced from generated protocol buffer files
 // ProtoPackageIsVersion1 is referenced from generated protocol buffer files
 // to assert that that code is compatible with this version of the proto package.
 // to assert that that code is compatible with this version of the proto package.
 const ProtoPackageIsVersion1 = true
 const ProtoPackageIsVersion1 = true

+ 37 - 6
proto/message_set.go

@@ -149,9 +149,21 @@ func skipVarint(buf []byte) []byte {
 
 
 // MarshalMessageSet encodes the extension map represented by m in the message set wire format.
 // MarshalMessageSet encodes the extension map represented by m in the message set wire format.
 // It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
 // It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
-func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
-	if err := encodeExtensionMap(m); err != nil {
-		return nil, err
+func MarshalMessageSet(exts interface{}) ([]byte, error) {
+	var m map[int32]Extension
+	switch exts := exts.(type) {
+	case *XXX_InternalExtensions:
+		if err := encodeExtensions(exts); err != nil {
+			return nil, err
+		}
+		m, _ = exts.extensionsRead()
+	case map[int32]Extension:
+		if err := encodeExtensionsMap(exts); err != nil {
+			return nil, err
+		}
+		m = exts
+	default:
+		return nil, errors.New("proto: not an extension map")
 	}
 	}
 
 
 	// Sort extension IDs to provide a deterministic encoding.
 	// Sort extension IDs to provide a deterministic encoding.
@@ -178,7 +190,17 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
 
 
 // UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
 // UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
 // It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
 // It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
-func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
+func UnmarshalMessageSet(buf []byte, exts interface{}) error {
+	var m map[int32]Extension
+	switch exts := exts.(type) {
+	case *XXX_InternalExtensions:
+		m = exts.extensionsWrite()
+	case map[int32]Extension:
+		m = exts
+	default:
+		return errors.New("proto: not an extension map")
+	}
+
 	ms := new(messageSet)
 	ms := new(messageSet)
 	if err := Unmarshal(buf, ms); err != nil {
 	if err := Unmarshal(buf, ms); err != nil {
 		return err
 		return err
@@ -209,7 +231,16 @@ func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
 
 
 // MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
 // MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
 // It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
 // It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
-func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) {
+func MarshalMessageSetJSON(exts interface{}) ([]byte, error) {
+	var m map[int32]Extension
+	switch exts := exts.(type) {
+	case *XXX_InternalExtensions:
+		m, _ = exts.extensionsRead()
+	case map[int32]Extension:
+		m = exts
+	default:
+		return nil, errors.New("proto: not an extension map")
+	}
 	var b bytes.Buffer
 	var b bytes.Buffer
 	b.WriteByte('{')
 	b.WriteByte('{')
 
 
@@ -252,7 +283,7 @@ func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) {
 
 
 // UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
 // UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
 // It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
 // It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
-func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error {
+func UnmarshalMessageSetJSON(buf []byte, exts interface{}) error {
 	// Common-case fast path.
 	// Common-case fast path.
 	if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) {
 	if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) {
 		return nil
 		return nil

+ 4 - 4
proto/message_set_test.go

@@ -50,13 +50,13 @@ func TestUnmarshalMessageSetWithDuplicate(t *testing.T) {
 	}
 	}
 	t.Logf("Marshaled bytes: %q", b)
 	t.Logf("Marshaled bytes: %q", b)
 
 
-	m := make(map[int32]Extension)
-	if err := UnmarshalMessageSet(b, m); err != nil {
+	var extensions XXX_InternalExtensions
+	if err := UnmarshalMessageSet(b, &extensions); err != nil {
 		t.Fatalf("UnmarshalMessageSet: %v", err)
 		t.Fatalf("UnmarshalMessageSet: %v", err)
 	}
 	}
-	ext, ok := m[12345]
+	ext, ok := extensions.p.extensionMap[12345]
 	if !ok {
 	if !ok {
-		t.Fatalf("Didn't retrieve extension 12345; map is %v", m)
+		t.Fatalf("Didn't retrieve extension 12345; map is %v", extensions.p.extensionMap)
 	}
 	}
 	// Skip wire type/field number and length varints.
 	// Skip wire type/field number and length varints.
 	got := skipVarint(skipVarint(ext.enc))
 	got := skipVarint(skipVarint(ext.enc))

+ 5 - 0
proto/pointer_reflect.go

@@ -139,6 +139,11 @@ func structPointer_StringSlice(p structPointer, f field) *[]string {
 	return structPointer_ifield(p, f).(*[]string)
 	return structPointer_ifield(p, f).(*[]string)
 }
 }
 
 
+// Extensions returns the address of an extension map field in the struct.
+func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
+	return structPointer_ifield(p, f).(*XXX_InternalExtensions)
+}
+
 // ExtMap returns the address of an extension map field in the struct.
 // ExtMap returns the address of an extension map field in the struct.
 func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
 func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
 	return structPointer_ifield(p, f).(*map[int32]Extension)
 	return structPointer_ifield(p, f).(*map[int32]Extension)

+ 4 - 0
proto/pointer_unsafe.go

@@ -126,6 +126,10 @@ func structPointer_StringSlice(p structPointer, f field) *[]string {
 }
 }
 
 
 // ExtMap returns the address of an extension map field in the struct.
 // ExtMap returns the address of an extension map field in the struct.
+func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
+	return (*XXX_InternalExtensions)(unsafe.Pointer(uintptr(p) + uintptr(f)))
+}
+
 func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
 func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
 	return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
 	return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
 }
 }

+ 8 - 4
proto/properties.go

@@ -682,7 +682,8 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
 	propertiesMap[t] = prop
 	propertiesMap[t] = prop
 
 
 	// build properties
 	// build properties
-	prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType)
+	prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) ||
+		reflect.PtrTo(t).Implements(extendableProtoV1Type)
 	prop.unrecField = invalidField
 	prop.unrecField = invalidField
 	prop.Prop = make([]*Properties, t.NumField())
 	prop.Prop = make([]*Properties, t.NumField())
 	prop.order = make([]int, t.NumField())
 	prop.order = make([]int, t.NumField())
@@ -693,12 +694,15 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
 		name := f.Name
 		name := f.Name
 		p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false)
 		p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false)
 
 
-		if f.Name == "XXX_extensions" { // special case
+		if f.Name == "XXX_InternalExtensions" { // special case
+			p.enc = (*Buffer).enc_exts
+			p.dec = nil // not needed
+			p.size = size_exts
+		} else if f.Name == "XXX_extensions" { // special case
 			p.enc = (*Buffer).enc_map
 			p.enc = (*Buffer).enc_map
 			p.dec = nil // not needed
 			p.dec = nil // not needed
 			p.size = size_map
 			p.size = size_map
-		}
-		if f.Name == "XXX_unrecognized" { // special case
+		} else if f.Name == "XXX_unrecognized" { // special case
 			prop.unrecField = toField(&f)
 			prop.unrecField = toField(&f)
 		}
 		}
 		oneof := f.Tag.Get("protobuf_oneof") // special case
 		oneof := f.Tag.Get("protobuf_oneof") // special case

+ 8 - 3
proto/text.go

@@ -455,7 +455,7 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
 
 
 	// Extensions (the XXX_extensions field).
 	// Extensions (the XXX_extensions field).
 	pv := sv.Addr()
 	pv := sv.Addr()
-	if pv.Type().Implements(extendableProtoType) {
+	if _, ok := extendable(pv.Interface()); ok {
 		if err := tm.writeExtensions(w, pv); err != nil {
 		if err := tm.writeExtensions(w, pv); err != nil {
 			return err
 			return err
 		}
 		}
@@ -689,17 +689,22 @@ func (s int32Slice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
 // pv is assumed to be a pointer to a protocol message struct that is extendable.
 // pv is assumed to be a pointer to a protocol message struct that is extendable.
 func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error {
 func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error {
 	emap := extensionMaps[pv.Type().Elem()]
 	emap := extensionMaps[pv.Type().Elem()]
-	ep := pv.Interface().(extendableProto)
+	ep, _ := extendable(pv.Interface())
 
 
 	// Order the extensions by ID.
 	// Order the extensions by ID.
 	// This isn't strictly necessary, but it will give us
 	// This isn't strictly necessary, but it will give us
 	// canonical output, which will also make testing easier.
 	// canonical output, which will also make testing easier.
-	m := ep.ExtensionMap()
+	m, mu := ep.extensionsRead()
+	if m == nil {
+		return nil
+	}
+	mu.Lock()
 	ids := make([]int32, 0, len(m))
 	ids := make([]int32, 0, len(m))
 	for id := range m {
 	for id := range m {
 		ids = append(ids, id)
 		ids = append(ids, id)
 	}
 	}
 	sort.Sort(int32Slice(ids))
 	sort.Sort(int32Slice(ids))
+	mu.Unlock()
 
 
 	for _, extNum := range ids {
 	for _, extNum := range ids {
 		ext := m[extNum]
 		ext := m[extNum]

+ 1 - 1
proto/text_parser.go

@@ -550,7 +550,7 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
 				}
 				}
 				reqFieldErr = err
 				reqFieldErr = err
 			}
 			}
-			ep := sv.Addr().Interface().(extendableProto)
+			ep := sv.Addr().Interface().(Message)
 			if !rep {
 			if !rep {
 				SetExtension(ep, desc, ext.Interface())
 				SetExtension(ep, desc, ext.Interface())
 			} else {
 			} else {

+ 11 - 20
protoc-gen-go/generator/generator.go

@@ -62,7 +62,7 @@ import (
 // It is incremented whenever an incompatibility between the generated code and
 // It is incremented whenever an incompatibility between the generated code and
 // proto package is introduced; the generated code references
 // proto package is introduced; the generated code references
 // a constant, proto.ProtoPackageIsVersionN (where N is generatedCodeVersion).
 // a constant, proto.ProtoPackageIsVersionN (where N is generatedCodeVersion).
-const generatedCodeVersion = 1
+const generatedCodeVersion = 2
 
 
 // A Plugin provides functionality to add to the output during Go code generation,
 // A Plugin provides functionality to add to the output during Go code generation,
 // such as to produce RPC stubs.
 // such as to produce RPC stubs.
@@ -363,8 +363,6 @@ func (ms *messageSymbol) GenerateAlias(g *Generator, pkg string) {
 	if ms.hasExtensions {
 	if ms.hasExtensions {
 		g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.Pkg["proto"], ".ExtensionRange ",
 		g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.Pkg["proto"], ".ExtensionRange ",
 			"{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }")
 			"{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }")
-		g.P("func (m *", ms.sym, ") ExtensionMap() map[int32]", g.Pkg["proto"], ".Extension ",
-			"{ return (*", remoteSym, ")(m).ExtensionMap() }")
 		if ms.isMessageSet {
 		if ms.isMessageSet {
 			g.P("func (m *", ms.sym, ") Marshal() ([]byte, error) ",
 			g.P("func (m *", ms.sym, ") Marshal() ([]byte, error) ",
 				"{ return (*", remoteSym, ")(m).Marshal() }")
 				"{ return (*", remoteSym, ")(m).Marshal() }")
@@ -1173,7 +1171,9 @@ func (g *Generator) generate(file *FileDescriptor) {
 		// For one file in the package, assert version compatibility.
 		// For one file in the package, assert version compatibility.
 		g.P("// This is a compile-time assertion to ensure that this generated file")
 		g.P("// This is a compile-time assertion to ensure that this generated file")
 		g.P("// is compatible with the proto package it is being compiled against.")
 		g.P("// is compatible with the proto package it is being compiled against.")
-		g.P("const _ = ", g.Pkg["proto"], ".ProtoPackageIsVersion", generatedCodeVersion)
+		g.P("// A compilation error at this line likely means your copy of the")
+		g.P("// proto package needs to be updated.")
+		g.P("const _ = ", g.Pkg["proto"], ".ProtoPackageIsVersion", generatedCodeVersion, " // please upgrade the proto package")
 		g.P()
 		g.P()
 	}
 	}
 
 
@@ -1684,7 +1684,8 @@ func (g *Generator) RecordTypeUse(t string) {
 }
 }
 
 
 // Method names that may be generated.  Fields with these names get an
 // Method names that may be generated.  Fields with these names get an
-// underscore appended.
+// underscore appended. Any change to this set is a potential incompatible
+// API change because it changes generated field names.
 var methodNames = [...]string{
 var methodNames = [...]string{
 	"Reset",
 	"Reset",
 	"String",
 	"String",
@@ -1869,7 +1870,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
 		g.RecordTypeUse(field.GetTypeName())
 		g.RecordTypeUse(field.GetTypeName())
 	}
 	}
 	if len(message.ExtensionRange) > 0 {
 	if len(message.ExtensionRange) > 0 {
-		g.P("XXX_extensions\t\tmap[int32]", g.Pkg["proto"], ".Extension `json:\"-\"`")
+		g.P(g.Pkg["proto"], ".XXX_InternalExtensions `json:\"-\"`")
 	}
 	}
 	if !message.proto3() {
 	if !message.proto3() {
 		g.P("XXX_unrecognized\t[]byte `json:\"-\"`")
 		g.P("XXX_unrecognized\t[]byte `json:\"-\"`")
@@ -1919,22 +1920,22 @@ func (g *Generator) generateMessage(message *Descriptor) {
 			g.P()
 			g.P()
 			g.P("func (m *", ccTypeName, ") Marshal() ([]byte, error) {")
 			g.P("func (m *", ccTypeName, ") Marshal() ([]byte, error) {")
 			g.In()
 			g.In()
-			g.P("return ", g.Pkg["proto"], ".MarshalMessageSet(m.ExtensionMap())")
+			g.P("return ", g.Pkg["proto"], ".MarshalMessageSet(&m.XXX_InternalExtensions)")
 			g.Out()
 			g.Out()
 			g.P("}")
 			g.P("}")
 			g.P("func (m *", ccTypeName, ") Unmarshal(buf []byte) error {")
 			g.P("func (m *", ccTypeName, ") Unmarshal(buf []byte) error {")
 			g.In()
 			g.In()
-			g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSet(buf, m.ExtensionMap())")
+			g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSet(buf, &m.XXX_InternalExtensions)")
 			g.Out()
 			g.Out()
 			g.P("}")
 			g.P("}")
 			g.P("func (m *", ccTypeName, ") MarshalJSON() ([]byte, error) {")
 			g.P("func (m *", ccTypeName, ") MarshalJSON() ([]byte, error) {")
 			g.In()
 			g.In()
-			g.P("return ", g.Pkg["proto"], ".MarshalMessageSetJSON(m.XXX_extensions)")
+			g.P("return ", g.Pkg["proto"], ".MarshalMessageSetJSON(&m.XXX_InternalExtensions)")
 			g.Out()
 			g.Out()
 			g.P("}")
 			g.P("}")
 			g.P("func (m *", ccTypeName, ") UnmarshalJSON(buf []byte) error {")
 			g.P("func (m *", ccTypeName, ") UnmarshalJSON(buf []byte) error {")
 			g.In()
 			g.In()
-			g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSetJSON(buf, m.XXX_extensions)")
+			g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSetJSON(buf, &m.XXX_InternalExtensions)")
 			g.Out()
 			g.Out()
 			g.P("}")
 			g.P("}")
 			g.P("// ensure ", ccTypeName, " satisfies proto.Marshaler and proto.Unmarshaler")
 			g.P("// ensure ", ccTypeName, " satisfies proto.Marshaler and proto.Unmarshaler")
@@ -1956,16 +1957,6 @@ func (g *Generator) generateMessage(message *Descriptor) {
 		g.P("return extRange_", ccTypeName)
 		g.P("return extRange_", ccTypeName)
 		g.Out()
 		g.Out()
 		g.P("}")
 		g.P("}")
-		g.P("func (m *", ccTypeName, ") ExtensionMap() map[int32]", g.Pkg["proto"], ".Extension {")
-		g.In()
-		g.P("if m.XXX_extensions == nil {")
-		g.In()
-		g.P("m.XXX_extensions = make(map[int32]", g.Pkg["proto"], ".Extension)")
-		g.Out()
-		g.P("}")
-		g.P("return m.XXX_extensions")
-		g.Out()
-		g.P("}")
 	}
 	}
 
 
 	// Default constants
 	// Default constants