Selaa lähdekoodia

proto: reject invalid UTF-8 in strings (#499)

The proto2 and proto3 specifications explicitly say that strings must be
composed of valid UTF-8 characters. 

The C++ implementation prints an error anytime strings with invalid UTF-8
is present during parsing or serializing. For proto3, it explicitly
errors out when parsing an invalid string.

Thus, we cause marshal/unmarshal to error if any string fields are invalid.
The error returned is fail-fast and indistinguishable. This means that the
we stop the unmarshal process the moment we detect an invalid string,
as opposed to finishing the unmarshal. An indistinguishable error means that
we provide no API for the user to programmatically check specifically for
invalid UTF-8 errors so that they can ignore it.

In conversations with the protobuf team, they felt strongly that there should be
no ability for users to ignore the UTF-8 validation error. However, if this change is
overly problematic for users, we can consider a workaround for users already
depending on strings containing invalid UTF-8.
Joe Tsai 8 vuotta sitten
vanhempi
commit
35253352f9
4 muutettua tiedostoa jossa 41 lisäystä ja 0 poistoa
  1. 15 0
      proto/all_test.go
  2. 3 0
      proto/lib.go
  3. 13 0
      proto/table_marshal.go
  4. 10 0
      proto/table_unmarshal.go

+ 15 - 0
proto/all_test.go

@@ -2184,6 +2184,21 @@ func TestConcurrentMarshal(t *testing.T) {
 	}
 }
 
+func TestInvalidUTF8(t *testing.T) {
+	const wire = "\x12\x04\xde\xea\xca\xfe"
+
+	var m GoTest
+	if err := Unmarshal([]byte(wire), &m); err == nil {
+		t.Errorf("Unmarshal error: got nil, want non-nil")
+	}
+
+	m.Reset()
+	m.Table = String(wire[2:])
+	if _, err := Marshal(&m); err == nil {
+		t.Errorf("Marshal error: got nil, want non-nil")
+	}
+}
+
 // Benchmarks
 
 func testMsg() *GoTest {

+ 3 - 0
proto/lib.go

@@ -265,6 +265,7 @@ package proto
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"log"
 	"reflect"
@@ -279,6 +280,8 @@ import (
 // This variable is temporary and will go away soon. Do not rely on it.
 var Proto3UnknownFields = false
 
+var errInvalidUTF8 = errors.New("proto: invalid UTF-8 string")
+
 // Message is implemented by generated protocol buffer messages.
 type Message interface {
 	Reset()

+ 13 - 0
proto/table_marshal.go

@@ -41,6 +41,7 @@ import (
 	"strings"
 	"sync"
 	"sync/atomic"
+	"unicode/utf8"
 )
 
 // a sizer takes a pointer to a field and the size of its tag, computes the size of
@@ -1983,6 +1984,9 @@ func appendBoolPackedSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byt
 }
 func appendStringValue(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
 	v := *ptr.toString()
+	if !utf8.ValidString(v) {
+		return nil, errInvalidUTF8
+	}
 	b = appendVarint(b, wiretag)
 	b = appendVarint(b, uint64(len(v)))
 	b = append(b, v...)
@@ -1993,6 +1997,9 @@ func appendStringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]b
 	if v == "" {
 		return b, nil
 	}
+	if !utf8.ValidString(v) {
+		return nil, errInvalidUTF8
+	}
 	b = appendVarint(b, wiretag)
 	b = appendVarint(b, uint64(len(v)))
 	b = append(b, v...)
@@ -2004,6 +2011,9 @@ func appendStringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, err
 		return b, nil
 	}
 	v := *p
+	if !utf8.ValidString(v) {
+		return nil, errInvalidUTF8
+	}
 	b = appendVarint(b, wiretag)
 	b = appendVarint(b, uint64(len(v)))
 	b = append(b, v...)
@@ -2012,6 +2022,9 @@ func appendStringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, err
 func appendStringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
 	s := *ptr.toStringSlice()
 	for _, v := range s {
+		if !utf8.ValidString(v) {
+			return nil, errInvalidUTF8
+		}
 		b = appendVarint(b, wiretag)
 		b = appendVarint(b, uint64(len(v)))
 		b = append(b, v...)

+ 10 - 0
proto/table_unmarshal.go

@@ -41,6 +41,7 @@ import (
 	"strings"
 	"sync"
 	"sync/atomic"
+	"unicode/utf8"
 )
 
 // Unmarshal is the entry point from the generated .pb.go files.
@@ -1514,6 +1515,9 @@ func unmarshalStringValue(b []byte, f pointer, w int) ([]byte, error) {
 		return nil, io.ErrUnexpectedEOF
 	}
 	v := string(b[:x])
+	if !utf8.ValidString(v) {
+		return nil, errInvalidUTF8
+	}
 	*f.toString() = v
 	return b[x:], nil
 }
@@ -1531,6 +1535,9 @@ func unmarshalStringPtr(b []byte, f pointer, w int) ([]byte, error) {
 		return nil, io.ErrUnexpectedEOF
 	}
 	v := string(b[:x])
+	if !utf8.ValidString(v) {
+		return nil, errInvalidUTF8
+	}
 	*f.toStringPtr() = &v
 	return b[x:], nil
 }
@@ -1548,6 +1555,9 @@ func unmarshalStringSlice(b []byte, f pointer, w int) ([]byte, error) {
 		return nil, io.ErrUnexpectedEOF
 	}
 	v := string(b[:x])
+	if !utf8.ValidString(v) {
+		return nil, errInvalidUTF8
+	}
 	s := f.toStringSlice()
 	*s = append(*s, v)
 	return b[x:], nil