timingwheel.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. package collection
  2. import (
  3. "container/list"
  4. "fmt"
  5. "time"
  6. "github.com/tal-tech/go-zero/core/lang"
  7. "github.com/tal-tech/go-zero/core/threading"
  8. "github.com/tal-tech/go-zero/core/timex"
  9. )
  // drainWorkers is passed to threading.NewTaskRunner in drainAll, presumably
// bounding how many drain callbacks run at the same time.
const (
	drainWorkers = 8
)
  11. type (
  12. Execute func(key, value interface{})
  13. TimingWheel struct {
  14. interval time.Duration
  15. ticker timex.Ticker
  16. slots []*list.List
  17. timers *SafeMap
  18. tickedPos int
  19. numSlots int
  20. execute Execute
  21. setChannel chan timingEntry
  22. moveChannel chan baseEntry
  23. removeChannel chan interface{}
  24. drainChannel chan func(key, value interface{})
  25. stopChannel chan lang.PlaceholderType
  26. }
  27. timingEntry struct {
  28. baseEntry
  29. value interface{}
  30. circle int
  31. diff int
  32. removed bool
  33. }
  34. baseEntry struct {
  35. delay time.Duration
  36. key interface{}
  37. }
  38. positionEntry struct {
  39. pos int
  40. item *timingEntry
  41. }
  42. timingTask struct {
  43. key interface{}
  44. value interface{}
  45. }
  46. )
  47. func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*TimingWheel, error) {
  48. if interval <= 0 || numSlots <= 0 || execute == nil {
  49. return nil, fmt.Errorf("interval: %v, slots: %d, execute: %p", interval, numSlots, execute)
  50. }
  51. return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
  52. }
  53. func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute, ticker timex.Ticker) (
  54. *TimingWheel, error) {
  55. tw := &TimingWheel{
  56. interval: interval,
  57. ticker: ticker,
  58. slots: make([]*list.List, numSlots),
  59. timers: NewSafeMap(),
  60. tickedPos: numSlots - 1, // at previous virtual circle
  61. execute: execute,
  62. numSlots: numSlots,
  63. setChannel: make(chan timingEntry),
  64. moveChannel: make(chan baseEntry),
  65. removeChannel: make(chan interface{}),
  66. drainChannel: make(chan func(key, value interface{})),
  67. stopChannel: make(chan lang.PlaceholderType),
  68. }
  69. tw.initSlots()
  70. go tw.run()
  71. return tw, nil
  72. }
  73. func (tw *TimingWheel) Drain(fn func(key, value interface{})) {
  74. tw.drainChannel <- fn
  75. }
  76. func (tw *TimingWheel) MoveTimer(key interface{}, delay time.Duration) {
  77. if delay <= 0 || key == nil {
  78. return
  79. }
  80. tw.moveChannel <- baseEntry{
  81. delay: delay,
  82. key: key,
  83. }
  84. }
  85. func (tw *TimingWheel) RemoveTimer(key interface{}) {
  86. if key == nil {
  87. return
  88. }
  89. tw.removeChannel <- key
  90. }
  91. func (tw *TimingWheel) SetTimer(key, value interface{}, delay time.Duration) {
  92. if delay <= 0 || key == nil {
  93. return
  94. }
  95. tw.setChannel <- timingEntry{
  96. baseEntry: baseEntry{
  97. delay: delay,
  98. key: key,
  99. },
  100. value: value,
  101. }
  102. }
  103. func (tw *TimingWheel) Stop() {
  104. close(tw.stopChannel)
  105. }
  106. func (tw *TimingWheel) drainAll(fn func(key, value interface{})) {
  107. runner := threading.NewTaskRunner(drainWorkers)
  108. for _, slot := range tw.slots {
  109. for e := slot.Front(); e != nil; {
  110. task := e.Value.(*timingEntry)
  111. next := e.Next()
  112. slot.Remove(e)
  113. e = next
  114. if !task.removed {
  115. runner.Schedule(func() {
  116. fn(task.key, task.value)
  117. })
  118. }
  119. }
  120. }
  121. }
  122. func (tw *TimingWheel) getPositionAndCircle(d time.Duration) (pos int, circle int) {
  123. steps := int(d / tw.interval)
  124. pos = (tw.tickedPos + steps) % tw.numSlots
  125. circle = (steps - 1) / tw.numSlots
  126. return
  127. }
  128. func (tw *TimingWheel) initSlots() {
  129. for i := 0; i < tw.numSlots; i++ {
  130. tw.slots[i] = list.New()
  131. }
  132. }
  133. func (tw *TimingWheel) moveTask(task baseEntry) {
  134. val, ok := tw.timers.Get(task.key)
  135. if !ok {
  136. return
  137. }
  138. timer := val.(*positionEntry)
  139. if task.delay < tw.interval {
  140. threading.GoSafe(func() {
  141. tw.execute(timer.item.key, timer.item.value)
  142. })
  143. return
  144. }
  145. pos, circle := tw.getPositionAndCircle(task.delay)
  146. if pos >= timer.pos {
  147. timer.item.circle = circle
  148. timer.item.diff = pos - timer.pos
  149. } else if circle > 0 {
  150. circle--
  151. timer.item.circle = circle
  152. timer.item.diff = tw.numSlots + pos - timer.pos
  153. } else {
  154. timer.item.removed = true
  155. newItem := &timingEntry{
  156. baseEntry: task,
  157. value: timer.item.value,
  158. }
  159. tw.slots[pos].PushBack(newItem)
  160. tw.setTimerPosition(pos, newItem)
  161. }
  162. }
  163. func (tw *TimingWheel) onTick() {
  164. tw.tickedPos = (tw.tickedPos + 1) % tw.numSlots
  165. l := tw.slots[tw.tickedPos]
  166. tw.scanAndRunTasks(l)
  167. }
  168. func (tw *TimingWheel) removeTask(key interface{}) {
  169. val, ok := tw.timers.Get(key)
  170. if !ok {
  171. return
  172. }
  173. timer := val.(*positionEntry)
  174. timer.item.removed = true
  175. }
  176. func (tw *TimingWheel) run() {
  177. for {
  178. select {
  179. case <-tw.ticker.Chan():
  180. tw.onTick()
  181. case task := <-tw.setChannel:
  182. tw.setTask(&task)
  183. case key := <-tw.removeChannel:
  184. tw.removeTask(key)
  185. case task := <-tw.moveChannel:
  186. tw.moveTask(task)
  187. case fn := <-tw.drainChannel:
  188. tw.drainAll(fn)
  189. case <-tw.stopChannel:
  190. tw.ticker.Stop()
  191. return
  192. }
  193. }
  194. }
  195. func (tw *TimingWheel) runTasks(tasks []timingTask) {
  196. if len(tasks) == 0 {
  197. return
  198. }
  199. go func() {
  200. for i := range tasks {
  201. threading.RunSafe(func() {
  202. tw.execute(tasks[i].key, tasks[i].value)
  203. })
  204. }
  205. }()
  206. }
  207. func (tw *TimingWheel) scanAndRunTasks(l *list.List) {
  208. var tasks []timingTask
  209. for e := l.Front(); e != nil; {
  210. task := e.Value.(*timingEntry)
  211. if task.removed {
  212. next := e.Next()
  213. l.Remove(e)
  214. tw.timers.Del(task.key)
  215. e = next
  216. continue
  217. } else if task.circle > 0 {
  218. task.circle--
  219. e = e.Next()
  220. continue
  221. } else if task.diff > 0 {
  222. next := e.Next()
  223. l.Remove(e)
  224. // (tw.tickedPos+task.diff)%tw.numSlots
  225. // cannot be the same value of tw.tickedPos
  226. pos := (tw.tickedPos + task.diff) % tw.numSlots
  227. tw.slots[pos].PushBack(task)
  228. tw.setTimerPosition(pos, task)
  229. task.diff = 0
  230. e = next
  231. continue
  232. }
  233. tasks = append(tasks, timingTask{
  234. key: task.key,
  235. value: task.value,
  236. })
  237. next := e.Next()
  238. l.Remove(e)
  239. tw.timers.Del(task.key)
  240. e = next
  241. }
  242. tw.runTasks(tasks)
  243. }
  244. func (tw *TimingWheel) setTask(task *timingEntry) {
  245. if task.delay < tw.interval {
  246. task.delay = tw.interval
  247. }
  248. if val, ok := tw.timers.Get(task.key); ok {
  249. entry := val.(*positionEntry)
  250. entry.item.value = task.value
  251. tw.moveTask(task.baseEntry)
  252. } else {
  253. pos, circle := tw.getPositionAndCircle(task.delay)
  254. task.circle = circle
  255. tw.slots[pos].PushBack(task)
  256. tw.setTimerPosition(pos, task)
  257. }
  258. }
  259. func (tw *TimingWheel) setTimerPosition(pos int, task *timingEntry) {
  260. if val, ok := tw.timers.Get(task.key); ok {
  261. timer := val.(*positionEntry)
  262. timer.pos = pos
  263. } else {
  264. tw.timers.Set(task.key, &positionEntry{
  265. pos: pos,
  266. item: task,
  267. })
  268. }
  269. }