瀏覽代碼

goprotobuf: Make several changes to ErrRequiredNotSet:
- Report the full or partial path to the first missing required field (where possible) instead of the message name.
- Make it ignorable. Unmarshal and Marshal will continue to decode/encode the full proto.

R=r
CC=golang-dev
https://codereview.appspot.com/13248047

David Symonds 12 年之前
父節點
當前提交
4646c37073
共有 3 個文件被更改,包括 135 次插入27 次删除
  1. 67 3
      proto/all_test.go
  2. 12 2
      proto/decode.go
  3. 56 22
      proto/encode.go

+ 67 - 3
proto/all_test.go

@@ -431,7 +431,7 @@ func TestRequiredBit(t *testing.T) {
 	err := o.Marshal(pb)
 	if err == nil {
 		t.Error("did not catch missing required fields")
-	} else if strings.Index(err.Error(), "GoTest") < 0 {
+	} else if strings.Index(err.Error(), "Kind") < 0 {
 		t.Error("wrong error type:", err)
 	}
 }
@@ -1205,7 +1205,7 @@ func TestRequiredFieldEnforcement(t *testing.T) {
 	_, err := Marshal(pb)
 	if err == nil {
 		t.Error("marshal: expected error, got nil")
-	} else if strings.Index(err.Error(), "GoTestField") < 0 {
+	} else if strings.Index(err.Error(), "Label") < 0 {
 		t.Errorf("marshal: bad error type: %v", err)
 	}
 
@@ -1216,7 +1216,7 @@ func TestRequiredFieldEnforcement(t *testing.T) {
 	err = Unmarshal(buf, pb)
 	if err == nil {
 		t.Error("unmarshal: expected error, got nil")
-	} else if strings.Index(err.Error(), "GoTestField") < 0 {
+	} else if strings.Index(err.Error(), "{Unknown}") < 0 {
 		t.Errorf("unmarshal: bad error type: %v", err)
 	}
 }
@@ -1670,6 +1670,70 @@ func TestEncodingSizes(t *testing.T) {
 	}
 }
 
