periodlimit.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package limit
  2. import (
  3. "errors"
  4. "strconv"
  5. "time"
  6. "git.i2edu.net/i2/go-zero/core/stores/redis"
  7. )
  const (
  	// periodScript is the Lua script executed in redis for every permit request.
  	//
  	// KEYS[1] is the counter key, ARGV[1] is the quota and ARGV[2] is the TTL
  	// in seconds. The script increments the counter by one, sets the TTL when
  	// the counter has just been created (value 1), and then reports:
  	//   1 - the counter is still below the quota
  	//   2 - the counter is exactly at the quota
  	//   0 - the counter is above the quota
  	//
  	// The quota comparison runs on every call, including the one that creates
  	// the counter. Returning "allowed" straight from the first-hit branch would
  	// report Allowed instead of HitQuota for a quota of 1, and would let the
  	// first request of each window through for a quota of 0 or less.
  	//
  	// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
  	periodScript = `local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = redis.call("INCRBY", KEYS[1], 1)
if current == 1 then
    redis.call("expire", KEYS[1], window)
end
if current < limit then
    return 1
elseif current == limit then
    return 2
else
    return 0
end`

  	// zoneDiff is the offset, in seconds, added to the unix time before
  	// aligning a window to a period boundary.
  	zoneDiff = 3600 * 8 // GMT+8 for our services
  )
  // Permit states reported to callers of Take.
  const (
  	// Unknown is the zero state; it is returned together with an error.
  	Unknown = iota
  	// Allowed is returned while the counter is still below the quota.
  	Allowed
  	// HitQuota is returned for the request that lands exactly on the quota.
  	HitQuota
  	// OverQuota is returned once the counter has gone past the quota.
  	OverQuota
  )

  // Status codes returned by the redis script. They are numbered differently
  // from the exported states above, so Take translates between the two.
  const (
  	internalOverQuota = 0
  	internalAllowed   = 1
  	internalHitQuota  = 2
  )
  var (
  	// ErrUnknownCode reports that the reply from redis was not one of the
  	// status codes the limiter knows how to interpret.
  	ErrUnknownCode = errors.New("unknown status code")
  )
  40. type (
  41. // PeriodOption defines the method to customize a PeriodLimit.
  42. PeriodOption func(l *PeriodLimit)
  43. // A PeriodLimit is used to limit requests during a period of time.
  44. PeriodLimit struct {
  45. period int
  46. quota int
  47. limitStore *redis.Redis
  48. keyPrefix string
  49. align bool
  50. }
  51. )
  52. // NewPeriodLimit returns a PeriodLimit with given parameters.
  53. func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string,
  54. opts ...PeriodOption) *PeriodLimit {
  55. limiter := &PeriodLimit{
  56. period: period,
  57. quota: quota,
  58. limitStore: limitStore,
  59. keyPrefix: keyPrefix,
  60. }
  61. for _, opt := range opts {
  62. opt(limiter)
  63. }
  64. return limiter
  65. }
  66. // Take requests a permit, it returns the permit state.
  67. func (h *PeriodLimit) Take(key string) (int, error) {
  68. resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{
  69. strconv.Itoa(h.quota),
  70. strconv.Itoa(h.calcExpireSeconds()),
  71. })
  72. if err != nil {
  73. return Unknown, err
  74. }
  75. code, ok := resp.(int64)
  76. if !ok {
  77. return Unknown, ErrUnknownCode
  78. }
  79. switch code {
  80. case internalOverQuota:
  81. return OverQuota, nil
  82. case internalAllowed:
  83. return Allowed, nil
  84. case internalHitQuota:
  85. return HitQuota, nil
  86. default:
  87. return Unknown, ErrUnknownCode
  88. }
  89. }
  90. func (h *PeriodLimit) calcExpireSeconds() int {
  91. if h.align {
  92. unix := time.Now().Unix() + zoneDiff
  93. return h.period - int(unix%int64(h.period))
  94. }
  95. return h.period
  96. }
  97. // Align returns a func to customize a PeriodLimit with alignment.
  98. func Align() PeriodOption {
  99. return func(l *PeriodLimit) {
  100. l.align = true
  101. }
  102. }