Просмотр исходного кода

Fix Merge (and Clone) for proto3.

Scalar fields should not be merged if they have the zero value.
For proto3, that means 0, "", bool, or a zero length []byte.
David Symonds 10 лет назад
Родитель
Сommit
ab974be44d
4 измененных файлов с 58 добавлено и 27 удалено
  1. 21 6
      proto/clone.go
  2. 18 0
      proto/clone_test.go
  3. 17 0
      proto/lib.go
  4. 2 21
      proto/text.go

+ 21 - 6
proto/clone.go

@@ -75,12 +75,13 @@ func Merge(dst, src Message) {
 }
 
 func mergeStruct(out, in reflect.Value) {
+	sprop := GetProperties(in.Type())
 	for i := 0; i < in.NumField(); i++ {
 		f := in.Type().Field(i)
 		if strings.HasPrefix(f.Name, "XXX_") {
 			continue
 		}
-		mergeAny(out.Field(i), in.Field(i))
+		mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
 	}
 
 	if emIn, ok := in.Addr().Interface().(extendableProto); ok {
@@ -98,7 +99,10 @@ func mergeStruct(out, in reflect.Value) {
 	}
 }
 
-func mergeAny(out, in reflect.Value) {
+// mergeAny performs a merge between two values of the same type.
+// viaPtr indicates whether the values were indirected through a pointer (implying proto2).
+// prop is set if this is a struct field (it may be nil).
+func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) {
 	if in.Type() == protoMessageType {
 		if !in.IsNil() {
 			if out.IsNil() {
@@ -112,6 +116,9 @@ func mergeAny(out, in reflect.Value) {
 	switch in.Kind() {
 	case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
 		reflect.String, reflect.Uint32, reflect.Uint64:
+		if !viaPtr && isProto3Zero(in) {
+			return
+		}
 		out.Set(in)
 	case reflect.Map:
 		if in.Len() == 0 {
@@ -127,7 +134,7 @@ func mergeAny(out, in reflect.Value) {
 			switch elemKind {
 			case reflect.Ptr:
 				val = reflect.New(in.Type().Elem().Elem())
-				mergeAny(val, in.MapIndex(key))
+				mergeAny(val, in.MapIndex(key), false, nil)
 			case reflect.Slice:
 				val = in.MapIndex(key)
 				val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
@@ -143,13 +150,21 @@ func mergeAny(out, in reflect.Value) {
 		if out.IsNil() {
 			out.Set(reflect.New(in.Elem().Type()))
 		}
-		mergeAny(out.Elem(), in.Elem())
+		mergeAny(out.Elem(), in.Elem(), true, nil)
 	case reflect.Slice:
 		if in.IsNil() {
 			return
 		}
 		if in.Type().Elem().Kind() == reflect.Uint8 {
 			// []byte is a scalar bytes field, not a repeated field.
+
+			// Edge case: if this is in a proto3 message, a zero length
+			// bytes field is considered the zero value, and should not
+			// be merged.
+			if prop != nil && prop.proto3 && in.Len() == 0 {
+				return
+			}
+
 			// Make a deep copy.
 			// Append to []byte{} instead of []byte(nil) so that we never end up
 			// with a nil result.
@@ -167,7 +182,7 @@ func mergeAny(out, in reflect.Value) {
 		default:
 			for i := 0; i < n; i++ {
 				x := reflect.Indirect(reflect.New(in.Type().Elem()))
-				mergeAny(x, in.Index(i))
+				mergeAny(x, in.Index(i), false, nil)
 				out.Set(reflect.Append(out, x))
 			}
 		}
@@ -184,7 +199,7 @@ func mergeExtension(out, in map[int32]Extension) {
 		eOut := Extension{desc: eIn.desc}
 		if eIn.value != nil {
 			v := reflect.New(reflect.TypeOf(eIn.value)).Elem()
-			mergeAny(v, reflect.ValueOf(eIn.value))
+			mergeAny(v, reflect.ValueOf(eIn.value), false, nil)
 			eOut.value = v.Interface()
 		}
 		if eIn.enc != nil {

+ 18 - 0
proto/clone_test.go

@@ -36,6 +36,7 @@ import (
 
 	"github.com/golang/protobuf/proto"
 
+	proto3pb "github.com/golang/protobuf/proto/proto3_proto"
 	pb "github.com/golang/protobuf/proto/testdata"
 )
 
@@ -214,6 +215,23 @@ var mergeTests = []struct {
 			ByteMapping: map[bool][]byte{true: []byte("wowsa")},
 		},
 	},
+	// proto3 shouldn't merge zero values,
+	// in the same way that proto2 shouldn't merge nils.
+	{
+		src: &proto3pb.Message{
+			Name: "Aaron",
+			Data: []byte(""), // zero value, but not nil
+		},
+		dst: &proto3pb.Message{
+			HeightInCm: 176,
+			Data:       []byte("texas!"),
+		},
+		want: &proto3pb.Message{
+			Name:       "Aaron",
+			HeightInCm: 176,
+			Data:       []byte("texas!"),
+		},
+	},
 }
 
 func TestMerge(t *testing.T) {

+ 17 - 0
proto/lib.go

@@ -794,3 +794,20 @@ func (s mapKeys) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
 func (s mapKeys) Less(i, j int) bool {
 	return fmt.Sprint(s[i].Interface()) < fmt.Sprint(s[j].Interface())
 }
+
+// isProto3Zero reports whether v is a zero proto3 value.
+func isProto3Zero(v reflect.Value) bool {
+	switch v.Kind() {
+	case reflect.Bool:
+		return !v.Bool()
+	case reflect.Int32, reflect.Int64:
+		return v.Int() == 0
+	case reflect.Uint32, reflect.Uint64:
+		return v.Uint() == 0
+	case reflect.Float32, reflect.Float64:
+		return v.Float() == 0
+	case reflect.String:
+		return v.String() == ""
+	}
+	return false
+}

+ 2 - 21
proto/text.go

@@ -317,27 +317,8 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
 		}
 		if fv.Kind() != reflect.Ptr && fv.Kind() != reflect.Slice {
 			// proto3 non-repeated scalar field; skip if zero value
-			switch fv.Kind() {
-			case reflect.Bool:
-				if !fv.Bool() {
-					continue
-				}
-			case reflect.Int32, reflect.Int64:
-				if fv.Int() == 0 {
-					continue
-				}
-			case reflect.Uint32, reflect.Uint64:
-				if fv.Uint() == 0 {
-					continue
-				}
-			case reflect.Float32, reflect.Float64:
-				if fv.Float() == 0 {
-					continue
-				}
-			case reflect.String:
-				if fv.String() == "" {
-					continue
-				}
+			if isProto3Zero(fv) {
+				continue
 			}
 		}