Selaa lähdekoodia

Fix proto.Equal handling of proto3 bytes fields.

proto3 specifies that non-message scalar fields don't have a "has" bit,
and so []byte{} and []byte(nil) are considered equivalent.
David Symonds 10 vuotta sitten
vanhempi
commit
4df416cb47
2 muutettua tiedostoa jossa 22 lisäystä ja 9 poistoa
  1. 19 9
      proto/equal.go
  2. 3 0
      proto/equal_test.go

+ 19 - 9
proto/equal.go

@@ -50,7 +50,9 @@ Equality is defined in this way:
     are equal, and extensions sets are equal.
     are equal, and extensions sets are equal.
   - Two set scalar fields are equal iff their values are equal.
   - Two set scalar fields are equal iff their values are equal.
     If the fields are of a floating-point type, remember that
     If the fields are of a floating-point type, remember that
-    NaN != x for all x, including NaN.
+    NaN != x for all x, including NaN. If the message is defined
+    in a proto3 .proto file, fields are not "set"; specifically,
+    zero length proto3 "bytes" fields are equal (nil == {}).
   - Two repeated fields are equal iff their lengths are the same,
   - Two repeated fields are equal iff their lengths are the same,
     and their corresponding elements are equal (a "bytes" field,
     and their corresponding elements are equal (a "bytes" field,
     although represented by []byte, is not a repeated field)
     although represented by []byte, is not a repeated field)
@@ -88,6 +90,7 @@ func Equal(a, b Message) bool {
 
 
 // v1 and v2 are known to have the same type.
 // v1 and v2 are known to have the same type.
 func equalStruct(v1, v2 reflect.Value) bool {
 func equalStruct(v1, v2 reflect.Value) bool {
+	sprop := GetProperties(v1.Type())
 	for i := 0; i < v1.NumField(); i++ {
 	for i := 0; i < v1.NumField(); i++ {
 		f := v1.Type().Field(i)
 		f := v1.Type().Field(i)
 		if strings.HasPrefix(f.Name, "XXX_") {
 		if strings.HasPrefix(f.Name, "XXX_") {
@@ -113,7 +116,7 @@ func equalStruct(v1, v2 reflect.Value) bool {
 			}
 			}
 			f1, f2 = f1.Elem(), f2.Elem()
 			f1, f2 = f1.Elem(), f2.Elem()
 		}
 		}
-		if !equalAny(f1, f2) {
+		if !equalAny(f1, f2, sprop.Prop[i]) {
 			return false
 			return false
 		}
 		}
 	}
 	}
@@ -140,7 +143,8 @@ func equalStruct(v1, v2 reflect.Value) bool {
 }
 }
 
 
 // v1 and v2 are known to have the same type.
 // v1 and v2 are known to have the same type.
-func equalAny(v1, v2 reflect.Value) bool {
+// prop may be nil.
+func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
 	if v1.Type() == protoMessageType {
 	if v1.Type() == protoMessageType {
 		m1, _ := v1.Interface().(Message)
 		m1, _ := v1.Interface().(Message)
 		m2, _ := v2.Interface().(Message)
 		m2, _ := v2.Interface().(Message)
@@ -163,7 +167,7 @@ func equalAny(v1, v2 reflect.Value) bool {
 		if e1.Type() != e2.Type() {
 		if e1.Type() != e2.Type() {
 			return false
 			return false
 		}
 		}
-		return equalAny(e1, e2)
+		return equalAny(e1, e2, nil)
 	case reflect.Map:
 	case reflect.Map:
 		if v1.Len() != v2.Len() {
 		if v1.Len() != v2.Len() {
 			return false
 			return false
@@ -174,16 +178,22 @@ func equalAny(v1, v2 reflect.Value) bool {
 				// This key was not found in the second map.
 				// This key was not found in the second map.
 				return false
 				return false
 			}
 			}
-			if !equalAny(v1.MapIndex(key), val2) {
+			if !equalAny(v1.MapIndex(key), val2, nil) {
 				return false
 				return false
 			}
 			}
 		}
 		}
 		return true
 		return true
 	case reflect.Ptr:
 	case reflect.Ptr:
-		return equalAny(v1.Elem(), v2.Elem())
+		return equalAny(v1.Elem(), v2.Elem(), prop)
 	case reflect.Slice:
 	case reflect.Slice:
 		if v1.Type().Elem().Kind() == reflect.Uint8 {
 		if v1.Type().Elem().Kind() == reflect.Uint8 {
 			// short circuit: []byte
 			// short circuit: []byte
+
+			// Edge case: if this is in a proto3 message, a zero length
+			// bytes field is considered the zero value.
+			if prop != nil && prop.proto3 && v1.Len() == 0 && v2.Len() == 0 {
+				return true
+			}
 			if v1.IsNil() != v2.IsNil() {
 			if v1.IsNil() != v2.IsNil() {
 				return false
 				return false
 			}
 			}
@@ -194,7 +204,7 @@ func equalAny(v1, v2 reflect.Value) bool {
 			return false
 			return false
 		}
 		}
 		for i := 0; i < v1.Len(); i++ {
 		for i := 0; i < v1.Len(); i++ {
-			if !equalAny(v1.Index(i), v2.Index(i)) {
+			if !equalAny(v1.Index(i), v2.Index(i), prop) {
 				return false
 				return false
 			}
 			}
 		}
 		}
@@ -229,7 +239,7 @@ func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
 
 
 		if m1 != nil && m2 != nil {
 		if m1 != nil && m2 != nil {
 			// Both are unencoded.
 			// Both are unencoded.
-			if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
+			if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
 				return false
 				return false
 			}
 			}
 			continue
 			continue
@@ -257,7 +267,7 @@ func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
 			log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
 			log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
 			return false
 			return false
 		}
 		}
-		if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
+		if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2), nil) {
 			return false
 			return false
 		}
 		}
 	}
 	}

+ 3 - 0
proto/equal_test.go

@@ -35,6 +35,7 @@ import (
 	"testing"
 	"testing"
 
 
 	. "github.com/golang/protobuf/proto"
 	. "github.com/golang/protobuf/proto"
+	proto3pb "github.com/golang/protobuf/proto/proto3_proto"
 	pb "github.com/golang/protobuf/proto/testdata"
 	pb "github.com/golang/protobuf/proto/testdata"
 )
 )
 
 
@@ -131,6 +132,8 @@ var EqualTests = []struct {
 		&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
 		&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
 		true,
 		true,
 	},
 	},
+	// In proto3, []byte{} and []byte(nil) are equal.
+	{"proto3 bytes, empty vs nil", &proto3pb.Message{Data: []byte{}}, &proto3pb.Message{Data: nil}, true},
 
 
 	{"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false},
 	{"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false},
 	{"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true},
 	{"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true},