Browse Source

proto: make invalid UTF-8 errors non-fatal (#660)

The current logic currently treats RequredNotSetError as a non-fatal
error such that it continues to proceed with its normal execution, but
returns that non-fatal error at the end.

We now make it such that invalid UTF-8 is also a distinguishable error
that is treated as non-fatal by the execution logic. In the rare event
that both an RequiredNotSet error and InvalidUTF8 error is encountered,
the first one encountered is returned.

In the process of making this change, we also fix a number of cases where
RequiredNotSet was treated as fatal, when it should not have been non-fatal
(notably in the logic for extensions, message sets, and maps).

This change deliberately does not provide an API make it easy to distinguish
invalid UTF-8 errors as this is not a normal behavior we want users to
depend on. We can always expose additional API for this in the future.

Users can test for invalid UTF-8 with:
	re, ok := err.(interface{ InvalidUTF8() bool })
	isInvalidUTF8 := ok && re.InvalidUTF8()
Joe Tsai 7 years ago
parent
commit
f5983d50c8
5 changed files with 169 additions and 94 deletions
  1. 27 6
      proto/all_test.go
  2. 0 15
      proto/encode.go
  3. 60 2
      proto/lib.go
  4. 57 47
      proto/table_marshal.go
  5. 25 24
      proto/table_unmarshal.go

+ 27 - 6
proto/all_test.go

@@ -2256,26 +2256,32 @@ func TestInvalidUTF8(t *testing.T) {
 		label  string
 		label  string
 		proto2 Message
 		proto2 Message
 		proto3 Message
 		proto3 Message
+		want   []byte
 	}{{
 	}{{
 		label:  "Scalar",
 		label:  "Scalar",
 		proto2: &TestUTF8{Scalar: String(invalidUTF8)},
 		proto2: &TestUTF8{Scalar: String(invalidUTF8)},
 		proto3: &pb3.TestUTF8{Scalar: invalidUTF8},
 		proto3: &pb3.TestUTF8{Scalar: invalidUTF8},
+		want:   []byte{0x0a, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff},
 	}, {
 	}, {
 		label:  "Vector",
 		label:  "Vector",
 		proto2: &TestUTF8{Vector: []string{invalidUTF8}},
 		proto2: &TestUTF8{Vector: []string{invalidUTF8}},
 		proto3: &pb3.TestUTF8{Vector: []string{invalidUTF8}},
 		proto3: &pb3.TestUTF8{Vector: []string{invalidUTF8}},
+		want:   []byte{0x12, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff},
 	}, {
 	}, {
 		label:  "Oneof",
 		label:  "Oneof",
 		proto2: &TestUTF8{Oneof: &TestUTF8_Field{invalidUTF8}},
 		proto2: &TestUTF8{Oneof: &TestUTF8_Field{invalidUTF8}},
 		proto3: &pb3.TestUTF8{Oneof: &pb3.TestUTF8_Field{invalidUTF8}},
 		proto3: &pb3.TestUTF8{Oneof: &pb3.TestUTF8_Field{invalidUTF8}},
+		want:   []byte{0x1a, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff},
 	}, {
 	}, {
 		label:  "MapKey",
 		label:  "MapKey",
 		proto2: &TestUTF8{MapKey: map[string]int64{invalidUTF8: 0}},
 		proto2: &TestUTF8{MapKey: map[string]int64{invalidUTF8: 0}},
 		proto3: &pb3.TestUTF8{MapKey: map[string]int64{invalidUTF8: 0}},
 		proto3: &pb3.TestUTF8{MapKey: map[string]int64{invalidUTF8: 0}},
+		want:   []byte{0x22, 0x0b, 0x0a, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff, 0x10, 0x00},
 	}, {
 	}, {
 		label:  "MapValue",
 		label:  "MapValue",
 		proto2: &TestUTF8{MapValue: map[int64]string{0: invalidUTF8}},
 		proto2: &TestUTF8{MapValue: map[int64]string{0: invalidUTF8}},
 		proto3: &pb3.TestUTF8{MapValue: map[int64]string{0: invalidUTF8}},
 		proto3: &pb3.TestUTF8{MapValue: map[int64]string{0: invalidUTF8}},
+		want:   []byte{0x2a, 0x0b, 0x08, 0x00, 0x12, 0x07, 0xde, 0xad, 0xbe, 0xef, 0x80, 0x00, 0xff},
 	}}
 	}}
 
 
 	for _, tt := range tests {
 	for _, tt := range tests {
@@ -2284,22 +2290,37 @@ func TestInvalidUTF8(t *testing.T) {
 		if err != nil {
 		if err != nil {
 			t.Errorf("Marshal(proto2.%s) = %v, want nil", tt.label, err)
 			t.Errorf("Marshal(proto2.%s) = %v, want nil", tt.label, err)
 		}
 		}
-		tt.proto2.Reset()
-		err = Unmarshal(b, tt.proto2)
-		if err != nil {
+		if !bytes.Equal(b, tt.want) {
+			t.Errorf("Marshal(proto2.%s) = %x, want %x", tt.label, b, tt.want)
+		}
+
+		m := Clone(tt.proto2)
+		m.Reset()
+		if err = Unmarshal(tt.want, m); err != nil {
 			t.Errorf("Unmarshal(proto2.%s) = %v, want nil", tt.label, err)
 			t.Errorf("Unmarshal(proto2.%s) = %v, want nil", tt.label, err)
 		}
 		}
+		if !Equal(m, tt.proto2) {
+			t.Errorf("proto2.%s: output mismatch:\ngot  %v\nwant %v", tt.label, m, tt.proto2)
+		}
 
 
 		// Proto3 should validate UTF-8.
 		// Proto3 should validate UTF-8.
-		_, err = Marshal(tt.proto3)
+		b, err = Marshal(tt.proto3)
 		if err == nil {
 		if err == nil {
 			t.Errorf("Marshal(proto3.%s) = %v, want non-nil", tt.label, err)
 			t.Errorf("Marshal(proto3.%s) = %v, want non-nil", tt.label, err)
 		}
 		}
-		tt.proto3.Reset()
-		err = Unmarshal(b, tt.proto3)
+		if !bytes.Equal(b, tt.want) {
+			t.Errorf("Marshal(proto3.%s) = %x, want %x", tt.label, b, tt.want)
+		}
+
+		m = Clone(tt.proto3)
+		m.Reset()
+		err = Unmarshal(tt.want, m)
 		if err == nil {
 		if err == nil {
 			t.Errorf("Unmarshal(proto3.%s) = %v, want non-nil", tt.label, err)
 			t.Errorf("Unmarshal(proto3.%s) = %v, want non-nil", tt.label, err)
 		}
 		}
+		if !Equal(m, tt.proto3) {
+			t.Errorf("proto3.%s: output mismatch:\ngot  %v\nwant %v", tt.label, m, tt.proto2)
+		}
 	}
 	}
 }
 }
 
 

