Просмотр исходного кода

proto: In Size, don't double-count the tagcode for structs that implement Marshaler.

In (*Buffer).Marshal, don't ignore output returned alongside RequiredNotSetError.

Remove a redundant shouldContinue check in Marshal (we're going to return the error we were given no matter what, and there's no field name to add here).

Update the test to check for consistency among Marshal, (*Buffer).Marshal, and Size.

This resolves https://github.com/golang/protobuf/issues/236.
PiperOrigin-RevId: 134444296
bcmills 9 лет назад
Родитель
Сommit
df1d3ca07d
2 измененных файлов с 26 добавлено и 25 удалено
  1. 23 14
      proto/all_test.go
  2. 3 11
      proto/encode.go

+ 23 - 14
proto/all_test.go

@@ -420,7 +420,7 @@ func TestMarshalerEncoding(t *testing.T) {
 		name    string
 		m       Message
 		want    []byte
-		wantErr error
+		errType reflect.Type
 	}{
 		{
 			name: "Marshaler that fails",
@@ -428,9 +428,11 @@ func TestMarshalerEncoding(t *testing.T) {
 				err: errors.New("some marshal err"),
 				b:   []byte{5, 6, 7},
 			},
-			// Since there's an error, nothing should be written to buffer.
-			want:    nil,
-			wantErr: errors.New("some marshal err"),
+			// Since the Marshal method returned bytes, they should be written to the
+			// buffer.  (For efficiency, we assume that Marshal implementations are
+			// always correct w.r.t. RequiredNotSetError and output.)
+			want:    []byte{5, 6, 7},
+			errType: reflect.TypeOf(errors.New("some marshal err")),
 		},
 		{
 			name: "Marshaler that fails with RequiredNotSetError",
@@ -446,30 +448,37 @@ func TestMarshalerEncoding(t *testing.T) {
 				10, 3, // for &msgWithFakeMarshaler
 				5, 6, 7, // for &fakeMarshaler
 			},
-			wantErr: &RequiredNotSetError{},
+			errType: reflect.TypeOf(&RequiredNotSetError{}),
 		},
 		{
 			name: "Marshaler that succeeds",
 			m: &fakeMarshaler{
 				b: []byte{0, 1, 2, 3, 4, 127, 255},
 			},
-			want:    []byte{0, 1, 2, 3, 4, 127, 255},
-			wantErr: nil,
+			want: []byte{0, 1, 2, 3, 4, 127, 255},
 		},
 	}
 	for _, test := range tests {
 		b := NewBuffer(nil)
 		err := b.Marshal(test.m)
-		if _, ok := err.(*RequiredNotSetError); ok {
-			// We're not in package proto, so we can only assert the type in this case.
-			err = &RequiredNotSetError{}
-		}
-		if !reflect.DeepEqual(test.wantErr, err) {
-			t.Errorf("%s: got err %v wanted %v", test.name, err, test.wantErr)
+		if reflect.TypeOf(err) != test.errType {
+			t.Errorf("%s: got err %T(%v) wanted %T", test.name, err, err, test.errType)
 		}
 		if !reflect.DeepEqual(test.want, b.Bytes()) {
 			t.Errorf("%s: got bytes %v wanted %v", test.name, b.Bytes(), test.want)
 		}
+		if size := Size(test.m); size != len(b.Bytes()) {
+			t.Errorf("%s: Size(_) = %v, but marshaled to %v bytes", test.name, size, len(b.Bytes()))
+		}
+
+		m, mErr := Marshal(test.m)
+		if !bytes.Equal(b.Bytes(), m) {
+			t.Errorf("%s: Marshal returned %v, but (*Buffer).Marshal wrote %v", test.name, m, b.Bytes())
+		}
+		if !reflect.DeepEqual(err, mErr) {
+			t.Errorf("%s: Marshal err = %q, but (*Buffer).Marshal returned %q",
+				test.name, fmt.Sprint(mErr), fmt.Sprint(err))
+		}
 	}
 }
 
@@ -1302,7 +1311,7 @@ func TestEnum(t *testing.T) {
 // We don't care what the value actually is, just as long as it doesn't crash.
 func TestPrintingNilEnumFields(t *testing.T) {
 	pb := new(GoEnum)
-	fmt.Sprintf("%+v", pb)
+	_ = fmt.Sprintf("%+v", pb)
 }
 
 // Verify that absent required fields cause Marshal/Unmarshal to return errors.

+ 3 - 11
proto/encode.go

@@ -234,10 +234,6 @@ func Marshal(pb Message) ([]byte, error) {
 	}
 	p := NewBuffer(nil)
 	err := p.Marshal(pb)
-	var state errorState
-	if err != nil && !state.shouldContinue(err, nil) {
-		return nil, err
-	}
 	if p.buf == nil && err == nil {
 		// Return a non-nil slice on success.
 		return []byte{}, nil
@@ -266,11 +262,8 @@ func (p *Buffer) Marshal(pb Message) error {
 	// Can the object marshal itself?
 	if m, ok := pb.(Marshaler); ok {
 		data, err := m.Marshal()
-		if err != nil {
-			return err
-		}
 		p.buf = append(p.buf, data...)
-		return nil
+		return err
 	}
 
 	t, base, err := getbase(pb)
@@ -282,7 +275,7 @@ func (p *Buffer) Marshal(pb Message) error {
 	}
 
 	if collectStats {
-		stats.Encode++
+		(stats).Encode++ // Parens are to work around a goimports bug.
 	}
 
 	if len(p.buf) > maxMarshalSize {
@@ -309,7 +302,7 @@ func Size(pb Message) (n int) {
 	}
 
 	if collectStats {
-		stats.Size++
+		(stats).Size++ // Parens are to work around a goimports bug.
 	}
 
 	return
@@ -1014,7 +1007,6 @@ func size_slice_struct_message(p *Properties, base structPointer) (n int) {
 		if p.isMarshaler {
 			m := structPointer_Interface(structp, p.stype).(Marshaler)
 			data, _ := m.Marshal()
-			n += len(p.tagcode)
 			n += sizeRawBytes(data)
 			continue
 		}