Procházet zdrojové kódy

Fix unmarshaling code to properly handle multiple instances of the same
extension appearing in the wire format. Prior to this change, multiple
custom options would result in all but the first being discarded when
processed using the proto.GetExtension() facilities.

Signed-off-by: David Symonds <dsymonds@golang.org>

Erik McClenney před 10 roky
rodič
revize
7c1e7ed8fe

+ 1 - 2
proto/extensions.go

@@ -301,7 +301,6 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
 	o := NewBuffer(b)
 
 	t := reflect.TypeOf(extension.ExtensionType)
-	rep := extension.repeated()
 
 	props := extensionProperties(extension)
 
@@ -323,7 +322,7 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
 			return nil, err
 		}
 
-		if !rep || o.index >= len(o.buf) {
+		if o.index >= len(o.buf) {
 			break
 		}
 	}

+ 138 - 0
proto/extensions_test.go

@@ -32,6 +32,7 @@
 package proto_test
 
 import (
+	"bytes"
 	"fmt"
 	"reflect"
 	"testing"
@@ -290,3 +291,140 @@ func TestNilExtension(t *testing.T) {
 	// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
 	// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
 }
+
+func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
+	// Add a repeated extension to the result.
+	tests := []struct {
+		name string
+		ext  []*pb.ComplexExtension
+	}{
+		{
+			"two fields",
+			[]*pb.ComplexExtension{
+				{First: proto.Int32(7)},
+				{Second: proto.Int32(11)},
+			},
+		},
+		{
+			"repeated field",
+			[]*pb.ComplexExtension{
+				{Third: []int32{1000}},
+				{Third: []int32{2000}},
+			},
+		},
+		{
+			"two fields and repeated field",
+			[]*pb.ComplexExtension{
+				{Third: []int32{1000}},
+				{First: proto.Int32(9)},
+				{Second: proto.Int32(21)},
+				{Third: []int32{2000}},
+			},
+		},
+	}
+	for _, test := range tests {
+		// Marshal message with a repeated extension.
+		msg1 := new(pb.OtherMessage)
+		err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
+		if err != nil {
+			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
+		}
+		b, err := proto.Marshal(msg1)
+		if err != nil {
+			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
+		}
+
+		// Unmarshal and read the merged proto.
+		msg2 := new(pb.OtherMessage)
+		err = proto.Unmarshal(b, msg2)
+		if err != nil {
+			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
+		}
+		e, err := proto.GetExtension(msg2, pb.E_RComplex)
+		if err != nil {
+			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
+		}
+		ext := e.([]*pb.ComplexExtension)
+		if ext == nil {
+			t.Fatalf("[%s] Invalid extension", test.name)
+		}
+		if !reflect.DeepEqual(ext, test.ext) {
+			t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, test.ext)
+		}
+	}
+}
+
+func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
+	// We may see multiple instances of the same extension in the wire
+	// format. For example, the proto compiler may encode custom options in
+	// this way. Here, we verify that we merge the extensions together.
+	tests := []struct {
+		name string
+		ext  []*pb.ComplexExtension
+	}{
+		{
+			"two fields",
+			[]*pb.ComplexExtension{
+				{First: proto.Int32(7)},
+				{Second: proto.Int32(11)},
+			},
+		},
+		{
+			"repeated field",
+			[]*pb.ComplexExtension{
+				{Third: []int32{1000}},
+				{Third: []int32{2000}},
+			},
+		},
+		{
+			"two fields and repeated field",
+			[]*pb.ComplexExtension{
+				{Third: []int32{1000}},
+				{First: proto.Int32(9)},
+				{Second: proto.Int32(21)},
+				{Third: []int32{2000}},
+			},
+		},
+	}
+	for _, test := range tests {
+		var buf bytes.Buffer
+		var want pb.ComplexExtension
+
+		// Generate a serialized representation of a repeated extension
+		// by catenating bytes together.
+		for i, e := range test.ext {
+			// Merge to create the wanted proto.
+			proto.Merge(&want, e)
+
+			// serialize the message
+			msg := new(pb.OtherMessage)
+			err := proto.SetExtension(msg, pb.E_Complex, e)
+			if err != nil {
+				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
+			}
+			b, err := proto.Marshal(msg)
+			if err != nil {
+				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
+			}
+			buf.Write(b)
+		}
+
+		// Unmarshal and read the merged proto.
+		msg2 := new(pb.OtherMessage)
+		err := proto.Unmarshal(buf.Bytes(), msg2)
+		if err != nil {
+			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
+		}
+		e, err := proto.GetExtension(msg2, pb.E_Complex)
+		if err != nil {
+			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
+		}
+		ext := e.(*pb.ComplexExtension)
+		if ext == nil {
+			t.Fatalf("[%s] Invalid extension", test.name)
+		}
+		if !reflect.DeepEqual(*ext, want) {
+			t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, want)
+		}
+	}
+}

+ 71 - 5
proto/testdata/test.pb.go

@@ -22,6 +22,7 @@ It has these top-level messages:
 	OtherMessage
 	MyMessage
 	Ext
+	ComplexExtension
 	DefaultsMessage
 	MyMessageSet
 	Empty
