Browse Source

internal/proto: implement DiscardUnknown with reflection

Change-Id: Ida73c810353bde868bcd7cf2ea40098695695458
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/167738
Reviewed-by: Herbie Ong <herbie@google.com>
Joe Tsai 7 years ago
parent
commit
2d5f4c3ffd
3 changed files with 126 additions and 11 deletions
  1. 76 0
      internal/proto/discard.go
  2. 22 0
      proto/all_test.go
  3. 28 11
      proto/discard_test.go

+ 76 - 0
internal/proto/discard.go

@@ -0,0 +1,76 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package proto
+
+import (
+	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+	"github.com/golang/protobuf/v2/runtime/protoimpl"
+)
+
+// DiscardUnknown recursively discards unknown fields from m and from all
+// messages embedded in it. A nil m is a no-op.
+//
+// When unmarshaling a message with unrecognized fields, their tags and
+// values are preserved in the Message so that a later marshal can emit
+// them again. Call DiscardUnknown after unmarshaling to drop them instead.
+//
+// Unknown fields whose numbers fall within a message's extension ranges
+// are left in place.
+//
+// For proto2 messages, the unknown fields of message extensions are only
+// discarded from messages that have been accessed via GetExtension.
+func DiscardUnknown(m Message) {
+	if m != nil {
+		discardUnknown(protoimpl.X.MessageOf(m))
+	}
+}
+
+// discardUnknown recurses into each message/group held in m's known fields,
+// then sets to nil every unknown field of m outside its extension ranges.
+func discardUnknown(m pref.Message) {
+	isMessage := func(k pref.Kind) bool { return k == pref.MessageKind || k == pref.GroupKind }
+
+	declared := m.Type().Fields()
+	known := m.KnownFields()
+	known.Range(func(num pref.FieldNumber, _ pref.Value) bool {
+		// Numbers not declared on the message type are resolved as extensions.
+		fd := declared.ByNumber(num)
+		if fd == nil {
+			fd = known.ExtensionTypes().ByNumber(num)
+		}
+		if fd.Cardinality() != pref.Repeated {
+			// Singular field: recurse only when it holds a message or group.
+			if isMessage(fd.Kind()) {
+				discardUnknown(known.Get(num).Message())
+			}
+		} else if !fd.IsMap() {
+			// List field: recurse into each element of a message/group list.
+			if isMessage(fd.Kind()) {
+				elems := known.Get(num).List()
+				for i := 0; i < elems.Len(); i++ {
+					discardUnknown(elems.Get(i).Message())
+				}
+			}
+		} else if isMessage(fd.MessageType().Fields().ByNumber(2).Kind()) {
+			// Map field whose value (entry field 2) is a message: visit values.
+			known.Get(num).Map().Range(func(_ pref.MapKey, v pref.Value) bool {
+				discardUnknown(v.Message())
+				return true
+			})
+		}
+		return true
+	})
+
+	// NOTE: Historically, unknown fields whose numbers fall inside the
+	// message's extension ranges were not discarded; keep that behavior.
+	extRanges := m.Type().ExtensionRanges()
+	unknown := m.UnknownFields()
+	unknown.Range(func(num pref.FieldNumber, _ pref.RawFields) bool {
+		if !extRanges.Has(num) {
+			unknown.Set(num, nil)
+		}
+		return true
+	})
+}

+ 22 - 0
proto/all_test.go

@@ -1170,6 +1170,28 @@ func TestBadWireTypeUnknown(t *testing.T) {
 	}
 }
 