+ 0 - 15
proto/encode.go

@@ -37,24 +37,9 @@ package proto
 
 
 import (
 import (
 	"errors"
 	"errors"
-	"fmt"
 	"reflect"
 	"reflect"
 )
 )
 
 
-// RequiredNotSetError is an error type returned by either Marshal or Unmarshal.
-// Marshal reports this when a required field is not initialized.
-// Unmarshal reports this when a required field is missing from the wire data.
-type RequiredNotSetError struct {
-	field string
-}
-
-func (e *RequiredNotSetError) Error() string {
-	if e.field == "" {
-		return fmt.Sprintf("proto: required field not set")
-	}
-	return fmt.Sprintf("proto: required field %q not set", e.field)
-}
-
 var (
 var (
 	// errRepeatedHasNil is the error returned if Marshal is called with
 	// errRepeatedHasNil is the error returned if Marshal is called with
 	// a struct with a repeated field containing a nil element.
 	// a struct with a repeated field containing a nil element.

+ 60 - 2
proto/lib.go

@@ -265,7 +265,6 @@ package proto
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"fmt"
 	"log"
 	"log"
 	"reflect"
 	"reflect"
@@ -274,7 +273,66 @@ import (
 	"sync"
 	"sync"
 )
 )
 
 
-var errInvalidUTF8 = errors.New("proto: invalid UTF-8 string")
+// RequiredNotSetError is an error type returned by either Marshal or Unmarshal.
+// Marshal reports this when a required field is not initialized.
+// Unmarshal reports this when a required field is missing from the wire data.
+type RequiredNotSetError struct{ field string }
+
+func (e *RequiredNotSetError) Error() string {
+	if e.field == "" {
+		return fmt.Sprintf("proto: required field not set")
+	}
+	return fmt.Sprintf("proto: required field %q not set", e.field)
+}
+func (e *RequiredNotSetError) RequiredNotSet() bool {
+	return true
+}
+
+type invalidUTF8Error struct{ field string }
+
+func (e *invalidUTF8Error) Error() string {
+	if e.field == "" {
+		return "proto: invalid UTF-8 detected"
+	}
+	return fmt.Sprintf("proto: field %q contains invalid UTF-8", e.field)
+}
+func (e *invalidUTF8Error) InvalidUTF8() bool {
+	return true
+}
+
+// errInvalidUTF8 is a sentinel error to identify fields with invalid UTF-8.
+// This error should not be exposed to the external API as such errors should
+// be recreated with the field information.
+var errInvalidUTF8 = &invalidUTF8Error{}
+
+// isNonFatal reports whether the error is either a RequiredNotSet error
+// or a InvalidUTF8 error.
+func isNonFatal(err error) bool {
+	if re, ok := err.(interface{ RequiredNotSet() bool }); ok && re.RequiredNotSet() {
+		return true
+	}
+	if re, ok := err.(interface{ InvalidUTF8() bool }); ok && re.InvalidUTF8() {
+		return true
+	}
+	return false
+}
+
+type nonFatal struct{ E error }
+
+// Merge merges err into nf and reports whether it was successful.
+// Otherwise it returns false for any fatal non-nil errors.
+func (nf *nonFatal) Merge(err error) (ok bool) {
+	if err == nil {
+		return true // not an error
+	}
+	if !isNonFatal(err) {
+		return false // fatal error
+	}
+	if nf.E == nil {
+		nf.E = err // store first instance of non-fatal error
+	}
+	return true
+}
 
 
 // Message is implemented by generated protocol buffer messages.
 // Message is implemented by generated protocol buffer messages.
 type Message interface {
 type Message interface {

+ 57 - 47
proto/table_marshal.go

@@ -231,7 +231,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
 		return b, err
 		return b, err
 	}
 	}
 
 
-	var err, errreq error
+	var err, errLater error
 	// The old marshaler encodes extensions at beginning.
 	// The old marshaler encodes extensions at beginning.
 	if u.extensions.IsValid() {
 	if u.extensions.IsValid() {
 		e := ptr.offset(u.extensions).toExtensions()
 		e := ptr.offset(u.extensions).toExtensions()
@@ -252,11 +252,11 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
 		}
 		}
 	}
 	}
 	for _, f := range u.fields {
 	for _, f := range u.fields {
-		if f.required && errreq == nil {
+		if f.required && errLater == nil {
 			if ptr.offset(f.field).getPointer().isNil() {
 			if ptr.offset(f.field).getPointer().isNil() {
 				// Required field is not set.
 				// Required field is not set.
 				// We record the error but keep going, to give a complete marshaling.
 				// We record the error but keep going, to give a complete marshaling.
-				errreq = &RequiredNotSetError{f.name}
+				errLater = &RequiredNotSetError{f.name}
 				continue
 				continue
 			}
 			}
 		}
 		}
@@ -269,8 +269,8 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
 			if err1, ok := err.(*RequiredNotSetError); ok {
 			if err1, ok := err.(*RequiredNotSetError); ok {
 				// Required field in submessage is not set.
 				// Required field in submessage is not set.
 				// We record the error but keep going, to give a complete marshaling.
 				// We record the error but keep going, to give a complete marshaling.
-				if errreq == nil {
-					errreq = &RequiredNotSetError{f.name + "." + err1.field}
+				if errLater == nil {
+					errLater = &RequiredNotSetError{f.name + "." + err1.field}
 				}
 				}
 				continue
 				continue
 			}
 			}
@@ -278,8 +278,11 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
 				err = errors.New("proto: repeated field " + f.name + " has nil element")
 				err = errors.New("proto: repeated field " + f.name + " has nil element")
 			}
 			}
 			if err == errInvalidUTF8 {
 			if err == errInvalidUTF8 {
-				fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
-				err = fmt.Errorf("proto: string field %q contains invalid UTF-8", fullName)
+				if errLater == nil {
+					fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
+					errLater = &invalidUTF8Error{fullName}
+				}
+				continue
 			}
 			}
 			return b, err
 			return b, err
 		}
 		}
