Просмотр исходного кода

net/proto2: remove <message>.ExtensionMap() from generated messages

Turn generated message struct field XXX_Extensions map[int32]Extension
into an embedded proto.InternalExtensions  struct

InternalExtensions is a struct without any exported fields and methods.
This effectively makes the representation of the extension map private.
The proto package can access InternalExtensions by checking that the
generated struct has the method 'extmap() proto.InternalExtensions'.

Also lock accesses to the extension map.

This change bumps the Go protobuf generated code version number. Any
.pb.go files generated with this version of the proto package or later
will require this version or later of the proto package to compile.
matloob@google.com 9 лет назад
Родитель
Сommit
e51d002c61

+ 6 - 11
jsonpb/jsonpb.go

@@ -233,12 +233,14 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU
 	}
 
 	// Handle proto2 extensions.
-	if ep, ok := v.(extendableProto); ok {
+	if ep, ok := v.(proto.Message); ok {
 		extensions := proto.RegisteredExtensions(v)
-		extensionMap := ep.ExtensionMap()
 		// Sort extensions for stable output.
-		ids := make([]int32, 0, len(extensionMap))
-		for id := range extensionMap {
+		ids := make([]int32, 0, len(extensions))
+		for id, desc := range extensions {
+			if !proto.HasExtension(ep, desc) {
+				continue
+			}
 			ids = append(ids, id)
 		}
 		sort.Sort(int32Slice(ids))
@@ -767,13 +769,6 @@ func acceptedJSONFieldNames(prop *proto.Properties) fieldNames {
 	return opts
 }
 
-// extendableProto is an interface implemented by any protocol buffer that may be extended.
-type extendableProto interface {
-	proto.Message
-	ExtensionRangeArray() []proto.ExtensionRange
-	ExtensionMap() map[int32]proto.Extension
-}
-
 // Writer wrapper inspired by https://blog.golang.org/errors-are-values
 type errWriter struct {
 	writer io.Writer

+ 9 - 3
proto/clone.go

@@ -84,9 +84,15 @@ func mergeStruct(out, in reflect.Value) {
 		mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
 	}
 
-	if emIn, ok := in.Addr().Interface().(extendableProto); ok {
-		emOut := out.Addr().Interface().(extendableProto)
-		mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
+	if emIn, ok := extendable(in.Addr().Interface()); ok {
+		emOut, _ := extendable(out.Addr().Interface())
+		mIn, muIn := emIn.extensionsRead()
+		if mIn != nil {
+			mOut := emOut.extensionsWrite()
+			muIn.Lock()
+			mergeExtension(mOut, mIn)
+			muIn.Unlock()
+		}
 	}
 
 	uf := in.FieldByName("XXX_unrecognized")

+ 4 - 3
proto/decode.go

@@ -390,11 +390,12 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
 		if !ok {
 			// Maybe it's an extension?
 			if prop.extendable {
-				if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) {
+				if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) {
 					if err = o.skip(st, tag, wire); err == nil {
-						ext := e.ExtensionMap()[int32(tag)] // may be missing
+						extmap := e.extensionsWrite()
+						ext := extmap[int32(tag)] // may be missing
 						ext.enc = append(ext.enc, o.buf[oi:o.index]...)
-						e.ExtensionMap()[int32(tag)] = ext
+						extmap[int32(tag)] = ext
 					}
 					continue
 				}

+ 24 - 4
proto/encode.go

@@ -1073,10 +1073,25 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) {
 
 // Encode an extension map.
 func (o *Buffer) enc_map(p *Properties, base structPointer) error {
-	v := *structPointer_ExtMap(base, p.field)
-	if err := encodeExtensionMap(v); err != nil {
+	exts := structPointer_ExtMap(base, p.field)
+	if err := encodeExtensionsMap(*exts); err != nil {
 		return err
 	}
+
+	return o.enc_map_body(*exts)
+}
+
+func (o *Buffer) enc_exts(p *Properties, base structPointer) error {
+	exts := structPointer_Extensions(base, p.field)
+	if err := encodeExtensions(exts); err != nil {
+		return err
+	}
+	v, _ := exts.extensionsRead()
+
+	return o.enc_map_body(v)
+}
+
+func (o *Buffer) enc_map_body(v map[int32]Extension) error {
 	// Fast-path for common cases: zero or one extensions.
 	if len(v) <= 1 {
 		for _, e := range v {
@@ -1099,8 +1114,13 @@ func (o *Buffer) enc_map(p *Properties, base structPointer) error {
 }
 
 func size_map(p *Properties, base structPointer) int {
-	v := *structPointer_ExtMap(base, p.field)
-	return sizeExtensionMap(v)
+	v := structPointer_ExtMap(base, p.field)
+	return extensionsMapSize(*v)
+}
+
+func size_exts(p *Properties, base structPointer) int {
+	v := structPointer_Extensions(base, p.field)
+	return extensionsSize(v)
 }
 
 // Encode a map field.

+ 7 - 5
proto/equal.go

@@ -121,9 +121,9 @@ func equalStruct(v1, v2 reflect.Value) bool {
 		}
 	}
 
-	if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
-		em2 := v2.FieldByName("XXX_extensions")
-		if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
+	if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() {
+		em2 := v2.FieldByName("XXX_InternalExtensions")
+		if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) {
 			return false
 		}
 	}
@@ -223,8 +223,10 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
 }
 
 // base is the struct type that the extensions are based on.
-// em1 and em2 are extension maps.
-func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
+// x1 and x2 are InternalExtensions.
+func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool {
+	em1, _ := x1.extensionsRead()
+	em2, _ := x2.extensionsRead()
 	if len(em1) != len(em2) {
 		return false
 	}

+ 142 - 18
proto/extensions.go

@@ -52,14 +52,99 @@ type ExtensionRange struct {
 	Start, End int32 // both inclusive
 }
 
-// extendableProto is an interface implemented by any protocol buffer that may be extended.
+// extendableProto is an interface implemented by any protocol buffer generated by the current
+// proto compiler that may be extended.
 type extendableProto interface {
+	Message
+	ExtensionRangeArray() []ExtensionRange
+	extensionsWrite() map[int32]Extension
+	extensionsRead() (map[int32]Extension, sync.Locker)
+}
+
+// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
+// version of the proto compiler that may be extended.
+type extendableProtoV1 interface {
 	Message
 	ExtensionRangeArray() []ExtensionRange
 	ExtensionMap() map[int32]Extension
 }
 
+// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
+type extensionAdapter struct {
+	extendableProtoV1
+}
+
+func (e extensionAdapter) extensionsWrite() map[int32]Extension {
+	return e.ExtensionMap()
+}
+
+func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
+	return e.ExtensionMap(), notLocker{}
+}
+
+// notLocker is a sync.Locker whose Lock and Unlock methods are nops.
+type notLocker struct{}
+
+func (n notLocker) Lock()   {}
+func (n notLocker) Unlock() {}
+
+// extendable returns the extendableProto interface for the given generated proto message.
+// If the proto message has the old extension format, it returns a wrapper that implements
+// the extendableProto interface.
+func extendable(p interface{}) (extendableProto, bool) {
+	if ep, ok := p.(extendableProto); ok {
+		return ep, ok
+	}
+	if ep, ok := p.(extendableProtoV1); ok {
+		return extensionAdapter{ep}, ok
+	}
+	return nil, false
+}
+
+// XXX_InternalExtensions is an internal representation of proto extensions.
+//
+// Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
+// thus gaining the unexported 'extensions' method, which can be called only from the proto package.
+//
+// The methods of XXX_InternalExtensions are not concurrency safe in general,
+// but calls to logically read-only methods such as has and get may be executed concurrently.
+type XXX_InternalExtensions struct {
+	// The struct must be indirect so that if a user inadvertently copies a
+	// generated message and its embedded XXX_InternalExtensions, they
+	// avoid the mayhem of a copied mutex.
+	//
+	// The mutex serializes all logically read-only operations to p.extensionMap.
+	// It is up to the client to ensure that write operations to p.extensionMap are
+	// mutually exclusive with other accesses.
+	p *struct {
+		mu           sync.Mutex
+		extensionMap map[int32]Extension
+	}
+}
+
+// extensionsWrite returns the extension map, creating it on first use.
+func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
+	if e.p == nil {
+		e.p = new(struct {
+			mu           sync.Mutex
+			extensionMap map[int32]Extension
+		})
+		e.p.extensionMap = make(map[int32]Extension)
+	}
+	return e.p.extensionMap
+}
+
+// extensionsRead returns the extensions map for read-only use.  It may be nil.
+// The caller must hold the returned mutex's lock when accessing Elements within the map.
+func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
+	if e.p == nil {
+		return nil, nil
+	}
+	return e.p.extensionMap, &e.p.mu
+}
+
 var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
+var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
 
 // ExtensionDesc represents an extension specification.
 // Used in generated code from the protocol compiler.
@@ -93,11 +178,12 @@ type Extension struct {
 
 // SetRawExtension is for testing only.
 func SetRawExtension(base Message, id int32, b []byte) {
-	epb, ok := base.(extendableProto)
+	epb, ok := extendable(base)
 	if !ok {
 		return
 	}
-	epb.ExtensionMap()[id] = Extension{enc: b}
+	extmap := epb.extensionsWrite()
+	extmap[id] = Extension{enc: b}
 }
 
 // isExtensionField returns true iff the given field number is in an extension range.
@@ -112,8 +198,12 @@ func isExtensionField(pb extendableProto, field int32) bool {
 
 // checkExtensionTypes checks that the given extension is valid for pb.
 func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
+	var pbi interface{} = pb
 	// Check the extended type.
-	if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
+	if ea, ok := pbi.(extensionAdapter); ok {
+		pbi = ea.extendableProtoV1
+	}
+	if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
 		return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
 	}
 	// Check the range.
@@ -159,8 +249,19 @@ func extensionProperties(ed *ExtensionDesc) *Properties {
 	return prop
 }
 
-// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m.
-func encodeExtensionMap(m map[int32]Extension) error {
+// encode encodes any unmarshaled (unencoded) extensions in e.
+func encodeExtensions(e *XXX_InternalExtensions) error {
+	m, mu := e.extensionsRead()
+	if m == nil {
+		return nil // fast path
+	}
+	mu.Lock()
+	defer mu.Unlock()
+	return encodeExtensionsMap(m)
+}
+
+// encode encodes any unmarshaled (unencoded) extensions in e.
+func encodeExtensionsMap(m map[int32]Extension) error {
 	for k, e := range m {
 		if e.value == nil || e.desc == nil {
 			// Extension is only in its encoded form.
@@ -188,7 +289,17 @@ func encodeExtensionMap(m map[int32]Extension) error {
 	return nil
 }
 
-func sizeExtensionMap(m map[int32]Extension) (n int) {
+func extensionsSize(e *XXX_InternalExtensions) (n int) {
+	m, mu := e.extensionsRead()
+	if m == nil {
+		return 0
+	}
+	mu.Lock()
+	defer mu.Unlock()
+	return extensionsMapSize(m)
+}
+
+func extensionsMapSize(m map[int32]Extension) (n int) {
 	for _, e := range m {
 		if e.value == nil || e.desc == nil {
 			// Extension is only in its encoded form.
@@ -215,28 +326,35 @@ func sizeExtensionMap(m map[int32]Extension) (n int) {
 // HasExtension returns whether the given extension is present in pb.
 func HasExtension(pb Message, extension *ExtensionDesc) bool {
 	// TODO: Check types, field numbers, etc.?
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 		return false
 	}
-	_, ok = epb.ExtensionMap()[extension.Field]
+	extmap, mu := epb.extensionsRead()
+	if extmap == nil {
+		return false
+	}
+	mu.Lock()
+	_, ok = extmap[extension.Field]
+	mu.Unlock()
 	return ok
 }
 
 // ClearExtension removes the given extension from pb.
 func ClearExtension(pb Message, extension *ExtensionDesc) {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 		return
 	}
 	// TODO: Check types, field numbers, etc.?
-	delete(epb.ExtensionMap(), extension.Field)
+	extmap := epb.extensionsWrite()
+	delete(extmap, extension.Field)
 }
 
 // GetExtension parses and returns the given extension of pb.
 // If the extension is not present and has no default value it returns ErrMissingExtension.
 func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 		return nil, errors.New("proto: not an extendable proto")
 	}
@@ -245,7 +363,12 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
 		return nil, err
 	}
 
-	emap := epb.ExtensionMap()
+	emap, mu := epb.extensionsRead()
+	if emap == nil {
+		return defaultExtensionValue(extension)
+	}
+	mu.Lock()
+	defer mu.Unlock()
 	e, ok := emap[extension.Field]
 	if !ok {
 		// defaultExtensionValue returns the default value or
@@ -349,7 +472,7 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
 // The returned slice has the same length as es; missing extensions will appear as nil elements.
 func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 		return nil, errors.New("proto: not an extendable proto")
 	}
@@ -368,7 +491,7 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e
 
 // SetExtension sets the specified extension of pb to the specified value.
 func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 		return errors.New("proto: not an extendable proto")
 	}
@@ -388,17 +511,18 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
 		return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
 	}
 
-	epb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
+	extmap := epb.extensionsWrite()
+	extmap[extension.Field] = Extension{desc: extension, value: value}
 	return nil
 }
 
 // ClearAllExtensions clears all extensions from pb.
 func ClearAllExtensions(pb Message) {
-	epb, ok := pb.(extendableProto)
+	epb, ok := extendable(pb)
 	if !ok {
 		return
 	}
-	m := epb.ExtensionMap()
+	m := epb.extensionsWrite()
 	for k := range m {
 		delete(m, k)
 	}

+ 4 - 0
proto/lib.go

@@ -889,6 +889,10 @@ func isProto3Zero(v reflect.Value) bool {
 	return false
 }
 
+// ProtoPackageIsVersion2 is referenced from generated protocol buffer files
+// to assert that that code is compatible with this version of the proto package.
+const ProtoPackageIsVersion2 = true
+
 // ProtoPackageIsVersion1 is referenced from generated protocol buffer files
 // to assert that that code is compatible with this version of the proto package.
 const ProtoPackageIsVersion1 = true

+ 37 - 6
proto/message_set.go

@@ -149,9 +149,21 @@ func skipVarint(buf []byte) []byte {
 
 // MarshalMessageSet encodes the extension map represented by m in the message set wire format.
 // It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
-func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
-	if err := encodeExtensionMap(m); err != nil {
-		return nil, err
+func MarshalMessageSet(exts interface{}) ([]byte, error) {
+	var m map[int32]Extension
+	switch exts := exts.(type) {
+	case *XXX_InternalExtensions:
+		if err := encodeExtensions(exts); err != nil {
+			return nil, err
+		}
+		m, _ = exts.extensionsRead()
+	case map[int32]Extension:
+		if err := encodeExtensionsMap(exts); err != nil {
+			return nil, err
+		}
+		m = exts
+	default:
+		return nil, errors.New("proto: not an extension map")
 	}
 
 	// Sort extension IDs to provide a deterministic encoding.
@@ -178,7 +190,17 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
 
 // UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
 // It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
-func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
+func UnmarshalMessageSet(buf []byte, exts interface{}) error {
+	var m map[int32]Extension
+	switch exts := exts.(type) {
+	case *XXX_InternalExtensions:
+		m = exts.extensionsWrite()
+	case map[int32]Extension:
+		m = exts
+	default:
+		return errors.New("proto: not an extension map")
+	}
+
 	ms := new(messageSet)
 	if err := Unmarshal(buf, ms); err != nil {
 		return err
@@ -209,7 +231,16 @@ func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
 
 // MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
 // It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
-func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) {
+func MarshalMessageSetJSON(exts interface{}) ([]byte, error) {
+	var m map[int32]Extension
+	switch exts := exts.(type) {
+	case *XXX_InternalExtensions:
+		m, _ = exts.extensionsRead()
+	case map[int32]Extension:
+		m = exts
+	default:
+		return nil, errors.New("proto: not an extension map")
+	}
 	var b bytes.Buffer
 	b.WriteByte('{')
 
@@ -252,7 +283,7 @@ func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) {
 
 // UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
 // It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
-func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error {
+func UnmarshalMessageSetJSON(buf []byte, exts interface{}) error {
 	// Common-case fast path.
 	if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) {
 		return nil

+ 4 - 4
proto/message_set_test.go

@@ -50,13 +50,13 @@ func TestUnmarshalMessageSetWithDuplicate(t *testing.T) {
 	}
 	t.Logf("Marshaled bytes: %q", b)
 
-	m := make(map[int32]Extension)
-	if err := UnmarshalMessageSet(b, m); err != nil {
+	var extensions XXX_InternalExtensions
+	if err := UnmarshalMessageSet(b, &extensions); err != nil {
 		t.Fatalf("UnmarshalMessageSet: %v", err)
 	}
-	ext, ok := m[12345]
+	ext, ok := extensions.p.extensionMap[12345]
 	if !ok {
-		t.Fatalf("Didn't retrieve extension 12345; map is %v", m)
+		t.Fatalf("Didn't retrieve extension 12345; map is %v", extensions.p.extensionMap)
 	}
 	// Skip wire type/field number and length varints.
 	got := skipVarint(skipVarint(ext.enc))

+ 5 - 0
proto/pointer_reflect.go

@@ -139,6 +139,11 @@ func structPointer_StringSlice(p structPointer, f field) *[]string {
 	return structPointer_ifield(p, f).(*[]string)
 }
 
+// Extensions returns the address of an extension map field in the struct.
+func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
+	return structPointer_ifield(p, f).(*XXX_InternalExtensions)
+}
+
 // ExtMap returns the address of an extension map field in the struct.
 func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
 	return structPointer_ifield(p, f).(*map[int32]Extension)

+ 4 - 0
proto/pointer_unsafe.go

@@ -126,6 +126,10 @@ func structPointer_StringSlice(p structPointer, f field) *[]string {
 }
 
 // ExtMap returns the address of an extension map field in the struct.
+func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions {
+	return (*XXX_InternalExtensions)(unsafe.Pointer(uintptr(p) + uintptr(f)))
+}
+
 func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
 	return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
 }

+ 8 - 4
proto/properties.go

@@ -682,7 +682,8 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
 	propertiesMap[t] = prop
 
 	// build properties
-	prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType)
+	prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) ||
+		reflect.PtrTo(t).Implements(extendableProtoV1Type)
 	prop.unrecField = invalidField
 	prop.Prop = make([]*Properties, t.NumField())
 	prop.order = make([]int, t.NumField())
@@ -693,12 +694,15 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
 		name := f.Name
 		p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false)
 
-		if f.Name == "XXX_extensions" { // special case
+		if f.Name == "XXX_InternalExtensions" { // special case
+			p.enc = (*Buffer).enc_exts
+			p.dec = nil // not needed
+			p.size = size_exts
+		} else if f.Name == "XXX_extensions" { // special case
 			p.enc = (*Buffer).enc_map
 			p.dec = nil // not needed
 			p.size = size_map
-		}
-		if f.Name == "XXX_unrecognized" { // special case
+		} else if f.Name == "XXX_unrecognized" { // special case
 			prop.unrecField = toField(&f)
 		}
 		oneof := f.Tag.Get("protobuf_oneof") // special case

+ 8 - 3
proto/text.go

@@ -455,7 +455,7 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
 
 	// Extensions (the XXX_extensions field).
 	pv := sv.Addr()
-	if pv.Type().Implements(extendableProtoType) {
+	if _, ok := extendable(pv.Interface()); ok {
 		if err := tm.writeExtensions(w, pv); err != nil {
 			return err
 		}
@@ -689,17 +689,22 @@ func (s int32Slice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
 // pv is assumed to be a pointer to a protocol message struct that is extendable.
 func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error {
 	emap := extensionMaps[pv.Type().Elem()]
-	ep := pv.Interface().(extendableProto)
+	ep, _ := extendable(pv.Interface())
 
 	// Order the extensions by ID.
 	// This isn't strictly necessary, but it will give us
 	// canonical output, which will also make testing easier.
-	m := ep.ExtensionMap()
+	m, mu := ep.extensionsRead()
+	if m == nil {
+		return nil
+	}
+	mu.Lock()
 	ids := make([]int32, 0, len(m))
 	for id := range m {
 		ids = append(ids, id)
 	}
 	sort.Sort(int32Slice(ids))
+	mu.Unlock()
 
 	for _, extNum := range ids {
 		ext := m[extNum]

+ 1 - 1
proto/text_parser.go

@@ -550,7 +550,7 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
 				}
 				reqFieldErr = err
 			}
-			ep := sv.Addr().Interface().(extendableProto)
+			ep := sv.Addr().Interface().(Message)
 			if !rep {
 				SetExtension(ep, desc, ext.Interface())
 			} else {

+ 11 - 20
protoc-gen-go/generator/generator.go

@@ -62,7 +62,7 @@ import (
 // It is incremented whenever an incompatibility between the generated code and
 // proto package is introduced; the generated code references
 // a constant, proto.ProtoPackageIsVersionN (where N is generatedCodeVersion).
-const generatedCodeVersion = 1
+const generatedCodeVersion = 2
 
 // A Plugin provides functionality to add to the output during Go code generation,
 // such as to produce RPC stubs.
@@ -363,8 +363,6 @@ func (ms *messageSymbol) GenerateAlias(g *Generator, pkg string) {
 	if ms.hasExtensions {
 		g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.Pkg["proto"], ".ExtensionRange ",
 			"{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }")
-		g.P("func (m *", ms.sym, ") ExtensionMap() map[int32]", g.Pkg["proto"], ".Extension ",
-			"{ return (*", remoteSym, ")(m).ExtensionMap() }")
 		if ms.isMessageSet {
 			g.P("func (m *", ms.sym, ") Marshal() ([]byte, error) ",
 				"{ return (*", remoteSym, ")(m).Marshal() }")
@@ -1173,7 +1171,9 @@ func (g *Generator) generate(file *FileDescriptor) {
 		// For one file in the package, assert version compatibility.
 		g.P("// This is a compile-time assertion to ensure that this generated file")
 		g.P("// is compatible with the proto package it is being compiled against.")
-		g.P("const _ = ", g.Pkg["proto"], ".ProtoPackageIsVersion", generatedCodeVersion)
+		g.P("// A compilation error at this line likely means your copy of the")
+		g.P("// proto package needs to be updated.")
+		g.P("const _ = ", g.Pkg["proto"], ".ProtoPackageIsVersion", generatedCodeVersion, " // please upgrade the proto package")
 		g.P()
 	}
 
@@ -1684,7 +1684,8 @@ func (g *Generator) RecordTypeUse(t string) {
 }
 
 // Method names that may be generated.  Fields with these names get an
-// underscore appended.
+// underscore appended. Any change to this set is a potential incompatible
+// API change because it changes generated field names.
 var methodNames = [...]string{
 	"Reset",
 	"String",
@@ -1869,7 +1870,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
 		g.RecordTypeUse(field.GetTypeName())
 	}
 	if len(message.ExtensionRange) > 0 {
-		g.P("XXX_extensions\t\tmap[int32]", g.Pkg["proto"], ".Extension `json:\"-\"`")
+		g.P(g.Pkg["proto"], ".XXX_InternalExtensions `json:\"-\"`")
 	}
 	if !message.proto3() {
 		g.P("XXX_unrecognized\t[]byte `json:\"-\"`")
@@ -1919,22 +1920,22 @@ func (g *Generator) generateMessage(message *Descriptor) {
 			g.P()
 			g.P("func (m *", ccTypeName, ") Marshal() ([]byte, error) {")
 			g.In()
-			g.P("return ", g.Pkg["proto"], ".MarshalMessageSet(m.ExtensionMap())")
+			g.P("return ", g.Pkg["proto"], ".MarshalMessageSet(&m.XXX_InternalExtensions)")
 			g.Out()
 			g.P("}")
 			g.P("func (m *", ccTypeName, ") Unmarshal(buf []byte) error {")
 			g.In()
-			g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSet(buf, m.ExtensionMap())")
+			g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSet(buf, &m.XXX_InternalExtensions)")
 			g.Out()
 			g.P("}")
 			g.P("func (m *", ccTypeName, ") MarshalJSON() ([]byte, error) {")
 			g.In()
-			g.P("return ", g.Pkg["proto"], ".MarshalMessageSetJSON(m.XXX_extensions)")
+			g.P("return ", g.Pkg["proto"], ".MarshalMessageSetJSON(&m.XXX_InternalExtensions)")
 			g.Out()
 			g.P("}")
 			g.P("func (m *", ccTypeName, ") UnmarshalJSON(buf []byte) error {")
 			g.In()
-			g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSetJSON(buf, m.XXX_extensions)")
+			g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSetJSON(buf, &m.XXX_InternalExtensions)")
 			g.Out()
 			g.P("}")
 			g.P("// ensure ", ccTypeName, " satisfies proto.Marshaler and proto.Unmarshaler")
@@ -1956,16 +1957,6 @@ func (g *Generator) generateMessage(message *Descriptor) {
 		g.P("return extRange_", ccTypeName)
 		g.Out()
 		g.P("}")
-		g.P("func (m *", ccTypeName, ") ExtensionMap() map[int32]", g.Pkg["proto"], ".Extension {")
-		g.In()
-		g.P("if m.XXX_extensions == nil {")
-		g.In()
-		g.P("m.XXX_extensions = make(map[int32]", g.Pkg["proto"], ".Extension)")
-		g.Out()
-		g.P("}")
-		g.P("return m.XXX_extensions")
-		g.Out()
-		g.P("}")
 	}
 
 	// Default constants