+// TestBadWireTypeUnknown2 unmarshals wire data, checks which bytes were kept
+// as unknown, then expects DiscardUnknown to leave only the known values.
+func TestBadWireTypeUnknown2(t *testing.T) {
+	var wire, wantUnknown []byte
+	fmt.Sscanf("0a01780d00000000080b101612036161611521000000202c220362626225370000002203636363214200000000000000584d5a036464645900000000000056405d63000000", "%x", &wire)
+	fmt.Sscanf("0a01780d0000000010161521000000202c2537000000214200000000000000584d5a036464645d63000000", "%x", &wantUnknown)
+
+	msg := new(MyMessage)
+	if err := Unmarshal(wire, msg); err != nil {
+		t.Errorf("unexpected Unmarshal error: %v", err)
+	}
+	if got := msg.XXX_unrecognized; !bytes.Equal(got, wantUnknown) {
+		t.Errorf("unknown bytes mismatch:\ngot  %x\nwant %x", got, wantUnknown)
+	}
+
+	protoV1a.DiscardUnknown(msg)
+	want := &MyMessage{Count: Int32(11), Name: String("aaa"), Pet: []string{"bbb", "ccc"}, Bigfloat: Float64(88)}
+	if !Equal(msg, want) {
+		t.Errorf("message mismatch:\ngot  %v\nwant %v", msg, want)
+	}
+}
+
 func encodeDecode(t *testing.T, in, out Message, msg string) {
 	buf, err := Marshal(in)
 	if err != nil {

+ 28 - 11
proto/discard_test.go

@@ -7,12 +7,15 @@ package proto_test
 import (
 	"testing"
 
+	protoV1a "github.com/golang/protobuf/internal/proto"
 	"github.com/golang/protobuf/proto"
 
 	proto3pb "github.com/golang/protobuf/proto/proto3_proto"
 	pb "github.com/golang/protobuf/proto/test_proto"
 )
 
+const rawFields = "\x2d\xc3\xd2\xe1\xf0" // one wire record: tag 0x2d (field 5, fixed32) followed by four payload bytes
+
 func TestDiscardUnknown(t *testing.T) {
 	tests := []struct {
 		desc     string
@@ -27,8 +30,8 @@ func TestDiscardUnknown(t *testing.T) {
 		desc: "Nested",
 		in: &proto3pb.Message{
 			Name:             "Aaron",
-			Nested:           &proto3pb.Nested{Cute: true, XXX_unrecognized: []byte("blah")},
-			XXX_unrecognized: []byte("blah"),
+			Nested:           &proto3pb.Nested{Cute: true, XXX_unrecognized: []byte(rawFields)},
+			XXX_unrecognized: []byte(rawFields),
 		},
 		want: &proto3pb.Message{
 			Name:   "Aaron",
@@ -39,10 +42,10 @@ func TestDiscardUnknown(t *testing.T) {
 		in: &proto3pb.Message{
 			Name: "Aaron",
 			Children: []*proto3pb.Message{
-				{Name: "Sarah", XXX_unrecognized: []byte("blah")},
-				{Name: "Abraham", XXX_unrecognized: []byte("blah")},
+				{Name: "Sarah", XXX_unrecognized: []byte(rawFields)},
+				{Name: "Abraham", XXX_unrecognized: []byte(rawFields)},
 			},
-			XXX_unrecognized: []byte("blah"),
+			XXX_unrecognized: []byte(rawFields),
 		},
 		want: &proto3pb.Message{
 			Name: "Aaron",
@@ -56,9 +59,9 @@ func TestDiscardUnknown(t *testing.T) {
 		in: &pb.Communique{
 			Union: &pb.Communique_Msg{&pb.Strings{
 				StringField:      proto.String("123"),
-				XXX_unrecognized: []byte("blah"),
+				XXX_unrecognized: []byte(rawFields),
 			}},
-			XXX_unrecognized: []byte("blah"),
+			XXX_unrecognized: []byte(rawFields),
 		},
 		want: &pb.Communique{
 			Union: &pb.Communique_Msg{&pb.Strings{StringField: proto.String("123")}},
@@ -68,7 +71,7 @@ func TestDiscardUnknown(t *testing.T) {
 		in: &pb.MessageWithMap{MsgMapping: map[int64]*pb.FloatingPoint{
 			0x4002: &pb.FloatingPoint{
 				Exact:            proto.Bool(true),
-				XXX_unrecognized: []byte("blah"),
+				XXX_unrecognized: []byte(rawFields),
 			},
 		}},
 		want: &pb.MessageWithMap{MsgMapping: map[int64]*pb.FloatingPoint{
@@ -81,13 +84,13 @@ func TestDiscardUnknown(t *testing.T) {
 				Count: proto.Int32(42),
 				Somegroup: &pb.MyMessage_SomeGroup{
 					GroupField:       proto.Int32(6),
-					XXX_unrecognized: []byte("blah"),
+					XXX_unrecognized: []byte(rawFields),
 				},
-				XXX_unrecognized: []byte("blah"),
+				XXX_unrecognized: []byte(rawFields),
 			}
 			proto.SetExtension(m, pb.E_Ext_More, &pb.Ext{
 				Data:             proto.String("extension"),
-				XXX_unrecognized: []byte("blah"),
+				XXX_unrecognized: []byte(rawFields),
 			})
 			return m
 		}(),
@@ -101,6 +104,20 @@ func TestDiscardUnknown(t *testing.T) {
 		}(),
 	}}
 
+	// Test the reflection code path.
+	for _, tt := range tests {
+		// Work on a clone: the legacy loop below reuses tt.in unmodified.
+		in := tt.in
+		if in != nil {
+			in = proto.Clone(tt.in)
+		}
+
+		protoV1a.DiscardUnknown(in)
+		if !proto.Equal(in, tt.want) {
+			t.Errorf("test %s, expected unknown fields to be discarded\ngot  %v\nwant %v", tt.desc, in, tt.want)
+		}
+	}
+
 	// Test the legacy code path.
 	for _, tt := range tests {
 		// Clone the input so that we don't alter the original.