@@ -288,7 +291,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
 		s := *ptr.offset(u.unrecognized).toBytes()
 		s := *ptr.offset(u.unrecognized).toBytes()
 		b = append(b, s...)
 		b = append(b, s...)
 	}
 	}
-	return b, errreq
+	return b, errLater
 }
 }
 
 
 // computeMarshalInfo initializes the marshal info.
 // computeMarshalInfo initializes the marshal info.
@@ -2038,52 +2041,68 @@ func appendStringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, e
 	return b, nil
 	return b, nil
 }
 }
 func appendUTF8StringValue(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
 func appendUTF8StringValue(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+	var invalidUTF8 bool
 	v := *ptr.toString()
 	v := *ptr.toString()
 	if !utf8.ValidString(v) {
 	if !utf8.ValidString(v) {
-		return nil, errInvalidUTF8
+		invalidUTF8 = true
 	}
 	}
 	b = appendVarint(b, wiretag)
 	b = appendVarint(b, wiretag)
 	b = appendVarint(b, uint64(len(v)))
 	b = appendVarint(b, uint64(len(v)))
 	b = append(b, v...)
 	b = append(b, v...)
+	if invalidUTF8 {
+		return b, errInvalidUTF8
+	}
 	return b, nil
 	return b, nil
 }
 }
 func appendUTF8StringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
 func appendUTF8StringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+	var invalidUTF8 bool
 	v := *ptr.toString()
 	v := *ptr.toString()
 	if v == "" {
 	if v == "" {
 		return b, nil
 		return b, nil
 	}
 	}
 	if !utf8.ValidString(v) {
 	if !utf8.ValidString(v) {
-		return nil, errInvalidUTF8
+		invalidUTF8 = true
 	}
 	}
 	b = appendVarint(b, wiretag)
 	b = appendVarint(b, wiretag)
 	b = appendVarint(b, uint64(len(v)))
 	b = appendVarint(b, uint64(len(v)))
 	b = append(b, v...)
 	b = append(b, v...)
