浏览代码

goprotobuf: Introduce new proto.Message type.

Every generated protocol buffer type now implements the proto.Message interface,
which means we can add some compile-time type safety throughout the API
as well as drop a bunch of error cases.

R=r, rsc
CC=golang-dev
http://codereview.appspot.com/6298073
David Symonds 13 年之前
父节点
当前提交
9f60f43c7a

+ 9 - 17
proto/all_test.go

@@ -1077,7 +1077,7 @@ func TestTypeMismatch(t *testing.T) {
 	}
 }
 
-func encodeDecode(t *testing.T, in, out interface{}, msg string) {
+func encodeDecode(t *testing.T, in, out Message, msg string) {
 	buf, err := Marshal(in)
 	if err != nil {
 		t.Fatalf("failed marshaling %v: %v", msg, err)
@@ -1183,6 +1183,10 @@ type NNIMessage struct {
 	nni nonNillableInt
 }
 
+func (*NNIMessage) Reset()         {}
+func (*NNIMessage) String() string { return "" }
+func (*NNIMessage) ProtoMessage()  {}
+
 // A type that implements the Marshaler interface and is nillable.
 type nillableMessage struct {
 	x uint64
@@ -1196,6 +1200,10 @@ type NMMessage struct {
 	nm *nillableMessage
 }
 
+func (*NMMessage) Reset()         {}
+func (*NMMessage) String() string { return "" }
+func (*NMMessage) ProtoMessage()  {}
+
 // Verify a type that uses the Marshaler interface, but has a nil pointer.
 func TestNilMarshaler(t *testing.T) {
 	// Try a struct with a Marshaler field that is nil.
@@ -1214,22 +1222,6 @@ func TestNilMarshaler(t *testing.T) {
 	}
 }
 
-// Check that passing things other than pointer to struct to Marshal
-// returns a good error, rather than panicking.
-func TestStructTyping(t *testing.T) {
-	om := &OtherMessage{}
-	bad := [...]interface{}{*om, &om}
-	for _, pb := range bad {
-		_, err := Marshal(pb)
-		if err != ErrNotPtr {
-			t.Errorf("marshaling %T: got %v, expected %v", pb, err, ErrNotPtr)
-		}
-		if err := Unmarshal([]byte{}, pb); err != ErrNotPtr {
-			t.Errorf("unmarshaling %T: got %v, expected %v", pb, err, ErrNotPtr)
-		}
-	}
-}
-
 func TestAllSetDefaults(t *testing.T) {
 	// Exercise SetDefaults with all scalar field types.
 	m := &Defaults{

+ 2 - 3
proto/clone.go

@@ -41,8 +41,7 @@ import (
 )
 
 // Clone returns a deep copy of a protocol buffer.
-// pb must be a pointer to a protocol buffer struct.
-func Clone(pb interface{}) interface{} {
+func Clone(pb Message) Message {
 	in := reflect.ValueOf(pb)
 	if in.Kind() != reflect.Ptr || in.Elem().Kind() != reflect.Struct {
 		return nil
@@ -50,7 +49,7 @@ func Clone(pb interface{}) interface{} {
 
 	out := reflect.New(in.Type().Elem())
 	copyStruct(out.Elem(), in.Elem())
-	return out.Interface()
+	return out.Interface().(Message)
 }
 
 func copyStruct(out, in reflect.Value) {

+ 5 - 3
proto/decode.go

@@ -293,7 +293,7 @@ type Unmarshaler interface {
 // Unmarshal parses the protocol buffer representation in buf and places the
 // decoded result in pb.  If the struct underlying pb does not match
 // the data in buf, the results can be unpredictable.
-func Unmarshal(buf []byte, pb interface{}) error {
+func Unmarshal(buf []byte, pb Message) error {
 	// If the object can unmarshal itself, let it.
 	if u, ok := pb.(Unmarshaler); ok {
 		return u.Unmarshal(buf)
@@ -306,7 +306,7 @@ func Unmarshal(buf []byte, pb interface{}) error {
 // Buffer and places the decoded result in pb.  If the struct
 // underlying pb does not match the data in the buffer, the results can be
 // unpredictable.
-func (p *Buffer) Unmarshal(pb interface{}) error {
+func (p *Buffer) Unmarshal(pb Message) error {
 	// If the object can unmarshal itself, let it.
 	if u, ok := pb.(Unmarshaler); ok {
 		err := u.Unmarshal(p.buf[p.index:])
@@ -321,7 +321,9 @@ func (p *Buffer) Unmarshal(pb interface{}) error {
 
 	err = p.unmarshalType(typ, false, base)
 
-	stats.Decode++
+	if collectStats {
+		stats.Decode++
+	}
 
 	return err
 }

+ 7 - 8
proto/encode.go

@@ -61,9 +61,6 @@ var (
 
 	// ErrNil is the error returned if Marshal is called with nil.
 	ErrNil = errors.New("proto: Marshal called with nil")
-
-	// ErrNotPtr is the error returned if Marshal is called with something other than a pointer to a struct.
-	ErrNotPtr = errors.New("proto: Marshal called with something other than a pointer to a struct")
 )
 
 // The fundamental encoders that put bytes on the wire.
@@ -169,9 +166,9 @@ type Marshaler interface {
 	Marshal() ([]byte, error)
 }
 
-// Marshal takes the protocol buffer struct represented by pb
+// Marshal takes the protocol buffer
 // and encodes it into the wire format, returning the data.
-func Marshal(pb interface{}) ([]byte, error) {
+func Marshal(pb Message) ([]byte, error) {
 	// Can the object marshal itself?
 	if m, ok := pb.(Marshaler); ok {
 		return m.Marshal()
@@ -184,10 +181,10 @@ func Marshal(pb interface{}) ([]byte, error) {
 	return p.buf, err
 }
 
-// Marshal takes the protocol buffer struct represented by pb
+// Marshal takes the protocol buffer
 // and encodes it into the wire format, writing the result to the
 // Buffer.
-func (p *Buffer) Marshal(pb interface{}) error {
+func (p *Buffer) Marshal(pb Message) error {
 	// Can the object marshal itself?
 	if m, ok := pb.(Marshaler); ok {
 		data, err := m.Marshal()
@@ -203,7 +200,9 @@ func (p *Buffer) Marshal(pb interface{}) error {
 		err = p.enc_struct(t.Elem(), b)
 	}
 
-	stats.Encode++
+	if collectStats {
+		stats.Encode++
+	}
 
 	return err
 }

+ 4 - 5
proto/equal.go

@@ -43,8 +43,7 @@ import (
 
 /*
 Equal returns true iff protocol buffers a and b are equal.
-The arguments must both be protocol buffer structs,
-or both be pointers to protocol buffer structs.
+The arguments must both be pointers to protocol buffer structs.
 
 Equality is defined in this way:
   - Two messages are equal iff they are the same type,
@@ -65,7 +64,7 @@ Equality is defined in this way:
 
 The return value is undefined if a and b are not protocol buffers.
 */
-func Equal(a, b interface{}) bool {
+func Equal(a, b Message) bool {
 	v1, v2 := reflect.ValueOf(a), reflect.ValueOf(b)
 	if v1.Type() != v2.Type() {
 		return false
@@ -182,7 +181,7 @@ func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
 
 		if m1 != nil && m2 != nil {
 			// Both are unencoded.
-			if !Equal(m1, m2) {
+			if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
 				return false
 			}
 			continue
@@ -210,7 +209,7 @@ func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
 			log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
 			return false
 		}
-		if !Equal(m1, m2) {
+		if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
 			return false
 		}
 	}

+ 15 - 3
proto/equal_test.go

@@ -46,6 +46,10 @@ var messageWithExtension1a = &pb.MyMessage{Count: Int32(7)}
 var messageWithExtension1b = &pb.MyMessage{Count: Int32(7)}
 var messageWithExtension2 = &pb.MyMessage{Count: Int32(7)}
 
+// Two messages with non-message extensions.
+var messageWithInt32Extension1 = &pb.MyMessage{Count: Int32(8)}
+var messageWithInt32Extension2 = &pb.MyMessage{Count: Int32(8)}
+
 func init() {
 	ext1 := &pb.Ext{Data: String("Kirk")}
 	ext2 := &pb.Ext{Data: String("Picard")}
@@ -72,16 +76,21 @@ func init() {
 	if err := SetExtension(messageWithExtension2, pb.E_Ext_More, ext2); err != nil {
 		log.Panicf("SetExtension on 2 failed: %v", err)
 	}
+
+	if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(23)); err != nil {
+		log.Panicf("SetExtension on Int32-1 failed: %v", err)
+	}
+	if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(24)); err != nil {
+		log.Panicf("SetExtension on Int32-2 failed: %v", err)
+	}
 }
 
 var EqualTests = []struct {
 	desc string
-	a, b interface{}
+	a, b Message
 	exp  bool
 }{
 	{"different types", &pb.GoEnum{}, &pb.GoTestField{}, false},
-	{"one pointer, one value", &pb.GoEnum{}, pb.GoEnum{}, false},
-	{"non-protocol buffers", 7, 7, false},
 	{"equal empty", &pb.GoEnum{}, &pb.GoEnum{}, true},
 
 	{"one set field, one unset field", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{}, false},
@@ -123,6 +132,9 @@ var EqualTests = []struct {
 	{"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false},
 	{"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true},
 	{"extension vs. different extension", messageWithExtension1a, messageWithExtension2, false},
+
+	{"int32 extension vs. itself", messageWithInt32Extension1, messageWithInt32Extension1, true},
+	{"int32 extension vs. a different int32", messageWithInt32Extension1, messageWithInt32Extension2, false},
 }
 
 func TestEqual(t *testing.T) {

+ 4 - 3
proto/extensions.go

@@ -52,6 +52,7 @@ type ExtensionRange struct {
 
 // extendableProto is an interface implemented by any protocol buffer that may be extended.
 type extendableProto interface {
+	Message
 	ExtensionRangeArray() []ExtensionRange
 	ExtensionMap() map[int32]Extension
 }
@@ -59,7 +60,7 @@ type extendableProto interface {
 // ExtensionDesc represents an extension specification.
 // Used in generated code from the protocol compiler.
 type ExtensionDesc struct {
-	ExtendedType  interface{} // nil pointer to the type that is being extended
+	ExtendedType  Message     // nil pointer to the type that is being extended
 	ExtensionType interface{} // nil pointer to the extension type
 	Field         int32       // field number
 	Name          string      // fully-qualified name of extension, for text formatting
@@ -230,7 +231,7 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
 
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
 // The returned slice has the same length as es; missing extensions will appear as nil elements.
-func GetExtensions(pb interface{}, es []*ExtensionDesc) (extensions []interface{}, err error) {
+func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
 	epb, ok := pb.(extendableProto)
 	if !ok {
 		err = errors.New("not an extendable proto")
@@ -282,6 +283,6 @@ func RegisterExtension(desc *ExtensionDesc) {
 // RegisteredExtensions returns a map of the registered extensions of a
 // protocol buffer struct, indexed by the extension number.
 // The argument pb should be a nil pointer to the struct type.
-func RegisteredExtensions(pb interface{}) map[int32]*ExtensionDesc {
+func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
 	return extensionMaps[reflect.TypeOf(pb).Elem()]
 }

+ 12 - 6
proto/lib.go

@@ -172,6 +172,13 @@ import (
 	"sync"
 )
 
+// Message is implemented by generated protocol buffer messages.
+type Message interface {
+	Reset()
+	String() string
+	ProtoMessage()
+}
+
 // Stats records allocation details about the protocol buffer encoders
 // and decoders.  Useful for tuning the library itself.
 type Stats struct {
@@ -183,6 +190,9 @@ type Stats struct {
 	Cmiss   uint64 // number of cache misses
 }
 
+// Set to true to enable stats collection.
+const collectStats = false
+
 var stats Stats
 
 // GetStats returns a copy of the global Stats structure.
@@ -545,12 +555,8 @@ out:
 // SetDefaults sets unset protocol buffer fields to their default values.
 // It only modifies fields that are both unset and have defined defaults.
 // It recursively sets default values in any non-nil sub-messages.
-func SetDefaults(pb interface{}) {
-	v := reflect.ValueOf(pb)
-	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
-		log.Printf("proto: hit non-pointer-to-struct %v", v)
-	}
-	setDefaults(v, true, false)
+func SetDefaults(pb Message) {
+	setDefaults(reflect.ValueOf(pb), true, false)
 }
 
 // v is a pointer to a struct.

+ 11 - 4
proto/message_set.go

@@ -74,13 +74,16 @@ type MessageSet struct {
 	// TODO: caching?
 }
 
+// Make sure MessageSet is a Message.
+var _ Message = (*MessageSet)(nil)
+
 // messageTypeIder is an interface satisfied by a protocol buffer type
 // that may be stored in a MessageSet.
 type messageTypeIder interface {
 	MessageTypeId() int32
 }
 
-func (ms *MessageSet) find(pb interface{}) *_MessageSet_Item {
+func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
 	mti, ok := pb.(messageTypeIder)
 	if !ok {
 		return nil
@@ -94,14 +97,14 @@ func (ms *MessageSet) find(pb interface{}) *_MessageSet_Item {
 	return nil
 }
 
-func (ms *MessageSet) Has(pb interface{}) bool {
+func (ms *MessageSet) Has(pb Message) bool {
 	if ms.find(pb) != nil {
 		return true
 	}
 	return false
 }
 
-func (ms *MessageSet) Unmarshal(pb interface{}) error {
+func (ms *MessageSet) Unmarshal(pb Message) error {
 	if item := ms.find(pb); item != nil {
 		return Unmarshal(item.Message, pb)
 	}
@@ -111,7 +114,7 @@ func (ms *MessageSet) Unmarshal(pb interface{}) error {
 	return nil // TODO: return error instead?
 }
 
-func (ms *MessageSet) Marshal(pb interface{}) error {
+func (ms *MessageSet) Marshal(pb Message) error {
 	msg, err := Marshal(pb)
 	if err != nil {
 		return err
@@ -135,6 +138,10 @@ func (ms *MessageSet) Marshal(pb interface{}) error {
 	return nil
 }
 
+func (ms *MessageSet) Reset()         { *ms = MessageSet{} }
+func (ms *MessageSet) String() string { return CompactTextString(ms) }
+func (*MessageSet) ProtoMessage()     {}
+
 // Support for the message_set_wire_format message option.
 
 func skipVarint(buf []byte) []byte {

+ 7 - 7
proto/properties.go

@@ -435,10 +435,14 @@ func GetProperties(t reflect.Type) *StructProperties {
 	mutex.Lock()
 	if prop, ok := propertiesMap[t]; ok {
 		mutex.Unlock()
-		stats.Chit++
+		if collectStats {
+			stats.Chit++
+		}
 		return prop
 	}
-	stats.Cmiss++
+	if collectStats {
+		stats.Cmiss++
+	}
 
 	prop := new(StructProperties)
 
@@ -514,17 +518,13 @@ func propByIndex(t reflect.Type, x []int) *Properties {
 }
 
 // Get the address and type of a pointer to a struct from an interface.
-func getbase(pb interface{}) (t reflect.Type, b uintptr, err error) {
+func getbase(pb Message) (t reflect.Type, b uintptr, err error) {
 	if pb == nil {
 		err = ErrNil
 		return
 	}
 	// get the reflect type of the pointer to the struct.
 	t = reflect.TypeOf(pb)
-	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
-		err = ErrNotPtr
-		return
-	}
 	// get the address of the struct.
 	value := reflect.ValueOf(pb)
 	b = value.Pointer()

+ 22 - 0
proto/testdata/test.pb.go

@@ -200,6 +200,7 @@ type GoEnum struct {
 
 func (this *GoEnum) Reset()         { *this = GoEnum{} }
 func (this *GoEnum) String() string { return proto.CompactTextString(this) }
+func (*GoEnum) ProtoMessage()       {}
 
 type GoTestField struct {
 	Label            *string `protobuf:"bytes,1,req" json:"Label,omitempty"`
@@ -209,6 +210,7 @@ type GoTestField struct {
 
 func (this *GoTestField) Reset()         { *this = GoTestField{} }
 func (this *GoTestField) String() string { return proto.CompactTextString(this) }
+func (*GoTestField) ProtoMessage()       {}
 
 type GoTest struct {
 	Kind                    *GoTest_KIND            `protobuf:"varint,1,req,enum=testdata.GoTest_KIND" json:"Kind,omitempty"`
@@ -288,6 +290,7 @@ type GoTest struct {
 
 func (this *GoTest) Reset()         { *this = GoTest{} }
 func (this *GoTest) String() string { return proto.CompactTextString(this) }
+func (*GoTest) ProtoMessage()       {}
 
 const Default_GoTest_F_BoolDefaulted bool = true
 const Default_GoTest_F_Int32Defaulted int32 = 32
@@ -312,6 +315,7 @@ type GoTest_RequiredGroup struct {
 
 func (this *GoTest_RequiredGroup) Reset()         { *this = GoTest_RequiredGroup{} }
 func (this *GoTest_RequiredGroup) String() string { return proto.CompactTextString(this) }
+func (*GoTest_RequiredGroup) ProtoMessage()       {}
 
 type GoTest_RepeatedGroup struct {
 	RequiredField    *string `protobuf:"bytes,81,req" json:"RequiredField,omitempty"`
@@ -320,6 +324,7 @@ type GoTest_RepeatedGroup struct {
 
 func (this *GoTest_RepeatedGroup) Reset()         { *this = GoTest_RepeatedGroup{} }
 func (this *GoTest_RepeatedGroup) String() string { return proto.CompactTextString(this) }
+func (*GoTest_RepeatedGroup) ProtoMessage()       {}
 
 type GoTest_OptionalGroup struct {
 	RequiredField    *string `protobuf:"bytes,91,req" json:"RequiredField,omitempty"`
@@ -328,6 +333,7 @@ type GoTest_OptionalGroup struct {
 
 func (this *GoTest_OptionalGroup) Reset()         { *this = GoTest_OptionalGroup{} }
 func (this *GoTest_OptionalGroup) String() string { return proto.CompactTextString(this) }
+func (*GoTest_OptionalGroup) ProtoMessage()       {}
 
 type GoSkipTest struct {
 	SkipInt32        *int32                `protobuf:"varint,11,req,name=skip_int32" json:"skip_int32,omitempty"`
@@ -340,6 +346,7 @@ type GoSkipTest struct {
 
 func (this *GoSkipTest) Reset()         { *this = GoSkipTest{} }
 func (this *GoSkipTest) String() string { return proto.CompactTextString(this) }
+func (*GoSkipTest) ProtoMessage()       {}
 
 type GoSkipTest_SkipGroup struct {
 	GroupInt32       *int32  `protobuf:"varint,16,req,name=group_int32" json:"group_int32,omitempty"`
@@ -349,6 +356,7 @@ type GoSkipTest_SkipGroup struct {
 
 func (this *GoSkipTest_SkipGroup) Reset()         { *this = GoSkipTest_SkipGroup{} }
 func (this *GoSkipTest_SkipGroup) String() string { return proto.CompactTextString(this) }
+func (*GoSkipTest_SkipGroup) ProtoMessage()       {}
 
 type NonPackedTest struct {
 	A                []int32 `protobuf:"varint,1,rep,name=a" json:"a,omitempty"`
@@ -357,6 +365,7 @@ type NonPackedTest struct {
 
 func (this *NonPackedTest) Reset()         { *this = NonPackedTest{} }
 func (this *NonPackedTest) String() string { return proto.CompactTextString(this) }
+func (*NonPackedTest) ProtoMessage()       {}
 
 type PackedTest struct {
 	B                []int32 `protobuf:"varint,1,rep,packed,name=b" json:"b,omitempty"`
@@ -365,6 +374,7 @@ type PackedTest struct {
 
 func (this *PackedTest) Reset()         { *this = PackedTest{} }
 func (this *PackedTest) String() string { return proto.CompactTextString(this) }
+func (*PackedTest) ProtoMessage()       {}
 
 type MaxTag struct {
 	LastField        *string `protobuf:"bytes,536870911,opt,name=last_field" json:"last_field,omitempty"`
@@ -373,6 +383,7 @@ type MaxTag struct {
 
 func (this *MaxTag) Reset()         { *this = MaxTag{} }
 func (this *MaxTag) String() string { return proto.CompactTextString(this) }
+func (*MaxTag) ProtoMessage()       {}
 
 type InnerMessage struct {
 	Host             *string `protobuf:"bytes,1,req,name=host" json:"host,omitempty"`
@@ -383,6 +394,7 @@ type InnerMessage struct {
 
 func (this *InnerMessage) Reset()         { *this = InnerMessage{} }
 func (this *InnerMessage) String() string { return proto.CompactTextString(this) }
+func (*InnerMessage) ProtoMessage()       {}
 
 const Default_InnerMessage_Port int32 = 4000
 
@@ -396,6 +408,7 @@ type OtherMessage struct {
 
 func (this *OtherMessage) Reset()         { *this = OtherMessage{} }
 func (this *OtherMessage) String() string { return proto.CompactTextString(this) }
+func (*OtherMessage) ProtoMessage()       {}
 
 type MyMessage struct {
 	Count            *int32                    `protobuf:"varint,1,req,name=count" json:"count,omitempty"`
@@ -413,6 +426,7 @@ type MyMessage struct {
 
 func (this *MyMessage) Reset()         { *this = MyMessage{} }
 func (this *MyMessage) String() string { return proto.CompactTextString(this) }
+func (*MyMessage) ProtoMessage()       {}
 
 var extRange_MyMessage = []proto.ExtensionRange{
 	{100, 536870911},
@@ -435,6 +449,7 @@ type MyMessage_SomeGroup struct {
 
 func (this *MyMessage_SomeGroup) Reset()         { *this = MyMessage_SomeGroup{} }
 func (this *MyMessage_SomeGroup) String() string { return proto.CompactTextString(this) }
+func (*MyMessage_SomeGroup) ProtoMessage()       {}
 
 type Ext struct {
 	Data             *string `protobuf:"bytes,1,opt,name=data" json:"data,omitempty"`
@@ -443,6 +458,7 @@ type Ext struct {
 
 func (this *Ext) Reset()         { *this = Ext{} }
 func (this *Ext) String() string { return proto.CompactTextString(this) }
+func (*Ext) ProtoMessage()       {}
 
 var E_Ext_More = &proto.ExtensionDesc{
 	ExtendedType:  (*MyMessage)(nil),
@@ -475,6 +491,7 @@ type MessageList struct {
 
 func (this *MessageList) Reset()         { *this = MessageList{} }
 func (this *MessageList) String() string { return proto.CompactTextString(this) }
+func (*MessageList) ProtoMessage()       {}
 
 type MessageList_Message struct {
 	Name             *string `protobuf:"bytes,2,req,name=name" json:"name,omitempty"`
@@ -484,6 +501,7 @@ type MessageList_Message struct {
 
 func (this *MessageList_Message) Reset()         { *this = MessageList_Message{} }
 func (this *MessageList_Message) String() string { return proto.CompactTextString(this) }
+func (*MessageList_Message) ProtoMessage()       {}
 
 type Strings struct {
 	StringField      *string `protobuf:"bytes,1,opt,name=string_field" json:"string_field,omitempty"`
@@ -493,6 +511,7 @@ type Strings struct {
 
 func (this *Strings) Reset()         { *this = Strings{} }
 func (this *Strings) String() string { return proto.CompactTextString(this) }
+func (*Strings) ProtoMessage()       {}
 
 type Defaults struct {
 	F_Bool           *bool           `protobuf:"varint,1,opt,def=1" json:"F_Bool,omitempty"`
@@ -518,6 +537,7 @@ type Defaults struct {
 
 func (this *Defaults) Reset()         { *this = Defaults{} }
 func (this *Defaults) String() string { return proto.CompactTextString(this) }
+func (*Defaults) ProtoMessage()       {}
 
 const Default_Defaults_F_Bool bool = true
 const Default_Defaults_F_Int32 int32 = 32
@@ -547,6 +567,7 @@ type SubDefaults struct {
 
 func (this *SubDefaults) Reset()         { *this = SubDefaults{} }
 func (this *SubDefaults) String() string { return proto.CompactTextString(this) }
+func (*SubDefaults) ProtoMessage()       {}
 
 const Default_SubDefaults_N int64 = 7
 
@@ -557,6 +578,7 @@ type RepeatedEnum struct {
 
 func (this *RepeatedEnum) Reset()         { *this = RepeatedEnum{} }
 func (this *RepeatedEnum) String() string { return proto.CompactTextString(this) }
+func (*RepeatedEnum) ProtoMessage()       {}
 
 var E_Greeting = &proto.ExtensionDesc{
 	ExtendedType:  (*MyMessage)(nil),

+ 8 - 22
proto/text.go

@@ -295,7 +295,7 @@ func writeMessageSet(w *textWriter, ms *MessageSet) {
 			w.indent()
 
 			pb := reflect.New(msd.t.Elem())
-			if err := Unmarshal(item.Message, pb.Interface()); err != nil {
+			if err := Unmarshal(item.Message, pb.Interface().(Message)); err != nil {
 				fmt.Fprintf(w, "/* bad message: %v */\n", err)
 			} else {
 				writeStruct(w, pb.Elem())
@@ -431,7 +431,7 @@ func writeExtension(w *textWriter, name string, pb interface{}) {
 	w.WriteByte('\n')
 }
 
-func marshalText(w io.Writer, pb interface{}, compact bool) {
+func marshalText(w io.Writer, pb Message, compact bool) {
 	if pb == nil {
 		w.Write([]byte("<nil>"))
 		return
@@ -441,40 +441,26 @@ func marshalText(w io.Writer, pb interface{}, compact bool) {
 	aw.complete = true
 	aw.compact = compact
 
-	// Reject non-pointer inputs (it's a bad practice to pass potentially large protos around by value).
-	v := reflect.ValueOf(pb)
-	if v.Kind() != reflect.Ptr {
-		w.Write([]byte("<struct-by-value>"))
-		return
-	}
-
 	// Dereference the received pointer so we don't have outer < and >.
-	v = reflect.Indirect(v)
-
-	if v.Kind() == reflect.Struct {
-		writeStruct(aw, v)
-	} else {
-		writeAny(aw, v, nil)
-	}
+	v := reflect.Indirect(reflect.ValueOf(pb))
+	writeStruct(aw, v)
 }
 
 // MarshalText writes a given protocol buffer in text format.
-// Values that are not protocol buffers can also be written, but their formatting is not guaranteed.
-func MarshalText(w io.Writer, pb interface{}) { marshalText(w, pb, false) }
+func MarshalText(w io.Writer, pb Message) { marshalText(w, pb, false) }
 
 // MarshalTextString is the same as MarshalText, but returns the string directly.
-func MarshalTextString(pb interface{}) string {
+func MarshalTextString(pb Message) string {
 	var buf bytes.Buffer
 	marshalText(&buf, pb, false)
 	return buf.String()
 }
 
 // CompactText writes a given protocl buffer in compact text format (one line).
-// Values that are not protocol buffers can also be written, but their formatting is not guaranteed.
-func CompactText(w io.Writer, pb interface{}) { marshalText(w, pb, true) }
+func CompactText(w io.Writer, pb Message) { marshalText(w, pb, true) }
 
 // CompactTextString is the same as CompactText, but returns the string directly.
-func CompactTextString(pb interface{}) string {
+func CompactTextString(pb Message) string {
 	var buf bytes.Buffer
 	marshalText(&buf, pb, true)
 	return buf.String()

+ 3 - 8
proto/text_parser.go

@@ -324,7 +324,7 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) *ParseError
 			var desc *ExtensionDesc
 			// This could be faster, but it's functional.
 			// TODO: Do something smarter than a linear scan.
-			for _, d := range RegisteredExtensions(reflect.New(st).Interface()) {
+			for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) {
 				if d.Name == tok.value {
 					desc = d
 					break
@@ -519,14 +519,9 @@ func (p *textParser) readAny(v reflect.Value, props *Properties) *ParseError {
 	return p.errorf("invalid %v: %v", v.Type(), tok.value)
 }
 
-var notPtrStruct error = &ParseError{"destination is not a pointer to a struct", 0, 0}
-
-// UnmarshalText reads a protobuffer in Text format.
-func UnmarshalText(s string, pb interface{}) error {
+// UnmarshalText reads a protocol buffer in Text format.
+func UnmarshalText(s string, pb Message) error {
 	v := reflect.ValueOf(pb)
-	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
-		return notPtrStruct
-	}
 	if pe := newTextParser(s).readStruct(v.Elem(), ""); pe != nil {
 		return pe
 	}

+ 0 - 9
proto/text_test.go

@@ -228,12 +228,3 @@ func TestStringEscaping(t *testing.T) {
 		}
 	}
 }
-
-func TestNonPtrMessage(t *testing.T) {
-	// Ensure we don't panic when we pass a non-pointer to MarshalText.
-	var buf bytes.Buffer
-	proto.MarshalText(&buf, pb.MyMessage{})
-	if s := buf.String(); s != "<struct-by-value>" {
-		t.Errorf("got: %q, want %q", s, "<struct-by-value>")
-	}
-}

+ 18 - 0
protoc-gen-go/descriptor/descriptor.pb.go

@@ -170,6 +170,7 @@ type FileDescriptorSet struct {
 
 func (this *FileDescriptorSet) Reset()         { *this = FileDescriptorSet{} }
 func (this *FileDescriptorSet) String() string { return proto.CompactTextString(this) }
+func (*FileDescriptorSet) ProtoMessage()       {}
 
 type FileDescriptorProto struct {
 	Name             *string                   `protobuf:"bytes,1,opt,name=name"`
@@ -187,6 +188,7 @@ type FileDescriptorProto struct {
 
 func (this *FileDescriptorProto) Reset()         { *this = FileDescriptorProto{} }
 func (this *FileDescriptorProto) String() string { return proto.CompactTextString(this) }
+func (*FileDescriptorProto) ProtoMessage()       {}
 
 type DescriptorProto struct {
 	Name             *string                           `protobuf:"bytes,1,opt,name=name"`
@@ -201,6 +203,7 @@ type DescriptorProto struct {
 
 func (this *DescriptorProto) Reset()         { *this = DescriptorProto{} }
 func (this *DescriptorProto) String() string { return proto.CompactTextString(this) }
+func (*DescriptorProto) ProtoMessage()       {}
 
 type DescriptorProto_ExtensionRange struct {
 	Start            *int32 `protobuf:"varint,1,opt,name=start"`
@@ -210,6 +213,7 @@ type DescriptorProto_ExtensionRange struct {
 
 func (this *DescriptorProto_ExtensionRange) Reset()         { *this = DescriptorProto_ExtensionRange{} }
 func (this *DescriptorProto_ExtensionRange) String() string { return proto.CompactTextString(this) }
+func (*DescriptorProto_ExtensionRange) ProtoMessage()       {}
 
 type FieldDescriptorProto struct {
 	Name             *string                     `protobuf:"bytes,1,opt,name=name"`
@@ -225,6 +229,7 @@ type FieldDescriptorProto struct {
 
 func (this *FieldDescriptorProto) Reset()         { *this = FieldDescriptorProto{} }
 func (this *FieldDescriptorProto) String() string { return proto.CompactTextString(this) }
+func (*FieldDescriptorProto) ProtoMessage()       {}
 
 type EnumDescriptorProto struct {
 	Name             *string                     `protobuf:"bytes,1,opt,name=name"`
@@ -235,6 +240,7 @@ type EnumDescriptorProto struct {
 
 func (this *EnumDescriptorProto) Reset()         { *this = EnumDescriptorProto{} }
 func (this *EnumDescriptorProto) String() string { return proto.CompactTextString(this) }
+func (*EnumDescriptorProto) ProtoMessage()       {}
 
 type EnumValueDescriptorProto struct {
 	Name             *string           `protobuf:"bytes,1,opt,name=name"`
@@ -245,6 +251,7 @@ type EnumValueDescriptorProto struct {
 
 func (this *EnumValueDescriptorProto) Reset()         { *this = EnumValueDescriptorProto{} }
 func (this *EnumValueDescriptorProto) String() string { return proto.CompactTextString(this) }
+func (*EnumValueDescriptorProto) ProtoMessage()       {}
 
 type ServiceDescriptorProto struct {
 	Name             *string                  `protobuf:"bytes,1,opt,name=name"`
@@ -255,6 +262,7 @@ type ServiceDescriptorProto struct {
 
 func (this *ServiceDescriptorProto) Reset()         { *this = ServiceDescriptorProto{} }
 func (this *ServiceDescriptorProto) String() string { return proto.CompactTextString(this) }
+func (*ServiceDescriptorProto) ProtoMessage()       {}
 
 type MethodDescriptorProto struct {
 	Name             *string        `protobuf:"bytes,1,opt,name=name"`
@@ -266,6 +274,7 @@ type MethodDescriptorProto struct {
 
 func (this *MethodDescriptorProto) Reset()         { *this = MethodDescriptorProto{} }
 func (this *MethodDescriptorProto) String() string { return proto.CompactTextString(this) }
+func (*MethodDescriptorProto) ProtoMessage()       {}
 
 type FileOptions struct {
 	JavaPackage         *string                   `protobuf:"bytes,1,opt,name=java_package"`
@@ -282,6 +291,7 @@ type FileOptions struct {
 
 func (this *FileOptions) Reset()         { *this = FileOptions{} }
 func (this *FileOptions) String() string { return proto.CompactTextString(this) }
+func (*FileOptions) ProtoMessage()       {}
 
 var extRange_FileOptions = []proto.ExtensionRange{
 	proto.ExtensionRange{1000, 536870911},
@@ -313,6 +323,7 @@ type MessageOptions struct {
 
 func (this *MessageOptions) Reset()         { *this = MessageOptions{} }
 func (this *MessageOptions) String() string { return proto.CompactTextString(this) }
+func (*MessageOptions) ProtoMessage()       {}
 
 var extRange_MessageOptions = []proto.ExtensionRange{
 	proto.ExtensionRange{1000, 536870911},
@@ -344,6 +355,7 @@ type FieldOptions struct {
 
 func (this *FieldOptions) Reset()         { *this = FieldOptions{} }
 func (this *FieldOptions) String() string { return proto.CompactTextString(this) }
+func (*FieldOptions) ProtoMessage()       {}
 
 var extRange_FieldOptions = []proto.ExtensionRange{
 	proto.ExtensionRange{1000, 536870911},
@@ -372,6 +384,7 @@ type EnumOptions struct {
 
 func (this *EnumOptions) Reset()         { *this = EnumOptions{} }
 func (this *EnumOptions) String() string { return proto.CompactTextString(this) }
+func (*EnumOptions) ProtoMessage()       {}
 
 var extRange_EnumOptions = []proto.ExtensionRange{
 	proto.ExtensionRange{1000, 536870911},
@@ -397,6 +410,7 @@ type EnumValueOptions struct {
 
 func (this *EnumValueOptions) Reset()         { *this = EnumValueOptions{} }
 func (this *EnumValueOptions) String() string { return proto.CompactTextString(this) }
+func (*EnumValueOptions) ProtoMessage()       {}
 
 var extRange_EnumValueOptions = []proto.ExtensionRange{
 	proto.ExtensionRange{1000, 536870911},
@@ -420,6 +434,7 @@ type ServiceOptions struct {
 
 func (this *ServiceOptions) Reset()         { *this = ServiceOptions{} }
 func (this *ServiceOptions) String() string { return proto.CompactTextString(this) }
+func (*ServiceOptions) ProtoMessage()       {}
 
 var extRange_ServiceOptions = []proto.ExtensionRange{
 	proto.ExtensionRange{1000, 536870911},
@@ -443,6 +458,7 @@ type MethodOptions struct {
 
 func (this *MethodOptions) Reset()         { *this = MethodOptions{} }
 func (this *MethodOptions) String() string { return proto.CompactTextString(this) }
+func (*MethodOptions) ProtoMessage()       {}
 
 var extRange_MethodOptions = []proto.ExtensionRange{
 	proto.ExtensionRange{1000, 536870911},
@@ -470,6 +486,7 @@ type UninterpretedOption struct {
 
 func (this *UninterpretedOption) Reset()         { *this = UninterpretedOption{} }
 func (this *UninterpretedOption) String() string { return proto.CompactTextString(this) }
+func (*UninterpretedOption) ProtoMessage()       {}
 
 type UninterpretedOption_NamePart struct {
 	NamePart         *string `protobuf:"bytes,1,req,name=name_part"`
@@ -479,6 +496,7 @@ type UninterpretedOption_NamePart struct {
 
 func (this *UninterpretedOption_NamePart) Reset()         { *this = UninterpretedOption_NamePart{} }
 func (this *UninterpretedOption_NamePart) String() string { return proto.CompactTextString(this) }
+func (*UninterpretedOption_NamePart) ProtoMessage()       {}
 
 func init() {
 	proto.RegisterEnum("google_protobuf.FieldDescriptorProto_Type", FieldDescriptorProto_Type_name, FieldDescriptorProto_Type_value)

+ 4 - 1
protoc-gen-go/generator/generator.go

@@ -268,6 +268,7 @@ func (ms messageSymbol) GenerateAlias(g *Generator, pkg string) {
 	g.P("type ", ms.sym, " ", remoteSym)
 	g.P("func (this *", ms.sym, ") Reset() { (*", remoteSym, ")(this).Reset() }")
 	g.P("func (this *", ms.sym, ") String() string { return (*", remoteSym, ")(this).String() }")
+	g.P("func (*", ms.sym, ") ProtoMessage() {}")
 	if ms.hasExtensions {
 		g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.ProtoPkg, ".ExtensionRange ",
 			"{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }")
@@ -1139,6 +1140,7 @@ func (g *Generator) RecordTypeUse(t string) {
 var methodNames = [...]string{
 	"Reset",
 	"String",
+	"ProtoMessage",
 	"Marshal",
 	"Unmarshal",
 	"ExtensionRangeArray",
@@ -1181,9 +1183,10 @@ func (g *Generator) generateMessage(message *Descriptor) {
 	g.Out()
 	g.P("}")
 
-	// Reset and String functions
+	// Reset, String and ProtoMessage methods.
 	g.P("func (this *", ccTypeName, ") Reset() { *this = ", ccTypeName, "{} }")
 	g.P("func (this *", ccTypeName, ") String() string { return ", g.ProtoPkg, ".CompactTextString(this) }")
+	g.P("func (*", ccTypeName, ") ProtoMessage() {}")
 
 	// Extension support methods
 	var hasExtensions, isMessageSet bool

+ 3 - 0
protoc-gen-go/plugin/plugin.pb.go

@@ -22,6 +22,7 @@ type CodeGeneratorRequest struct {
 
 func (this *CodeGeneratorRequest) Reset()         { *this = CodeGeneratorRequest{} }
 func (this *CodeGeneratorRequest) String() string { return proto.CompactTextString(this) }
+func (*CodeGeneratorRequest) ProtoMessage()       {}
 
 type CodeGeneratorResponse struct {
 	Error            *string                       `protobuf:"bytes,1,opt,name=error"`
@@ -31,6 +32,7 @@ type CodeGeneratorResponse struct {
 
 func (this *CodeGeneratorResponse) Reset()         { *this = CodeGeneratorResponse{} }
 func (this *CodeGeneratorResponse) String() string { return proto.CompactTextString(this) }
+func (*CodeGeneratorResponse) ProtoMessage()       {}
 
 type CodeGeneratorResponse_File struct {
 	Name             *string `protobuf:"bytes,1,opt,name=name"`
@@ -41,6 +43,7 @@ type CodeGeneratorResponse_File struct {
 
 func (this *CodeGeneratorResponse_File) Reset()         { *this = CodeGeneratorResponse_File{} }
 func (this *CodeGeneratorResponse_File) String() string { return proto.CompactTextString(this) }
+func (*CodeGeneratorResponse_File) ProtoMessage()       {}
 
 func init() {
 }

+ 6 - 0
protoc-gen-go/testdata/my_test/test.pb.go

@@ -150,6 +150,7 @@ type Request struct {
 
 func (this *Request) Reset()         { *this = Request{} }
 func (this *Request) String() string { return proto.CompactTextString(this) }
+func (*Request) ProtoMessage()       {}
 
 const Default_Request_Hat HatType = HatType_FEDORA
 
@@ -162,6 +163,7 @@ type Request_SomeGroup struct {
 
 func (this *Request_SomeGroup) Reset()         { *this = Request_SomeGroup{} }
 func (this *Request_SomeGroup) String() string { return proto.CompactTextString(this) }
+func (*Request_SomeGroup) ProtoMessage()       {}
 
 type Reply struct {
 	Found            []*Reply_Entry            `protobuf:"bytes,1,rep,name=found" json:"found,omitempty"`
@@ -172,6 +174,7 @@ type Reply struct {
 
 func (this *Reply) Reset()         { *this = Reply{} }
 func (this *Reply) String() string { return proto.CompactTextString(this) }
+func (*Reply) ProtoMessage()       {}
 
 var extRange_Reply = []proto.ExtensionRange{
 	{100, 536870911},
@@ -196,6 +199,7 @@ type Reply_Entry struct {
 
 func (this *Reply_Entry) Reset()         { *this = Reply_Entry{} }
 func (this *Reply_Entry) String() string { return proto.CompactTextString(this) }
+func (*Reply_Entry) ProtoMessage()       {}
 
 const Default_Reply_Entry_Value int64 = 7
 
@@ -205,6 +209,7 @@ type ReplyExtensions struct {
 
 func (this *ReplyExtensions) Reset()         { *this = ReplyExtensions{} }
 func (this *ReplyExtensions) String() string { return proto.CompactTextString(this) }
+func (*ReplyExtensions) ProtoMessage()       {}
 
 var E_ReplyExtensions_Time = &proto.ExtensionDesc{
 	ExtendedType:  (*Reply)(nil),
@@ -221,6 +226,7 @@ type OldReply struct {
 
 func (this *OldReply) Reset()         { *this = OldReply{} }
 func (this *OldReply) String() string { return proto.CompactTextString(this) }
+func (*OldReply) ProtoMessage()       {}
 
 func (this *OldReply) Marshal() ([]byte, error) {
 	return proto.MarshalMessageSet(this.ExtensionMap())

+ 6 - 0
protoc-gen-go/testdata/my_test/test.pb.go.golden

@@ -150,6 +150,7 @@ type Request struct {
 
 func (this *Request) Reset()         { *this = Request{} }
 func (this *Request) String() string { return proto.CompactTextString(this) }
+func (*Request) ProtoMessage()       {}
 
 const Default_Request_Hat HatType = HatType_FEDORA
 
@@ -162,6 +163,7 @@ type Request_SomeGroup struct {
 
 func (this *Request_SomeGroup) Reset()         { *this = Request_SomeGroup{} }
 func (this *Request_SomeGroup) String() string { return proto.CompactTextString(this) }
+func (*Request_SomeGroup) ProtoMessage()       {}
 
 type Reply struct {
 	Found            []*Reply_Entry            `protobuf:"bytes,1,rep,name=found" json:"found,omitempty"`
@@ -172,6 +174,7 @@ type Reply struct {
 
 func (this *Reply) Reset()         { *this = Reply{} }
 func (this *Reply) String() string { return proto.CompactTextString(this) }
+func (*Reply) ProtoMessage()       {}
 
 var extRange_Reply = []proto.ExtensionRange{
 	{100, 536870911},
@@ -196,6 +199,7 @@ type Reply_Entry struct {
 
 func (this *Reply_Entry) Reset()         { *this = Reply_Entry{} }
 func (this *Reply_Entry) String() string { return proto.CompactTextString(this) }
+func (*Reply_Entry) ProtoMessage()       {}
 
 const Default_Reply_Entry_Value int64 = 7
 
@@ -205,6 +209,7 @@ type ReplyExtensions struct {
 
 func (this *ReplyExtensions) Reset()         { *this = ReplyExtensions{} }
 func (this *ReplyExtensions) String() string { return proto.CompactTextString(this) }
+func (*ReplyExtensions) ProtoMessage()       {}
 
 var E_ReplyExtensions_Time = &proto.ExtensionDesc{
 	ExtendedType:  (*Reply)(nil),
@@ -221,6 +226,7 @@ type OldReply struct {
 
 func (this *OldReply) Reset()         { *this = OldReply{} }
 func (this *OldReply) String() string { return proto.CompactTextString(this) }
+func (*OldReply) ProtoMessage()       {}
 
 func (this *OldReply) Marshal() ([]byte, error) {
 	return proto.MarshalMessageSet(this.ExtensionMap())