equal.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "bytes"
  7. "math"
  8. "reflect"
  9. "google.golang.org/protobuf/internal/encoding/wire"
  10. pref "google.golang.org/protobuf/reflect/protoreflect"
  11. )
  12. // Equal reports whether two messages are equal.
  13. // If two messages marshal to the same bytes under deterministic serialization,
  14. // then Equal is guaranteed to report true.
  15. //
  16. // Two messages are equal if they belong to the same message descriptor,
  17. // have the same set of populated known and extension field values,
  18. // and the same set of unknown fields values.
  19. //
  20. // Scalar values are compared with the equivalent of the == operator in Go,
  21. // except bytes values which are compared using bytes.Equal and
  22. // floating point values which specially treat NaNs as equal.
  23. // Message values are compared by recursively calling Equal.
  24. // Lists are equal if each element value is also equal.
  25. // Maps are equal if they have the same set of keys, where the pair of values
  26. // for each key is also equal.
  27. func Equal(x, y Message) bool {
  28. if x == nil || y == nil {
  29. return x == nil && y == nil
  30. }
  31. return equalMessage(x.ProtoReflect(), y.ProtoReflect())
  32. }
  33. // equalMessage compares two messages.
  34. func equalMessage(mx, my pref.Message) bool {
  35. if mx.Descriptor() != my.Descriptor() {
  36. return false
  37. }
  38. nx := 0
  39. equal := true
  40. mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
  41. nx++
  42. vy := my.Get(fd)
  43. equal = my.Has(fd) && equalField(fd, vx, vy)
  44. return equal
  45. })
  46. if !equal {
  47. return false
  48. }
  49. ny := 0
  50. my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
  51. ny++
  52. return true
  53. })
  54. if nx != ny {
  55. return false
  56. }
  57. return equalUnknown(mx.GetUnknown(), my.GetUnknown())
  58. }
  59. // equalField compares two fields.
  60. func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
  61. switch {
  62. case fd.IsList():
  63. return equalList(fd, x.List(), y.List())
  64. case fd.IsMap():
  65. return equalMap(fd, x.Map(), y.Map())
  66. default:
  67. return equalValue(fd, x, y)
  68. }
  69. }
  70. // equalMap compares two maps.
  71. func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
  72. if x.Len() != y.Len() {
  73. return false
  74. }
  75. equal := true
  76. x.Range(func(k pref.MapKey, vx pref.Value) bool {
  77. vy := y.Get(k)
  78. equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
  79. return equal
  80. })
  81. return equal
  82. }
  83. // equalList compares two lists.
  84. func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
  85. if x.Len() != y.Len() {
  86. return false
  87. }
  88. for i := x.Len() - 1; i >= 0; i-- {
  89. if !equalValue(fd, x.Get(i), y.Get(i)) {
  90. return false
  91. }
  92. }
  93. return true
  94. }
  95. // equalValue compares two singular values.
  96. func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
  97. switch {
  98. case fd.Message() != nil:
  99. return equalMessage(x.Message(), y.Message())
  100. case fd.Kind() == pref.BytesKind:
  101. return bytes.Equal(x.Bytes(), y.Bytes())
  102. case fd.Kind() == pref.FloatKind, fd.Kind() == pref.DoubleKind:
  103. fx := x.Float()
  104. fy := y.Float()
  105. if math.IsNaN(fx) || math.IsNaN(fy) {
  106. return math.IsNaN(fx) && math.IsNaN(fy)
  107. }
  108. return fx == fy
  109. default:
  110. return x.Interface() == y.Interface()
  111. }
  112. }
  113. // equalUnknown compares unknown fields by direct comparison on the raw bytes
  114. // of each individual field number.
  115. func equalUnknown(x, y pref.RawFields) bool {
  116. if len(x) != len(y) {
  117. return false
  118. }
  119. if bytes.Equal([]byte(x), []byte(y)) {
  120. return true
  121. }
  122. mx := make(map[pref.FieldNumber]pref.RawFields)
  123. my := make(map[pref.FieldNumber]pref.RawFields)
  124. for len(x) > 0 {
  125. fnum, _, n := wire.ConsumeField(x)
  126. mx[fnum] = append(mx[fnum], x[:n]...)
  127. x = x[n:]
  128. }
  129. for len(y) > 0 {
  130. fnum, _, n := wire.ConsumeField(y)
  131. my[fnum] = append(my[fnum], y[:n]...)
  132. y = y[n:]
  133. }
  134. return reflect.DeepEqual(mx, my)
  135. }