+	if invalidUTF8 {
+		return b, errInvalidUTF8
+	}
 	return b, nil
 	return b, nil
 }
 }
 func appendUTF8StringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
 func appendUTF8StringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+	var invalidUTF8 bool
 	p := *ptr.toStringPtr()
 	p := *ptr.toStringPtr()
 	if p == nil {
 	if p == nil {
 		return b, nil
 		return b, nil
 	}
 	}
 	v := *p
 	v := *p
 	if !utf8.ValidString(v) {
 	if !utf8.ValidString(v) {
-		return nil, errInvalidUTF8
+		invalidUTF8 = true
 	}
 	}
 	b = appendVarint(b, wiretag)
 	b = appendVarint(b, wiretag)
 	b = appendVarint(b, uint64(len(v)))
 	b = appendVarint(b, uint64(len(v)))
 	b = append(b, v...)
 	b = append(b, v...)
+	if invalidUTF8 {
+		return b, errInvalidUTF8
+	}
 	return b, nil
 	return b, nil
 }
 }
 func appendUTF8StringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
 func appendUTF8StringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+	var invalidUTF8 bool
 	s := *ptr.toStringSlice()
 	s := *ptr.toStringSlice()
 	for _, v := range s {
 	for _, v := range s {
 		if !utf8.ValidString(v) {
 		if !utf8.ValidString(v) {
-			return nil, errInvalidUTF8
+			invalidUTF8 = true
 		}
 		}
 		b = appendVarint(b, wiretag)
 		b = appendVarint(b, wiretag)
 		b = appendVarint(b, uint64(len(v)))
 		b = appendVarint(b, uint64(len(v)))
 		b = append(b, v...)
 		b = append(b, v...)
 	}
 	}
