|
|
@@ -37,6 +37,7 @@ package proto
|
|
|
import (
|
|
|
"bytes"
|
|
|
"log"
|
|
|
+ "os"
|
|
|
"reflect"
|
|
|
"strings"
|
|
|
)
|
|
|
@@ -59,7 +60,7 @@ Equality is defined in this way:
|
|
|
- Two unknown field sets are equal if their current
|
|
|
encoded state is equal. (TODO)
|
|
|
- Two extension sets are equal iff they have corresponding
|
|
|
- elements that are pairwise equal. (TODO)
|
|
|
+ elements that are pairwise equal.
|
|
|
- Every other combination of things are not equal.
|
|
|
|
|
|
The return value is undefined if a and b are not protocol buffers.
|
|
|
@@ -101,7 +102,14 @@ func equalStruct(v1, v2 reflect.Value) bool {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // TODO: Deal with XXX_unrecognized and XXX_extensions.
|
|
|
+ if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
|
|
|
+ em2 := v2.FieldByName("XXX_extensions")
|
|
|
+ if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // TODO: Deal with XXX_unrecognized.
|
|
|
|
|
|
return true
|
|
|
}
|
|
|
@@ -148,3 +156,56 @@ func equalAny(v1, v2 reflect.Value) bool {
|
|
|
log.Printf("proto: don't know how to compare %v", v1)
|
|
|
return false
|
|
|
}
|
|
|
+
|
|
|
+// base is the struct type that the extensions are based on.
|
|
|
+// em1 and em2 are extension maps.
|
|
|
+func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
|
|
|
+ if len(em1) != len(em2) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ for extNum, e1 := range em1 {
|
|
|
+ e2, ok := em2[extNum]
|
|
|
+ if !ok {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ m1, m2 := e1.value, e2.value
|
|
|
+
|
|
|
+ if m1 != nil && m2 != nil {
|
|
|
+ // Both are unencoded.
|
|
|
+ if !Equal(m1, m2) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // At least one is encoded. To do a semantically correct comparison
|
|
|
+ // we need to unmarshal them first.
|
|
|
+ var desc *ExtensionDesc
|
|
|
+ if m := extensionMaps[base]; m != nil {
|
|
|
+ desc = m[extNum]
|
|
|
+ }
|
|
|
+ if desc == nil {
|
|
|
+ log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ var err os.Error
|
|
|
+ if m1 == nil {
|
|
|
+ m1, err = decodeExtension(e1.enc, desc)
|
|
|
+ }
|
|
|
+ if m2 == nil && err == nil {
|
|
|
+ m2, err = decodeExtension(e2.enc, desc)
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ // The encoded form is invalid.
|
|
|
+ log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if !Equal(m1, m2) {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return true
|
|
|
+}
|