Browse Source

internal/impl: fix race in aberrant message logic

Previously, when aberrantLoadMessageDesc returned, it was guaranteed
to have initialized the current message through the use of the done signal.
However, this does not guarantee that the descriptor for a cyclic reference
has also finished initialization.

Rather than add more complicated logic to wait until all cyclic references
have finished initializing, just add a global lock for the entire
aberrantLoadMessageDesc function.

This reduces performance, but is easier to reason about.

Change-Id: I4cdae8b955f71ee40fa6979f5a8d548d9749042c
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/184657
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 years ago
parent
commit
32e8a52cbf
2 changed files with 57 additions and 34 deletions
  1. 34 0
      internal/impl/legacy_aberrant_test.go
  2. 23 34
      internal/impl/legacy_message.go

+ 34 - 0
internal/impl/legacy_aberrant_test.go

@@ -7,12 +7,14 @@ package impl_test
 import (
 import (
 	"io"
 	"io"
 	"reflect"
 	"reflect"
+	"sync"
 	"testing"
 	"testing"
 
 
 	"google.golang.org/protobuf/encoding/prototext"
 	"google.golang.org/protobuf/encoding/prototext"
 	"google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/proto"
 	"google.golang.org/protobuf/reflect/protodesc"
 	"google.golang.org/protobuf/reflect/protodesc"
+	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/runtime/protoiface"
 	"google.golang.org/protobuf/runtime/protoiface"
 
 
 	"google.golang.org/protobuf/types/descriptorpb"
 	"google.golang.org/protobuf/types/descriptorpb"
@@ -286,3 +288,35 @@ func TestAberrant(t *testing.T) {
 		t.Errorf("mismatching descriptor:\ngot  %v\nwant %v", got, want)
 		t.Errorf("mismatching descriptor:\ngot  %v\nwant %v", got, want)
 	}
 	}
 }
 }
+
+type AberrantMessage1 struct {
+	M *AberrantMessage2 `protobuf:"bytes,1,opt,name=message"`
+}
+
+type AberrantMessage2 struct {
+	M *AberrantMessage1 `protobuf:"bytes,1,opt,name=message"`
+}
+
+func TestAberrantRace(t *testing.T) {
+	var gotMD1, wantMD1, gotMD2, wantMD2 protoreflect.MessageDescriptor
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		md := impl.LegacyLoadMessageDesc(reflect.TypeOf(&AberrantMessage1{}))
+		wantMD2 = md.Fields().Get(0).Message()
+		gotMD2 = wantMD2.Fields().Get(0).Message().Fields().Get(0).Message()
+	}()
+	go func() {
+		defer wg.Done()
+		md := impl.LegacyLoadMessageDesc(reflect.TypeOf(&AberrantMessage2{}))
+		wantMD1 = md.Fields().Get(0).Message()
+		gotMD1 = wantMD1.Fields().Get(0).Message().Fields().Get(0).Message()
+	}()
+	wg.Wait()
+
+	if gotMD1 != wantMD1 || gotMD2 != wantMD2 {
+		t.Errorf("mismatching exact message descriptors")
+	}
+}

+ 23 - 34
internal/impl/legacy_message.go

@@ -59,9 +59,6 @@ var legacyMessageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDesc
 //
 //
 // This is exported for testing purposes.
 // This is exported for testing purposes.
 func LegacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
 func LegacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
-	return legacyLoadMessageDesc(t, true)
-}
-func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescriptor {
 	// Fast-path: check if a MessageDescriptor is cached for this concrete type.
 	// Fast-path: check if a MessageDescriptor is cached for this concrete type.
 	if mi, ok := legacyMessageDescCache.Load(t); ok {
 	if mi, ok := legacyMessageDescCache.Load(t); ok {
 		return mi.(pref.MessageDescriptor)
 		return mi.(pref.MessageDescriptor)
@@ -74,7 +71,7 @@ func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescripto
 	}
 	}
 	mdV1, ok := mv.(messageV1)
 	mdV1, ok := mv.(messageV1)
 	if !ok {
 	if !ok {
-		return aberrantLoadMessageDesc(t, finalized)
+		return aberrantLoadMessageDesc(t)
 	}
 	}
 	b, idxs := mdV1.Descriptor()
 	b, idxs := mdV1.Descriptor()
 
 
@@ -88,16 +85,10 @@ func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescripto
 	return md
 	return md
 }
 }
 
 