+	if invalidUTF8 {
+		return b, errInvalidUTF8
+	}
 	return b, nil
 	return b, nil
 }
 }
 func appendBytes(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
 func appendBytes(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
@@ -2162,7 +2181,8 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
 		},
 		},
 		func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) {
 		func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) {
 			s := ptr.getPointerSlice()
 			s := ptr.getPointerSlice()
-			var err, errreq error
+			var err error
+			var nerr nonFatal
 			for _, v := range s {
 			for _, v := range s {
 				if v.isNil() {
 				if v.isNil() {
 					return b, errRepeatedHasNil
 					return b, errRepeatedHasNil
@@ -2170,22 +2190,14 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
 				b = appendVarint(b, wiretag) // start group
 				b = appendVarint(b, wiretag) // start group
 				b, err = u.marshal(b, v, deterministic)
 				b, err = u.marshal(b, v, deterministic)
 				b = appendVarint(b, wiretag+(WireEndGroup-WireStartGroup)) // end group
 				b = appendVarint(b, wiretag+(WireEndGroup-WireStartGroup)) // end group
-				if err != nil {
-					if _, ok := err.(*RequiredNotSetError); ok {
-						// Required field in submessage is not set.
-						// We record the error but keep going, to give a complete marshaling.
-						if errreq == nil {
-							errreq = err
-						}
-						continue
-					}
+				if !nerr.Merge(err) {
 					if err == ErrNil {
 					if err == ErrNil {
 						err = errRepeatedHasNil
 						err = errRepeatedHasNil
 					}
 					}
 					return b, err
 					return b, err
 				}
 				}
 			}
 			}
-			return b, errreq
+			return b, nerr.E
 		}
 		}
 }
 }
 
 
@@ -2229,7 +2241,8 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
 		},
 		},
 		func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) {
 		func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) {
 			s := ptr.getPointerSlice()
 			s := ptr.getPointerSlice()
-			var err, errreq error
+			var err error
+			var nerr nonFatal
 			for _, v := range s {
 			for _, v := range s {
 				if v.isNil() {
 				if v.isNil() {
 					return b, errRepeatedHasNil
 					return b, errRepeatedHasNil
@@ -2239,22 +2252,14 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
 				b = appendVarint(b, uint64(siz))
 				b = appendVarint(b, uint64(siz))
 				b, err = u.marshal(b, v, deterministic)
 				b, err = u.marshal(b, v, deterministic)
 
 
-				if err != nil {
-					if _, ok := err.(*RequiredNotSetError); ok {
-						// Required field in submessage is not set.
-						// We record the error but keep going, to give a complete marshaling.
-						if errreq == nil {
-							errreq = err
-						}
-						continue
-					}
+				if !nerr.Merge(err) {
 					if err == ErrNil {
 					if err == ErrNil {
 						err = errRepeatedHasNil
 						err = errRepeatedHasNil
 					}
 					}
 					return b, err
 					return b, err
 				}
 				}
 			}
 			}
-			return b, errreq
+			return b, nerr.E
 		}
 		}
 }
 }
 
 
@@ -2317,6 +2322,8 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
 			if len(keys) > 1 && deterministic {
 			if len(keys) > 1 && deterministic {
 				sort.Sort(mapKeys(keys))
 				sort.Sort(mapKeys(keys))
 			}
 			}
