123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- package security
import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/tal-tech/go-zero/core/codec"
	"github.com/tal-tech/go-zero/core/iox"
	"github.com/tal-tech/go-zero/core/logx"
	"github.com/tal-tech/go-zero/rest/httpx"
)
const (
	// signatureField is the attribute of the content security header that
	// carries the client-computed signature.
	signatureField = "signature"
	// timeField is the attribute of the decrypted secret that carries the
	// request timestamp.
	timeField = "time"
	// requestUriHeader, when present, supplies the path and query used for
	// signing instead of r.URL (presumably set by a fronting proxy — confirm).
	requestUriHeader = "X-Request-Uri"
)

// Sentinel errors returned while parsing the content security header.
var (
	// ErrInvalidHeader is an error that indicates invalid X-Content-Security header.
	ErrInvalidHeader = errors.New("invalid X-Content-Security header")
	// ErrInvalidPublicKey is an error that indicates invalid public key.
	ErrInvalidPublicKey = errors.New("invalid public key")
	// ErrInvalidSecret is an error that indicates invalid secret.
	ErrInvalidSecret = errors.New("invalid secret")
	// ErrInvalidKey is an error that indicates invalid key.
	ErrInvalidKey = errors.New("invalid key")
	// ErrInvalidContentType is an error that indicates invalid content type.
	ErrInvalidContentType = errors.New("invalid content type")
)
- // A ContentSecurityHeader is a content security header.
- type ContentSecurityHeader struct {
- Key []byte
- Timestamp string
- ContentType int
- Signature string
- }
- // Encrypted checks if it's a crypted request.
- func (h *ContentSecurityHeader) Encrypted() bool {
- return h.ContentType == httpx.CryptionType
- }
- // ParseContentSecurity parses content security settings in give r.
- func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) (
- *ContentSecurityHeader, error) {
- contentSecurity := r.Header.Get(httpx.ContentSecurity)
- attrs := httpx.ParseHeader(contentSecurity)
- fingerprint := attrs[httpx.KeyField]
- secret := attrs[httpx.SecretField]
- signature := attrs[signatureField]
- if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 {
- return nil, ErrInvalidHeader
- }
- decrypter, ok := decrypters[fingerprint]
- if !ok {
- return nil, ErrInvalidPublicKey
- }
- decryptedSecret, err := decrypter.DecryptBase64(secret)
- if err != nil {
- return nil, ErrInvalidSecret
- }
- attrs = httpx.ParseHeader(string(decryptedSecret))
- base64Key := attrs[httpx.KeyField]
- timestamp := attrs[timeField]
- contentType := attrs[httpx.TypeField]
- key, err := base64.StdEncoding.DecodeString(base64Key)
- if err != nil {
- return nil, ErrInvalidKey
- }
- cType, err := strconv.Atoi(contentType)
- if err != nil {
- return nil, ErrInvalidContentType
- }
- return &ContentSecurityHeader{
- Key: key,
- Timestamp: timestamp,
- ContentType: cType,
- Signature: signature,
- }, nil
- }
- // VerifySignature verifies the signature in given r.
- func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int {
- seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64)
- if err != nil {
- return httpx.CodeSignatureInvalidHeader
- }
- now := time.Now().Unix()
- toleranceSeconds := int64(tolerance.Seconds())
- if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds {
- return httpx.CodeSignatureWrongTime
- }
- reqPath, reqQuery := getPathQuery(r)
- signContent := strings.Join([]string{
- securityHeader.Timestamp,
- r.Method,
- reqPath,
- reqQuery,
- computeBodySignature(r),
- }, "\n")
- actualSignature := codec.HmacBase64(securityHeader.Key, signContent)
- passed := securityHeader.Signature == actualSignature
- if !passed {
- logx.Infof("signature different, expect: %s, actual: %s",
- securityHeader.Signature, actualSignature)
- }
- if passed {
- return httpx.CodeSignaturePass
- }
- return httpx.CodeSignatureInvalidToken
- }
- func computeBodySignature(r *http.Request) string {
- var dup io.ReadCloser
- r.Body, dup = iox.DupReadCloser(r.Body)
- sha := sha256.New()
- io.Copy(sha, r.Body)
- r.Body = dup
- return fmt.Sprintf("%x", sha.Sum(nil))
- }
- func getPathQuery(r *http.Request) (string, string) {
- requestUri := r.Header.Get(requestUriHeader)
- if len(requestUri) == 0 {
- return r.URL.Path, r.URL.RawQuery
- }
- uri, err := url.Parse(requestUri)
- if err != nil {
- return r.URL.Path, r.URL.RawQuery
- }
- return uri.Path, uri.RawQuery
- }
|