Browse Source

proto: fix equality to work with V1 generated format

When the new V2 generated extension format was introduced, we
mistakenly dropped support for comparing V1 generated extensions
for equality. Add that back.
matloob@google.com 9 years ago
parent
commit
9e6977f30c
4 changed files with 60 additions and 18 deletions
  1. 44 4
      proto/all_test.go
  2. 1 6
      proto/encode.go
  3. 11 0
      proto/equal.go
  4. 4 8
      proto/properties.go

+ 44 - 4
proto/all_test.go

@@ -1956,14 +1956,54 @@ func TestMapFieldRoundTrips(t *testing.T) {
 }
 
 func TestMapFieldWithNil(t *testing.T) {
-	m := &MessageWithMap{
+	m1 := &MessageWithMap{
 		MsgMapping: map[int64]*FloatingPoint{
 			1: nil,
 		},
 	}
-	b, err := Marshal(m)
-	if err == nil {
-		t.Fatalf("Marshal of bad map should have failed, got these bytes: %v", b)
+	b, err := Marshal(m1)
+	if err != nil {
+		t.Fatalf("Marshal: %v", err)
+	}
+	m2 := new(MessageWithMap)
+	if err := Unmarshal(b, m2); err != nil {
+		t.Fatalf("Unmarshal: %v, got these bytes: %v", err, b)
+	}
+	if v, ok := m2.MsgMapping[1]; !ok {
+		t.Error("msg_mapping[1] not present")
+	} else if v != nil {
+		t.Errorf("msg_mapping[1] not nil: %v", v)
+	}
+}
+
+func TestMapFieldWithNilBytes(t *testing.T) {
+	m1 := &MessageWithMap{
+		ByteMapping: map[bool][]byte{
+			false: []byte{},
+			true:  nil,
+		},
+	}
+	n := Size(m1)
+	b, err := Marshal(m1)
+	if err != nil {
+		t.Fatalf("Marshal: %v", err)
+	}
+	if n != len(b) {
+		t.Errorf("Size(m1) = %d; want len(Marshal(m1)) = %d", n, len(b))
+	}
+	m2 := new(MessageWithMap)
+	if err := Unmarshal(b, m2); err != nil {
+		t.Fatalf("Unmarshal: %v, got these bytes: %v", err, b)
+	}
+	if v, ok := m2.ByteMapping[false]; !ok {
+		t.Error("byte_mapping[false] not present")
+	} else if len(v) != 0 {
+		t.Errorf("byte_mapping[false] not empty: %#v", v)
+	}
+	if v, ok := m2.ByteMapping[true]; !ok {
+		t.Error("byte_mapping[true] not present")
+	} else if len(v) != 0 {
+		t.Errorf("byte_mapping[true] not empty: %#v", v)
 	}
 }
 

+ 1 - 6
proto/encode.go

@@ -1149,7 +1149,7 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
 		if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil {
 			return err
 		}
-		if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil {
+		if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil && err != ErrNil {
 			return err
 		}
 		return nil
@@ -1159,11 +1159,6 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
 	for _, key := range v.MapKeys() {
 		val := v.MapIndex(key)
 
-		// The only illegal map entry values are nil message pointers.
-		if val.Kind() == reflect.Ptr && val.IsNil() {
-			return errors.New("proto: map has nil element")
-		}
-
 		keycopy.Set(key)
 		valcopy.Set(val)
 

+ 11 - 0
proto/equal.go

@@ -128,6 +128,13 @@ func equalStruct(v1, v2 reflect.Value) bool {
 		}
 	}
 
+	if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
+		em2 := v2.FieldByName("XXX_extensions")
+		if !equalExtMap(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
+			return false
+		}
+	}
+
 	uf := v1.FieldByName("XXX_unrecognized")
 	if !uf.IsValid() {
 		return true
@@ -227,6 +234,10 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
 func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool {
 	em1, _ := x1.extensionsRead()
 	em2, _ := x2.extensionsRead()
+	return equalExtMap(base, em1, em2)
+}
+
+func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
 	if len(em1) != len(em2) {
 		return false
 	}

+ 4 - 8
proto/properties.go

@@ -473,17 +473,13 @@ func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lock
 			p.dec = (*Buffer).dec_slice_int64
 			p.packedDec = (*Buffer).dec_slice_packed_int64
 		case reflect.Uint8:
-			p.enc = (*Buffer).enc_slice_byte
 			p.dec = (*Buffer).dec_slice_byte
-			p.size = size_slice_byte
-			// This is a []byte, which is either a bytes field,
-			// or the value of a map field. In the latter case,
-			// we always encode an empty []byte, so we should not
-			// use the proto3 enc/size funcs.
-			// f == nil iff this is the key/value of a map field.
-			if p.proto3 && f != nil {
+			if p.proto3 {
 				p.enc = (*Buffer).enc_proto3_slice_byte
 				p.size = size_proto3_slice_byte
+			} else {
+				p.enc = (*Buffer).enc_slice_byte
+				p.size = size_slice_byte
 			}
 		case reflect.Float32, reflect.Float64:
 			switch t2.Bits() {