equal.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "bytes"
  7. pref "github.com/golang/protobuf/v2/reflect/protoreflect"
  8. )
  9. // Equal returns true of two messages are equal.
  10. //
  11. // Two messages are equal if they have identical types and registered extension fields,
  12. // marshal to the same bytes under deterministic serialization,
  13. // and contain no floating point NaNs.
  14. func Equal(a, b Message) bool {
  15. return equalMessage(a.ProtoReflect(), b.ProtoReflect())
  16. }
  17. // equalMessage compares two messages.
  18. func equalMessage(a, b pref.Message) bool {
  19. mda, mdb := a.Type(), b.Type()
  20. if mda != mdb && mda.FullName() != mdb.FullName() {
  21. return false
  22. }
  23. // TODO: The v1 says that a nil message is not equal to an empty one.
  24. // Decide what to do about this when v1 wraps v2.
  25. knowna, knownb := a.KnownFields(), b.KnownFields()
  26. fields := mda.Fields()
  27. for i, flen := 0, fields.Len(); i < flen; i++ {
  28. fd := fields.Get(i)
  29. num := fd.Number()
  30. hasa, hasb := knowna.Has(num), knownb.Has(num)
  31. if !hasa && !hasb {
  32. continue
  33. }
  34. if hasa != hasb || !equalFields(fd, knowna.Get(num), knownb.Get(num)) {
  35. return false
  36. }
  37. }
  38. equal := true
  39. unknowna, unknownb := a.UnknownFields(), b.UnknownFields()
  40. ulen := unknowna.Len()
  41. if ulen != unknownb.Len() {
  42. return false
  43. }
  44. unknowna.Range(func(num pref.FieldNumber, ra pref.RawFields) bool {
  45. rb := unknownb.Get(num)
  46. if !bytes.Equal([]byte(ra), []byte(rb)) {
  47. equal = false
  48. return false
  49. }
  50. return true
  51. })
  52. if !equal {
  53. return false
  54. }
  55. // If the set of extension types is not identical for both messages, we report
  56. // a inequality.
  57. //
  58. // This requirement is stringent. Registering an extension type for a message
  59. // without setting a value for the extension will cause that message to compare
  60. // as inequal to the same message without the registration.
  61. //
  62. // TODO: Revisit this behavior after eager decoding of extensions is implemented.
  63. xtypesa, xtypesb := knowna.ExtensionTypes(), knownb.ExtensionTypes()
  64. if la, lb := xtypesa.Len(), xtypesb.Len(); la != lb {
  65. return false
  66. } else if la == 0 {
  67. return true
  68. }
  69. xtypesa.Range(func(xt pref.ExtensionType) bool {
  70. num := xt.Number()
  71. if xtypesb.ByNumber(num) != xt {
  72. equal = false
  73. return false
  74. }
  75. hasa, hasb := knowna.Has(num), knownb.Has(num)
  76. if !hasa && !hasb {
  77. return true
  78. }
  79. if hasa != hasb || !equalFields(xt, knowna.Get(num), knownb.Get(num)) {
  80. equal = false
  81. return false
  82. }
  83. return true
  84. })
  85. return equal
  86. }
  87. // equalFields compares two fields.
  88. func equalFields(fd pref.FieldDescriptor, a, b pref.Value) bool {
  89. switch {
  90. case fd.IsMap():
  91. return equalMap(fd, a.Map(), b.Map())
  92. case fd.Cardinality() == pref.Repeated:
  93. return equalList(fd, a.List(), b.List())
  94. default:
  95. return equalValue(fd, a, b)
  96. }
  97. }
  98. // equalMap compares a map field.
  99. func equalMap(fd pref.FieldDescriptor, a, b pref.Map) bool {
  100. fdv := fd.Message().Fields().ByNumber(2)
  101. alen := a.Len()
  102. if alen != b.Len() {
  103. return false
  104. }
  105. equal := true
  106. a.Range(func(k pref.MapKey, va pref.Value) bool {
  107. vb := b.Get(k)
  108. if !vb.IsValid() || !equalValue(fdv, va, vb) {
  109. equal = false
  110. return false
  111. }
  112. return true
  113. })
  114. return equal
  115. }
  116. // equalList compares a non-map repeated field.
  117. func equalList(fd pref.FieldDescriptor, a, b pref.List) bool {
  118. alen := a.Len()
  119. if alen != b.Len() {
  120. return false
  121. }
  122. for i := 0; i < alen; i++ {
  123. if !equalValue(fd, a.Get(i), b.Get(i)) {
  124. return false
  125. }
  126. }
  127. return true
  128. }
  129. // equalValue compares the scalar value type of a field.
  130. func equalValue(fd pref.FieldDescriptor, a, b pref.Value) bool {
  131. switch {
  132. case fd.Message() != nil:
  133. return equalMessage(a.Message(), b.Message())
  134. case fd.Kind() == pref.BytesKind:
  135. return bytes.Equal(a.Bytes(), b.Bytes())
  136. default:
  137. return a.Interface() == b.Interface()
  138. }
  139. }