equal.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "bytes"
  7. "math"
  8. "reflect"
  9. "google.golang.org/protobuf/internal/encoding/wire"
  10. pref "google.golang.org/protobuf/reflect/protoreflect"
  11. )
  12. // Equal reports whether two messages are equal.
  13. // If two messages marshal to the same bytes under deterministic serialization,
  14. // then Equal is guaranteed to report true.
  15. //
  16. // Two messages are equal if they belong to the same message descriptor,
  17. // have the same set of populated known and extension field values,
  18. // and the same set of unknown fields values.
  19. //
  20. // Scalar values are compared with the equivalent of the == operator in Go,
  21. // except bytes values which are compared using bytes.Equal and
  22. // floating point values which specially treat NaNs as equal.
  23. // Message values are compared by recursively calling Equal.
  24. // Lists are equal if each element value is also equal.
  25. // Maps are equal if they have the same set of keys, where the pair of values
  26. // for each key is also equal.
  27. func Equal(x, y Message) bool {
  28. return equalMessage(x.ProtoReflect(), y.ProtoReflect())
  29. }
  30. // equalMessage compares two messages.
  31. func equalMessage(mx, my pref.Message) bool {
  32. if mx.Descriptor() != my.Descriptor() {
  33. return false
  34. }
  35. nx := 0
  36. equal := true
  37. mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
  38. nx++
  39. vy := my.Get(fd)
  40. equal = my.Has(fd) && equalField(fd, vx, vy)
  41. return equal
  42. })
  43. if !equal {
  44. return false
  45. }
  46. ny := 0
  47. my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
  48. ny++
  49. return true
  50. })
  51. if nx != ny {
  52. return false
  53. }
  54. return equalUnknown(mx.GetUnknown(), my.GetUnknown())
  55. }
  56. // equalField compares two fields.
  57. func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
  58. switch {
  59. case fd.IsList():
  60. return equalList(fd, x.List(), y.List())
  61. case fd.IsMap():
  62. return equalMap(fd, x.Map(), y.Map())
  63. default:
  64. return equalValue(fd, x, y)
  65. }
  66. }
  67. // equalMap compares two maps.
  68. func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
  69. if x.Len() != y.Len() {
  70. return false
  71. }
  72. equal := true
  73. x.Range(func(k pref.MapKey, vx pref.Value) bool {
  74. vy := y.Get(k)
  75. equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
  76. return equal
  77. })
  78. return equal
  79. }
  80. // equalList compares two lists.
  81. func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
  82. if x.Len() != y.Len() {
  83. return false
  84. }
  85. for i := x.Len() - 1; i >= 0; i-- {
  86. if !equalValue(fd, x.Get(i), y.Get(i)) {
  87. return false
  88. }
  89. }
  90. return true
  91. }
  92. // equalValue compares two singular values.
  93. func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
  94. switch {
  95. case fd.Message() != nil:
  96. return equalMessage(x.Message(), y.Message())
  97. case fd.Kind() == pref.BytesKind:
  98. return bytes.Equal(x.Bytes(), y.Bytes())
  99. case fd.Kind() == pref.FloatKind, fd.Kind() == pref.DoubleKind:
  100. fx := x.Float()
  101. fy := y.Float()
  102. if math.IsNaN(fx) || math.IsNaN(fy) {
  103. return math.IsNaN(fx) && math.IsNaN(fy)
  104. }
  105. return fx == fy
  106. default:
  107. return x.Interface() == y.Interface()
  108. }
  109. }
  110. // equalUnknown compares unknown fields by direct comparison on the raw bytes
  111. // of each individual field number.
  112. func equalUnknown(x, y pref.RawFields) bool {
  113. if len(x) != len(y) {
  114. return false
  115. }
  116. if bytes.Equal([]byte(x), []byte(y)) {
  117. return true
  118. }
  119. mx := make(map[pref.FieldNumber]pref.RawFields)
  120. my := make(map[pref.FieldNumber]pref.RawFields)
  121. for len(x) > 0 {
  122. fnum, _, n := wire.ConsumeField(x)
  123. mx[fnum] = append(mx[fnum], x[:n]...)
  124. x = x[n:]
  125. }
  126. for len(y) > 0 {
  127. fnum, _, n := wire.ConsumeField(y)
  128. my[fnum] = append(my[fnum], y[:n]...)
  129. y = y[n:]
  130. }
  131. return reflect.DeepEqual(mx, my)
  132. }