Просмотр исходного кода

internal/impl: simplify getMessageInfo

Our specific protoreflect.Message implementations have a special
ProtoMessageInfo method to obtain the *MessageInfo for v1 compatibility.
Use that instead to implement getMessageInfo.

Change-Id: I6cab9aeaa93714be73bd812c3d9a3be0ec86dd52
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/187777
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 лет назад
Родитель
Сommit
070c1010d9
2 измененных файлов с 29 добавлено и 36 удалено
  1. 20 20
      internal/impl/codec_field.go
  2. 9 16
      internal/impl/message.go

+ 20 - 20
internal/impl/codec_field.go

@@ -79,19 +79,19 @@ func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFun
 }
 
 func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
-	if fi, ok := getMessageInfo(ft); ok {
+	if mi := getMessageInfo(ft); mi != nil {
 		return pointerCoderFuncs{
 			size: func(p pointer, tagsize int, opts marshalOptions) int {
-				return sizeMessageInfo(p, fi, tagsize, opts)
+				return sizeMessageInfo(p, mi, tagsize, opts)
 			},
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-				return appendMessageInfo(b, p, wiretag, fi, opts)
+				return appendMessageInfo(b, p, wiretag, mi, opts)
 			},
 			unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
-				return consumeMessageInfo(b, p, fi, wtyp, opts)
+				return consumeMessageInfo(b, p, mi, wtyp, opts)
 			},
 			isInit: func(p pointer) error {
-				return fi.isInitializedPointer(p.Elem())
+				return mi.isInitializedPointer(p.Elem())
 			},
 		}
 	} else {
@@ -200,19 +200,19 @@ var coderMessageIface = ifaceCoderFuncs{
 
 func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
 	num := fd.Number()
-	if fi, ok := getMessageInfo(ft); ok {
+	if mi := getMessageInfo(ft); mi != nil {
 		return pointerCoderFuncs{
 			size: func(p pointer, tagsize int, opts marshalOptions) int {
-				return sizeGroupType(p, fi, tagsize, opts)
+				return sizeGroupType(p, mi, tagsize, opts)
 			},
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-				return appendGroupType(b, p, wiretag, fi, opts)
+				return appendGroupType(b, p, wiretag, mi, opts)
 			},
 			unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
-				return consumeGroupType(b, p, fi, num, wtyp, opts)
+				return consumeGroupType(b, p, mi, num, wtyp, opts)
 			},
 			isInit: func(p pointer) error {
-				return fi.isInitializedPointer(p.Elem())
+				return mi.isInitializedPointer(p.Elem())
 			},
 		}
 	} else {
@@ -303,19 +303,19 @@ func makeGroupValueCoder(fd pref.FieldDescriptor, ft reflect.Type) ifaceCoderFun
 }
 
 func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
-	if fi, ok := getMessageInfo(ft); ok {
+	if mi := getMessageInfo(ft); mi != nil {
 		return pointerCoderFuncs{
 			size: func(p pointer, tagsize int, opts marshalOptions) int {
-				return sizeMessageSliceInfo(p, fi, tagsize, opts)
+				return sizeMessageSliceInfo(p, mi, tagsize, opts)
 			},
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-				return appendMessageSliceInfo(b, p, wiretag, fi, opts)
+				return appendMessageSliceInfo(b, p, wiretag, mi, opts)
 			},
 			unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
-				return consumeMessageSliceInfo(b, p, fi, wtyp, opts)
+				return consumeMessageSliceInfo(b, p, mi, wtyp, opts)
 			},
 			isInit: func(p pointer) error {
-				return isInitMessageSliceInfo(p, fi)
+				return isInitMessageSliceInfo(p, mi)
 			},
 		}
 	}
@@ -471,19 +471,19 @@ var coderMessageSliceIface = ifaceCoderFuncs{
 
 func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
 	num := fd.Number()
-	if fi, ok := getMessageInfo(ft); ok {
+	if mi := getMessageInfo(ft); mi != nil {
 		return pointerCoderFuncs{
 			size: func(p pointer, tagsize int, opts marshalOptions) int {
-				return sizeGroupSliceInfo(p, fi, tagsize, opts)
+				return sizeGroupSliceInfo(p, mi, tagsize, opts)
 			},
 			marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-				return appendGroupSliceInfo(b, p, wiretag, fi, opts)
+				return appendGroupSliceInfo(b, p, wiretag, mi, opts)
 			},
 			unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
-				return consumeGroupSliceInfo(b, p, num, wtyp, fi, opts)
+				return consumeGroupSliceInfo(b, p, num, wtyp, mi, opts)
 			},
 			isInit: func(p pointer) error {
-				return isInitMessageSliceInfo(p, fi)
+				return isInitMessageSliceInfo(p, mi)
 			},
 		}
 	}

+ 9 - 16
internal/impl/message.go

@@ -66,26 +66,19 @@ type exporter func(v interface{}, i int) interface{}
 
 var prefMessageType = reflect.TypeOf((*pref.Message)(nil)).Elem()
 
-// getMessageInfo returns the MessageInfo (if any) for a type.
-//
-// We find the MessageInfo by calling the ProtoReflect method on the type's
-// zero value and looking at the returned type to see if it is a
-// messageReflectWrapper. Note that the MessageInfo may still be uninitialized
-// at this point.
-func getMessageInfo(mt reflect.Type) (mi *MessageInfo, ok bool) {
-	method, ok := mt.MethodByName("ProtoReflect")
+// getMessageInfo returns the MessageInfo for any message type that
+// is generated by our implementation of protoc-gen-go (for v2 and on).
+// If it is unable to obtain a MessageInfo, it returns nil.
+func getMessageInfo(mt reflect.Type) *MessageInfo {
+	m, ok := reflect.Zero(mt).Interface().(pref.ProtoMessage)
 	if !ok {
-		return nil, false
-	}
-	if method.Type.NumIn() != 1 || method.Type.NumOut() != 1 || method.Type.Out(0) != prefMessageType {
-		return nil, false
+		return nil
 	}
-	ret := reflect.Zero(mt).Method(method.Index).Call(nil)
-	m, ok := ret[0].Elem().Interface().(*messageReflectWrapper)
+	mr, ok := m.ProtoReflect().(interface{ ProtoMessageInfo() *MessageInfo })
 	if !ok {
-		return nil, ok
+		return nil
 	}
-	return m.mi, true
+	return mr.ProtoMessageInfo()
 }
 
 func (mi *MessageInfo) init() {