tokenparser.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package token
  2. import (
  3. "net/http"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. "github.com/dgrijalva/jwt-go"
  8. "github.com/dgrijalva/jwt-go/request"
  9. "github.com/tal-tech/go-zero/core/timex"
  10. )
  11. const claimHistoryResetDuration = time.Hour * 24
  12. type (
  13. ParseOption func(parser *TokenParser)
  14. TokenParser struct {
  15. resetTime time.Duration
  16. resetDuration time.Duration
  17. history sync.Map
  18. }
  19. )
  20. func NewTokenParser(opts ...ParseOption) *TokenParser {
  21. parser := &TokenParser{
  22. resetTime: timex.Now(),
  23. resetDuration: claimHistoryResetDuration,
  24. }
  25. for _, opt := range opts {
  26. opt(parser)
  27. }
  28. return parser
  29. }
  30. func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) {
  31. var token *jwt.Token
  32. var err error
  33. if len(prevSecret) > 0 {
  34. count := tp.loadCount(secret)
  35. prevCount := tp.loadCount(prevSecret)
  36. var first, second string
  37. if count > prevCount {
  38. first = secret
  39. second = prevSecret
  40. } else {
  41. first = prevSecret
  42. second = secret
  43. }
  44. token, err = tp.doParseToken(r, first)
  45. if err != nil {
  46. token, err = tp.doParseToken(r, second)
  47. if err != nil {
  48. return nil, err
  49. }
  50. tp.incrementCount(second)
  51. } else {
  52. tp.incrementCount(first)
  53. }
  54. } else {
  55. token, err = tp.doParseToken(r, secret)
  56. if err != nil {
  57. return nil, err
  58. }
  59. }
  60. return token, nil
  61. }
  62. func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) {
  63. return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor,
  64. func(token *jwt.Token) (interface{}, error) {
  65. return []byte(secret), nil
  66. }, request.WithParser(newParser()))
  67. }
  68. func (tp *TokenParser) incrementCount(secret string) {
  69. now := timex.Now()
  70. if tp.resetTime+tp.resetDuration < now {
  71. tp.history.Range(func(key, value interface{}) bool {
  72. tp.history.Delete(key)
  73. return true
  74. })
  75. }
  76. value, ok := tp.history.Load(secret)
  77. if ok {
  78. atomic.AddUint64(value.(*uint64), 1)
  79. } else {
  80. var count uint64 = 1
  81. tp.history.Store(secret, &count)
  82. }
  83. }
  84. func (tp *TokenParser) loadCount(secret string) uint64 {
  85. value, ok := tp.history.Load(secret)
  86. if ok {
  87. return *value.(*uint64)
  88. }
  89. return 0
  90. }
  91. func WithResetDuration(duration time.Duration) ParseOption {
  92. return func(parser *TokenParser) {
  93. parser.resetDuration = duration
  94. }
  95. }
  96. func newParser() *jwt.Parser {
  97. return &jwt.Parser{
  98. UseJSONNumber: true,
  99. }
  100. }