Browse Source

proto: add logic to handle legacy message (#496)

Not all proto.Message implementations will be updated to be
using the most recent protoc-gen-go. Thus, they will lack an
XXX_DiscardUnknown method. Add logic to handle older protobufs.
Joe Tsai 8 years ago
parent
commit
10c2d9d3cc
2 changed files with 157 additions and 15 deletions
  1. 106 14
      proto/discard.go
  2. 51 1
      proto/discard_test.go

+ 106 - 14
proto/discard.go

@@ -59,16 +59,10 @@ func DiscardUnknown(m Message) {
 		m.XXX_DiscardUnknown()
 		return
 	}
-	if m == nil {
-		return
-	}
-
-	// The Message interface really needs to provide some form of reflection
-	// API that can be used to implement the fallback if someone is using
-	// a custom protobuf message implementation.
-	//
-	// See https://github.com/golang/protobuf/issues/364
-	panic(fmt.Sprintf("cannot discard unknown fields on %T", m))
+	// TODO: Dynamically populate a InternalMessageInfo for legacy messages,
+	// but the master branch has no implementation for InternalMessageInfo,
+	// so it would be more work to replicate that approach.
+	discardLegacy(m)
 }
 
 // DiscardUnknown recursively discards all unknown fields.
@@ -172,14 +166,14 @@ func (di *discardInfo) computeDiscardInfo() {
 			tf = tf.Elem()
 		}
 		if isPointer && isSlice && tf.Kind() != reflect.Struct {
-			panic("both pointer and slice for basic type in " + tf.Name())
+			panic(fmt.Sprintf("%v.%s cannot be a slice of pointers to primitive types", t, f.Name))
 		}
 
 		switch tf.Kind() {
 		case reflect.Struct:
 			switch {
 			case !isPointer:
-				panic(fmt.Sprintf("message field %s without pointer", tf))
+				panic(fmt.Sprintf("%v.%s cannot be a direct struct value", t, f.Name))
 			case isSlice: // E.g., []*pb.T
 				di := getDiscardInfo(tf)
 				dfi.discard = func(src pointer) {
@@ -202,7 +196,7 @@ func (di *discardInfo) computeDiscardInfo() {
 		case reflect.Map:
 			switch {
 			case isPointer || isSlice:
-				panic("bad pointer or slice in map case in " + tf.Name())
+				panic(fmt.Sprintf("%v.%s cannot be a pointer to a map or a slice of map values", t, f.Name))
 			default: // E.g., map[K]V
 				if tf.Elem().Kind() == reflect.Ptr { // Proto struct (e.g., *T)
 					dfi.discard = func(src pointer) {
@@ -223,7 +217,7 @@ func (di *discardInfo) computeDiscardInfo() {
 			// Must be oneof field.
 			switch {
 			case isPointer || isSlice:
-				panic("bad pointer or slice in interface case in " + tf.Name())
+				panic(fmt.Sprintf("%v.%s cannot be a pointer to a interface or a slice of interface values", t, f.Name))
 			default: // E.g., interface{}
 				// TODO: Make this faster?
 				dfi.discard = func(src pointer) {
@@ -256,3 +250,101 @@ func (di *discardInfo) computeDiscardInfo() {
 
 	atomic.StoreInt32(&di.initialized, 1)
 }
+
+func discardLegacy(m Message) {
+	v := reflect.ValueOf(m)
+	if v.Kind() != reflect.Ptr || v.IsNil() {
+		return
+	}
+	v = v.Elem()
+	if v.Kind() != reflect.Struct {
+		return
+	}
+	t := v.Type()
+
+	for i := 0; i < v.NumField(); i++ {
+		f := t.Field(i)
+		if strings.HasPrefix(f.Name, "XXX_") {
+			continue
+		}
+		vf := v.Field(i)
+		tf := f.Type
+
+		// Unwrap tf to get its most basic type.
+		var isPointer, isSlice bool
+		if tf.Kind() == reflect.Slice && tf.Elem().Kind() != reflect.Uint8 {
+			isSlice = true
+			tf = tf.Elem()
+		}
+		if tf.Kind() == reflect.Ptr {
+			isPointer = true
+			tf = tf.Elem()
+		}
+		if isPointer && isSlice && tf.Kind() != reflect.Struct {
+			panic(fmt.Sprintf("%T.%s cannot be a slice of pointers to primitive types", m, f.Name))
+		}
+
+		switch tf.Kind() {
+		case reflect.Struct:
+			switch {
+			case !isPointer:
+				panic(fmt.Sprintf("%T.%s cannot be a direct struct value", m, f.Name))
+			case isSlice: // E.g., []*pb.T
+				for j := 0; j < vf.Len(); j++ {
+					discardLegacy(vf.Index(j).Interface().(Message))
+				}
+			default: // E.g., *pb.T
+				discardLegacy(vf.Interface().(Message))
+			}
+		case reflect.Map:
+			switch {
+			case isPointer || isSlice:
+				panic(fmt.Sprintf("%T.%s cannot be a pointer to a map or a slice of map values", m, f.Name))
+			default: // E.g., map[K]V
+				tv := vf.Type().Elem()
+				if tv.Kind() == reflect.Ptr && tv.Implements(protoMessageType) { // Proto struct (e.g., *T)
+					for _, key := range vf.MapKeys() {
+						val := vf.MapIndex(key)
+						discardLegacy(val.Interface().(Message))
+					}
+				}
+			}
+		case reflect.Interface:
+			// Must be oneof field.
+			switch {
+			case isPointer || isSlice:
+				panic(fmt.Sprintf("%T.%s cannot be a pointer to a interface or a slice of interface values", m, f.Name))
+			default: // E.g., test_proto.isCommunique_Union interface
+				if !vf.IsNil() && f.Tag.Get("protobuf_oneof") != "" {
+					vf = vf.Elem() // E.g., *test_proto.Communique_Msg
+					if !vf.IsNil() {
+						vf = vf.Elem()   // E.g., test_proto.Communique_Msg
+						vf = vf.Field(0) // E.g., Proto struct (e.g., *T) or primitive value
+						if vf.Kind() == reflect.Ptr {
+							discardLegacy(vf.Interface().(Message))
+						}
+					}
+				}
+			}
+		}
+	}
+
+	if vf := v.FieldByName("XXX_unrecognized"); vf.IsValid() {
+		if vf.Type() != reflect.TypeOf([]byte{}) {
+			panic("expected XXX_unrecognized to be of type []byte")
+		}
+		vf.Set(reflect.ValueOf([]byte(nil)))
+	}
+
+	// For proto2 messages, only discard unknown fields in message extensions
+	// that have been accessed via GetExtension.
+	if em, err := extendable(m); err == nil {
+		// Ignore lock since discardLegacy is not concurrency safe.
+		emm, _ := em.extensionsRead()
+		for _, mx := range emm {
+			if m, ok := mx.value.(Message); ok {
+				discardLegacy(m)
+			}
+		}
+	}
+}

+ 51 - 1
proto/discard_test.go

@@ -61,6 +61,23 @@ func TestDiscardUnknown(t *testing.T) {
 			Name:   "Aaron",
 			Nested: &proto3pb.Nested{Cute: true},
 		},
+	}, {
+		desc: "Slice",
+		in: &proto3pb.Message{
+			Name: "Aaron",
+			Children: []*proto3pb.Message{
+				{Name: "Sarah", XXX_unrecognized: []byte("blah")},
+				{Name: "Abraham", XXX_unrecognized: []byte("blah")},
+			},
+			XXX_unrecognized: []byte("blah"),
+		},
+		want: &proto3pb.Message{
+			Name: "Aaron",
+			Children: []*proto3pb.Message{
+				{Name: "Sarah"},
+				{Name: "Abraham"},
+			},
+		},
 	}, {
 		desc: "OneOf",
 		in: &pb.Communique{
@@ -111,10 +128,43 @@ func TestDiscardUnknown(t *testing.T) {
 		}(),
 	}}
 
+	// Test the legacy code path.
+	for _, tt := range tests {
+		// Clone the input so that we don't alter the original.
+		in := tt.in
+		if in != nil {
+			in = proto.Clone(tt.in)
+		}
+
+		var m LegacyMessage
+		m.Message, _ = in.(*proto3pb.Message)
+		m.Communique, _ = in.(*pb.Communique)
+		m.MessageWithMap, _ = in.(*pb.MessageWithMap)
+		m.MyMessage, _ = in.(*pb.MyMessage)
+		proto.DiscardUnknown(&m)
+		if !proto.Equal(in, tt.want) {
+			t.Errorf("test %s/Legacy, expected unknown fields to be discarded\ngot  %v\nwant %v", tt.desc, in, tt.want)
+		}
+	}
+
 	for _, tt := range tests {
 		proto.DiscardUnknown(tt.in)
 		if !proto.Equal(tt.in, tt.want) {
-			t.Errorf("test %s, expected unknown fields to be discarde\ngot  %v\nwant %v", tt.desc, tt.in, tt.want)
+			t.Errorf("test %s, expected unknown fields to be discarded\ngot  %v\nwant %v", tt.desc, tt.in, tt.want)
 		}
 	}
 }
+
+// LegacyMessage is a proto.Message that has several nested messages.
+// This does not have the XXX_DiscardUnknown method and so forces DiscardUnknown
+// to use the legacy fallback logic.
+type LegacyMessage struct {
+	Message        *proto3pb.Message
+	Communique     *pb.Communique
+	MessageWithMap *pb.MessageWithMap
+	MyMessage      *pb.MyMessage
+}
+
+func (m *LegacyMessage) Reset()         { *m = LegacyMessage{} }
+func (m *LegacyMessage) String() string { return proto.CompactTextString(m) }
+func (*LegacyMessage) ProtoMessage()    {}