|
|
@@ -12,6 +12,7 @@ import (
|
|
|
protoV1 "github.com/golang/protobuf/proto"
|
|
|
"github.com/golang/protobuf/v2/encoding/textpb"
|
|
|
"github.com/golang/protobuf/v2/internal/encoding/pack"
|
|
|
+ "github.com/golang/protobuf/v2/internal/errors"
|
|
|
"github.com/golang/protobuf/v2/internal/scalar"
|
|
|
"github.com/golang/protobuf/v2/proto"
|
|
|
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
|
|
|
@@ -80,6 +81,23 @@ func TestDecodeRequiredFieldChecks(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+func TestDecodeInvalidUTF8(t *testing.T) {
|
|
|
+ for _, test := range invalidUTF8TestProtos {
|
|
|
+ for _, want := range test.decodeTo {
|
|
|
+ t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
|
|
|
+ got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
|
|
|
+ err := proto.Unmarshal(test.wire, got)
|
|
|
+ if !isErrInvalidUTF8(err) {
|
|
|
+ t.Errorf("Unmarshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
|
|
|
+ }
|
|
|
+ if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
|
|
|
+ t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
var testProtos = []testProto{
|
|
|
{
|
|
|
desc: "basic scalar types",
|
|
|
@@ -1158,6 +1176,69 @@ var testProtos = []testProto{
|
|
|
},
|
|
|
}
|
|
|
|
|
|
+var invalidUTF8TestProtos = []testProto{
|
|
|
+ {
|
|
|
+ desc: "invalid UTF-8 in optional string field",
|
|
|
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
|
|
+ OptionalString: "abc\xff",
|
|
|
+ }},
|
|
|
+ wire: pack.Message{
|
|
|
+ pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
|
|
|
+ }.Marshal(),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ desc: "invalid UTF-8 in repeated string field",
|
|
|
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
|
|
+ RepeatedString: []string{"foo", "abc\xff"},
|
|
|
+ }},
|
|
|
+ wire: pack.Message{
|
|
|
+ pack.Tag{44, pack.BytesType}, pack.String("foo"),
|
|
|
+ pack.Tag{44, pack.BytesType}, pack.String("abc\xff"),
|
|
|
+ }.Marshal(),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ desc: "invalid UTF-8 in nested message",
|
|
|
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
|
|
+ OptionalNestedMessage: &test3pb.TestAllTypes_NestedMessage{
|
|
|
+ Corecursive: &test3pb.TestAllTypes{
|
|
|
+ OptionalString: "abc\xff",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }},
|
|
|
+ wire: pack.Message{
|
|
|
+ pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
|
|
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
|
|
+ pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
|
|
|
+ }),
|
|
|
+ }),
|
|
|
+ }.Marshal(),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ desc: "invalid UTF-8 in map key",
|
|
|
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
|
|
+ MapStringString: map[string]string{"key\xff": "val"},
|
|
|
+ }},
|
|
|
+ wire: pack.Message{
|
|
|
+ pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
|
|
+ pack.Tag{1, pack.BytesType}, pack.String("key\xff"),
|
|
|
+ pack.Tag{2, pack.BytesType}, pack.String("val"),
|
|
|
+ }),
|
|
|
+ }.Marshal(),
|
|
|
+ },
|
|
|
+ {
|
|
|
+ desc: "invalid UTF-8 in map value",
|
|
|
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
|
|
|
+ MapStringString: map[string]string{"key": "val\xff"},
|
|
|
+ }},
|
|
|
+ wire: pack.Message{
|
|
|
+ pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
|
|
|
+ pack.Tag{1, pack.BytesType}, pack.String("key"),
|
|
|
+ pack.Tag{2, pack.BytesType}, pack.String("val\xff"),
|
|
|
+ }),
|
|
|
+ }.Marshal(),
|
|
|
+ },
|
|
|
+}
|
|
|
+
|
|
|
func build(m proto.Message, opts ...buildOpt) proto.Message {
|
|
|
for _, opt := range opts {
|
|
|
opt(m)
|
|
|
@@ -1185,3 +1266,17 @@ func marshalText(m proto.Message) string {
|
|
|
b, _ := textpb.Marshal(m)
|
|
|
return string(b)
|
|
|
}
|
|
|
+
|
|
|
+func isErrInvalidUTF8(err error) bool {
|
|
|
+ nerr, ok := err.(errors.NonFatalErrors)
|
|
|
+ if !ok || len(nerr) == 0 {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ for _, err := range nerr {
|
|
|
+ if e, ok := err.(interface{ InvalidUTF8() bool }); ok && e.InvalidUTF8() {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ return true
|
|
|
+}
|