Sfoglia il codice sorgente

internal/fileinit: prevent map entry descriptors from implementing MessageType

The protobuf type system hacks the representation of map entries into that
of a pseudo-message descriptor.

Previously, we made all message descriptors implement MessageType,
with the descriptors for map entries (which have no associated Go type)
having a GoType method that simply returned nil.
Unfortunately, this violates a nice property in the Go type system
where being able to assert to a MessageType guarantees that Go type
information is truly associated with that descriptor.

This CL makes it such that message descriptors for map entries
do not implement MessageType.

Change-Id: I23873cb71fe0ab3c0befd8052830ea6e53c97ca9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/168399
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 7 anni fa
parent
commit
4532dd7969

+ 8 - 0
internal/cmd/generate-types/main.go

@@ -227,11 +227,19 @@ var fileinitDescListTemplate = template.Must(template.New("").Funcs(template.Fun
 		return len(p.list)
 	}
 	func (p *{{$nameList}}) Get(i int) {{.Expr}} {
+		{{- if (eq . "Message")}}
+		return p.list[i].asDesc()
+		{{- else}}
 		return &p.list[i]
+		{{- end}}
 	}
 	func (p *{{$nameList}}) ByName(s protoreflect.Name) {{.Expr}} {
 		if d := p.lazyInit().byName[s]; d != nil {
+			{{- if (eq . "Message")}}
+			return d.asDesc()
+			{{- else}}
 			return d
+			{{- end}}
 		}
 		return nil
 	}

+ 32 - 11
internal/fileinit/desc.go

@@ -108,7 +108,8 @@ type FileBuilder struct {
 	// in "flattened ordering".
 	EnumOutputTypes []pref.EnumType
 	// MessageOutputTypes is where Init stores all initialized message types
-	// in "flattened ordering"; this includes map entry types.
+	// in "flattened ordering". This includes slots for map entry messages,
+	// which are skipped over.
 	MessageOutputTypes []pref.MessageType
 	// ExtensionOutputTypes is where Init stores all initialized extension types
 	// in "flattened ordering".
@@ -141,7 +142,9 @@ func (fb FileBuilder) Init() pref.FileDescriptor {
 		fb.EnumOutputTypes[i] = &fd.allEnums[i]
 	}
 	for i := range fd.allMessages {
-		fb.MessageOutputTypes[i] = &fd.allMessages[i]
+		if mt, _ := fd.allMessages[i].asDesc().(pref.MessageType); mt != nil {
+			fb.MessageOutputTypes[i] = mt
+		}
 	}
 	for i := range fd.allExtensions {
 		fb.ExtensionOutputTypes[i] = &fd.allExtensions[i]
@@ -160,8 +163,10 @@ func (fb FileBuilder) Init() pref.FileDescriptor {
 			}
 		}
 		for i := range fd.allMessages {
-			if err := fb.TypesRegistry.Register(&fd.allMessages[i]); err != nil {
-				panic(err)
+			if mt, _ := fd.allMessages[i].asDesc().(pref.MessageType); mt != nil {
+				if err := fb.TypesRegistry.Register(mt); err != nil {
+					panic(err)
+				}
 			}
 		}
 		for i := range fd.allExtensions {
@@ -278,6 +283,11 @@ func (ed *enumValueDesc) Format(s fmt.State, r rune)         { pfmt.FormatDesc(s
 func (ed *enumValueDesc) ProtoType(pref.EnumValueDescriptor) {}
 
 type (
+	messageType       struct{ *messageDesc }
+	messageDescriptor struct{ *messageDesc }
+
+	// messageDesc does not implement protoreflect.Descriptor to avoid
+	// accidental usages of it as such. Use the asDesc method to retrieve one.
 	messageDesc struct {
 		baseDesc
 
@@ -285,13 +295,13 @@ type (
 		messages   messageDescs
 		extensions extensionDescs
 
-		lazy *messageLazy // protected by fileDesc.once
+		isMapEntry bool
+		lazy       *messageLazy // protected by fileDesc.once
 	}
 	messageLazy struct {
 		typ reflect.Type
 		new func() pref.Message
 
-		isMapEntry      bool
 		isMessageSet    bool
 		fields          fieldDescs
 		oneofs          oneofDescs
@@ -328,12 +338,10 @@ type (
 	}
 )
 
-func (md *messageDesc) GoType() reflect.Type { return md.lazyInit().typ }
-func (md *messageDesc) New() pref.Message    { return md.lazyInit().new() }
-func (md *messageDesc) Options() pref.OptionsMessage {
+func (md *messageDesc) options() pref.OptionsMessage {
 	return unmarshalOptions(ptype.X.MessageOptions(), md.lazyInit().options)
 }
-func (md *messageDesc) IsMapEntry() bool                   { return md.lazyInit().isMapEntry }
+func (md *messageDesc) IsMapEntry() bool                   { return md.isMapEntry }
 func (md *messageDesc) Fields() pref.FieldDescriptors      { return &md.lazyInit().fields }
 func (md *messageDesc) Oneofs() pref.OneofDescriptors      { return &md.lazyInit().oneofs }
 func (md *messageDesc) ReservedNames() pref.Names          { return &md.lazyInit().resvNames }
@@ -346,8 +354,8 @@ func (md *messageDesc) ExtensionRangeOptions(i int) pref.OptionsMessage {
 func (md *messageDesc) Enums() pref.EnumDescriptors           { return &md.enums }
 func (md *messageDesc) Messages() pref.MessageDescriptors     { return &md.messages }
 func (md *messageDesc) Extensions() pref.ExtensionDescriptors { return &md.extensions }
-func (md *messageDesc) Format(s fmt.State, r rune)            { pfmt.FormatDesc(s, r, md) }
 func (md *messageDesc) ProtoType(pref.MessageDescriptor)      {}
+func (md *messageDesc) Format(s fmt.State, r rune)            { pfmt.FormatDesc(s, r, md.asDesc()) }
 func (md *messageDesc) lazyInit() *messageLazy {
 	md.parentFile.lazyInit() // implicitly initializes messageLazy
 	return md.lazy
@@ -359,6 +367,19 @@ func (md *messageDesc) IsMessageSet() bool {
 	return md.lazyInit().isMessageSet
 }
 
+// asDesc returns a protoreflect.MessageDescriptor or protoreflect.MessageType
+// depending on whether the message is a map entry or not.
+func (mb *messageDesc) asDesc() pref.MessageDescriptor {
+	if !mb.isMapEntry {
+		return messageType{mb}
+	}
+	return messageDescriptor{mb}
+}
+func (mt messageType) GoType() reflect.Type               { return mt.lazyInit().typ }
+func (mt messageType) New() pref.Message                  { return mt.lazyInit().new() }
+func (mt messageType) Options() pref.OptionsMessage       { return mt.options() }
+func (md messageDescriptor) Options() pref.OptionsMessage { return md.options() }
+
 func (fd *fieldDesc) Options() pref.OptionsMessage {
 	return unmarshalOptions(ptype.X.FieldOptions(), fd.options)
 }

+ 11 - 4
internal/fileinit/desc_init.go

@@ -19,6 +19,13 @@ func newFileDesc(fb FileBuilder) *fileDesc {
 	file.initDecls(len(fb.EnumOutputTypes), len(fb.MessageOutputTypes), len(fb.ExtensionOutputTypes))
 	file.unmarshalSeed(fb.RawDescriptor)
 
+	// Determine which message descriptors represent map entries based on the
+	// lack of an associated Go type.
+	messageDecls := file.GoTypes[len(file.allEnums):]
+	for i := range file.allMessages {
+		file.allMessages[i].isMapEntry = messageDecls[i] == nil
+	}
+
 	// Extended message dependencies are eagerly handled since registration
 	// needs this information at program init time.
 	for i := range file.allExtensions {
@@ -31,7 +38,7 @@ func newFileDesc(fb FileBuilder) *fileDesc {
 }
 
 // initDecls pre-allocates slices for the exact number of enums, messages
-// (excluding map entries), and extensions declared in the proto file.
+// (including map entries), and extensions declared in the proto file.
 // This is done to avoid regrowing the slice, which would change the address
 // for any previously seen declaration.
 //
@@ -279,7 +286,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd
 		for i := range md.enums.list {
 			_, n := wire.ConsumeVarint(b)
 			v, m := wire.ConsumeBytes(b[n:])
-			md.enums.list[i].unmarshalSeed(v, nb, pf, md, i)
+			md.enums.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
 			b = b[n+m:]
 		}
 	}
@@ -288,7 +295,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd
 		for i := range md.messages.list {
 			_, n := wire.ConsumeVarint(b)
 			v, m := wire.ConsumeBytes(b[n:])
-			md.messages.list[i].unmarshalSeed(v, nb, pf, md, i)
+			md.messages.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
 			b = b[n+m:]
 		}
 	}
@@ -297,7 +304,7 @@ func (md *messageDesc) unmarshalSeed(b []byte, nb *nameBuilder, pf *fileDesc, pd
 		for i := range md.extensions.list {
 			_, n := wire.ConsumeVarint(b)
 			v, m := wire.ConsumeBytes(b[n:])
-			md.extensions.list[i].unmarshalSeed(v, nb, pf, md, i)
+			md.extensions.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
 			b = b[n+m:]
 		}
 	}

+ 25 - 23
internal/fileinit/desc_lazy.go

@@ -64,15 +64,12 @@ func (file *fileDesc) resolveMessages() {
 		md := &file.allMessages[i]
 
 		// Associate the MessageType with a concrete Go type.
-		//
-		// Note that descriptors for map entries, which have no associated
-		// Go type, also implement the protoreflect.MessageType interface,
-		// but have a GoType accessor that reports nil. Calling New results
-		// in a panic, which is sensible behavior.
-		md.lazy.typ = reflect.TypeOf(messageDecls[i])
-		md.lazy.new = func() pref.Message {
-			t := md.lazy.typ.Elem()
-			return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect()
+		if !md.isMapEntry {
+			md.lazy.typ = reflect.TypeOf(messageDecls[i])
+			md.lazy.new = func() pref.Message {
+				t := md.lazy.typ.Elem()
+				return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect()
+			}
 		}
 
 		// Resolve message field dependencies.
@@ -173,9 +170,9 @@ func (file *fileDesc) resolveExtensions() {
 		// Resolve extension field dependency.
 		switch xd.lazy.kind {
 		case pref.EnumKind:
-			xd.lazy.enumType = file.popEnumDependency()
+			xd.lazy.enumType = file.popEnumDependency().(pref.EnumType)
 		case pref.MessageKind, pref.GroupKind:
-			xd.lazy.messageType = file.popMessageDependency()
+			xd.lazy.messageType = file.popMessageDependency().(pref.MessageType)
 		}
 		xd.lazy.defVal.lazyInit(xd.lazy.kind, file.enumValuesOf(xd.lazy.enumType))
 	}
@@ -219,8 +216,8 @@ func (fd *fileDesc) isMapEntry(md pref.MessageDescriptor) bool {
 	if md == nil {
 		return false
 	}
-	if md, ok := md.(*messageDesc); ok && md.parentFile == fd {
-		return md.lazy.isMapEntry
+	if md, ok := md.(*messageDescriptor); ok && md.parentFile == fd {
+		return md.isMapEntry
 	}
 	return md.IsMapEntry()
 }
@@ -238,7 +235,7 @@ func (fd *fileDesc) enumValuesOf(ed pref.EnumDescriptor) pref.EnumValueDescripto
 	return ed.Values()
 }
 
-func (fd *fileDesc) popEnumDependency() pref.EnumType {
+func (fd *fileDesc) popEnumDependency() pref.EnumDescriptor {
 	depIdx := fd.popDependencyIndex()
 	if depIdx < len(fd.allEnums)+len(fd.allMessages) {
 		return &fd.allEnums[depIdx]
@@ -247,10 +244,10 @@ func (fd *fileDesc) popEnumDependency() pref.EnumType {
 	}
 }
 
-func (fd *fileDesc) popMessageDependency() pref.MessageType {
+func (fd *fileDesc) popMessageDependency() pref.MessageDescriptor {
 	depIdx := fd.popDependencyIndex()
 	if depIdx < len(fd.allEnums)+len(fd.allMessages) {
-		return &fd.allMessages[depIdx-len(fd.allEnums)]
+		return fd.allMessages[depIdx-len(fd.allEnums)].asDesc()
 	} else {
 		return pimpl.Export{}.MessageTypeOf(fd.GoTypes[depIdx])
 	}
@@ -490,6 +487,7 @@ func (vd *enumValueDesc) unmarshalFull(b []byte, nb *nameBuilder, pf *fileDesc,
 func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
 	var rawFields, rawOneofs [][]byte
 	var enumIdx, messageIdx, extensionIdx int
+	var isMapEntry bool
 	md.lazy = new(messageLazy)
 	for len(b) > 0 {
 		num, typ, n := wire.ConsumeTag(b)
@@ -521,7 +519,7 @@ func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
 				md.extensions.list[extensionIdx].unmarshalFull(v, nb)
 				extensionIdx++
 			case fieldnum.DescriptorProto_Options:
-				md.unmarshalOptions(v)
+				md.unmarshalOptions(v, &isMapEntry)
 			}
 		default:
 			m := wire.ConsumeFieldValue(num, typ, b)
@@ -534,21 +532,25 @@ func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
 		md.lazy.oneofs.list = make([]oneofDesc, len(rawOneofs))
 		for i, b := range rawFields {
 			fd := &md.lazy.fields.list[i]
-			fd.unmarshalFull(b, nb, md.parentFile, md, i)
+			fd.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i)
 			if fd.cardinality == pref.Required {
 				md.lazy.reqNumbers.list = append(md.lazy.reqNumbers.list, fd.number)
 			}
 		}
 		for i, b := range rawOneofs {
 			od := &md.lazy.oneofs.list[i]
-			od.unmarshalFull(b, nb, md.parentFile, md, i)
+			od.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i)
 		}
 	}
 
-	md.parentFile.lazy.byName[md.FullName()] = md
+	if isMapEntry != md.isMapEntry {
+		panic("mismatching map entry property")
+	}
+
+	md.parentFile.lazy.byName[md.FullName()] = md.asDesc()
 }
 
-func (md *messageDesc) unmarshalOptions(b []byte) {
+func (md *messageDesc) unmarshalOptions(b []byte, isMapEntry *bool) {
 	md.lazy.options = append(md.lazy.options, b...)
 	for len(b) > 0 {
 		num, typ, n := wire.ConsumeTag(b)
@@ -559,7 +561,7 @@ func (md *messageDesc) unmarshalOptions(b []byte) {
 			b = b[m:]
 			switch num {
 			case fieldnum.MessageOptions_MapEntry:
-				md.lazy.isMapEntry = wire.DecodeBool(v)
+				*isMapEntry = wire.DecodeBool(v)
 			case fieldnum.MessageOptions_MessageSetWireFormat:
 				md.lazy.isMessageSet = wire.DecodeBool(v)
 			}
@@ -646,7 +648,7 @@ func (fd *fieldDesc) unmarshalFull(b []byte, nb *nameBuilder, pf *fileDesc, pd p
 				// In messageDesc.UnmarshalFull, we allocate slices for both
 				// the field and oneof descriptors before unmarshaling either
 				// of them. This ensures pointers to slice elements are stable.
-				od := &pd.(*messageDesc).lazy.oneofs.list[v]
+				od := &pd.(messageType).lazy.oneofs.list[v]
 				od.fields.list = append(od.fields.list, fd)
 				if fd.oneofType != nil {
 					panic("oneof type already set")

+ 2 - 2
internal/fileinit/desc_list_gen.go

@@ -110,11 +110,11 @@ func (p *messageDescs) Len() int {
 	return len(p.list)
 }
 func (p *messageDescs) Get(i int) protoreflect.MessageDescriptor {
-	return &p.list[i]
+	return p.list[i].asDesc()
 }
 func (p *messageDescs) ByName(s protoreflect.Name) protoreflect.MessageDescriptor {
 	if d := p.lazyInit().byName[s]; d != nil {
-		return d
+		return d.asDesc()
 	}
 	return nil
 }

+ 9 - 0
internal/fileinit/fileinit_test.go

@@ -68,6 +68,15 @@ func TestInit(t *testing.T) {
 		}
 	}
 
+	// Verify that message descriptors for map entries have no Go type info.
+	mapEntryName := protoreflect.FullName("goproto.proto.test.TestAllTypes.MapInt32Int32Entry")
+	d := testpb.File_test_test_proto.DescriptorByName(mapEntryName)
+	if _, ok := d.(protoreflect.MessageDescriptor); !ok {
+		t.Errorf("message descriptor for %v not found", mapEntryName)
+	}
+	if _, ok := d.(protoreflect.MessageType); ok {
+		t.Errorf("message descriptor for %v must not implement protoreflect.MessageType", mapEntryName)
+	}
 }
 
 // visitFields calls f for every field set in m and its children.