contentsecurity.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package security
  import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/tal-tech/go-zero/core/codec"
	"github.com/tal-tech/go-zero/core/iox"
	"github.com/tal-tech/go-zero/core/logx"
	"github.com/tal-tech/go-zero/rest/httpx"
)
  const (
	// signatureField is the attribute read from the outer content-security
	// header; timeField is the attribute read from the decrypted secret.
	signatureField = "signature"
	timeField      = "time"

	// requestUriHeader names a request header whose value, when present and
	// parseable as a URL, supplies the path and query used for signing
	// instead of r.URL.
	requestUriHeader = "X-Request-Uri"
)
  // Sentinel errors returned by ParseContentSecurity, listed in the order in
// which its validation steps can produce them.
var (
	// ErrInvalidHeader reports that the key, secret or signature attribute
	// is missing from the content-security header.
	ErrInvalidHeader = errors.New("invalid X-Content-Security header")
	// ErrInvalidPublicKey reports that no decrypter is registered for the
	// key fingerprint carried in the header.
	ErrInvalidPublicKey = errors.New("invalid public key")
	// ErrInvalidSecret reports that the encrypted secret could not be
	// decrypted.
	ErrInvalidSecret = errors.New("invalid secret")
	// ErrInvalidKey reports that the key inside the decrypted secret is not
	// valid standard base64.
	ErrInvalidKey = errors.New("invalid key")
	// ErrInvalidContentType reports that the type inside the decrypted
	// secret is not an integer.
	ErrInvalidContentType = errors.New("invalid content type")
)
  // ContentSecurityHeader holds the values extracted by ParseContentSecurity
// from a request's content-security header and its decrypted secret.
type ContentSecurityHeader struct {
	// Key is the base64-decoded key from the decrypted secret; VerifySignature
	// uses it as the HMAC key.
	Key []byte
	// Timestamp is the raw "time" attribute from the decrypted secret;
	// VerifySignature parses it as Unix seconds.
	Timestamp string
	// ContentType is the integer "type" attribute from the decrypted secret.
	ContentType int
	// Signature is the "signature" attribute of the outer header, compared
	// by VerifySignature against the locally computed HMAC.
	Signature string
}
  36. func (h *ContentSecurityHeader) Encrypted() bool {
  37. return h.ContentType == httpx.CryptionType
  38. }
  39. func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
  40. *ContentSecurityHeader, error) {
  41. contentSecurity := r.Header.Get(httpx.ContentSecurity)
  42. attrs := httpx.ParseHeader(contentSecurity)
  43. fingerprint := attrs[httpx.KeyField]
  44. secret := attrs[httpx.SecretField]
  45. signature := attrs[signatureField]
  46. if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 {
  47. return nil, ErrInvalidHeader
  48. }
  49. decrypter, ok := decrypters[fingerprint]
  50. if !ok {
  51. return nil, ErrInvalidPublicKey
  52. }
  53. decryptedSecret, err := decrypter.DecryptBase64(secret)
  54. if err != nil {
  55. return nil, ErrInvalidSecret
  56. }
  57. attrs = httpx.ParseHeader(string(decryptedSecret))
  58. base64Key := attrs[httpx.KeyField]
  59. timestamp := attrs[timeField]
  60. contentType := attrs[httpx.TypeField]
  61. key, err := base64.StdEncoding.DecodeString(base64Key)
  62. if err != nil {
  63. return nil, ErrInvalidKey
  64. }
  65. cType, err := strconv.Atoi(contentType)
  66. if err != nil {
  67. return nil, ErrInvalidContentType
  68. }
  69. return &ContentSecurityHeader{
  70. Key: key,
  71. Timestamp: timestamp,
  72. ContentType: cType,
  73. Signature: signature,
  74. }, nil
  75. }
  76. func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
  77. seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
  78. if err != nil {
  79. return httpx.CodeSignatureInvalidHeader
  80. }
  81. now := time.Now().Unix()
  82. toleranceSeconds := int64(tolerance.Seconds())
  83. if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds {
  84. return httpx.CodeSignatureWrongTime
  85. }
  86. reqPath, reqQuery := getPathQuery(r)
  87. signContent := strings.Join([]string{
  88. securityHeader.Timestamp,
  89. r.Method,
  90. reqPath,
  91. reqQuery,
  92. computeBodySignature(r),
  93. }, "\n")
  94. actualSignature := codec.HmacBase64(securityHeader.Key, signContent)
  95. passed := securityHeader.Signature == actualSignature
  96. if !passed {
  97. logx.Infof("signature different, expect: %s, actual: %s",
  98. securityHeader.Signature, actualSignature)
  99. }
  100. if passed {
  101. return httpx.CodeSignaturePass
  102. }
  103. return httpx.CodeSignatureInvalidToken
  104. }
  105. func computeBodySignature(r *http.Request) string {
  106. var dup io.ReadCloser
  107. r.Body, dup = iox.DupReadCloser(r.Body)
  108. sha := sha256.New()
  109. io.Copy(sha, r.Body)
  110. r.Body = dup
  111. return fmt.Sprintf("%x", sha.Sum(nil))
  112. }
  113. func getPathQuery(r *http.Request) (string, string) {
  114. requestUri := r.Header.Get(requestUriHeader)
  115. if len(requestUri) == 0 {
  116. return r.URL.Path, r.URL.RawQuery
  117. }
  118. uri, err := url.Parse(requestUri)
  119. if err != nil {
  120. return r.URL.Path, r.URL.RawQuery
  121. }
  122. return uri.Path, uri.RawQuery
  123. }