Просмотр исходного кода

proto: cleanup invalid extension exception

The invalidExtensions flag no longer seems necessary. Tests pass without it.

Change-Id: Ieb35e26912b047718ccbfcdc926625aec1cd8c87
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185937
Reviewed-by: Damien Neil <dneil@google.com>
Joe Tsai 6 лет назад
Родитель
Сommit
09cef32bce
2 измененных файлов с 8 добавлено и 30 удалено
  1. 8 20
      proto/decode_test.go
  2. 0 10
      proto/encode_test.go

+ 8 - 20
proto/decode_test.go

@@ -23,11 +23,10 @@ import (
 )
 )
 
 
 type testProto struct {
 type testProto struct {
-	desc              string
-	decodeTo          []proto.Message
-	wire              []byte
-	partial           bool
-	invalidExtensions bool
+	desc     string
+	decodeTo []proto.Message
+	wire     []byte
+	partial  bool
 }
 }
 
 
 func TestDecode(t *testing.T) {
 func TestDecode(t *testing.T) {
@@ -49,11 +48,6 @@ func TestDecode(t *testing.T) {
 				for i := range wire {
 				for i := range wire {
 					wire[i] = 0
 					wire[i] = 0
 				}
 				}
-
-				if test.invalidExtensions {
-					// Equal doesn't work on messages containing invalid extension data.
-					return
-				}
 				if !proto.Equal(got, want) {
 				if !proto.Equal(got, want) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
 				}
@@ -67,10 +61,6 @@ func TestDecodeRequiredFieldChecks(t *testing.T) {
 		if !test.partial {
 		if !test.partial {
 			continue
 			continue
 		}
 		}
-		if test.invalidExtensions {
-			// Missing required fields in extensions just end up in the unknown fields.
-			continue
-		}
 		for _, m := range test.decodeTo {
 		for _, m := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
 				got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
 				got := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
@@ -1255,9 +1245,8 @@ var testProtos = []testProto{
 		})}.Marshal(),
 		})}.Marshal(),
 	},
 	},
 	{
 	{
-		desc:              "required field in extension message unset",
-		partial:           true,
-		invalidExtensions: true,
+		desc:    "required field in extension message unset",
+		partial: true,
 		decodeTo: []proto.Message{build(
 		decodeTo: []proto.Message{build(
 			&testpb.TestAllExtensions{},
 			&testpb.TestAllExtensions{},
 			extend(testpb.E_TestRequired_Single, &testpb.TestRequired{}),
 			extend(testpb.E_TestRequired_Single, &testpb.TestRequired{}),
@@ -1281,9 +1270,8 @@ var testProtos = []testProto{
 		}.Marshal(),
 		}.Marshal(),
 	},
 	},
 	{
 	{
-		desc:              "required field in repeated extension message unset",
-		partial:           true,
-		invalidExtensions: true,
+		desc:    "required field in repeated extension message unset",
+		partial: true,
 		decodeTo: []proto.Message{build(
 		decodeTo: []proto.Message{build(
 			&testpb.TestAllExtensions{},
 			&testpb.TestAllExtensions{},
 			extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{
 			extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{

+ 0 - 10
proto/encode_test.go

@@ -40,11 +40,6 @@ func TestEncode(t *testing.T) {
 					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
 					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
 					return
 					return
 				}
 				}
-
-				if test.invalidExtensions {
-					// Equal doesn't work on messages containing invalid extension data.
-					return
-				}
 				if !proto.Equal(got, want) {
 				if !proto.Equal(got, want) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
 				}
@@ -81,11 +76,6 @@ func TestEncodeDeterministic(t *testing.T) {
 					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
 					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
 					return
 					return
 				}
 				}
-
-				if test.invalidExtensions {
-					// Equal doesn't work on messages containing invalid extension data.
-					return
-				}
 				if !proto.Equal(got, want) {
 				if !proto.Equal(got, want) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
 				}