responses_test.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package httpx
  2. import (
  3. "errors"
  4. "net/http"
  5. "strings"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/tal-tech/go-zero/core/logx"
  9. )
  10. type message struct {
  11. Name string `json:"name"`
  12. }
  13. func init() {
  14. logx.Disable()
  15. }
  16. func TestError(t *testing.T) {
  17. const body = "foo"
  18. w := tracedResponseWriter{
  19. headers: make(map[string][]string),
  20. }
  21. Error(&w, errors.New(body))
  22. assert.Equal(t, http.StatusBadRequest, w.code)
  23. assert.Equal(t, body, strings.TrimSpace(w.builder.String()))
  24. }
  25. func TestOk(t *testing.T) {
  26. w := tracedResponseWriter{
  27. headers: make(map[string][]string),
  28. }
  29. Ok(&w)
  30. assert.Equal(t, http.StatusOK, w.code)
  31. }
  32. func TestOkJson(t *testing.T) {
  33. w := tracedResponseWriter{
  34. headers: make(map[string][]string),
  35. }
  36. msg := message{Name: "anyone"}
  37. OkJson(&w, msg)
  38. assert.Equal(t, http.StatusOK, w.code)
  39. assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
  40. }
  41. func TestWriteJsonTimeout(t *testing.T) {
  42. // only log it and ignore
  43. w := tracedResponseWriter{
  44. headers: make(map[string][]string),
  45. timeout: true,
  46. }
  47. msg := message{Name: "anyone"}
  48. WriteJson(&w, http.StatusOK, msg)
  49. assert.Equal(t, http.StatusOK, w.code)
  50. }
  51. func TestWriteJsonLessWritten(t *testing.T) {
  52. w := tracedResponseWriter{
  53. headers: make(map[string][]string),
  54. lessWritten: true,
  55. }
  56. msg := message{Name: "anyone"}
  57. WriteJson(&w, http.StatusOK, msg)
  58. assert.Equal(t, http.StatusOK, w.code)
  59. }
  60. type tracedResponseWriter struct {
  61. headers map[string][]string
  62. builder strings.Builder
  63. code int
  64. lessWritten bool
  65. timeout bool
  66. }
  67. func (w *tracedResponseWriter) Header() http.Header {
  68. return w.headers
  69. }
  70. func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) {
  71. if w.timeout {
  72. return 0, http.ErrHandlerTimeout
  73. }
  74. n, err = w.builder.Write(bytes)
  75. if w.lessWritten {
  76. n -= 1
  77. }
  78. return
  79. }
  80. func (w *tracedResponseWriter) WriteHeader(code int) {
  81. w.code = code
  82. }