// periodlimit.go 2.0 KB

  1. package limit
  2. import (
  3. "errors"
  4. "strconv"
  5. "time"
  6. "github.com/tal-tech/go-zero/core/stores/redis"
  7. )
  const (
  	// periodScript is the Lua script evaluated by Take. It increments the
  	// counter stored at KEYS[1] and compares it with the quota in ARGV[1]:
  	//   - counter == 1: first hit of the window; the key's expiry is set to
  	//     ARGV[2] seconds and 1 is returned
  	//   - counter <  quota: returns 1
  	//   - counter == quota: returns 2
  	//   - otherwise: returns 0
  	//
  	// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
  	periodScript = `local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = redis.call("INCRBY", KEYS[1], 1)
if current == 1 then
redis.call("expire", KEYS[1], window)
return 1
elseif current < limit then
return 1
elseif current == limit then
return 2
else
return 0
end`

  	// zoneDiff is the number of seconds (8 hours) added to the Unix time
  	// before window boundaries are computed in calcExpireSeconds.
  	zoneDiff = 3600 * 8 // GMT+8 for our services
  )
  const (
  	// Unknown is the status Take returns together with a non-nil error.
  	Unknown = iota
  	// Allowed means the request is permitted and the quota is not yet reached.
  	Allowed
  	// HitQuota means the request is permitted and it used up the quota exactly.
  	HitQuota
  	// OverQuota means the counter has gone past the quota.
  	OverQuota

  	// Raw codes returned by periodScript; Take maps them onto the public
  	// statuses above.
  	internalOverQuota = 0
  	internalAllowed   = 1
  	internalHitQuota  = 2
  )
  // ErrUnknownCode is returned by Take when the script reply is not an int64
  // or is an integer that matches none of the known status codes.
  var ErrUnknownCode = errors.New("unknown status code")
  35. type (
  36. LimitOption func(l *PeriodLimit)
  37. PeriodLimit struct {
  38. period int
  39. quota int
  40. limitStore *redis.Redis
  41. keyPrefix string
  42. align bool
  43. }
  44. )
  45. func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
  46. opts ...LimitOption) *PeriodLimit {
  47. limiter := &PeriodLimit{
  48. period: period,
  49. quota: quota,
  50. limitStore: limitStore,
  51. keyPrefix: keyPrefix,
  52. }
  53. for _, opt := range opts {
  54. opt(limiter)
  55. }
  56. return limiter
  57. }
  58. func (h *PeriodLimit) Take(key string) (int, error) {
  59. resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{
  60. strconv.Itoa(h.quota),
  61. strconv.Itoa(h.calcExpireSeconds()),
  62. })
  63. if err != nil {
  64. return Unknown, err
  65. }
  66. code, ok := resp.(int64)
  67. if !ok {
  68. return Unknown, ErrUnknownCode
  69. }
  70. switch code {
  71. case internalOverQuota:
  72. return OverQuota, nil
  73. case internalAllowed:
  74. return Allowed, nil
  75. case internalHitQuota:
  76. return HitQuota, nil
  77. default:
  78. return Unknown, ErrUnknownCode
  79. }
  80. }
  81. func (h *PeriodLimit) calcExpireSeconds() int {
  82. if h.align {
  83. unix := time.Now().Unix() + zoneDiff
  84. return h.period - int(unix%int64(h.period))
  85. } else {
  86. return h.period
  87. }
  88. }
  89. func Align() LimitOption {
  90. return func(l *PeriodLimit) {
  91. l.align = true
  92. }
  93. }