+
+			var nerr nonFatal
 			for _, k := range keys {
 			for _, k := range keys {
 				ki := k.Interface()
 				ki := k.Interface()
 				vi := m.MapIndex(k).Interface()
 				vi := m.MapIndex(k).Interface()
@@ -2326,15 +2333,15 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
 				siz := keySizer(kaddr, 1) + valCachedSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
 				siz := keySizer(kaddr, 1) + valCachedSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
 				b = appendVarint(b, uint64(siz))
 				b = appendVarint(b, uint64(siz))
 				b, err = keyMarshaler(b, kaddr, keyWireTag, deterministic)
 				b, err = keyMarshaler(b, kaddr, keyWireTag, deterministic)
-				if err != nil {
+				if !nerr.Merge(err) {
 					return b, err
 					return b, err
 				}
 				}
 				b, err = valMarshaler(b, vaddr, valWireTag, deterministic)
 				b, err = valMarshaler(b, vaddr, valWireTag, deterministic)
-				if err != nil && err != ErrNil { // allow nil value in map
+				if err != ErrNil && !nerr.Merge(err) { // allow nil value in map
 					return b, err
 					return b, err
 				}
 				}
 			}
 			}
-			return b, nil
+			return b, nerr.E
 		}
 		}
 }
 }
 
 
@@ -2407,6 +2414,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 	defer mu.Unlock()
 	defer mu.Unlock()
 
 
 	var err error
 	var err error
+	var nerr nonFatal
 
 
 	// Fast-path for common cases: zero or one extensions.
 	// Fast-path for common cases: zero or one extensions.
 	// Don't bother sorting the keys.
 	// Don't bother sorting the keys.
@@ -2426,11 +2434,11 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 			v := e.value
 			v := e.value
 			p := toAddrPointer(&v, ei.isptr)
 			p := toAddrPointer(&v, ei.isptr)
 			b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
 			b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
-			if err != nil {
+			if !nerr.Merge(err) {
 				return b, err
 				return b, err
 			}
 			}
 		}
 		}
-		return b, nil
+		return b, nerr.E
 	}
 	}
 
 
 	// Sort the keys to provide a deterministic encoding.
 	// Sort the keys to provide a deterministic encoding.
@@ -2457,11 +2465,11 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
 		v := e.value
 		v := e.value
 		p := toAddrPointer(&v, ei.isptr)
 		p := toAddrPointer(&v, ei.isptr)
 		b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
 		b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
-		if err != nil {
+		if !nerr.Merge(err) {
 			return b, err
 			return b, err
 		}
 		}
 	}
 	}
-	return b, nil
+	return b, nerr.E
 }
 }
 
 
 // message set format is:
 // message set format is:
@@ -2518,6 +2526,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 	defer mu.Unlock()
 	defer mu.Unlock()
 
 
 	var err error
 	var err error
+	var nerr nonFatal
 
 
 	// Fast-path for common cases: zero or one extensions.
 	// Fast-path for common cases: zero or one extensions.
 	// Don't bother sorting the keys.
 	// Don't bother sorting the keys.
@@ -2544,12 +2553,12 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 			v := e.value
 			v := e.value
 			p := toAddrPointer(&v, ei.isptr)
 			p := toAddrPointer(&v, ei.isptr)
 			b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
 			b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
-			if err != nil {
+			if !nerr.Merge(err) {
 				return b, err
 				return b, err
 			}
 			}
 			b = append(b, 1<<3|WireEndGroup)
 			b = append(b, 1<<3|WireEndGroup)
 		}
 		}
-		return b, nil
+		return b, nerr.E
 	}
 	}
 
 
 	// Sort the keys to provide a deterministic encoding.
 	// Sort the keys to provide a deterministic encoding.
@@ -2583,11 +2592,11 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
 		p := toAddrPointer(&v, ei.isptr)
 		p := toAddrPointer(&v, ei.isptr)
 		b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
 		b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
 		b = append(b, 1<<3|WireEndGroup)
 		b = append(b, 1<<3|WireEndGroup)
-		if err != nil {
+		if nerr.Merge(err) {
 			return b, err
 			return b, err
 		}
 		}
 	}
 	}
-	return b, nil
+	return b, nerr.E
 }
 }
 
 
 // sizeV1Extensions computes the size of encoded data for a V1-API extension field.
 // sizeV1Extensions computes the size of encoded data for a V1-API extension field.
