فهرست منبع

reflect/prototype: simplify Go type descriptor constructors

The Go type descriptors protoreflect.{Enum,Message,Extension}Type are simple
wrappers over protoreflect.{Enum,Message,Extension}Descriptor with a small
number of additional methods. It is very unlikely that more will be added in
the near future.

For this reason, construct the types directly using arguments to the constructor
function, as opposed to taking in another struct (which was originally done
to provide flexibility in-case we needed more fields).

Furthmore, rename GoNew and New.

Change-Id: Ic7fb5bc250cdb2761ae03b388b5147ff50f37d15
Reviewed-on: https://go-review.googlesource.com/c/148822
Reviewed-by: Herbie Ong <herbie@google.com>
Joe Tsai 7 سال پیش
والد
کامیت
1c40f4957d
4فایلهای تغییر یافته به همراه162 افزوده شده و 188 حذف شده
  1. 3 6
      internal/impl/message.go
  2. 2 2
      internal/value/convert.go
  3. 10 12
      reflect/protoreflect/type.go
  4. 147 168
      reflect/prototype/go_type.go

+ 3 - 6
internal/impl/message.go

@@ -66,12 +66,9 @@ func (mi *MessageType) init(p interface{}) {
 		//
 		// Generated code ensures that this property holds.
 		if _, ok := p.(pref.ProtoMessage); !ok {
-			mi.pbType = ptype.NewGoMessage(&ptype.GoMessage{
-				MessageDescriptor: mi.Desc,
-				New: func(pref.MessageType) pref.ProtoMessage {
-					p := reflect.New(t.Elem()).Interface()
-					return (*message)(mi.dataTypeOf(p))
-				},
+			mi.pbType = ptype.GoMessage(mi.Desc, func(pref.MessageType) pref.ProtoMessage {
+				p := reflect.New(t.Elem()).Interface()
+				return (*message)(mi.dataTypeOf(p))
 			})
 		}
 

+ 2 - 2
internal/value/convert.go

@@ -111,7 +111,7 @@ func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapLegacyMessage func(refl
 					return pref.ValueOf(e.ProtoReflect().Number())
 				},
 				toGo: func(v pref.Value) reflect.Value {
-					rv := reflect.ValueOf(et.GoNew(v.Enum()))
+					rv := reflect.ValueOf(et.New(v.Enum()))
 					if rv.Type() != t {
 						panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
 					}
@@ -153,7 +153,7 @@ func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapLegacyMessage func(refl
 					return rv
 				},
 				newMessage: func() pref.Message {
-					return mt.GoNew().ProtoReflect()
+					return mt.New().ProtoReflect()
 				},
 			}
 		}

+ 10 - 12
reflect/protoreflect/type.go

@@ -6,8 +6,6 @@ package protoreflect
 
 import "reflect"
 
-// TODO: Rename GoNew as New for MessageType, EnumType, and ExtensionType?
-
 // TODO: For all ByX methods (e.g., ByName, ByJSONName, ByNumber, etc),
 // should they use a (v, ok) signature for the return value?
 
@@ -259,12 +257,12 @@ type isMessageDescriptor interface{ ProtoType(MessageDescriptor) }
 type MessageType interface {
 	MessageDescriptor
 
-	// GoNew returns a newly allocated empty message.
-	GoNew() ProtoMessage
+	// New returns a newly allocated empty message.
+	New() ProtoMessage
 
 	// GoType returns the Go type of the allocated message.
 	//
-	// Invariant: t.GoType() == reflect.TypeOf(t.GoNew())
+	// Invariant: t.GoType() == reflect.TypeOf(t.New())
 	GoType() reflect.Type
 }
 
@@ -437,15 +435,15 @@ type ExtensionDescriptors interface {
 type ExtensionType interface {
 	ExtensionDescriptor
 
-	// GoNew returns a new value for the field.
+	// New returns a new value for the field.
 	// For scalars, this returns the default value in native Go form.
-	GoNew() interface{}
+	New() interface{}
 
 	// GoType returns the Go type of the field value.
 	//
 	// Invariants:
-	//	t.GoType() == reflect.TypeOf(t.GoNew())
-	//	t.GoType() == reflect.TypeOf(t.InterfaceOf(t.ValueOf(t.GoNew())))
+	//	t.GoType() == reflect.TypeOf(t.New())
+	//	t.GoType() == reflect.TypeOf(t.InterfaceOf(t.ValueOf(t.New())))
 	GoType() reflect.Type
 
 	// TODO: How do we reconcile GoType with the existing extension API,
@@ -487,12 +485,12 @@ type isEnumDescriptor interface{ ProtoType(EnumDescriptor) }
 type EnumType interface {
 	EnumDescriptor
 
-	// GoNew returns an instance of this enum type with its value set to n.
-	GoNew(n EnumNumber) ProtoEnum
+	// New returns an instance of this enum type with its value set to n.
+	New(n EnumNumber) ProtoEnum
 
 	// GoType returns the Go type of the enum value.
 	//
-	// Invariants: t.GoType() == reflect.TypeOf(t.GoNew(0))
+	// Invariants: t.GoType() == reflect.TypeOf(t.New(0))
 	GoType() reflect.Type
 }
 

+ 147 - 168
reflect/prototype/go_type.go

@@ -13,125 +13,74 @@ import (
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
-// GoEnum is a constructor for a protoreflect.EnumType.
-type GoEnum struct {
-	protoreflect.EnumDescriptor
-
-	// New returns a concrete proto.Enum value with the given enum number.
-	// The constructor must return the same concrete type for each invocation.
-	New func(protoreflect.EnumType, protoreflect.EnumNumber) protoreflect.ProtoEnum
+// GoEnum creates a new protoreflect.EnumType by combining the provided
+// protoreflect.EnumDescriptor with the provided constructor function.
+func GoEnum(ed protoreflect.EnumDescriptor, fn func(protoreflect.EnumType, protoreflect.EnumNumber) protoreflect.ProtoEnum) protoreflect.EnumType {
+	if ed.IsPlaceholder() {
+		panic("enum descriptor must not be a placeholder")
+	}
+	t := &goEnum{EnumDescriptor: ed, new: fn}
+	t.typ = reflect.TypeOf(fn(t, 0))
+	return t
+}
 
-	once   sync.Once
-	goType reflect.Type
+type goEnum struct {
+	protoreflect.EnumDescriptor
+	typ reflect.Type
+	new func(protoreflect.EnumType, protoreflect.EnumNumber) protoreflect.ProtoEnum
 }
-type goEnum struct{ *GoEnum }
 
-// NewGoEnum creates a new protoreflect.EnumType.
-//
-// The caller must relinquish full ownership of the input t and must not
-// access or mutate any fields.
-func NewGoEnum(t *GoEnum) protoreflect.EnumType {
-	if t.IsPlaceholder() {
-		panic("enum descriptor must not be a placeholder")
-	}
-	if t.New == nil {
-		panic("invalid nil constructor for enum kind")
-	}
-	return goEnum{t}
+func (t *goEnum) GoType() reflect.Type {
+	return t.typ
 }
-func (p goEnum) GoNew(n protoreflect.EnumNumber) protoreflect.ProtoEnum {
-	e := p.New(p, n)
-	p.once.Do(func() { p.goType = reflect.TypeOf(e) })
-	if p.goType != reflect.TypeOf(e) {
-		panic(fmt.Sprintf("mismatching types for enum: got %T, want %v", e, p.goType))
+func (t *goEnum) New(n protoreflect.EnumNumber) protoreflect.ProtoEnum {
+	e := t.new(t, n)
+	if t.typ != reflect.TypeOf(e) {
+		panic(fmt.Sprintf("mismatching types for enum: got %T, want %v", e, t.typ))
 	}
 	return e
 }
-func (p goEnum) GoType() reflect.Type {
-	p.once.Do(func() { p.goType = reflect.TypeOf(p.New(p, 0)) })
-	return p.goType
+
+// GoMessage creates a new protoreflect.MessageType by combining the provided
+// protoreflect.MessageDescriptor with the provided constructor function.
+func GoMessage(md protoreflect.MessageDescriptor, fn func(protoreflect.MessageType) protoreflect.ProtoMessage) protoreflect.MessageType {
+	if md.IsPlaceholder() {
+		panic("message descriptor must not be a placeholder")
+	}
+	t := &goMessage{MessageDescriptor: md, new: fn}
+	t.typ = reflect.TypeOf(fn(t))
+	return t
 }
 
-// GoMessage is a constructor for a protoreflect.MessageType.
-type GoMessage struct {
+type goMessage struct {
 	protoreflect.MessageDescriptor
-
-	// New returns a new empty proto.Message instance.
-	// The constructor must return the same concrete type for each invocation.
-	New func(protoreflect.MessageType) protoreflect.ProtoMessage
-
-	once   sync.Once
-	goType reflect.Type
+	typ reflect.Type
+	new func(protoreflect.MessageType) protoreflect.ProtoMessage
 }
-type goMessage struct{ *GoMessage }
 
-// NewGoMessage creates a new protoreflect.MessageType.
-//
-// The caller must relinquish full ownership of the input t and must not
-// access or mutate any fields.
-func NewGoMessage(t *GoMessage) protoreflect.MessageType {
-	if t.IsPlaceholder() {
-		panic("message descriptor must not be a placeholder")
-	}
-	if t.New == nil {
-		panic("invalid nil constructor for message kind")
-	}
-	return goMessage{t}
+func (t *goMessage) GoType() reflect.Type {
+	return t.typ
 }
-func (p goMessage) GoNew() protoreflect.ProtoMessage {
-	m := p.New(p)
-	p.once.Do(func() { p.goType = reflect.TypeOf(m) })
-	if p.goType != reflect.TypeOf(m) {
-		panic(fmt.Sprintf("mismatching types for message: got %T, want %v", m, p.goType))
+func (t *goMessage) New() protoreflect.ProtoMessage {
+	m := t.new(t)
+	if t.typ != reflect.TypeOf(m) {
+		panic(fmt.Sprintf("mismatching types for message: got %T, want %v", m, t.typ))
 	}
 	return m
 }
-func (p goMessage) GoType() reflect.Type {
-	p.once.Do(func() { p.goType = reflect.TypeOf(p.New(p)) })
-	return p.goType
-}
-
-// GoExtension is a constructor for a protoreflect.ExtensionType.
-type GoExtension struct {
-	protoreflect.ExtensionDescriptor
-
-	// NewEnum returns a concrete proto.Enum value with the given enum number.
-	// The constructor must be provided if protoreflect.ExtensionDescriptor.Kind
-	// is protoreflect.EnumKind.
-	//
-	// The returned enum must represent an protoreflect.EnumDescriptor
-	// that matches protoreflect.ExtensionDescriptor.EnumType.
-	NewEnum func(protoreflect.EnumNumber) protoreflect.ProtoEnum
-
-	// NewMessage returns a new empty proto.Message instance.
-	// The constructor must be provided if protoreflect.ExtensionDescriptor.Kind
-	// is protoreflect.MessageKind or protoreflect.GroupKind.
-	//
-	// The returned message must represent an protoreflect.MessageDescriptor
-	// that matches protoreflect.ExtensionDescriptor.MessageType.
-	NewMessage func() protoreflect.ProtoMessage
-
-	// TODO: Separate NewEnum and NewMessage constructors make it possible for
-	// users to provide a constructor that returns a Go type does not match
-	// the corresponding protobuf descriptor in ExtensionDescriptor.
-	// Checking for correctness is hard since descriptors are not comparable.
-	//
-	// An alternative API is for ExtensionDescriptor.{EnumType,MessageType}
-	// to document that it must implement protoreflect.{EnumType,MessageType}.
-
-	once        sync.Once
-	new         func() interface{}
-	goType      reflect.Type
-	valueOf     func(v interface{}) protoreflect.Value
-	interfaceOf func(v protoreflect.Value) interface{}
-}
-type goExtension struct{ *GoExtension }
 
-// NewGoExtension creates a new protoreflect.ExtensionType.
+// GoExtension creates a new protoreflect.ExtensionType.
 //
-// The Go type is currently determined automatically (although custom Go types
-// may be supported in the future). The type is T for scalars and
-// *[]T for vectors. Maps are not valid in extension fields.
+// An enum type must be provided for enum extension fields if
+// ExtensionDescriptor.EnumType does not implement protoreflect.EnumType,
+// in which case it replaces the original enum in ExtensionDescriptor.
+//
+// Similarly, a message type must be provided for message extension fields if
+// ExtensionDescriptor.MessageType does not implement protoreflect.MessageType,
+// in which case it replaces the original message in ExtensionDescriptor.
+//
+// The Go type is currently determined automatically.
+// The type is T for scalars and *[]T for vectors (maps are not allowed).
 // The type T is determined as follows:
 //
 //	+------------+-------------------------------------+
@@ -154,121 +103,151 @@ type goExtension struct{ *GoExtension }
 // which is often, but not required to be, a named int32 type.
 // The type M is the concrete message type returned by NewMessage,
 // which is often, but not required to be, a pointer to a named struct type.
-//
-// The caller must relinquish full ownership of the input t and must not
-// access or mutate any fields.
-func NewGoExtension(t *GoExtension) protoreflect.ExtensionType {
-	if t.ExtendedType() == nil {
+func GoExtension(xd protoreflect.ExtensionDescriptor, et protoreflect.EnumType, mt protoreflect.MessageType) protoreflect.ExtensionType {
+	if xd.ExtendedType() == nil {
 		panic("field descriptor does not extend a message")
 	}
-	switch t.Kind() {
+	switch xd.Kind() {
 	case protoreflect.EnumKind:
-		if t.NewEnum == nil {
-			panic("enum constructor not provided for enum kind")
+		if et2, ok := xd.EnumType().(protoreflect.EnumType); ok && et == nil {
+			et = et2
 		}
-		if t.NewMessage != nil {
-			panic("message constructor provided for enum kind")
+		if et == nil {
+			panic("enum type not provided for enum kind")
+		}
+		if mt != nil {
+			panic("message type provided for enum kind")
 		}
 	case protoreflect.MessageKind, protoreflect.GroupKind:
-		if t.NewMessage == nil {
-			panic("message constructor not provided for message kind")
+		if mt2, ok := xd.MessageType().(protoreflect.MessageType); ok && mt == nil {
+			mt = mt2
+		}
+		if et != nil {
+			panic("enum type provided for message kind")
 		}
-		if t.NewEnum != nil {
-			panic("enum constructor provided for message kind")
+		if mt == nil {
+			panic("message type not provided for message kind")
 		}
 	default:
-		if t.NewMessage != nil || t.NewEnum != nil {
-			panic(fmt.Sprintf("enum or message constructor provided for %v kind", t.Kind()))
+		if et != nil || mt != nil {
+			panic(fmt.Sprintf("enum or message type provided for %v kind", xd.Kind()))
 		}
 	}
-	return goExtension{t}
+	return &goExtension{ExtensionDescriptor: xd, enumType: et, messageType: mt}
 }
-func (p goExtension) GoNew() interface{} {
-	p.lazyInit()
-	v := p.new()
-	if reflect.TypeOf(v) != p.goType {
-		panic(fmt.Sprintf("invalid type: got %T, want %v", v, p.goType))
+
+type goExtension struct {
+	protoreflect.ExtensionDescriptor
+	enumType    protoreflect.EnumType
+	messageType protoreflect.MessageType
+
+	once        sync.Once
+	typ         reflect.Type
+	new         func() interface{}
+	valueOf     func(v interface{}) protoreflect.Value
+	interfaceOf func(v protoreflect.Value) interface{}
+}
+
+func (t *goExtension) EnumType() protoreflect.EnumDescriptor {
+	return t.enumType
+}
+func (t *goExtension) MessageType() protoreflect.MessageDescriptor {
+	return t.messageType
+}
+func (t *goExtension) GoType() reflect.Type {
+	t.lazyInit()
+	return t.typ
+}
+func (t *goExtension) New() interface{} {
+	t.lazyInit()
+	v := t.new()
+	if reflect.TypeOf(v) != t.typ {
+		panic(fmt.Sprintf("invalid type: got %T, want %v", v, t.typ))
 	}
 	return v
 }
-func (p goExtension) GoType() reflect.Type {
-	p.lazyInit()
-	return p.goType
-}
-func (p goExtension) ValueOf(v interface{}) protoreflect.Value {
-	p.lazyInit()
-	if reflect.TypeOf(v) != p.goType {
-		panic(fmt.Sprintf("invalid type: got %T, want %v", v, p.goType))
+func (t *goExtension) ValueOf(v interface{}) protoreflect.Value {
+	t.lazyInit()
+	if reflect.TypeOf(v) != t.typ {
+		panic(fmt.Sprintf("invalid type: got %T, want %v", v, t.typ))
 	}
-	return p.valueOf(v)
+	return t.valueOf(v)
 }
-func (p goExtension) InterfaceOf(pv protoreflect.Value) interface{} {
-	p.lazyInit()
-	v := p.interfaceOf(pv)
-	if reflect.TypeOf(v) != p.goType {
-		panic(fmt.Sprintf("invalid type: got %T, want %v", v, p.goType))
+func (t *goExtension) InterfaceOf(pv protoreflect.Value) interface{} {
+	t.lazyInit()
+	v := t.interfaceOf(pv)
+	if reflect.TypeOf(v) != t.typ {
+		panic(fmt.Sprintf("invalid type: got %T, want %v", v, t.typ))
 	}
 	return v
 }
-func (p goExtension) lazyInit() {
-	p.once.Do(func() {
-		switch p.Cardinality() {
+func (t *goExtension) lazyInit() {
+	t.once.Do(func() {
+		switch t.Cardinality() {
 		case protoreflect.Optional:
-			switch p.Kind() {
+			switch t.Kind() {
 			case protoreflect.EnumKind:
-				p.goType = reflect.TypeOf(p.NewEnum(0))
-				p.new = func() interface{} { return p.NewEnum(p.Default().Enum()) }
-				p.valueOf = func(v interface{}) protoreflect.Value {
+				t.typ = t.enumType.GoType()
+				t.new = func() interface{} {
+					return t.enumType.New(t.Default().Enum())
+				}
+				t.valueOf = func(v interface{}) protoreflect.Value {
 					ev := v.(protoreflect.ProtoEnum).ProtoReflect()
 					return protoreflect.ValueOf(ev.Number())
 				}
-				p.interfaceOf = func(pv protoreflect.Value) interface{} {
-					return p.NewEnum(pv.Enum())
+				t.interfaceOf = func(pv protoreflect.Value) interface{} {
+					return t.enumType.New(pv.Enum())
 				}
 			case protoreflect.MessageKind, protoreflect.GroupKind:
-				p.goType = reflect.TypeOf(p.NewMessage())
-				p.new = func() interface{} { return p.NewMessage() }
-				p.valueOf = func(v interface{}) protoreflect.Value {
-					return protoreflect.ValueOf(v)
+				t.typ = t.messageType.GoType()
+				t.new = func() interface{} {
+					return t.messageType.New()
 				}
-				p.interfaceOf = func(pv protoreflect.Value) interface{} {
+				t.valueOf = func(v interface{}) protoreflect.Value {
+					mv := v.(protoreflect.ProtoMessage).ProtoReflect()
+					return protoreflect.ValueOf(mv)
+				}
+				t.interfaceOf = func(pv protoreflect.Value) interface{} {
 					return pv.Message().Interface()
 				}
 			default:
-				p.goType = goTypeForPBKind[p.Kind()]
-				p.new = func() interface{} { return p.Default().Interface() }
-				p.valueOf = func(v interface{}) protoreflect.Value {
+				t.typ = goTypeForPBKind[t.Kind()]
+				t.new = func() interface{} {
+					return t.Default().Interface()
+				}
+				t.valueOf = func(v interface{}) protoreflect.Value {
 					return protoreflect.ValueOf(v)
 				}
-				p.interfaceOf = func(pv protoreflect.Value) interface{} {
-					v := pv.Interface()
-					return v
+				t.interfaceOf = func(pv protoreflect.Value) interface{} {
+					return pv.Interface()
 				}
 			}
 		case protoreflect.Repeated:
-			var goType reflect.Type
-			switch p.Kind() {
+			var typ reflect.Type
+			switch t.Kind() {
 			case protoreflect.EnumKind:
-				goType = reflect.TypeOf(p.NewEnum(p.Default().Enum()))
+				typ = t.enumType.GoType()
 			case protoreflect.MessageKind, protoreflect.GroupKind:
-				goType = reflect.TypeOf(p.NewMessage())
+				typ = t.messageType.GoType()
 			default:
-				goType = goTypeForPBKind[p.Kind()]
+				typ = goTypeForPBKind[t.Kind()]
+			}
+			c := value.NewConverter(typ, t.Kind())
+			t.typ = reflect.PtrTo(reflect.SliceOf(typ))
+			t.new = func() interface{} {
+				return reflect.New(t.typ.Elem()).Interface()
 			}
-			c := value.NewConverter(goType, p.Kind())
-			p.goType = reflect.PtrTo(reflect.SliceOf(goType))
-			p.new = func() interface{} { return reflect.New(p.goType.Elem()).Interface() }
-			p.valueOf = func(v interface{}) protoreflect.Value {
+			t.valueOf = func(v interface{}) protoreflect.Value {
 				return protoreflect.ValueOf(value.VectorOf(v, c))
 			}
-			p.interfaceOf = func(v protoreflect.Value) interface{} {
+			t.interfaceOf = func(v protoreflect.Value) interface{} {
 				// TODO: Can we assume that Vector implementations know how
 				// to unwrap themselves?
 				// Should this be part of the public API in protoreflect?
 				return v.Vector().(value.Unwrapper).Unwrap()
 			}
 		default:
-			panic(fmt.Sprintf("invalid cardinality: %v", p.Cardinality()))
+			panic(fmt.Sprintf("invalid cardinality: %v", t.Cardinality()))
 		}
 	})
 }