contentsecurity.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package security
  2. import (
  3. "crypto/sha256"
  4. "encoding/base64"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/url"
  10. "strconv"
  11. "strings"
  12. "time"
  13. "github.com/tal-tech/go-zero/core/codec"
  14. "github.com/tal-tech/go-zero/core/iox"
  15. "github.com/tal-tech/go-zero/core/logx"
  16. "github.com/tal-tech/go-zero/rest/httpx"
  17. )
  18. const (
  19. requestUriHeader = "X-Request-Uri"
  20. signatureField = "signature"
  21. timeField = "time"
  22. )
  23. var (
  24. // ErrInvalidContentType is an error that indicates invalid content type.
  25. ErrInvalidContentType = errors.New("invalid content type")
  26. // ErrInvalidHeader is an error that indicates invalid X-Content-Security header.
  27. ErrInvalidHeader = errors.New("invalid X-Content-Security header")
  28. // ErrInvalidKey is an error that indicates invalid key.
  29. ErrInvalidKey = errors.New("invalid key")
  30. // ErrInvalidPublicKey is an error that indicates invalid public key.
  31. ErrInvalidPublicKey = errors.New("invalid public key")
  32. // ErrInvalidSecret is an error that indicates invalid secret.
  33. ErrInvalidSecret = errors.New("invalid secret")
  34. )
  35. // A ContentSecurityHeader is a content security header.
  36. type ContentSecurityHeader struct {
  37. Key []byte
  38. Timestamp string
  39. ContentType int
  40. Signature string
  41. }
  42. // Encrypted checks if it's a crypted request.
  43. func (h *ContentSecurityHeader) Encrypted() bool {
  44. return h.ContentType == httpx.CryptionType
  45. }
  46. // ParseContentSecurity parses content security settings in give r.
  47. func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
  48. *ContentSecurityHeader, error) {
  49. contentSecurity := r.Header.Get(httpx.ContentSecurity)
  50. attrs := httpx.ParseHeader(contentSecurity)
  51. fingerprint := attrs[httpx.KeyField]
  52. secret := attrs[httpx.SecretField]
  53. signature := attrs[signatureField]
  54. if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 {
  55. return nil, ErrInvalidHeader
  56. }
  57. decrypter, ok := decrypters[fingerprint]
  58. if !ok {
  59. return nil, ErrInvalidPublicKey
  60. }
  61. decryptedSecret, err := decrypter.DecryptBase64(secret)
  62. if err != nil {
  63. return nil, ErrInvalidSecret
  64. }
  65. attrs = httpx.ParseHeader(string(decryptedSecret))
  66. base64Key := attrs[httpx.KeyField]
  67. timestamp := attrs[timeField]
  68. contentType := attrs[httpx.TypeField]
  69. key, err := base64.StdEncoding.DecodeString(base64Key)
  70. if err != nil {
  71. return nil, ErrInvalidKey
  72. }
  73. cType, err := strconv.Atoi(contentType)
  74. if err != nil {
  75. return nil, ErrInvalidContentType
  76. }
  77. return &ContentSecurityHeader{
  78. Key: key,
  79. Timestamp: timestamp,
  80. ContentType: cType,
  81. Signature: signature,
  82. }, nil
  83. }
  84. // VerifySignature verifies the signature in given r.
  85. func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
  86. seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
  87. if err != nil {
  88. return httpx.CodeSignatureInvalidHeader
  89. }
  90. now := time.Now().Unix()
  91. toleranceSeconds := int64(tolerance.Seconds())
  92. if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds {
  93. return httpx.CodeSignatureWrongTime
  94. }
  95. reqPath, reqQuery := getPathQuery(r)
  96. signContent := strings.Join([]string{
  97. securityHeader.Timestamp,
  98. r.Method,
  99. reqPath,
  100. reqQuery,
  101. computeBodySignature(r),
  102. }, "\n")
  103. actualSignature := codec.HmacBase64(securityHeader.Key, signContent)
  104. passed := securityHeader.Signature == actualSignature
  105. if !passed {
  106. logx.Infof("signature different, expect: %s, actual: %s",
  107. securityHeader.Signature, actualSignature)
  108. }
  109. if passed {
  110. return httpx.CodeSignaturePass
  111. }
  112. return httpx.CodeSignatureInvalidToken
  113. }
  114. func computeBodySignature(r *http.Request) string {
  115. var dup io.ReadCloser
  116. r.Body, dup = iox.DupReadCloser(r.Body)
  117. sha := sha256.New()
  118. io.Copy(sha, r.Body)
  119. r.Body = dup
  120. return fmt.Sprintf("%x", sha.Sum(nil))
  121. }
  122. func getPathQuery(r *http.Request) (string, string) {
  123. requestUri := r.Header.Get(requestUriHeader)
  124. if len(requestUri) == 0 {
  125. return r.URL.Path, r.URL.RawQuery
  126. }
  127. uri, err := url.Parse(requestUri)
  128. if err != nil {
  129. return r.URL.Path, r.URL.RawQuery
  130. }
  131. return uri.Path, uri.RawQuery
  132. }