-var aberrantMessageDescCache sync.Map // map[reflect.Type]aberrantMessageDesc
-
-// aberrantMessageDesc is a tuple containing a MessageDescriptor and a channel
-// to signal whether the descriptor is initialized. For external lookups,
-// we must ensure that the descriptor is fully initialized. For internal lookups
-// to resolve cycles, we only need to obtain the descriptor reference.
-type aberrantMessageDesc struct {
-	desc protoreflect.MessageDescriptor
-	done chan struct{} // closed when desc is fully initialized
-}
+var (
+	aberrantMessageDescLock  sync.Mutex
+	aberrantMessageDescCache map[reflect.Type]protoreflect.MessageDescriptor
+)
 
 
 // aberrantLoadEnumDesc returns an EnumDescriptor derived from the Go type,
 // aberrantLoadEnumDesc returns an EnumDescriptor derived from the Go type,
 // which must not implement protoreflect.ProtoMessage or messageV1.
 // which must not implement protoreflect.ProtoMessage or messageV1.
@@ -107,31 +98,27 @@ type aberrantMessageDesc struct {
 //
 //
 // The finalized flag determines whether the returned message descriptor must
 // The finalized flag determines whether the returned message descriptor must
 // be fully initialized.
 // be fully initialized.
-func aberrantLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescriptor {
-	// Fast-path: check if an MessageDescriptor is cached for this concrete type.
-	if mdi, ok := aberrantMessageDescCache.Load(t); ok {
-		if finalized {
-			<-mdi.(aberrantMessageDesc).done
-		}
-		return mdi.(aberrantMessageDesc).desc
+func aberrantLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
+	aberrantMessageDescLock.Lock()
+	defer aberrantMessageDescLock.Unlock()
+	if aberrantMessageDescCache == nil {
+		aberrantMessageDescCache = make(map[reflect.Type]protoreflect.MessageDescriptor)
 	}
 	}
-
-	// Medium-path: create an initial descriptor and cache it immediately,
-	// so that cyclic references can be resolved. Each descriptor is paired
-	// with a channel to signal when the descriptor is fully initialized.
-	md := &filedesc.Message{L2: new(filedesc.MessageL2)}
-	mdi := aberrantMessageDesc{desc: md, done: make(chan struct{})}
-	if mdi, ok := aberrantMessageDescCache.LoadOrStore(t, mdi); ok {
-		if finalized {
-			<-mdi.(aberrantMessageDesc).done
-		}
-		return mdi.(aberrantMessageDesc).desc
+	return aberrantLoadMessageDescReentrant(t)
+}
+func aberrantLoadMessageDescReentrant(t reflect.Type) pref.MessageDescriptor {
+	// Fast-path: check if an MessageDescriptor is cached for this concrete type.
+	if md, ok := aberrantMessageDescCache[t]; ok {
+		return md
 	}
 	}
-	defer func() { close(mdi.done) }()
 
 
 	// Slow-path: construct a descriptor from the Go struct type (best-effort).
 	// Slow-path: construct a descriptor from the Go struct type (best-effort).
+	// Cache the MessageDescriptor early on so that we can resolve internal
+	// cyclic references.
+	md := &filedesc.Message{L2: new(filedesc.MessageL2)}
 	md.L0.FullName = aberrantDeriveFullName(t.Elem())
 	md.L0.FullName = aberrantDeriveFullName(t.Elem())
 	md.L0.ParentFile = filedesc.SurrogateProto2
 	md.L0.ParentFile = filedesc.SurrogateProto2
+	aberrantMessageDescCache[t] = md
 
 
 	// Try to determine if the message is using proto3 by checking scalars.
 	// Try to determine if the message is using proto3 by checking scalars.
 	for i := 0; i < t.Elem().NumField(); i++ {
 	for i := 0; i < t.Elem().NumField(); i++ {
@@ -257,6 +244,8 @@ func aberrantAppendField(md *filedesc.Message, goType reflect.Type, tag, tagKey,
 		switch v := reflect.Zero(t).Interface().(type) {
 		switch v := reflect.Zero(t).Interface().(type) {
 		case pref.ProtoMessage:
 		case pref.ProtoMessage:
 			fd.L1.Message = v.ProtoReflect().Descriptor()
 			fd.L1.Message = v.ProtoReflect().Descriptor()
+		case messageV1:
+			fd.L1.Message = LegacyLoadMessageDesc(t)
 		default:
 		default:
 			if t.Kind() == reflect.Map {
 			if t.Kind() == reflect.Map {
 				n := len(md.L1.Messages.List)
 				n := len(md.L1.Messages.List)
@@ -280,7 +269,7 @@ func aberrantAppendField(md *filedesc.Message, goType reflect.Type, tag, tagKey,
 				fd.L1.Message = md2
 				fd.L1.Message = md2
 				break
 				break
 			}
 			}
-			fd.L1.Message = aberrantLoadMessageDesc(t, false)
+			fd.L1.Message = aberrantLoadMessageDescReentrant(t)
 		}
 		}
 	}
 	}
 }
 }