Procházet zdrojové kódy

reflect/protoregistry: add Num methods for every Range method

The Num methods provide an O(1) lookup for the number of entries that Range
would return. This is needed to implement efficient cache invalidation logic
for caches that wrap the global registry.

Change-Id: I7c4ff97f674c4e9e4caae291f017cfad7294856c
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/193599
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai před 6 roky
rodič
revize
72980ee410

+ 66 - 8
reflect/protoregistry/registry.go

@@ -255,23 +255,39 @@ func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error)
 	return nil, NotFound
 }
 
+// NumFiles reports the number of registered files.
+func (r *Files) NumFiles() int {
+	if r == nil {
+		return 0
+	}
+	return len(r.filesByPath)
+}
+
 // RangeFiles iterates over all registered files.
 // The iteration order is undefined.
 func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
 	if r == nil {
 		return
 	}
-	for _, d := range r.descsByName {
-		if p, ok := d.(*packageDescriptor); ok {
-			for _, file := range p.files {
-				if !f(file) {
-					return
-				}
-			}
+	for _, file := range r.filesByPath {
+		if !f(file) {
+			return
 		}
 	}
 }
 
+// NumFilesByPackage reports the number of registered files in a proto package.
+func (r *Files) NumFilesByPackage(name protoreflect.FullName) int {
+	if r == nil {
+		return 0
+	}
+	p, ok := r.descsByName[name].(*packageDescriptor)
+	if !ok {
+		return 0
+	}
+	return len(p.files)
+}
+
 // RangeFilesByPackage iterates over all registered files in a give proto package.
 // The iteration order is undefined.
 func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) {
@@ -399,6 +415,10 @@ type Types struct {
 
 	typesByName         typesByName
 	extensionsByMessage extensionsByMessage
+
+	numEnums      int
+	numMessages   int
+	numExtensions int
 }
 
 type (
@@ -428,13 +448,17 @@ typeLoop:
 		case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType:
 			// Check for conflicts in typesByName.
 			var desc protoreflect.Descriptor
+			var pcnt *int
 			switch t := typ.(type) {
 			case protoreflect.EnumType:
 				desc = t.Descriptor()
+				pcnt = &r.numEnums
 			case protoreflect.MessageType:
 				desc = t.Descriptor()
+				pcnt = &r.numMessages
 			case protoreflect.ExtensionType:
 				desc = t.TypeDescriptor()
+				pcnt = &r.numExtensions
 			default:
 				panic(fmt.Sprintf("invalid type: %T", t))
 			}
@@ -478,11 +502,12 @@ typeLoop:
 				r.extensionsByMessage[message][field] = xt
 			}
 
-			// Update typesByName.
+			// Update typesByName and the count.
 			if r.typesByName == nil {
 				r.typesByName = make(typesByName)
 			}
 			r.typesByName[name] = typ
+			(*pcnt)++
 		default:
 			if firstErr == nil {
 				firstErr = errors.New("invalid type: %v", typeName(typ))
@@ -573,6 +598,14 @@ func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field proto
 	return nil, NotFound
 }
 
+// NumEnums reports the number of registered enums.
+func (r *Types) NumEnums() int {
+	if r == nil {
+		return 0
+	}
+	return r.numEnums
+}
+
 // RangeEnums iterates over all registered enums.
 // Iteration order is undefined.
 func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
@@ -588,6 +621,14 @@ func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
 	}
 }
 
+// NumMessages reports the number of registered messages.
+func (r *Types) NumMessages() int {
+	if r == nil {
+		return 0
+	}
+	return r.numMessages
+}
+
 // RangeMessages iterates over all registered messages.
 // Iteration order is undefined.
 func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
@@ -603,6 +644,14 @@ func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
 	}
 }
 
+// NumExtensions reports the number of registered extensions.
+func (r *Types) NumExtensions() int {
+	if r == nil {
+		return 0
+	}
+	return r.numExtensions
+}
+
 // RangeExtensions iterates over all registered extensions.
 // Iteration order is undefined.
 func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
@@ -618,6 +667,15 @@ func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
 	}
 }
 