@@ -2630,6 +2639,7 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
 	sort.Ints(keys)
 	sort.Ints(keys)
 
 
 	var err error
 	var err error
+	var nerr nonFatal
 	for _, k := range keys {
 	for _, k := range keys {
 		e := m[int32(k)]
 		e := m[int32(k)]
 		if e.value == nil || e.desc == nil {
 		if e.value == nil || e.desc == nil {
@@ -2646,11 +2656,11 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
 		v := e.value
 		v := e.value
 		p := toAddrPointer(&v, ei.isptr)
 		p := toAddrPointer(&v, ei.isptr)
 		b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
 		b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
-		if err != nil {
+		if !nerr.Merge(err) {
 			return b, err
 			return b, err
 		}
 		}
 	}
 	}
-	return b, nil
+	return b, nerr.E
 }
 }
 
 
 // newMarshaler is the interface representing objects that can marshal themselves.
 // newMarshaler is the interface representing objects that can marshal themselves.

+ 25 - 24
proto/table_unmarshal.go

@@ -138,8 +138,8 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 	if u.isMessageSet {
 	if u.isMessageSet {
 		return UnmarshalMessageSet(b, m.offset(u.extensions).toExtensions())
 		return UnmarshalMessageSet(b, m.offset(u.extensions).toExtensions())
 	}
 	}
-	var reqMask uint64            // bitmask of required fields we've seen.
-	var rnse *RequiredNotSetError // an instance of a RequiredNotSetError returned by a submessage.
+	var reqMask uint64 // bitmask of required fields we've seen.
+	var errLater error
 	for len(b) > 0 {
 	for len(b) > 0 {
 		// Read tag and wire type.
 		// Read tag and wire type.
 		// Special case 1 and 2 byte varints.
 		// Special case 1 and 2 byte varints.
@@ -175,17 +175,20 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 				reqMask |= f.reqMask
 				reqMask |= f.reqMask
 				continue
 				continue
 			}
 			}
-			if r, ok := err.(*RequiredNotSetError); ok {
+			if r, ok := err.(*RequiredNotSetError); ok && errLater == nil {
 				// Remember this error, but keep parsing. We need to produce
 				// Remember this error, but keep parsing. We need to produce
 				// a full parse even if a required field is missing.
 				// a full parse even if a required field is missing.
-				rnse = r
+				errLater = r
 				reqMask |= f.reqMask
 				reqMask |= f.reqMask
 				continue
 				continue
 			}
 			}
 			if err != errInternalBadWireType {
 			if err != errInternalBadWireType {
 				if err == errInvalidUTF8 {
 				if err == errInvalidUTF8 {
-					fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
-					err = fmt.Errorf("proto: string field %q contains invalid UTF-8", fullName)
+					if errLater == nil {
+						fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
+						errLater = &invalidUTF8Error{fullName}
+					}
+					continue
 				}
 				}
 				return err
 				return err
 			}
 			}
@@ -245,20 +248,16 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 			emap[int32(tag)] = e
 			emap[int32(tag)] = e
 		}
 		}
 	}
 	}
-	if rnse != nil {
-		// A required field of a submessage/group is missing. Return that error.
-		return rnse
-	}
-	if reqMask != u.reqMask {
+	if reqMask != u.reqMask && errLater == nil {
 		// A required field of this message is missing.
 		// A required field of this message is missing.
 		for _, n := range u.reqFields {
 		for _, n := range u.reqFields {
 			if reqMask&1 == 0 {
 			if reqMask&1 == 0 {
-				return &RequiredNotSetError{n}
+				errLater = &RequiredNotSetError{n}
 			}
 			}
 			reqMask >>= 1
 			reqMask >>= 1
 		}
 		}
 	}
 	}
-	return nil
+	return errLater
 }
 }
 
 
 // computeUnmarshalInfo fills in u with information for use
 // computeUnmarshalInfo fills in u with information for use
