Przeglądaj źródła

reflect/protoregistry: assume unique proto file path

Previously, we liberally permitted mutiple files to be registered that
have the same path. However, doing so causes complexity in various places
that need to assume that file paths are unique. Since unique paths are
the intention of the proto language, we strictly enforce that now.

Change-Id: Ie8fdd57c824c9809a51859cf20c4bc477b6871be
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/182497
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 lat temu
rodzic
commit
bd7b7a9e0c

+ 4 - 11
reflect/protodesc/protodesc.go

@@ -73,19 +73,12 @@ func NewFile(fd *descriptorpb.FileDescriptorProto, r *protoregistry.Files) (prot
 		f.Imports[i].IsWeak = true
 	}
 	for i, path := range fd.GetDependency() {
-		var n int
 		imp := &f.Imports[i]
-		r.RangeFilesByPath(path, func(fd protoreflect.FileDescriptor) bool {
-			imp.FileDescriptor = fd
-			n++
-			return true
-		})
-		if n > 1 {
-			return nil, errors.New("duplicate files for import %q", path)
-		}
-		if imp.IsWeak || imp.FileDescriptor == nil {
-			imp.FileDescriptor = prototype.PlaceholderFile(path, "")
+		fd, err := r.FindFileByPath(path)
+		if err != nil {
+			fd = prototype.PlaceholderFile(path, "")
 		}
+		imp.FileDescriptor = fd
 	}
 
 	imps := importedFiles(f.Imports)

+ 25 - 9
reflect/protoregistry/registry.go

@@ -53,7 +53,7 @@ type Files struct {
 	// Note that files are stored as a slice, since a package may contain
 	// multiple files.
 	descs       map[protoreflect.FullName]interface{}
-	filesByPath map[string][]protoreflect.FileDescriptor
+	filesByPath map[string]protoreflect.FileDescriptor
 }
 
 type packageDescriptor struct {
@@ -88,7 +88,7 @@ func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
 		r.descs = map[protoreflect.FullName]interface{}{
 			"": &packageDescriptor{},
 		}
-		r.filesByPath = make(map[string][]protoreflect.FileDescriptor)
+		r.filesByPath = make(map[string]protoreflect.FileDescriptor)
 	}
 	var firstErr error
 	for _, file := range files {
@@ -99,6 +99,11 @@ func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
 	return firstErr
 }
 func (r *Files) registerFile(file protoreflect.FileDescriptor) error {
+	path := file.Path()
+	if r.filesByPath[path] != nil {
+		return errors.New("file %q is already registered", file.Path())
+	}
+
 	for name := file.Package(); name != ""; name = name.Parent() {
 		switch r.descs[name].(type) {
 		case nil, *packageDescriptor:
@@ -116,9 +121,6 @@ func (r *Files) registerFile(file protoreflect.FileDescriptor) error {
 		return err
 	}
 
-	path := file.Path()
-	r.filesByPath[path] = append(r.filesByPath[path], file)
-
 	for name := file.Package(); name != ""; name = name.Parent() {
 		if r.descs[name] == nil {
 			r.descs[name] = &packageDescriptor{}
@@ -129,6 +131,7 @@ func (r *Files) registerFile(file protoreflect.FileDescriptor) error {
 	rangeRegisteredDescriptors(file, func(desc protoreflect.Descriptor) {
 		r.descs[desc.FullName()] = desc
 	})
+	r.filesByPath[path] = file
 	return nil
 }
 
@@ -187,6 +190,19 @@ func (r *Files) FindServiceByName(name protoreflect.FullName) (protoreflect.Serv
 	return nil, NotFound
 }
 
+// FindFileByPath looks up a file by the path.
+//
+// This returns (nil, NotFound) if not found.
+func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
+	if r == nil {
+		return nil, NotFound
+	}
+	if fd, ok := r.filesByPath[path]; ok {
+		return fd, nil
+	}
+	return nil, NotFound
+}
+
 // RangeFiles iterates over all registered files.
 // The iteration order is undefined.
 func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
@@ -206,14 +222,14 @@ func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
 
 // RangeFilesByPath iterates over all registered files filtered by
 // the given proto path. The iteration order is undefined.
+//
+// Deprecated: Use FindFileByPath instead.
 func (r *Files) RangeFilesByPath(path string, f func(protoreflect.FileDescriptor) bool) {
 	if r == nil {
 		return
 	}
-	for _, file := range r.filesByPath[path] {
-		if !f(file) {
-			return
-		}
+	if fd, ok := r.filesByPath[path]; ok {
+		f(fd)
 	}
 }
 

+ 20 - 21
reflect/protoregistry/registry_test.go

@@ -35,23 +35,23 @@ func TestFiles(t *testing.T) {
 			inPkg     pref.FullName
 			wantFiles []file
 		}
-		testRangePath struct {
+		testFindPath struct {
 			inPath    string
 			wantFiles []file
 		}
 	)
 
 	tests := []struct {
-		files      []testFile
-		rangePkgs  []testRangePkg
-		rangePaths []testRangePath
+		files     []testFile
+		rangePkgs []testRangePkg
+		findPaths []testFindPath
 	}{{
 		// Test that overlapping packages and files are permitted.
 		files: []testFile{
-			{inFile: &ptype.File{Syntax: pref.Proto2, Package: "foo.bar"}},
+			{inFile: &ptype.File{Syntax: pref.Proto2, Path: "test1.proto", Package: "foo.bar"}},
 			{inFile: &ptype.File{Syntax: pref.Proto2, Path: "foo/bar/test.proto", Package: "my.test"}},
-			{inFile: &ptype.File{Syntax: pref.Proto2, Path: "foo/bar/test.proto", Package: "foo.bar.baz"}},
-			{inFile: &ptype.File{Syntax: pref.Proto2, Package: "my.test.package"}},
+			{inFile: &ptype.File{Syntax: pref.Proto2, Path: "foo/bar/test.proto", Package: "foo.bar.baz"}, wantErr: "already registered"},
+			{inFile: &ptype.File{Syntax: pref.Proto2, Path: "test2.proto", Package: "my.test.package"}},
 			{inFile: &ptype.File{Syntax: pref.Proto2, Package: "foo.bar"}},
 			{inFile: &ptype.File{Syntax: pref.Proto2, Path: "foo/bar/baz/../test.proto", Package: "my.test"}},
 		},
@@ -71,31 +71,29 @@ func TestFiles(t *testing.T) {
 		}, {
 			inPkg: "foo.bar",
 			wantFiles: []file{
-				{"", "foo.bar"},
+				{"test1.proto", "foo.bar"},
 				{"", "foo.bar"},
 			},
 		}, {
-			inPkg: "foo.bar.baz",
+			inPkg: "my.test",
 			wantFiles: []file{
-				{"foo/bar/test.proto", "foo.bar.baz"},
+				{"foo/bar/baz/../test.proto", "my.test"},
+				{"foo/bar/test.proto", "my.test"},
 			},
 		}, {
 			inPkg: "fo",
 		}},
 
-		rangePaths: []testRangePath{{
+		findPaths: []testFindPath{{
 			inPath: "nothing",
 		}, {
 			inPath: "",
 			wantFiles: []file{
 				{"", "foo.bar"},
-				{"", "foo.bar"},
-				{"", "my.test.package"},
 			},
 		}, {
 			inPath: "foo/bar/test.proto",
 			wantFiles: []file{
-				{"foo/bar/test.proto", "foo.bar.baz"},
 				{"foo/bar/test.proto", "my.test"},
 			},
 		}},
@@ -127,6 +125,7 @@ func TestFiles(t *testing.T) {
 		files: []testFile{{
 			inFile: &ptype.File{
 				Syntax:  pref.Proto2,
+				Path:    "test1.proto",
 				Package: "fizz.buzz",
 				Messages: []ptype.Message{{
 					Name: "Message",
@@ -165,6 +164,7 @@ func TestFiles(t *testing.T) {
 		}, {
 			inFile: &ptype.File{
 				Syntax:  pref.Proto2,
+				Path:    "test2.proto",
 				Package: "fizz.buzz.gazz",
 				Enums: []ptype.Enum{{
 					Name:   "Enum",
@@ -172,9 +172,9 @@ func TestFiles(t *testing.T) {
 				}},
 			},
 		}, {
-			// Previously failed registration should not pollute the namespace.
 			inFile: &ptype.File{
 				Syntax:  pref.Proto2,
+				Path:    "test3.proto",
 				Package: "fizz.buzz",
 				Enums: []ptype.Enum{{
 					Name:   "Enum1",
@@ -213,7 +213,7 @@ func TestFiles(t *testing.T) {
 					t.Fatalf("file %d, prototype.NewFile() error: %v", i, err)
 				}
 				gotErr := files.Register(fd)
-				if (gotErr == nil && tc.wantErr != "") || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
+				if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
 					t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr)
 				}
 			}
@@ -229,14 +229,13 @@ func TestFiles(t *testing.T) {
 				}
 			}
 
-			for _, tc := range tt.rangePaths {
+			for _, tc := range tt.findPaths {
 				var gotFiles []file
-				files.RangeFilesByPath(tc.inPath, func(fd pref.FileDescriptor) bool {
+				if fd, err := files.FindFileByPath(tc.inPath); err == nil {
 					gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
-					return true
-				})
+				}
 				if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
-					t.Errorf("RangeFilesByPath(%v) mismatch (-want +got):\n%v", tc.inPath, diff)
+					t.Errorf("FindFileByPath(%v) mismatch (-want +got):\n%v", tc.inPath, diff)
 				}
 			}
 		})