Kaynağa Gözat

proto: mention field name in error message (#616)

Instead of simply saying there was a UTF-8 validation error, specify
which field in which message had an issue.
Joe Tsai 7 yıl önce
ebeveyn
işleme
3a3da3a4e2
2 değiştirilmiş dosya ile 26 ekleme ve 8 silme
  1. 4 0
      proto/table_marshal.go
  2. 22 8
      proto/table_unmarshal.go

+ 4 - 0
proto/table_marshal.go

@@ -277,6 +277,10 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
 			if err == errRepeatedHasNil {
 			if err == errRepeatedHasNil {
 				err = errors.New("proto: repeated field " + f.name + " has nil element")
 				err = errors.New("proto: repeated field " + f.name + " has nil element")
 			}
 			}
+			if err == errInvalidUTF8 {
+				fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
+				err = fmt.Errorf("proto: string field %q contains invalid UTF-8", fullName)
+			}
 			return b, err
 			return b, err
 		}
 		}
 	}
 	}

+ 22 - 8
proto/table_unmarshal.go

@@ -97,6 +97,8 @@ type unmarshalFieldInfo struct {
 
 
 	// if a required field, contains a single set bit at this field's index in the required field list.
 	// if a required field, contains a single set bit at this field's index in the required field list.
 	reqMask uint64
 	reqMask uint64
+
+	name string // name of the field, for error reporting
 }
 }
 
 
 var (
 var (
@@ -181,6 +183,10 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error {
 				continue
 				continue
 			}
 			}
 			if err != errInternalBadWireType {
 			if err != errInternalBadWireType {
+				if err == errInvalidUTF8 {
+					fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
+					err = fmt.Errorf("proto: string field %q contains invalid UTF-8", fullName)
+				}
 				return err
 				return err
 			}
 			}
 			// Fragments with bad wire type are treated as unknown fields.
 			// Fragments with bad wire type are treated as unknown fields.
@@ -351,7 +357,7 @@ func (u *unmarshalInfo) computeUnmarshalInfo() {
 		}
 		}
 
 
 		// Store the info in the correct slot in the message.
 		// Store the info in the correct slot in the message.
-		u.setTag(tag, toField(&f), unmarshal, reqMask)
+		u.setTag(tag, toField(&f), unmarshal, reqMask, name)
 	}
 	}
 
 
 	// Find any types associated with oneof fields.
 	// Find any types associated with oneof fields.
@@ -366,10 +372,17 @@ func (u *unmarshalInfo) computeUnmarshalInfo() {
 
 
 			f := typ.Field(0) // oneof implementers have one field
 			f := typ.Field(0) // oneof implementers have one field
 			baseUnmarshal := fieldUnmarshaler(&f)
 			baseUnmarshal := fieldUnmarshaler(&f)
-			tagstr := strings.Split(f.Tag.Get("protobuf"), ",")[1]
-			tag, err := strconv.Atoi(tagstr)
+			tags := strings.Split(f.Tag.Get("protobuf"), ",")
+			fieldNum, err := strconv.Atoi(tags[1])
 			if err != nil {
 			if err != nil {
-				panic("protobuf tag field not an integer: " + tagstr)
+				panic("protobuf tag field not an integer: " + tags[1])
+			}
+			var name string
+			for _, tag := range tags {
+				if strings.HasPrefix(tag, "name=") {
+					name = strings.TrimPrefix(tag, "name=")
+					break
+				}
 			}
 			}
 
 
 			// Find the oneof field that this struct implements.
 			// Find the oneof field that this struct implements.
@@ -380,7 +393,7 @@ func (u *unmarshalInfo) computeUnmarshalInfo() {
 					// That lets us know where this struct should be stored
 					// That lets us know where this struct should be stored
 					// when we encounter it during unmarshaling.
 					// when we encounter it during unmarshaling.
 					unmarshal := makeUnmarshalOneof(typ, of.ityp, baseUnmarshal)
 					unmarshal := makeUnmarshalOneof(typ, of.ityp, baseUnmarshal)
-					u.setTag(tag, of.field, unmarshal, 0)
+					u.setTag(fieldNum, of.field, unmarshal, 0, name)
 				}
 				}
 			}
 			}
 		}
 		}
@@ -401,7 +414,7 @@ func (u *unmarshalInfo) computeUnmarshalInfo() {
 	// [0 0] is [tag=0/wiretype=varint varint-encoded-0].
 	// [0 0] is [tag=0/wiretype=varint varint-encoded-0].
 	u.setTag(0, zeroField, func(b []byte, f pointer, w int) ([]byte, error) {
 	u.setTag(0, zeroField, func(b []byte, f pointer, w int) ([]byte, error) {
 		return nil, fmt.Errorf("proto: %s: illegal tag 0 (wire type %d)", t, w)
 		return nil, fmt.Errorf("proto: %s: illegal tag 0 (wire type %d)", t, w)
-	}, 0)
+	}, 0, "")
 
 
 	// Set mask for required field check.
 	// Set mask for required field check.
 	u.reqMask = uint64(1)<<uint(len(u.reqFields)) - 1
 	u.reqMask = uint64(1)<<uint(len(u.reqFields)) - 1
@@ -413,8 +426,9 @@ func (u *unmarshalInfo) computeUnmarshalInfo() {
 // tag = tag # for field
 // tag = tag # for field
 // field/unmarshal = unmarshal info for that field.
 // field/unmarshal = unmarshal info for that field.
 // reqMask = if required, bitmask for field position in required field list. 0 otherwise.
 // reqMask = if required, bitmask for field position in required field list. 0 otherwise.
-func (u *unmarshalInfo) setTag(tag int, field field, unmarshal unmarshaler, reqMask uint64) {
-	i := unmarshalFieldInfo{field: field, unmarshal: unmarshal, reqMask: reqMask}
+// name = short name of the field.
+func (u *unmarshalInfo) setTag(tag int, field field, unmarshal unmarshaler, reqMask uint64, name string) {
+	i := unmarshalFieldInfo{field: field, unmarshal: unmarshal, reqMask: reqMask, name: name}
 	n := u.typ.NumField()
 	n := u.typ.NumField()
 	if tag >= 0 && (tag < 16 || tag < 2*n) { // TODO: what are the right numbers here?
 	if tag >= 0 && (tag < 16 || tag < 2*n) { // TODO: what are the right numbers here?
 		for len(u.dense) <= tag {
 		for len(u.dense) <= tag {