Quellcode durchsuchen

all: fix reflect.Value.Interface races (#913)

The reflect.Value.Interface method shallow copies the underlying value,
which may copy mutexes and atomically-accessed fields.
Some usages of the Interface method is only to check if the interface value
implements an interface. In which case the shallow copy was unnecessary.
Change those usages to use the reflect.Value.Implements method instead.

Fixes #838
Joe Tsai vor 6 Jahren
Ursprung
Commit
4c88cc3f1a
4 geänderte Dateien mit 47 neuen und 5 gelöschten Zeilen
  1. 2 0
      go.mod
  2. 9 3
      jsonpb/jsonpb.go
  3. 32 0
      proto/all_test.go
  4. 4 2
      proto/text.go

+ 2 - 0
go.mod

@@ -1 +1,3 @@
 module github.com/golang/protobuf
+
+go 1.12

+ 9 - 3
jsonpb/jsonpb.go

@@ -165,6 +165,11 @@ type wkt interface {
 	XXX_WellKnownType() string
 }
 
+var (
+	wktType     = reflect.TypeOf((*wkt)(nil)).Elem()
+	messageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
+)
+
 // marshalObject writes a struct to the Writer.
 func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeURL string) error {
 	if jsm, ok := v.(JSONPBMarshaler); ok {
@@ -531,7 +536,8 @@ func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v refle
 
 	// Handle well-known types.
 	// Most are handled up in marshalObject (because 99% are messages).
-	if wkt, ok := v.Interface().(wkt); ok {
+	if v.Type().Implements(wktType) {
+		wkt := v.Interface().(wkt)
 		switch wkt.XXX_WellKnownType() {
 		case "NullValue":
 			out.write("null")
@@ -1277,8 +1283,8 @@ func checkRequiredFields(pb proto.Message) error {
 }
 
 func checkRequiredFieldsInValue(v reflect.Value) error {
-	if pm, ok := v.Interface().(proto.Message); ok {
-		return checkRequiredFields(pm)
+	if v.Type().Implements(messageType) {
+		return checkRequiredFields(v.Interface().(proto.Message))
 	}
 	return nil
 }

+ 32 - 0
proto/all_test.go

@@ -45,9 +45,11 @@ import (
 	"testing"
 	"time"
 
+	"github.com/golang/protobuf/jsonpb"
 	. "github.com/golang/protobuf/proto"
 	pb3 "github.com/golang/protobuf/proto/proto3_proto"
 	. "github.com/golang/protobuf/proto/test_proto"
+	descriptorpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
 )
 
 var globalO *Buffer
@@ -2490,3 +2492,33 @@ func BenchmarkUnmarshalUnrecognizedFields(b *testing.B) {
 		p2.Unmarshal(pbd)
 	}
 }
+
+// TestRace tests whether there are races among the different marshalers.
+func TestRace(t *testing.T) {
+	m := &descriptorpb.FileDescriptorProto{
+		Options: &descriptorpb.FileOptions{
+			GoPackage: String("path/to/my/package"),
+		},
+	}
+
+	wg := &sync.WaitGroup{}
+	defer wg.Wait()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		Marshal(m)
+	}()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		(&jsonpb.Marshaler{}).MarshalToString(m)
+	}()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		m.String()
+	}()
+}

+ 4 - 2
proto/text.go

@@ -456,6 +456,8 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error {
 	return nil
 }
 
+var textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
+
 // writeAny writes an arbitrary field.
 func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Properties) error {
 	v = reflect.Indirect(v)
@@ -519,8 +521,8 @@ func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Propert
 			// mutating this value.
 			v = v.Addr()
 		}
-		if etm, ok := v.Interface().(encoding.TextMarshaler); ok {
-			text, err := etm.MarshalText()
+		if v.Type().Implements(textMarshalerType) {
+			text, err := v.Interface().(encoding.TextMarshaler).MarshalText()
 			if err != nil {
 				return err
 			}