Browse Source

goprotobuf: Support extensions in proto.Equal.

R=r
CC=golang-dev
http://codereview.appspot.com/4950066
David Symonds 14 years ago
parent
commit
a4b61c0f37
3 changed files with 121 additions and 10 deletions
  1. 63 2
      proto/equal.go
  2. 40 0
      proto/equal_test.go
  3. 18 8
      proto/extensions.go

+ 63 - 2
proto/equal.go

@@ -37,6 +37,7 @@ package proto
 import (
 import (
 	"bytes"
 	"bytes"
 	"log"
 	"log"
+	"os"
 	"reflect"
 	"reflect"
 	"strings"
 	"strings"
 )
 )
@@ -59,7 +60,7 @@ Equality is defined in this way:
   - Two unknown field sets are equal if their current
   - Two unknown field sets are equal if their current
     encoded state is equal. (TODO)
     encoded state is equal. (TODO)
   - Two extension sets are equal iff they have corresponding
   - Two extension sets are equal iff they have corresponding
-    elements that are pairwise equal. (TODO)
+    elements that are pairwise equal.
   - Every other combination of things are not equal.
   - Every other combination of things are not equal.
 
 
 The return value is undefined if a and b are not protocol buffers.
 The return value is undefined if a and b are not protocol buffers.
@@ -101,7 +102,14 @@ func equalStruct(v1, v2 reflect.Value) bool {
 		}
 		}
 	}
 	}
 
 
-	// TODO: Deal with XXX_unrecognized and XXX_extensions.
+	if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
+		em2 := v2.FieldByName("XXX_extensions")
+		if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
+			return false
+		}
+	}
+
+	// TODO: Deal with XXX_unrecognized.
 
 
 	return true
 	return true
 }
 }
@@ -148,3 +156,56 @@ func equalAny(v1, v2 reflect.Value) bool {
 	log.Printf("proto: don't know how to compare %v", v1)
 	log.Printf("proto: don't know how to compare %v", v1)
 	return false
 	return false
 }
 }
+
+// base is the struct type that the extensions are based on.
+// em1 and em2 are extension maps.
+func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
+	if len(em1) != len(em2) {
+		return false
+	}
+
+	for extNum, e1 := range em1 {
+		e2, ok := em2[extNum]
+		if !ok {
+			return false
+		}
+
+		m1, m2 := e1.value, e2.value
+
+		if m1 != nil && m2 != nil {
+			// Both are unencoded.
+			if !Equal(m1, m2) {
+				return false
+			}
+			continue
+		}
+
+		// At least one is encoded. To do a semantically correct comparison
+		// we need to unmarshal them first.
+		var desc *ExtensionDesc
+		if m := extensionMaps[base]; m != nil {
+			desc = m[extNum]
+		}
+		if desc == nil {
+			log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
+			continue
+		}
+		var err os.Error
+		if m1 == nil {
+			m1, err = decodeExtension(e1.enc, desc)
+		}
+		if m2 == nil && err == nil {
+			m2, err = decodeExtension(e2.enc, desc)
+		}
+		if err != nil {
+			// The encoded form is invalid.
+			log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
+			return false
+		}
+		if !Equal(m1, m2) {
+			return false
+		}
+	}
+
+	return true
+}

+ 40 - 0
proto/equal_test.go

@@ -32,12 +32,48 @@
 package proto_test
 package proto_test
 
 
 import (
 import (
+	"log"
 	"testing"
 	"testing"
 
 
 	. "goprotobuf.googlecode.com/hg/proto"
 	. "goprotobuf.googlecode.com/hg/proto"
 	pb "./testdata/_obj/test_proto"
 	pb "./testdata/_obj/test_proto"
 )
 )
 
 
+// Four identical base messages.
+// The init function adds extensions to some of them.
+var messageWithoutExtension = &pb.MyMessage{Count: Int32(7)}
+var messageWithExtension1a = &pb.MyMessage{Count: Int32(7)}
+var messageWithExtension1b = &pb.MyMessage{Count: Int32(7)}
+var messageWithExtension2 = &pb.MyMessage{Count: Int32(7)}
+
+func init() {
+	ext1 := &pb.Ext{Data: String("Kirk")}
+	ext2 := &pb.Ext{Data: String("Picard")}
+
+	// messageWithExtension1a has ext1, but never marshals it.
+	if err := SetExtension(messageWithExtension1a, pb.E_Ext_More, ext1); err != nil {
+		log.Panicf("SetExtension on 1a failed: %v", err)
+	}
+
+	// messageWithExtension1b is the unmarshaled form of messageWithExtension1a.
+	if err := SetExtension(messageWithExtension1b, pb.E_Ext_More, ext1); err != nil {
+		log.Panicf("SetExtension on 1b failed: %v", err)
+	}
+	buf, err := Marshal(messageWithExtension1b)
+	if err != nil {
+		log.Panicf("Marshal of 1b failed: %v", err)
+	}
+	messageWithExtension1b.Reset()
+	if err := Unmarshal(buf, messageWithExtension1b); err != nil {
+		log.Panicf("Unmarshal of 1b failed: %v", err)
+	}
+
+	// messageWithExtension2 has ext2.
+	if err := SetExtension(messageWithExtension2, pb.E_Ext_More, ext2); err != nil {
+		log.Panicf("SetExtension on 2 failed: %v", err)
+	}
+}
+
 var EqualTests = []struct {
 var EqualTests = []struct {
 	desc string
 	desc string
 	a, b interface{}
 	a, b interface{}
@@ -70,6 +106,10 @@ var EqualTests = []struct {
 		&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
 		&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
 		true,
 		true,
 	},
 	},
+
+	{"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false},
+	{"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true},
+	{"extension vs. different extension", messageWithExtension1a, messageWithExtension2, false},
 }
 }
 
 
 func TestEqual(t *testing.T) {
 func TestEqual(t *testing.T) {

+ 18 - 8
proto/extensions.go

@@ -176,9 +176,24 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, os
 		return e.value, nil
 		return e.value, nil
 	}
 	}
 
 
+	v, err := decodeExtension(e.enc, extension)
+	if err != nil {
+		return nil, err
+	}
+
+	// Remember the decoded version and drop the encoded version.
+	// That way it is safe to mutate what we return.
+	e.value = v
+	e.desc = extension
+	e.enc = nil
+	return e.value, nil
+}
+
+// decodeExtension decodes an extension encoded in b.
+func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, os.Error) {
 	// Discard wire type and field number varint. It isn't needed.
 	// Discard wire type and field number varint. It isn't needed.
-	_, n := DecodeVarint(e.enc)
-	o := NewBuffer(e.enc[n:])
+	_, n := DecodeVarint(b)
+	o := NewBuffer(b[n:])
 
 
 	t := reflect.TypeOf(extension.ExtensionType)
 	t := reflect.TypeOf(extension.ExtensionType)
 	props := &Properties{}
 	props := &Properties{}
@@ -195,12 +210,7 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, os
 	if err := props.dec(o, props, uintptr(base), sbase); err != nil {
 	if err := props.dec(o, props, uintptr(base), sbase); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	// Remember the decoded version and drop the encoded version.
-	// That way it is safe to mutate what we return.
-	e.value = unsafe.Unreflect(t, base)
-	e.desc = extension
-	e.enc = nil
-	return e.value, nil
+	return unsafe.Unreflect(t, base), nil
 }
 }
 
 
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.