tokenlimit.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package limit
  2. import (
  3. "fmt"
  4. "strconv"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. "github.com/tal-tech/go-zero/core/logx"
  9. "github.com/tal-tech/go-zero/core/stores/redis"
  10. xrate "golang.org/x/time/rate"
  11. )
  12. const (
  13. // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
  14. // KEYS[1] as tokens_key
  15. // KEYS[2] as timestamp_key
  16. script = `local rate = tonumber(ARGV[1])
  17. local capacity = tonumber(ARGV[2])
  18. local now = tonumber(ARGV[3])
  19. local requested = tonumber(ARGV[4])
  20. local fill_time = capacity/rate
  21. local ttl = math.floor(fill_time*2)
  22. local last_tokens = tonumber(redis.call("get", KEYS[1]))
  23. if last_tokens == nil then
  24. last_tokens = capacity
  25. end
  26. local last_refreshed = tonumber(redis.call("get", KEYS[2]))
  27. if last_refreshed == nil then
  28. last_refreshed = 0
  29. end
  30. local delta = math.max(0, now-last_refreshed)
  31. local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
  32. local allowed = filled_tokens >= requested
  33. local new_tokens = filled_tokens
  34. if allowed then
  35. new_tokens = filled_tokens - requested
  36. end
  37. redis.call("setex", KEYS[1], ttl, new_tokens)
  38. redis.call("setex", KEYS[2], ttl, now)
  39. return allowed`
  40. tokenFormat = "{%s}.tokens"
  41. timestampFormat = "{%s}.ts"
  42. pingInterval = time.Millisecond * 100
  43. )
  44. // A TokenLimiter controls how frequently events are allowed to happen with in one second.
  45. type TokenLimiter struct {
  46. rate int
  47. burst int
  48. store *redis.Redis
  49. tokenKey string
  50. timestampKey string
  51. rescueLock sync.Mutex
  52. redisAlive uint32
  53. rescueLimiter *xrate.Limiter
  54. monitorStarted bool
  55. }
  56. // NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
  57. // bursts of at most burst tokens.
  58. func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter {
  59. tokenKey := fmt.Sprintf(tokenFormat, key)
  60. timestampKey := fmt.Sprintf(timestampFormat, key)
  61. return &TokenLimiter{
  62. rate: rate,
  63. burst: burst,
  64. store: store,
  65. tokenKey: tokenKey,
  66. timestampKey: timestampKey,
  67. redisAlive: 1,
  68. rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst),
  69. }
  70. }
  71. // Allow is shorthand for AllowN(time.Now(), 1).
  72. func (lim *TokenLimiter) Allow() bool {
  73. return lim.AllowN(time.Now(), 1)
  74. }
  75. // AllowN reports whether n events may happen at time now.
  76. // Use this method if you intend to drop / skip events that exceed the rate rate.
  77. // Otherwise use Reserve or Wait.
  78. func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
  79. return lim.reserveN(now, n)
  80. }
  81. func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
  82. if atomic.LoadUint32(&lim.redisAlive) == 0 {
  83. return lim.rescueLimiter.AllowN(now, n)
  84. }
  85. resp, err := lim.store.Eval(
  86. script,
  87. []string{
  88. lim.tokenKey,
  89. lim.timestampKey,
  90. },
  91. []string{
  92. strconv.Itoa(lim.rate),
  93. strconv.Itoa(lim.burst),
  94. strconv.FormatInt(now.Unix(), 10),
  95. strconv.Itoa(n),
  96. })
  97. // redis allowed == false
  98. // Lua boolean false -> r Nil bulk reply
  99. if err == redis.Nil {
  100. return false
  101. } else if err != nil {
  102. logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
  103. lim.startMonitor()
  104. return lim.rescueLimiter.AllowN(now, n)
  105. }
  106. code, ok := resp.(int64)
  107. if !ok {
  108. logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp)
  109. lim.startMonitor()
  110. return lim.rescueLimiter.AllowN(now, n)
  111. }
  112. // redis allowed == true
  113. // Lua boolean true -> r integer reply with value of 1
  114. return code == 1
  115. }
  116. func (lim *TokenLimiter) startMonitor() {
  117. lim.rescueLock.Lock()
  118. defer lim.rescueLock.Unlock()
  119. if lim.monitorStarted {
  120. return
  121. }
  122. lim.monitorStarted = true
  123. atomic.StoreUint32(&lim.redisAlive, 0)
  124. go lim.waitForRedis()
  125. }
  126. func (lim *TokenLimiter) waitForRedis() {
  127. ticker := time.NewTicker(pingInterval)
  128. defer func() {
  129. ticker.Stop()
  130. lim.rescueLock.Lock()
  131. lim.monitorStarted = false
  132. lim.rescueLock.Unlock()
  133. }()
  134. for range ticker.C {
  135. if lim.store.Ping() {
  136. atomic.StoreUint32(&lim.redisAlive, 1)
  137. return
  138. }
  139. }
  140. }