contentsecurityhandler_test.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. package handler
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "encoding/base64"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "log"
  10. "net/http"
  11. "net/http/httptest"
  12. "net/url"
  13. "os"
  14. "strconv"
  15. "strings"
  16. "testing"
  17. "time"
  18. "github.com/stretchr/testify/assert"
  19. "github.com/tal-tech/go-zero/core/codec"
  20. "github.com/tal-tech/go-zero/rest/httpx"
  21. )
  22. const timeDiff = time.Hour * 2 * 24
  23. var (
  24. fingerprint = "12345"
  25. pubKey = []byte(`-----BEGIN PUBLIC KEY-----
  26. MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQD7bq4FLG0ctccbEFEsUBuRxkjE
  27. eJ5U+0CAEjJk20V9/u2Fu76i1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVH
  28. miYbRgh5Fy6336KepLCtCmV/r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwR
  29. my47YlhspwszKdRP+wIDAQAB
  30. -----END PUBLIC KEY-----`)
  31. priKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
  32. MIICXAIBAAKBgQD7bq4FLG0ctccbEFEsUBuRxkjEeJ5U+0CAEjJk20V9/u2Fu76i
  33. 1oKoShCs7GXtAFbDb5A/ImIXkPY62nAaxTGK4KVHmiYbRgh5Fy6336KepLCtCmV/
  34. r0PKZeCyJH9uYLs7EuE1z9Hgm5UUjmpHDhJtkAwRmy47YlhspwszKdRP+wIDAQAB
  35. AoGBANs1qf7UtuSbD1ZnKX5K8V5s07CHwPMygw+lzc3k5ndtNUStZQ2vnAaBXHyH
  36. Nm4lJ4AI2mhQ39jQB/1TyP1uAzvpLhT60fRybEq9zgJ/81Gm9bnaEpFJ9bP2bBrY
  37. J0jbaTMfbzL/PJFl3J3RGMR40C76h5yRYSnOpMoMiKWnJqrhAkEA/zCOkR+34Pk0
  38. Yo3sIP4ranY6AAvwacgNaui4ll5xeYwv3iLOQvPlpxIxFHKXEY0klNNyjjXqgYjP
  39. cOenqtt6UwJBAPw7EYuteVHvHvQVuTbKAaYHcOrp4nFeZF3ndFfl0w2dwGhfzcXO
  40. ROyd5dNQCuCWRo8JBpjG6PFyzezayF4KLrkCQCGditoxHG7FRRJKcbVy5dMzWbaR
  41. 3AyDLslLeK1OKZKCVffkC9mj+TeF3PM9mQrV1eDI7ckv7wE7PWA5E8wc90MCQEOV
  42. MCZU3OTvRUPxbicYCUkLRV4sPNhTimD+21WR5vMHCb7trJ0Ln7wmsqXkFIYIve8l
  43. Y/cblN7c/AAyvu0znUECQA318nPldsxR6+H8HTS3uEbkL4UJdjQJHsvTwKxAw5qc
  44. moKExvRlN0zmGGuArKcqS38KG7PXZMrUv3FXPdp6BDQ=
  45. -----END RSA PRIVATE KEY-----`)
  46. key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
  47. )
  48. type requestSettings struct {
  49. method string
  50. url string
  51. body io.Reader
  52. strict bool
  53. crypt bool
  54. requestUri string
  55. timestamp int64
  56. fingerprint string
  57. missHeader bool
  58. signature string
  59. }
  60. func init() {
  61. log.SetOutput(ioutil.Discard)
  62. }
  63. func TestContentSecurityHandler(t *testing.T) {
  64. tests := []struct {
  65. method string
  66. url string
  67. body string
  68. strict bool
  69. crypt bool
  70. requestUri string
  71. timestamp int64
  72. fingerprint string
  73. missHeader bool
  74. signature string
  75. statusCode int
  76. }{
  77. {
  78. method: http.MethodGet,
  79. url: "http://localhost/a/b?c=d&e=f",
  80. strict: true,
  81. crypt: false,
  82. },
  83. {
  84. method: http.MethodPost,
  85. url: "http://localhost/a/b?c=d&e=f",
  86. body: "hello",
  87. strict: true,
  88. crypt: false,
  89. },
  90. {
  91. method: http.MethodGet,
  92. url: "http://localhost/a/b?c=d&e=f",
  93. strict: true,
  94. crypt: true,
  95. },
  96. {
  97. method: http.MethodPost,
  98. url: "http://localhost/a/b?c=d&e=f",
  99. body: "hello",
  100. strict: true,
  101. crypt: true,
  102. },
  103. {
  104. method: http.MethodGet,
  105. url: "http://localhost/a/b?c=d&e=f",
  106. strict: true,
  107. crypt: true,
  108. timestamp: time.Now().Add(timeDiff).Unix(),
  109. statusCode: http.StatusUnauthorized,
  110. },
  111. {
  112. method: http.MethodPost,
  113. url: "http://localhost/a/b?c=d&e=f",
  114. body: "hello",
  115. strict: true,
  116. crypt: true,
  117. timestamp: time.Now().Add(-timeDiff).Unix(),
  118. statusCode: http.StatusUnauthorized,
  119. },
  120. {
  121. method: http.MethodPost,
  122. url: "http://remotehost/",
  123. body: "hello",
  124. strict: true,
  125. crypt: true,
  126. requestUri: "http://localhost/a/b?c=d&e=f",
  127. },
  128. {
  129. method: http.MethodPost,
  130. url: "http://localhost/a/b?c=d&e=f",
  131. body: "hello",
  132. strict: false,
  133. crypt: true,
  134. fingerprint: "badone",
  135. },
  136. {
  137. method: http.MethodPost,
  138. url: "http://localhost/a/b?c=d&e=f",
  139. body: "hello",
  140. strict: true,
  141. crypt: true,
  142. timestamp: time.Now().Add(-timeDiff).Unix(),
  143. fingerprint: "badone",
  144. statusCode: http.StatusUnauthorized,
  145. },
  146. {
  147. method: http.MethodPost,
  148. url: "http://localhost/a/b?c=d&e=f",
  149. body: "hello",
  150. strict: true,
  151. crypt: true,
  152. missHeader: true,
  153. statusCode: http.StatusUnauthorized,
  154. },
  155. {
  156. method: http.MethodHead,
  157. url: "http://localhost/a/b?c=d&e=f",
  158. strict: true,
  159. crypt: false,
  160. },
  161. {
  162. method: http.MethodGet,
  163. url: "http://localhost/a/b?c=d&e=f",
  164. strict: true,
  165. crypt: false,
  166. signature: "badone",
  167. statusCode: http.StatusUnauthorized,
  168. },
  169. }
  170. for _, test := range tests {
  171. t.Run(test.url, func(t *testing.T) {
  172. if test.statusCode == 0 {
  173. test.statusCode = http.StatusOK
  174. }
  175. if len(test.fingerprint) == 0 {
  176. test.fingerprint = fingerprint
  177. }
  178. if test.timestamp == 0 {
  179. test.timestamp = time.Now().Unix()
  180. }
  181. func() {
  182. keyFile, err := createTempFile(priKey)
  183. defer os.Remove(keyFile)
  184. assert.Nil(t, err)
  185. decrypter, err := codec.NewRsaDecrypter(keyFile)
  186. assert.Nil(t, err)
  187. contentSecurityHandler := ContentSecurityHandler(map[string]codec.RsaDecrypter{
  188. fingerprint: decrypter,
  189. }, time.Hour, test.strict)
  190. handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  191. }))
  192. var reader io.Reader
  193. if len(test.body) > 0 {
  194. reader = strings.NewReader(test.body)
  195. }
  196. setting := requestSettings{
  197. method: test.method,
  198. url: test.url,
  199. body: reader,
  200. strict: test.strict,
  201. crypt: test.crypt,
  202. requestUri: test.requestUri,
  203. timestamp: test.timestamp,
  204. fingerprint: test.fingerprint,
  205. missHeader: test.missHeader,
  206. signature: test.signature,
  207. }
  208. req, err := buildRequest(setting)
  209. assert.Nil(t, err)
  210. resp := httptest.NewRecorder()
  211. handler.ServeHTTP(resp, req)
  212. assert.Equal(t, test.statusCode, resp.Code)
  213. }()
  214. })
  215. }
  216. }
  217. func TestContentSecurityHandler_UnsignedCallback(t *testing.T) {
  218. keyFile, err := createTempFile(priKey)
  219. defer os.Remove(keyFile)
  220. assert.Nil(t, err)
  221. decrypter, err := codec.NewRsaDecrypter(keyFile)
  222. assert.Nil(t, err)
  223. contentSecurityHandler := ContentSecurityHandler(
  224. map[string]codec.RsaDecrypter{
  225. fingerprint: decrypter,
  226. },
  227. time.Hour,
  228. true,
  229. func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
  230. w.WriteHeader(http.StatusOK)
  231. })
  232. handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
  233. setting := requestSettings{
  234. method: http.MethodGet,
  235. url: "http://localhost/a/b?c=d&e=f",
  236. signature: "badone",
  237. }
  238. req, err := buildRequest(setting)
  239. assert.Nil(t, err)
  240. resp := httptest.NewRecorder()
  241. handler.ServeHTTP(resp, req)
  242. assert.Equal(t, http.StatusOK, resp.Code)
  243. }
  244. func TestContentSecurityHandler_UnsignedCallback_WrongTime(t *testing.T) {
  245. keyFile, err := createTempFile(priKey)
  246. defer os.Remove(keyFile)
  247. assert.Nil(t, err)
  248. decrypter, err := codec.NewRsaDecrypter(keyFile)
  249. assert.Nil(t, err)
  250. contentSecurityHandler := ContentSecurityHandler(
  251. map[string]codec.RsaDecrypter{
  252. fingerprint: decrypter,
  253. },
  254. time.Hour,
  255. true,
  256. func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) {
  257. assert.Equal(t, httpx.CodeSignatureWrongTime, code)
  258. w.WriteHeader(http.StatusOK)
  259. })
  260. handler := contentSecurityHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
  261. var reader io.Reader
  262. reader = strings.NewReader("hello")
  263. setting := requestSettings{
  264. method: http.MethodPost,
  265. url: "http://localhost/a/b?c=d&e=f",
  266. body: reader,
  267. strict: true,
  268. crypt: true,
  269. timestamp: time.Now().Add(time.Hour * 24 * 365).Unix(),
  270. fingerprint: fingerprint,
  271. }
  272. req, err := buildRequest(setting)
  273. assert.Nil(t, err)
  274. resp := httptest.NewRecorder()
  275. handler.ServeHTTP(resp, req)
  276. assert.Equal(t, http.StatusOK, resp.Code)
  277. }
  278. func buildRequest(rs requestSettings) (*http.Request, error) {
  279. var bodyStr string
  280. var err error
  281. if rs.crypt && rs.body != nil {
  282. var buf bytes.Buffer
  283. io.Copy(&buf, rs.body)
  284. bodyBytes, err := codec.EcbEncrypt(key, buf.Bytes())
  285. if err != nil {
  286. return nil, err
  287. }
  288. bodyStr = base64.StdEncoding.EncodeToString(bodyBytes)
  289. }
  290. r := httptest.NewRequest(rs.method, rs.url, strings.NewReader(bodyStr))
  291. if len(rs.signature) == 0 {
  292. sha := sha256.New()
  293. sha.Write([]byte(bodyStr))
  294. bodySign := fmt.Sprintf("%x", sha.Sum(nil))
  295. var path string
  296. var query string
  297. if len(rs.requestUri) > 0 {
  298. if u, err := url.Parse(rs.requestUri); err != nil {
  299. return nil, err
  300. } else {
  301. path = u.Path
  302. query = u.RawQuery
  303. }
  304. } else {
  305. path = r.URL.Path
  306. query = r.URL.RawQuery
  307. }
  308. contentOfSign := strings.Join([]string{
  309. strconv.FormatInt(rs.timestamp, 10),
  310. rs.method,
  311. path,
  312. query,
  313. bodySign,
  314. }, "\n")
  315. rs.signature = codec.HmacBase64([]byte(key), contentOfSign)
  316. }
  317. var mode string
  318. if rs.crypt {
  319. mode = "1"
  320. } else {
  321. mode = "0"
  322. }
  323. content := strings.Join([]string{
  324. "version=v1",
  325. "type=" + mode,
  326. fmt.Sprintf("key=%s", base64.StdEncoding.EncodeToString(key)),
  327. "time=" + strconv.FormatInt(rs.timestamp, 10),
  328. }, "; ")
  329. encrypter, err := codec.NewRsaEncrypter([]byte(pubKey))
  330. if err != nil {
  331. log.Fatal(err)
  332. }
  333. output, err := encrypter.Encrypt([]byte(content))
  334. if err != nil {
  335. log.Fatal(err)
  336. }
  337. encryptedContent := base64.StdEncoding.EncodeToString(output)
  338. if !rs.missHeader {
  339. r.Header.Set(httpx.ContentSecurity, strings.Join([]string{
  340. fmt.Sprintf("key=%s", rs.fingerprint),
  341. "secret=" + encryptedContent,
  342. "signature=" + rs.signature,
  343. }, "; "))
  344. }
  345. if len(rs.requestUri) > 0 {
  346. r.Header.Set("X-Request-Uri", rs.requestUri)
  347. }
  348. return r, nil
  349. }
  350. func createTempFile(body []byte) (string, error) {
  351. tmpFile, err := ioutil.TempFile(os.TempDir(), "go-unit-*.tmp")
  352. if err != nil {
  353. return "", err
  354. } else {
  355. tmpFile.Close()
  356. }
  357. err = ioutil.WriteFile(tmpFile.Name(), body, os.ModePerm)
  358. if err != nil {
  359. return "", err
  360. }
  361. return tmpFile.Name(), nil
  362. }