tokenparser.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package token
  2. import (
  3. "net/http"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. "github.com/dgrijalva/jwt-go"
  8. "github.com/dgrijalva/jwt-go/request"
  9. "github.com/tal-tech/go-zero/core/timex"
  10. )
  // claimHistoryResetDuration is the default interval after which a
  // TokenParser discards its per-secret success counts; WithResetDuration
  // overrides it.
  const claimHistoryResetDuration = 24 * time.Hour
  // ParseOption customizes a TokenParser. NewTokenParser applies the options it
  // receives in order, after setting its defaults.
  type ParseOption func(parser *TokenParser)

  // A TokenParser parses JWT tokens from HTTP requests. When ParseToken is given
  // a previous secret as well as the current one, the parser keeps a count of
  // successful parses per secret and tries the more successful secret first.
  type TokenParser struct {
  	// resetTime is the timex.Now() reading from which the history-reset
  	// interval is measured.
  	resetTime time.Duration
  	// resetDuration is how long success counts are kept before the history
  	// is cleared.
  	resetDuration time.Duration
  	// history maps a secret (string) to its success counter (*uint64).
  	history sync.Map
  }
  22. // NewTokenParser returns a TokenParser.
  23. func NewTokenParser(opts ...ParseOption) *TokenParser {
  24. parser := &TokenParser{
  25. resetTime: timex.Now(),
  26. resetDuration: claimHistoryResetDuration,
  27. }
  28. for _, opt := range opts {
  29. opt(parser)
  30. }
  31. return parser
  32. }
  33. // ParseToken parses token from given r, with passed in secret and prevSecret.
  34. func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) {
  35. var token *jwt.Token
  36. var err error
  37. if len(prevSecret) > 0 {
  38. count := tp.loadCount(secret)
  39. prevCount := tp.loadCount(prevSecret)
  40. var first, second string
  41. if count > prevCount {
  42. first = secret
  43. second = prevSecret
  44. } else {
  45. first = prevSecret
  46. second = secret
  47. }
  48. token, err = tp.doParseToken(r, first)
  49. if err != nil {
  50. token, err = tp.doParseToken(r, second)
  51. if err != nil {
  52. return nil, err
  53. }
  54. tp.incrementCount(second)
  55. } else {
  56. tp.incrementCount(first)
  57. }
  58. } else {
  59. token, err = tp.doParseToken(r, secret)
  60. if err != nil {
  61. return nil, err
  62. }
  63. }
  64. return token, nil
  65. }
  66. func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) {
  67. return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor,
  68. func(token *jwt.Token) (interface{}, error) {
  69. return []byte(secret), nil
  70. }, request.WithParser(newParser()))
  71. }
  72. func (tp *TokenParser) incrementCount(secret string) {
  73. now := timex.Now()
  74. if tp.resetTime+tp.resetDuration < now {
  75. tp.history.Range(func(key, value interface{}) bool {
  76. tp.history.Delete(key)
  77. return true
  78. })
  79. }
  80. value, ok := tp.history.Load(secret)
  81. if ok {
  82. atomic.AddUint64(value.(*uint64), 1)
  83. } else {
  84. var count uint64 = 1
  85. tp.history.Store(secret, &count)
  86. }
  87. }
  88. func (tp *TokenParser) loadCount(secret string) uint64 {
  89. value, ok := tp.history.Load(secret)
  90. if ok {
  91. return *value.(*uint64)
  92. }
  93. return 0
  94. }
  95. // WithResetDuration returns a func to customize a TokenParser with reset duration.
  96. func WithResetDuration(duration time.Duration) ParseOption {
  97. return func(parser *TokenParser) {
  98. parser.resetDuration = duration
  99. }
  100. }
  101. func newParser() *jwt.Parser {
  102. return &jwt.Parser{
  103. UseJSONNumber: true,
  104. }
  105. }