cryptionhandler_test.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package handler
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "io/ioutil"
  6. "log"
  7. "net/http"
  8. "net/http/httptest"
  9. "testing"
  10. "github.com/stretchr/testify/assert"
  11. "github.com/tal-tech/go-zero/core/codec"
  12. )
  13. const (
  14. reqText = "ping"
  15. respText = "pong"
  16. )
  17. var aesKey = []byte(`PdSgVkYp3s6v9y$B&E)H+MbQeThWmZq4`)
  18. func init() {
  19. log.SetOutput(ioutil.Discard)
  20. }
  21. func TestCryptionHandlerGet(t *testing.T) {
  22. req := httptest.NewRequest(http.MethodGet, "/any", nil)
  23. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  24. _, err := w.Write([]byte(respText))
  25. w.Header().Set("X-Test", "test")
  26. assert.Nil(t, err)
  27. }))
  28. recorder := httptest.NewRecorder()
  29. handler.ServeHTTP(recorder, req)
  30. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  31. assert.Nil(t, err)
  32. assert.Equal(t, http.StatusOK, recorder.Code)
  33. assert.Equal(t, "test", recorder.Header().Get("X-Test"))
  34. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  35. }
  36. func TestCryptionHandlerPost(t *testing.T) {
  37. var buf bytes.Buffer
  38. enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
  39. assert.Nil(t, err)
  40. buf.WriteString(base64.StdEncoding.EncodeToString(enc))
  41. req := httptest.NewRequest(http.MethodPost, "/any", &buf)
  42. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  43. body, err := ioutil.ReadAll(r.Body)
  44. assert.Nil(t, err)
  45. assert.Equal(t, reqText, string(body))
  46. w.Write([]byte(respText))
  47. }))
  48. recorder := httptest.NewRecorder()
  49. handler.ServeHTTP(recorder, req)
  50. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  51. assert.Nil(t, err)
  52. assert.Equal(t, http.StatusOK, recorder.Code)
  53. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  54. }
  55. func TestCryptionHandlerPostBadEncryption(t *testing.T) {
  56. var buf bytes.Buffer
  57. enc, err := codec.EcbEncrypt(aesKey, []byte(reqText))
  58. assert.Nil(t, err)
  59. buf.Write(enc)
  60. req := httptest.NewRequest(http.MethodPost, "/any", &buf)
  61. handler := CryptionHandler(aesKey)(nil)
  62. recorder := httptest.NewRecorder()
  63. handler.ServeHTTP(recorder, req)
  64. assert.Equal(t, http.StatusBadRequest, recorder.Code)
  65. }
  66. func TestCryptionHandlerWriteHeader(t *testing.T) {
  67. req := httptest.NewRequest(http.MethodGet, "/any", nil)
  68. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  69. w.WriteHeader(http.StatusServiceUnavailable)
  70. }))
  71. recorder := httptest.NewRecorder()
  72. handler.ServeHTTP(recorder, req)
  73. assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
  74. }
  75. func TestCryptionHandlerFlush(t *testing.T) {
  76. req := httptest.NewRequest(http.MethodGet, "/any", nil)
  77. handler := CryptionHandler(aesKey)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  78. w.Write([]byte(respText))
  79. flusher, ok := w.(http.Flusher)
  80. assert.True(t, ok)
  81. flusher.Flush()
  82. }))
  83. recorder := httptest.NewRecorder()
  84. handler.ServeHTTP(recorder, req)
  85. expect, err := codec.EcbEncrypt(aesKey, []byte(respText))
  86. assert.Nil(t, err)
  87. assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
  88. }
  89. func TestCryptionHandler_Hijack(t *testing.T) {
  90. resp := httptest.NewRecorder()
  91. writer := newCryptionResponseWriter(resp)
  92. assert.NotPanics(t, func() {
  93. writer.Hijack()
  94. })
  95. writer = newCryptionResponseWriter(mockedHijackable{resp})
  96. assert.NotPanics(t, func() {
  97. writer.Hijack()
  98. })
  99. }