+func TestErrRequiredNotSet(t *testing.T) {
+	pb := initGoTest(false)
+	pb.RequiredField.Label = nil
+	pb.F_Int32Required = nil
+	pb.F_Int64Required = nil
+
+	expected := "0807" + // field 1, encoding 0, value 7
+		"2206" + "120474797065" + // field 4, encoding 2 (GoTestField)
+		"5001" + // field 10, encoding 0, value 1
+		"6d20000000" + // field 13, encoding 5, value 0x20
+		"714000000000000000" + // field 14, encoding 1, value 0x40
+		"78a019" + // field 15, encoding 0, value 0xca0 = 3232
+		"8001c032" + // field 16, encoding 0, value 0x1940 = 6464
+		"8d0100004a45" + // field 17, encoding 5, value 3232.0
+		"9101000000000040b940" + // field 18, encoding 1, value 6464.0
+		"9a0106" + "737472696e67" + // field 19, encoding 2, string "string"
+		"b304" + // field 70, encoding 3, start group
+		"ba0408" + "7265717569726564" + // field 71, encoding 2, string "required"
+		"b404" + // field 70, encoding 4, end group
+		"aa0605" + "6279746573" + // field 101, encoding 2, string "bytes"
+		"b0063f" + // field 102, encoding 0, 0x3f zigzag32
+		"b8067f" // field 103, encoding 0, 0x7f zigzag64
+
+	o := old()
+	bytes, err := Marshal(pb)
+	if _, ok := err.(*ErrRequiredNotSet); !ok {
+		fmt.Printf("marshal-1 err = %v, want *ErrRequiredNotSet", err)
+		o.DebugPrint("", bytes)
+		t.Fatalf("expected = %s", expected)
+	}
+	if strings.Index(err.Error(), "RequiredField.Label") < 0 {
+		t.Errorf("marshal-1 wrong err msg: %v", err)
+	}
+	if !equal(bytes, expected, t) {
+		o.DebugPrint("neq 1", bytes)
+		t.Fatalf("expected = %s", expected)
+	}
+
+	// Now test Unmarshal by recreating the original buffer.
+	pbd := new(GoTest)
+	err = Unmarshal(bytes, pbd)
+	if _, ok := err.(*ErrRequiredNotSet); !ok {
+		t.Fatalf("unmarshal err = %v, want *ErrRequiredNotSet", err)
+		o.DebugPrint("", bytes)
+		t.Fatalf("string = %s", expected)
+	}
+	if strings.Index(err.Error(), "RequiredField.{Unknown}") < 0 {
+		t.Errorf("unmarshal wrong err msg: %v", err)
+	}
+	bytes, err = Marshal(pbd)
+	if _, ok := err.(*ErrRequiredNotSet); !ok {
+		t.Errorf("marshal-2 err = %v, want *ErrRequiredNotSet", err)
+		o.DebugPrint("", bytes)
+		t.Fatalf("string = %s", expected)
+	}
+	if strings.Index(err.Error(), "RequiredField.Label") < 0 {
+		t.Errorf("marshal-2 wrong err msg: %v", err)
+	}
+	if !equal(bytes, expected, t) {
+		o.DebugPrint("neq 2", bytes)
+		t.Fatalf("string = %s", expected)
+	}
+}
+
 func fuzzUnmarshal(t *testing.T, data []byte) {
 	defer func() {
 		if e := recover(); e != nil {

+ 12 - 2
proto/decode.go

@@ -353,6 +353,7 @@ func (p *Buffer) Unmarshal(pb Message) error {
 
 // unmarshalType does the work of unmarshaling a structure.
 func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error {
+	var state errorState
 	required, reqFields := prop.reqCount, uint64(0)
 
 	var err error
@@ -406,7 +407,10 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
 				continue
 			}
 		}
-		err = dec(o, p, base)
+		decErr := dec(o, p, base)
+		if decErr != nil && !state.shouldContinue(decErr, p) {
+			err = decErr
+		}
 		if err == nil && p.Required {
 			// Successfully decoded a required field.
 			if tag <= 64 {
@@ -430,8 +434,14 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
 		if is_group {
 			return io.ErrUnexpectedEOF
 		}
+		if state.err != nil {
+			return state.err
+		}
 		if required > 0 {
-			return &ErrRequiredNotSet{st}
+			// Not enough information to determine the exact field. If we use extra
+			// CPU, we could determine the field only if the missing required field
+			// has a tag <= 64 and we check reqFields.
+			return &ErrRequiredNotSet{"{Unknown}"}
 		}
 	}
 	return err

+ 56 - 22
proto/encode.go

@@ -37,6 +37,7 @@ package proto
 
 import (
 	"errors"
+	"fmt"
 	"reflect"
 	"sort"
 )
@@ -46,12 +47,16 @@ import (
 // all been initialized. It is also the error returned if Unmarshal is
 // called with an encoded protocol buffer that does not include all the
 // required fields.
+//
+// When printed, ErrRequiredNotSet reports the first unset required field in a
+// message. If the field cannot be precisely determined, it is reported as
+// "{Unknown}".
 type ErrRequiredNotSet struct {
-	t reflect.Type
+	field string
 }
 
 func (e *ErrRequiredNotSet) Error() string {
-	return "proto: required fields not set in " + e.t.String()
+	return fmt.Sprintf("proto: required field %q not set", e.field)
 }
 
 var (
@@ -175,7 +180,8 @@ func Marshal(pb Message) ([]byte, error) {
 	}
 	p := NewBuffer(nil)
 	err := p.Marshal(pb)
-	if err != nil {
+	var state errorState
+	if err != nil && !state.shouldContinue(err, nil) {
 		return nil, err
 	}
 	return p.buf, err
@@ -274,6 +280,7 @@ func isNil(v reflect.Value) bool {
 
 // Encode a message struct.
 func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error {
+	var state errorState
 	structp := structPointer_GetStructPointer(base, p.field)
 	if structPointer_IsNil(structp) {
 		return ErrNil
@@ -283,7 +290,7 @@ func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error {
 	if p.isMarshaler {
 		m := structPointer_Interface(structp, p.stype).(Marshaler)
 		data, err := m.Marshal()
-		if err != nil {
+		if err != nil && !state.shouldContinue(err, nil) {
 			return err
 		}
 		o.buf = append(o.buf, p.tagcode...)
@@ -300,18 +307,19 @@ func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error {
 
 	nbuf := o.buf
 	o.buf = obuf
-	if err != nil {
+	if err != nil && !state.shouldContinue(err, nil) {
 		o.buffree(nbuf)
 		return err
 	}
 	o.buf = append(o.buf, p.tagcode...)
 	o.EncodeRawBytes(nbuf)
 	o.buffree(nbuf)
-	return nil
+	return state.err
 }
 
 // Encode a group struct.
 func (o *Buffer) enc_struct_group(p *Properties, base structPointer) error {
+	var state errorState
 	b := structPointer_GetStructPointer(base, p.field)
 	if structPointer_IsNil(b) {
 		return ErrNil
@@ -319,11 +327,11 @@ func (o *Buffer) enc_struct_group(p *Properties, base structPointer) error {
 
 	o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup))
 	err := o.enc_struct(p.stype, p.sprop, b)
-	if err != nil {
+	if err != nil && !state.shouldContinue(err, nil) {
 		return err
 	}
 	o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup))
-	return nil
+	return state.err
 }
 
 // Encode a slice of bools ([]bool).
@@ -470,6 +478,7 @@ func (o *Buffer) enc_slice_string(p *Properties, base structPointer) error {
 
 // Encode a slice of message structs ([]*struct).
 func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) error {
+	var state errorState
 	s := structPointer_StructPointerSlice(base, p.field)
 	l := s.Len()
 
@@ -483,7 +492,7 @@ func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) err
 		if p.isMarshaler {
 			m := structPointer_Interface(structp, p.stype).(Marshaler)
 			data, err := m.Marshal()
-			if err != nil {
+			if err != nil && !state.shouldContinue(err, nil) {
 				return err
 			}
 			o.buf = append(o.buf, p.tagcode...)
@@ -498,7 +507,7 @@ func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) err
 
 		nbuf := o.buf
 		o.buf = obuf
-		if err != nil {
+		if err != nil && !state.shouldContinue(err, nil) {
 			o.buffree(nbuf)
 			if err == ErrNil {
 				return ErrRepeatedHasNil
@@ -510,11 +519,12 @@ func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) err
 
 		o.buffree(nbuf)
 	}
-	return nil
+	return state.err
 }
 
 // Encode a slice of group structs ([]*struct).
 func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error {
+	var state errorState
 	s := structPointer_StructPointerSlice(base, p.field)
 	l := s.Len()
 
@@ -528,7 +538,7 @@ func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error
 
 		err := o.enc_struct(p.stype, p.sprop, b)
 
-		if err != nil {
+		if err != nil && !state.shouldContinue(err, nil) {
 			if err == ErrNil {
 				return ErrRepeatedHasNil
 			}
@@ -537,7 +547,7 @@ func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error
 
 		o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup))
 	}
-	return nil
+	return state.err
 }
 
 // Encode an extension map.
@@ -569,7 +579,7 @@ func (o *Buffer) enc_map(p *Properties, base structPointer) error {
 
 // Encode a struct.
 func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structPointer) error {
-	required := prop.reqCount
+	var state errorState
 	// Encode fields in tag order so that decoders may use optimizations
 	// that depend on the ordering.
 	// http://code.google.com/apis/protocolbuffers/docs/encoding.html#order
@@ -577,19 +587,15 @@ func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structP
 		p := prop.Prop[i]
 		if p.enc != nil {
 			err := p.enc(o, p, base)
-			if err != nil {
+			if err != nil && !state.shouldContinue(err, p) {
 				if err != ErrNil {
 					return err
+				} else if p.Required && state.err == nil {
+					state.err = &ErrRequiredNotSet{p.Name}
 				}
-			} else if p.Required {
-				required--
 			}
 		}
 	}
-	// See if we encoded all required fields.
-	if required > 0 {
-		return &ErrRequiredNotSet{t}
-	}
 
 	// Add unrecognized fields at the end.
 	if prop.unrecField.IsValid() {
@@ -599,5 +605,33 @@ func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structP
 		}
 	}
 
-	return nil
+	return state.err
+}
+
+// errorState maintains the first error that occurs and updates that error
+// with additional context.
+type errorState struct {
+	err error
+}
+
+// shouldContinue reports whether encoding should continue upon encountering the
+// given error. If the error is ErrRequiredNotSet, shouldContinue returns true
+// and, if this is the first appearance of that error, remembers it for future
+// reporting.
+//
+// If prop is not nil, it may update any error with additional context about the
+// field with the error.
+func (s *errorState) shouldContinue(err error, prop *Properties) bool {
+	// Ignore unset required fields.
+	reqNotSet, ok := err.(*ErrRequiredNotSet)
+	if !ok {
+		return false
+	}
+	if s.err == nil {
+		if prop != nil {
+			err = &ErrRequiredNotSet{prop.Name + "." + reqNotSet.field}
+		}
+		s.err = err
+	}
+	return true
 }