123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390 |
- package handler
- import (
- "bytes"
- "crypto/sha256"
- "encoding/base64"
- "fmt"
- "io"
- "io/ioutil"
- "log"
- "net/http"
- "net/http/httptest"
- "net/url"
- "os"
- "strconv"
- "strings"
- "testing"
- "time"
- "github.com/stretchr/testify/assert"
- "github.com/tal-tech/go-zero/core/codec"
- "github.com/tal-tech/go-zero/rest/httpx"
- )
// timeDiff is a clock skew (two days) far outside the one-hour tolerance the
// tests hand to ContentSecurityHandler; it is used to forge stale and future
// timestamps.
const timeDiff = 48 * time.Hour

// fingerprint is the key id under which the test decrypter is registered and
// which buildRequest advertises through the "key=" field of the header.
var fingerprint = "12345"

// pubKey is the public half of priKey; buildRequest encrypts the header secret
// with it.
var pubKey = []byte(`-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE
eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH
miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR
my47YlhspwszKdRP+wIDAQAB
-----END PUBLIC KEY-----`)

// priKey is written to a temp file and loaded with codec.NewRsaDecrypter so
// the handler under test can decrypt the header secret.
var priKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i
1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/
r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB
AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH
Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY
J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0
Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP
cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO
ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR
3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV
MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l
Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc
moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ=
-----END RSA PRIVATE KEY-----`)

// key is the 32-byte symmetric secret: buildRequest uses it for ECB body
// encryption, as the HMAC signing key, and (base64-encoded) inside the
// encrypted header secret.
var key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
// requestSettings describes the request buildRequest should construct for the
// content-security handler tests.
type requestSettings struct {
	// HTTP verb and target URL handed to httptest.NewRequest.
	method, url string
	// Optional payload; nil means the request carries no body.
	body io.Reader
	// strict is not read by buildRequest (it mirrors the test table); crypt
	// selects encrypted mode ("type=1") and encrypts the body.
	strict, crypt bool
	// requestUri, when set, is signed in place of url and sent as X-Request-Uri.
	requestUri string
	// Unix seconds placed in both the signature and the encrypted secret.
	timestamp int64
	// Key id advertised in the header's "key=" field.
	fingerprint string
	// missHeader suppresses the content-security header entirely.
	missHeader bool
	// signature, when non-empty, is sent verbatim instead of a computed one.
	signature string
}
// init silences the standard library logger for every test in this package.
// Note that this also makes any log.Fatal call exit without printing a message.
func init() {
	sink := ioutil.Discard
	log.SetOutput(sink)
}
- func TestContentSecurityHandler(t *testing.T) {
- tests := []struct {
- method string
- url string
- body string
- strict bool
- crypt bool
- requestUri string
- timestamp int64
- fingerprint string
- missHeader bool
- signature string
- statusCode int
- }{
- {
- method: http.MethodGet,
- url: "http://localhost/a/b?c=d&e=f",
- strict: true,
- crypt: false,
- },
- {
- method: http.MethodPost,
- url: "http://localhost/a/b?c=d&e=f",
- body: "hello",
- strict: true,
- crypt: false,
- },
- {
- method: http.MethodGet,
- url: "http://localhost/a/b?c=d&e=f",
- strict: true,
- crypt: true,
- },
- {
- method: http.MethodPost,
- url: "http://localhost/a/b?c=d&e=f",
- body: "hello",
- strict: true,
- crypt: true,
- },
- {
- method: http.MethodGet,
- url: "http://localhost/a/b?c=d&e=f",
- strict: true,
- crypt: true,
- timestamp: time.Now().Add(timeDiff).Unix(),
- statusCode: http.StatusUnauthorized,
- },
- {
- method: http.MethodPost,
- url: "http://localhost/a/b?c=d&e=f",
- body: "hello",
- strict: true,
- crypt: true,
- timestamp: time.Now().Add(-timeDiff).Unix(),
- statusCode: http.StatusUnauthorized,
- },
- {
- method: http.MethodPost,
- url: "http://remotehost/",
- body: "hello",
- strict: true,
- crypt: true,
- requestUri: "http://localhost/a/b?c=d&e=f",
- },
- {
- method: http.MethodPost,
- url: "http://localhost/a/b?c=d&e=f",
- body: "hello",
- strict: false,
- crypt: true,
- fingerprint: "badone",
- },
- {
- method: http.MethodPost,
- url: "http://localhost/a/b?c=d&e=f",
- body: "hello",
- strict: true,
- crypt: true,
- timestamp: time.Now().Add(-timeDiff).Unix(),
- fingerprint: "badone",
- statusCode: http.StatusUnauthorized,
- },
- {
- method: http.MethodPost,
- url: "http://localhost/a/b?c=d&e=f",
- body: "hello",
- strict: true,
- crypt: true,
- missHeader: true,
- statusCode: http.StatusUnauthorized,
- },
- {
- method: http.MethodHead,
- url: "http://localhost/a/b?c=d&e=f",
- strict: true,
- crypt: false,
- },
- {
- method: http.MethodGet,
- url: "http://localhost/a/b?c=d&e=f",
- strict: true,
- crypt: false,
- signature: "badone",
- statusCode: http.StatusUnauthorized,
- },
- }
- for _, test := range tests {
- t.Run(test.url, func(t *testing.T) {
- if test.statusCode == 0 {
- test.statusCode = http.StatusOK
- }
- if len(test.fingerprint) == 0 {
- test.fingerprint = fingerprint
- }
- if test.timestamp == 0 {
- test.timestamp = time.Now().Unix()
- }
- func() {
- keyFile, err := createTempFile(priKey)
- defer os.Remove(keyFile)
- assert.Nil(t, err)
- decrypter, err := codec.NewRsaDecrypter(keyFile)
- assert.Nil(t, err)
- contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{
- fingerprint: decrypter,
- }, time.Hour, test.strict)
- handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- }))
- var reader io.Reader
- if len(test.body) > 0 {
- reader = strings.NewReader(test.body)
- }
- setting := requestSettings{
- method: test.method,
- url: test.url,
- body: reader,
- strict: test.strict,
- crypt: test.crypt,
- requestUri: test.requestUri,
- timestamp: test.timestamp,
- fingerprint: test.fingerprint,
- missHeader: test.missHeader,
- signature: test.signature,
- }
- req, err := buildRequest(setting)
- assert.Nil(t, err)
- resp := httptest.NewRecorder()
- handler.ServeHTTP(resp, req)
- assert.Equal(t, test.statusCode, resp.Code)
- }()
- })
- }
- }
- func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
- keyFile, err := createTempFile(priKey)
- defer os.Remove(keyFile)
- assert.Nil(t, err)
- decrypter, err := codec.NewRsaDecrypter(keyFile)
- assert.Nil(t, err)
- contentSecurityHandler := ContentSecurityHandler(
- map[string]codec.RsaDecrypter{
- fingerprint: decrypter,
- },
- time.Hour,
- true,
- func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
- w.WriteHeader(http.StatusOK)
- })
- handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- setting := requestSettings{
- method: http.MethodGet,
- url: "http://localhost/a/b?c=d&e=f",
- signature: "badone",
- }
- req, err := buildRequest(setting)
- assert.Nil(t, err)
- resp := httptest.NewRecorder()
- handler.ServeHTTP(resp, req)
- assert.Equal(t, http.StatusOK, resp.Code)
- }
- func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
- keyFile, err := createTempFile(priKey)
- defer os.Remove(keyFile)
- assert.Nil(t, err)
- decrypter, err := codec.NewRsaDecrypter(keyFile)
- assert.Nil(t, err)
- contentSecurityHandler := ContentSecurityHandler(
- map[string]codec.RsaDecrypter{
- fingerprint: decrypter,
- },
- time.Hour,
- true,
- func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
- assert.Equal(t, httpx.CodeSignatureWrongTime, code)
- w.WriteHeader(http.StatusOK)
- })
- handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
- var reader io.Reader
- reader = strings.NewReader("hello")
- setting := requestSettings{
- method: http.MethodPost,
- url: "http://localhost/a/b?c=d&e=f",
- body: reader,
- strict: true,
- crypt: true,
- timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(),
- fingerprint: fingerprint,
- }
- req, err := buildRequest(setting)
- assert.Nil(t, err)
- resp := httptest.NewRecorder()
- handler.ServeHTTP(resp, req)
- assert.Equal(t, http.StatusOK, resp.Code)
- }
- func buildRequest(rs requestSettings) (*http.Request, error) {
- var bodyStr string
- var err error
- if rs.crypt && rs.body != nil {
- var buf bytes.Buffer
- io.Copy(&buf, rs.body)
- bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes())
- if err != nil {
- return nil, err
- }
- bodyStr = base64.StdEncoding.EncodeToString(bodyBytes)
- }
- r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr))
- if len(rs.signature) == 0 {
- sha := sha256.New()
- sha.Write([]byte(bodyStr))
- bodySign := fmt.Sprintf("%x", sha.Sum(nil))
- var path string
- var query string
- if len(rs.requestUri) > 0 {
- if u, err := url.Parse(rs.requestUri); err != nil {
- return nil, err
- } else {
- path = u.Path
- query = u.RawQuery
- }
- } else {
- path = r.URL.Path
- query = r.URL.RawQuery
- }
- contentOfSign := strings.Join([]string{
- strconv.FormatInt(rs.timestamp, 10),
- rs.method,
- path,
- query,
- bodySign,
- }, "\n")
- rs.signature = codec.HmacBase64([]byte(key), contentOfSign)
- }
- var mode string
- if rs.crypt {
- mode = "1"
- } else {
- mode = "0"
- }
- content := strings.Join([]string{
- "version=v1",
- "type=" + mode,
- fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)),
- "time=" + strconv.FormatInt(rs.timestamp, 10),
- }, "; ")
- encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
- if err != nil {
- log.Fatal(err)
- }
- output, err := encrypter.Encrypt([]byte(content))
- if err != nil {
- log.Fatal(err)
- }
- encryptedContent := base64.StdEncoding.EncodeToString(output)
- if !rs.missHeader {
- r.Header.Set(httpx.ContentSecurity, strings.Join([]string{
- fmt.Sprintf("key=%s", rs.fingerprint),
- "secret=" + encryptedContent,
- "signature=" + rs.signature,
- }, "; "))
- }
- if len(rs.requestUri) > 0 {
- r.Header.Set("X-Request-Uri", rs.requestUri)
- }
- return r, nil
- }
// createTempFile writes body to a new file in os.TempDir() and returns its
// path. The caller owns the file and is responsible for removing it. On any
// failure the file is removed and "" is returned with the error, so nothing
// is left behind for a caller whose cleanup would target an empty path.
func createTempFile(body []byte) (string, error) {
	tmpFile, err := ioutil.TempFile(os.TempDir(), "go-unit-*.tmp")
	if err != nil {
		return "", err
	}
	name := tmpFile.Name()

	// Write through the handle TempFile already opened (mode 0600) instead of
	// closing and reopening by name, and surface Close errors as well.
	_, err = tmpFile.Write(body)
	if closeErr := tmpFile.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(name) // best-effort cleanup; the write/close error is what we report
		return "", err
	}
	return name, nil
}
|