consistenthash.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package hash
  2. import (
  3. "fmt"
  4. "sort"
  5. "strconv"
  6. "sync"
  7. "github.com/tal-tech/go-zero/core/lang"
  8. "github.com/tal-tech/go-zero/core/mapping"
  9. )
  // TopWeight is the weight that gives a node the full number of replicas in
// AddWithWeight; replicas are scaled by weight/TopWeight.
const TopWeight = 100

const (
	// minReplicas is both the default replica count (NewConsistentHash) and the
	// floor applied by NewCustomConsistentHash.
	minReplicas = 100
	// prime is prefixed to values by innerRepr, so the hash Get uses to choose
	// among nodes sharing a ring position differs from the ring-position hash.
	prime = 16777619
)
  15. type (
  16. HashFunc func(data []byte) uint64
  17. ConsistentHash struct {
  18. hashFunc HashFunc
  19. replicas int
  20. keys []uint64
  21. ring map[uint64][]interface{}
  22. nodes map[string]lang.PlaceholderType
  23. lock sync.RWMutex
  24. }
  25. )
  26. func NewConsistentHash() *ConsistentHash {
  27. return NewCustomConsistentHash(minReplicas, Hash)
  28. }
  29. func NewCustomConsistentHash(replicas int, fn HashFunc) *ConsistentHash {
  30. if replicas < minReplicas {
  31. replicas = minReplicas
  32. }
  33. if fn == nil {
  34. fn = Hash
  35. }
  36. return &ConsistentHash{
  37. hashFunc: fn,
  38. replicas: replicas,
  39. ring: make(map[uint64][]interface{}),
  40. nodes: make(map[string]lang.PlaceholderType),
  41. }
  42. }
  43. // Add adds the node with the number of h.replicas,
  44. // the later call will overwrite the replicas of the former calls.
  45. func (h *ConsistentHash) Add(node interface{}) {
  46. h.AddWithReplicas(node, h.replicas)
  47. }
  48. // AddWithReplicas adds the node with the number of replicas,
  49. // replicas will be truncated to h.replicas if it's larger than h.replicas,
  50. // the later call will overwrite the replicas of the former calls.
  51. func (h *ConsistentHash) AddWithReplicas(node interface{}, replicas int) {
  52. h.Remove(node)
  53. if replicas > h.replicas {
  54. replicas = h.replicas
  55. }
  56. nodeRepr := repr(node)
  57. h.lock.Lock()
  58. defer h.lock.Unlock()
  59. h.addNode(nodeRepr)
  60. for i := 0; i < replicas; i++ {
  61. hash := h.hashFunc([]byte(nodeRepr + strconv.Itoa(i)))
  62. h.keys = append(h.keys, hash)
  63. h.ring[hash] = append(h.ring[hash], node)
  64. }
  65. sort.Slice(h.keys, func(i int, j int) bool {
  66. return h.keys[i] < h.keys[j]
  67. })
  68. }
  69. // AddWithWeight adds the node with weight, the weight can be 1 to 100, indicates the percent,
  70. // the later call will overwrite the replicas of the former calls.
  71. func (h *ConsistentHash) AddWithWeight(node interface{}, weight int) {
  72. // don't need to make sure weight not larger than TopWeight,
  73. // because AddWithReplicas makes sure replicas cannot be larger than h.replicas
  74. replicas := h.replicas * weight / TopWeight
  75. h.AddWithReplicas(node, replicas)
  76. }
  77. func (h *ConsistentHash) Get(v interface{}) (interface{}, bool) {
  78. h.lock.RLock()
  79. defer h.lock.RUnlock()
  80. if len(h.ring) == 0 {
  81. return nil, false
  82. }
  83. hash := h.hashFunc([]byte(repr(v)))
  84. index := sort.Search(len(h.keys), func(i int) bool {
  85. return h.keys[i] >= hash
  86. }) % len(h.keys)
  87. nodes := h.ring[h.keys[index]]
  88. switch len(nodes) {
  89. case 0:
  90. return nil, false
  91. case 1:
  92. return nodes[0], true
  93. default:
  94. innerIndex := h.hashFunc([]byte(innerRepr(v)))
  95. pos := int(innerIndex % uint64(len(nodes)))
  96. return nodes[pos], true
  97. }
  98. }
  99. func (h *ConsistentHash) Remove(node interface{}) {
  100. nodeRepr := repr(node)
  101. h.lock.Lock()
  102. defer h.lock.Unlock()
  103. if !h.containsNode(nodeRepr) {
  104. return
  105. }
  106. for i := 0; i < h.replicas; i++ {
  107. hash := h.hashFunc([]byte(nodeRepr + strconv.Itoa(i)))
  108. index := sort.Search(len(h.keys), func(i int) bool {
  109. return h.keys[i] >= hash
  110. })
  111. if index < len(h.keys) {
  112. h.keys = append(h.keys[:index], h.keys[index+1:]...)
  113. }
  114. h.removeRingNode(hash, nodeRepr)
  115. }
  116. h.removeNode(nodeRepr)
  117. }
  118. func (h *ConsistentHash) removeRingNode(hash uint64, nodeRepr string) {
  119. if nodes, ok := h.ring[hash]; ok {
  120. newNodes := nodes[:0]
  121. for _, x := range nodes {
  122. if repr(x) != nodeRepr {
  123. newNodes = append(newNodes, x)
  124. }
  125. }
  126. if len(newNodes) > 0 {
  127. h.ring[hash] = newNodes
  128. } else {
  129. delete(h.ring, hash)
  130. }
  131. }
  132. }
  133. func (h *ConsistentHash) addNode(nodeRepr string) {
  134. h.nodes[nodeRepr] = lang.Placeholder
  135. }
  136. func (h *ConsistentHash) containsNode(nodeRepr string) bool {
  137. _, ok := h.nodes[nodeRepr]
  138. return ok
  139. }
  140. func (h *ConsistentHash) removeNode(nodeRepr string) {
  141. delete(h.nodes, nodeRepr)
  142. }
  143. func innerRepr(node interface{}) string {
  144. return fmt.Sprintf("%d:%v", prime, node)
  145. }
  146. func repr(node interface{}) string {
  147. return mapping.Repr(node)
  148. }