@@ -1529,10 +1528,10 @@ func unmarshalUTF8StringValue(b []byte, f pointer, w int) ([]byte, error) {
 		return nil, io.ErrUnexpectedEOF
 		return nil, io.ErrUnexpectedEOF
 	}
 	}
 	v := string(b[:x])
 	v := string(b[:x])
+	*f.toString() = v
 	if !utf8.ValidString(v) {
 	if !utf8.ValidString(v) {
-		return nil, errInvalidUTF8
+		return b[x:], errInvalidUTF8
 	}
 	}
-	*f.toString() = v
 	return b[x:], nil
 	return b[x:], nil
 }
 }
 
 
@@ -1549,10 +1548,10 @@ func unmarshalUTF8StringPtr(b []byte, f pointer, w int) ([]byte, error) {
 		return nil, io.ErrUnexpectedEOF
 		return nil, io.ErrUnexpectedEOF
 	}
 	}
 	v := string(b[:x])
 	v := string(b[:x])
+	*f.toStringPtr() = &v
 	if !utf8.ValidString(v) {
 	if !utf8.ValidString(v) {
-		return nil, errInvalidUTF8
+		return b[x:], errInvalidUTF8
 	}
 	}
-	*f.toStringPtr() = &v
 	return b[x:], nil
 	return b[x:], nil
 }
 }
 
 
@@ -1569,11 +1568,11 @@ func unmarshalUTF8StringSlice(b []byte, f pointer, w int) ([]byte, error) {
 		return nil, io.ErrUnexpectedEOF
 		return nil, io.ErrUnexpectedEOF
 	}
 	}
 	v := string(b[:x])
 	v := string(b[:x])
-	if !utf8.ValidString(v) {
-		return nil, errInvalidUTF8
-	}
 	s := f.toStringSlice()
 	s := f.toStringSlice()
 	*s = append(*s, v)
 	*s = append(*s, v)
+	if !utf8.ValidString(v) {
+		return b[x:], errInvalidUTF8
+	}
 	return b[x:], nil
 	return b[x:], nil
 }
 }
 
 
@@ -1755,6 +1754,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
 		// Maps will be somewhat slow. Oh well.
 		// Maps will be somewhat slow. Oh well.
 
 
 		// Read key and value from data.
 		// Read key and value from data.
+		var nerr nonFatal
 		k := reflect.New(kt)
 		k := reflect.New(kt)
 		v := reflect.New(vt)
 		v := reflect.New(vt)
 		for len(b) > 0 {
 		for len(b) > 0 {
@@ -1775,7 +1775,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
 				err = errInternalBadWireType // skip unknown tag
 				err = errInternalBadWireType // skip unknown tag
 			}
 			}
 
 
-			if err == nil {
+			if nerr.Merge(err) {
 				continue
 				continue
 			}
 			}
 			if err != errInternalBadWireType {
 			if err != errInternalBadWireType {
@@ -1798,7 +1798,7 @@ func makeUnmarshalMap(f *reflect.StructField) unmarshaler {
 		// Insert into map.
 		// Insert into map.
 		m.SetMapIndex(k.Elem(), v.Elem())
 		m.SetMapIndex(k.Elem(), v.Elem())
 
 
-		return r, nil
+		return r, nerr.E
 	}
 	}
 }
 }
 
 
@@ -1824,15 +1824,16 @@ func makeUnmarshalOneof(typ, ityp reflect.Type, unmarshal unmarshaler) unmarshal
 		// Unmarshal data into holder.
 		// Unmarshal data into holder.
 		// We unmarshal into the first field of the holder object.
 		// We unmarshal into the first field of the holder object.
 		var err error
 		var err error
+		var nerr nonFatal
 		b, err = unmarshal(b, valToPointer(v).offset(field0), w)
 		b, err = unmarshal(b, valToPointer(v).offset(field0), w)
-		if err != nil {
+		if !nerr.Merge(err) {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
 		// Write pointer to holder into target field.
 		// Write pointer to holder into target field.
 		f.asPointerTo(ityp).Elem().Set(v)
 		f.asPointerTo(ityp).Elem().Set(v)
 
 
-		return b, nil
+		return b, nerr.E
 	}
 	}
 }
 }