@@ -1235,17 +1236,32 @@ func (m *InnerMessage) GetConnected() bool {
 }
 
 type OtherMessage struct {
-	Key              *int64        `protobuf:"varint,1,opt,name=key" json:"key,omitempty"`
-	Value            []byte        `protobuf:"bytes,2,opt,name=value" json:"value,omitempty"`
-	Weight           *float32      `protobuf:"fixed32,3,opt,name=weight" json:"weight,omitempty"`
-	Inner            *InnerMessage `protobuf:"bytes,4,opt,name=inner" json:"inner,omitempty"`
-	XXX_unrecognized []byte        `json:"-"`
+	Key              *int64                    `protobuf:"varint,1,opt,name=key" json:"key,omitempty"`
+	Value            []byte                    `protobuf:"bytes,2,opt,name=value" json:"value,omitempty"`
+	Weight           *float32                  `protobuf:"fixed32,3,opt,name=weight" json:"weight,omitempty"`
+	Inner            *InnerMessage             `protobuf:"bytes,4,opt,name=inner" json:"inner,omitempty"`
+	XXX_extensions   map[int32]proto.Extension `json:"-"`
+	XXX_unrecognized []byte                    `json:"-"`
 }
 
 func (m *OtherMessage) Reset()         { *m = OtherMessage{} }
 func (m *OtherMessage) String() string { return proto.CompactTextString(m) }
 func (*OtherMessage) ProtoMessage()    {}
 
+var extRange_OtherMessage = []proto.ExtensionRange{
+	{100, 536870911},
+}
+
+func (*OtherMessage) ExtensionRangeArray() []proto.ExtensionRange {
+	return extRange_OtherMessage
+}
+func (m *OtherMessage) ExtensionMap() map[int32]proto.Extension {
+	if m.XXX_extensions == nil {
+		m.XXX_extensions = make(map[int32]proto.Extension)
+	}
+	return m.XXX_extensions
+}
+
 func (m *OtherMessage) GetKey() int64 {
 	if m != nil && m.Key != nil {
 		return *m.Key
@@ -1442,6 +1458,38 @@ var E_Ext_Number = &proto.ExtensionDesc{
 	Tag:           "varint,105,opt,name=number",
 }
 
+type ComplexExtension struct {
+	First            *int32  `protobuf:"varint,1,opt,name=first" json:"first,omitempty"`
+	Second           *int32  `protobuf:"varint,2,opt,name=second" json:"second,omitempty"`
+	Third            []int32 `protobuf:"varint,3,rep,name=third" json:"third,omitempty"`
+	XXX_unrecognized []byte  `json:"-"`
+}
+
+func (m *ComplexExtension) Reset()         { *m = ComplexExtension{} }
+func (m *ComplexExtension) String() string { return proto.CompactTextString(m) }
+func (*ComplexExtension) ProtoMessage()    {}
+
+func (m *ComplexExtension) GetFirst() int32 {
+	if m != nil && m.First != nil {
+		return *m.First
+	}
+	return 0
+}
+
+func (m *ComplexExtension) GetSecond() int32 {
+	if m != nil && m.Second != nil {
+		return *m.Second
+	}
+	return 0
+}
+
+func (m *ComplexExtension) GetThird() []int32 {
+	if m != nil {
+		return m.Third
+	}
+	return nil
+}
+
 type DefaultsMessage struct {
 	XXX_extensions   map[int32]proto.Extension `json:"-"`
 	XXX_unrecognized []byte                    `json:"-"`
@@ -2196,6 +2244,22 @@ var E_Greeting = &proto.ExtensionDesc{
 	Tag:           "bytes,106,rep,name=greeting",
 }
 
+var E_Complex = &proto.ExtensionDesc{
+	ExtendedType:  (*OtherMessage)(nil),
+	ExtensionType: (*ComplexExtension)(nil),
+	Field:         200,
+	Name:          "testdata.complex",
+	Tag:           "bytes,200,opt,name=complex",
+}
+
+var E_RComplex = &proto.ExtensionDesc{
+	ExtendedType:  (*OtherMessage)(nil),
+	ExtensionType: ([]*ComplexExtension)(nil),
+	Field:         201,
+	Name:          "testdata.r_complex",
+	Tag:           "bytes,201,rep,name=r_complex",
+}
+
 var E_NoDefaultDouble = &proto.ExtensionDesc{
 	ExtendedType:  (*DefaultsMessage)(nil),
 	ExtensionType: (*float64)(nil),
@@ -2863,6 +2927,8 @@ func init() {
 	proto.RegisterExtension(E_Ext_Text)
 	proto.RegisterExtension(E_Ext_Number)
 	proto.RegisterExtension(E_Greeting)
+	proto.RegisterExtension(E_Complex)
+	proto.RegisterExtension(E_RComplex)
 	proto.RegisterExtension(E_NoDefaultDouble)
 	proto.RegisterExtension(E_NoDefaultFloat)
 	proto.RegisterExtension(E_NoDefaultInt32)

+ 13 - 0
proto/testdata/test.proto

@@ -233,6 +233,8 @@ message OtherMessage {
   optional bytes value = 2;
   optional float weight = 3;
   optional InnerMessage inner = 4;
+
+  extensions 100 to max;
 }
 
 message MyMessage {
@@ -277,6 +279,17 @@ extend MyMessage {
   repeated string greeting = 106;
 }
 
+message ComplexExtension {
+  optional int32 first = 1;
+  optional int32 second = 2;
+  repeated int32 third = 3;
+}
+
+extend OtherMessage {
+  optional ComplexExtension complex = 200;
+  repeated ComplexExtension r_complex = 201;
+}
+
 message DefaultsMessage {
   enum DefaultsEnum {
     ZERO = 0;