equal.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "bytes"
  7. "reflect"
  8. "google.golang.org/protobuf/internal/encoding/wire"
  9. pref "google.golang.org/protobuf/reflect/protoreflect"
  10. )
  11. // Equal reports whether two messages are equal.
  12. //
  13. // Two messages are equal if they belong to the same message descriptor,
  14. // have the same set of populated known and extension field values,
  15. // and the same set of unknown fields values.
  16. //
  17. // Scalar values are compared with the equivalent of the == operator in Go,
  18. // except bytes values which are compared using bytes.Equal.
  19. // Note that this means that floating point NaNs are considered inequal.
  20. // Message values are compared by recursively calling Equal.
  21. // Lists are equal if each element value is also equal.
  22. // Maps are equal if they have the same set of keys, where the pair of values
  23. // for each key is also equal.
  24. func Equal(x, y Message) bool {
  25. return equalMessage(x.ProtoReflect(), y.ProtoReflect())
  26. }
  27. // equalMessage compares two messages.
  28. func equalMessage(mx, my pref.Message) bool {
  29. if mx.Descriptor() != my.Descriptor() {
  30. return false
  31. }
  32. nx := 0
  33. equal := true
  34. mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
  35. nx++
  36. vy := my.Get(fd)
  37. equal = my.Has(fd) && equalField(fd, vx, vy)
  38. return equal
  39. })
  40. if !equal {
  41. return false
  42. }
  43. ny := 0
  44. my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
  45. ny++
  46. return true
  47. })
  48. if nx != ny {
  49. return false
  50. }
  51. return equalUnknown(mx.GetUnknown(), my.GetUnknown())
  52. }
  53. // equalField compares two fields.
  54. func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
  55. switch {
  56. case fd.IsList():
  57. return equalList(fd, x.List(), y.List())
  58. case fd.IsMap():
  59. return equalMap(fd, x.Map(), y.Map())
  60. default:
  61. return equalValue(fd, x, y)
  62. }
  63. }
  64. // equalMap compares two maps.
  65. func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
  66. if x.Len() != y.Len() {
  67. return false
  68. }
  69. equal := true
  70. x.Range(func(k pref.MapKey, vx pref.Value) bool {
  71. vy := y.Get(k)
  72. equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
  73. return equal
  74. })
  75. return equal
  76. }
  77. // equalList compares two lists.
  78. func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
  79. if x.Len() != y.Len() {
  80. return false
  81. }
  82. for i := x.Len() - 1; i >= 0; i-- {
  83. if !equalValue(fd, x.Get(i), y.Get(i)) {
  84. return false
  85. }
  86. }
  87. return true
  88. }
  89. // equalValue compares two singular values.
  90. func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
  91. switch {
  92. case fd.Message() != nil:
  93. return equalMessage(x.Message(), y.Message())
  94. case fd.Kind() == pref.BytesKind:
  95. return bytes.Equal(x.Bytes(), y.Bytes())
  96. default:
  97. return x.Interface() == y.Interface()
  98. }
  99. }
  100. // equalUnknown compares unknown fields by direct comparison on the raw bytes
  101. // of each individual field number.
  102. func equalUnknown(x, y pref.RawFields) bool {
  103. if len(x) != len(y) {
  104. return false
  105. }
  106. if bytes.Equal([]byte(x), []byte(y)) {
  107. return true
  108. }
  109. mx := make(map[pref.FieldNumber]pref.RawFields)
  110. my := make(map[pref.FieldNumber]pref.RawFields)
  111. for len(x) > 0 {
  112. fnum, _, n := wire.ConsumeField(x)
  113. mx[fnum] = append(mx[fnum], x[:n]...)
  114. x = x[n:]
  115. }
  116. for len(y) > 0 {
  117. fnum, _, n := wire.ConsumeField(y)
  118. my[fnum] = append(my[fnum], y[:n]...)
  119. y = y[n:]
  120. }
  121. return reflect.DeepEqual(mx, my)
  122. }