123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- package limit
- import (
- "fmt"
- "strconv"
- "sync"
- "sync/atomic"
- "time"
- "github.com/tal-tech/go-zero/core/logx"
- "github.com/tal-tech/go-zero/core/stores/redis"
- xrate "golang.org/x/time/rate"
- )
const (
	// script atomically refills and consumes the token bucket stored in Redis.
	//
	// To be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key.
	// KEYS[1] as tokens_key
	// KEYS[2] as timestamp_key
	// ARGV[1] rate (tokens added per second), ARGV[2] capacity (burst),
	// ARGV[3] now (unix seconds), ARGV[4] requested tokens.
	//
	// It returns Lua true (integer reply 1) when the request is allowed and
	// Lua false (nil reply) otherwise.
	//
	// The key TTL is twice the time needed to refill a full bucket, clamped to at
	// least one second: SETEX rejects a zero expire time, which floor(fill_time*2)
	// produces whenever capacity < rate/2. An expired key is read back as a full
	// bucket, which is correct because such a bucket refills in under half a second.
	script = `local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local fill_time = capacity/rate
local ttl = math.max(1, math.floor(fill_time*2))
local last_tokens = tonumber(redis.call("get", KEYS[1]))
if last_tokens == nil then
    last_tokens = capacity
end
local last_refreshed = tonumber(redis.call("get", KEYS[2]))
if last_refreshed == nil then
    last_refreshed = 0
end
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
if allowed then
    new_tokens = filled_tokens - requested
end
redis.call("setex", KEYS[1], ttl, new_tokens)
redis.call("setex", KEYS[2], ttl, now)
return allowed`

	// tokenFormat and timestampFormat derive the two Redis keys from the user key.
	// NOTE(review): the braces look like a Redis Cluster hash tag so both keys map
	// to the same slot for the script — confirm.
	tokenFormat     = "{%s}.tokens"
	timestampFormat = "{%s}.ts"

	// pingInterval is how often the rescue monitor pings Redis while it is down.
	pingInterval = time.Millisecond * 100
)
- // A TokenLimiter controls how frequently events are allowed to happen with in one second.
- type TokenLimiter struct {
- rate int
- burst int
- store *redis.Redis
- tokenKey string
- timestampKey string
- rescueLock sync.Mutex
- redisAlive uint32
- rescueLimiter *xrate.Limiter
- monitorStarted bool
- }
- // NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
- // bursts of at most burst tokens.
- func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter {
- tokenKey := fmt.Sprintf(tokenFormat, key)
- timestampKey := fmt.Sprintf(timestampFormat, key)
- return &TokenLimiter{
- rate: rate,
- burst: burst,
- store: store,
- tokenKey: tokenKey,
- timestampKey: timestampKey,
- redisAlive: 1,
- rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst),
- }
- }
- // Allow is shorthand for AllowN(time.Now(), 1).
- func (lim *TokenLimiter) Allow() bool {
- return lim.AllowN(time.Now(), 1)
- }
- // AllowN reports whether n events may happen at time now.
- // Use this method if you intend to drop / skip events that exceed the rate rate.
- // Otherwise use Reserve or Wait.
- func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
- return lim.reserveN(now, n)
- }
- func (lim *TokenLimiter) reserveN(now time.Time, n int) bool {
- if atomic.LoadUint32(&lim.redisAlive) == 0 {
- return lim.rescueLimiter.AllowN(now, n)
- }
- resp, err := lim.store.Eval(
- script,
- []string{
- lim.tokenKey,
- lim.timestampKey,
- },
- []string{
- strconv.Itoa(lim.rate),
- strconv.Itoa(lim.burst),
- strconv.FormatInt(now.Unix(), 10),
- strconv.Itoa(n),
- })
- // redis allowed == false
- // Lua boolean false -> r Nil bulk reply
- if err == redis.Nil {
- return false
- } else if err != nil {
- logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
- lim.startMonitor()
- return lim.rescueLimiter.AllowN(now, n)
- }
- code, ok := resp.(int64)
- if !ok {
- logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp)
- lim.startMonitor()
- return lim.rescueLimiter.AllowN(now, n)
- }
- // redis allowed == true
- // Lua boolean true -> r integer reply with value of 1
- return code == 1
- }
- func (lim *TokenLimiter) startMonitor() {
- lim.rescueLock.Lock()
- defer lim.rescueLock.Unlock()
- if lim.monitorStarted {
- return
- }
- lim.monitorStarted = true
- atomic.StoreUint32(&lim.redisAlive, 0)
- go lim.waitForRedis()
- }
- func (lim *TokenLimiter) waitForRedis() {
- ticker := time.NewTicker(pingInterval)
- defer func() {
- ticker.Stop()
- lim.rescueLock.Lock()
- lim.monitorStarted = false
- lim.rescueLock.Unlock()
- }()
- for range ticker.C {
- if lim.store.Ping() {
- atomic.StoreUint32(&lim.redisAlive, 1)
- return
- }
- }
- }
|