+// NumExtensionsByMessage reports the number of registered extensions for
+// a given message type.
+func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int {
+	if r == nil {
+		return 0
+	}
+	return len(r.extensionsByMessage[message])
+}
+
 // RangeExtensionsByMessage iterates over all registered extensions filtered
 // by a given message type. Iteration order is undefined.
 func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {

+ 45 - 19
reflect/protoregistry/registry_test.go

@@ -298,10 +298,16 @@ func TestFiles(t *testing.T) {
 
 			for _, tc := range tt.rangePkgs {
 				var gotFiles []file
+				var gotCnt int
+				wantCnt := files.NumFilesByPackage(tc.inPkg)
 				files.RangeFilesByPackage(tc.inPkg, func(fd pref.FileDescriptor) bool {
 					gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
+					gotCnt++
 					return true
 				})
+				if gotCnt != wantCnt {
+					t.Errorf("NumFilesByPackage(%v) = %v, want %v", tc.inPkg, gotCnt, wantCnt)
+				}
 				if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
 					t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff)
 				}
@@ -552,44 +558,59 @@ func TestTypes(t *testing.T) {
 		return x == y
 	})
 
-	t.Run("RangeMessages", func(t *testing.T) {
-		want := []preg.Type{mt1}
+	t.Run("RangeEnums", func(t *testing.T) {
+		want := []preg.Type{et1}
 		var got []preg.Type
-		registry.RangeMessages(func(mt pref.MessageType) bool {
-			got = append(got, mt)
+		var gotCnt int
+		wantCnt := registry.NumEnums()
+		registry.RangeEnums(func(et pref.EnumType) bool {
+			got = append(got, et)
+			gotCnt++
 			return true
 		})
 
-		diff := cmp.Diff(want, got, sortTypes, compare)
-		if diff != "" {
-			t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
+		if gotCnt != wantCnt {
+			t.Errorf("NumEnums() = %v, want %v", gotCnt, wantCnt)
+		}
+		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
+			t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
 		}
 	})
 
-	t.Run("RangeEnums", func(t *testing.T) {
-		want := []preg.Type{et1}
+	t.Run("RangeMessages", func(t *testing.T) {
+		want := []preg.Type{mt1}
 		var got []preg.Type
-		registry.RangeEnums(func(et pref.EnumType) bool {
-			got = append(got, et)
+		var gotCnt int
+		wantCnt := registry.NumMessages()
+		registry.RangeMessages(func(mt pref.MessageType) bool {
+			got = append(got, mt)
+			gotCnt++
 			return true
 		})
 
-		diff := cmp.Diff(want, got, sortTypes, compare)
-		if diff != "" {
-			t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
+		if gotCnt != wantCnt {
+			t.Errorf("NumMessages() = %v, want %v", gotCnt, wantCnt)
+		}
+		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
+			t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
 		}
 	})
 
 	t.Run("RangeExtensions", func(t *testing.T) {
 		want := []preg.Type{xt1, xt2}
 		var got []preg.Type
+		var gotCnt int
+		wantCnt := registry.NumExtensions()
 		registry.RangeExtensions(func(xt pref.ExtensionType) bool {
 			got = append(got, xt)
+			gotCnt++
 			return true
 		})
 
-		diff := cmp.Diff(want, got, sortTypes, compare)
-		if diff != "" {
+		if gotCnt != wantCnt {
+			t.Errorf("NumExtensions() = %v, want %v", gotCnt, wantCnt)
+		}
+		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
 			t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff)
 		}
 	})
@@ -597,13 +618,18 @@ func TestTypes(t *testing.T) {
 	t.Run("RangeExtensionsByMessage", func(t *testing.T) {
 		want := []preg.Type{xt1, xt2}
 		var got []preg.Type
-		registry.RangeExtensionsByMessage(pref.FullName("testprotos.Message1"), func(xt pref.ExtensionType) bool {
+		var gotCnt int
+		wantCnt := registry.NumExtensionsByMessage("testprotos.Message1")
+		registry.RangeExtensionsByMessage("testprotos.Message1", func(xt pref.ExtensionType) bool {
 			got = append(got, xt)
+			gotCnt++
 			return true
 		})
 
-		diff := cmp.Diff(want, got, sortTypes, compare)
-		if diff != "" {
+		if gotCnt != wantCnt {
+			t.Errorf("NumExtensionsByMessage() = %v, want %v", gotCnt, wantCnt)
+		}
+		if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
 			t.Errorf("RangeExtensionsByMessage() mismatch (-want +got):\n%v", diff)
 		}
 	})