responses_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package httpx
  2. import (
  3. "errors"
  4. "net/http"
  5. "strings"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/tal-tech/go-zero/core/logx"
  9. )
  10. type message struct {
  11. Name string `json:"name"`
  12. }
  13. func init() {
  14. logx.Disable()
  15. }
  16. func TestError(t *testing.T) {
  17. const (
  18. body = "foo"
  19. wrappedBody = `"foo"`
  20. )
  21. tests := []struct {
  22. name string
  23. input string
  24. errorHandler func(error) (int, interface{})
  25. expectBody string
  26. expectCode int
  27. }{
  28. {
  29. name: "default error handler",
  30. input: body,
  31. expectBody: body,
  32. expectCode: http.StatusBadRequest,
  33. },
  34. {
  35. name: "customized error handler return string",
  36. input: body,
  37. errorHandler: func(err error) (int, interface{}) {
  38. return http.StatusForbidden, err.Error()
  39. },
  40. expectBody: wrappedBody,
  41. expectCode: http.StatusForbidden,
  42. },
  43. {
  44. name: "customized error handler return error",
  45. input: body,
  46. errorHandler: func(err error) (int, interface{}) {
  47. return http.StatusForbidden, err
  48. },
  49. expectBody: body,
  50. expectCode: http.StatusForbidden,
  51. },
  52. }
  53. for _, test := range tests {
  54. t.Run(test.name, func(t *testing.T) {
  55. w := tracedResponseWriter{
  56. headers: make(map[string][]string),
  57. }
  58. if test.errorHandler != nil {
  59. lock.RLock()
  60. prev := errorHandler
  61. lock.RUnlock()
  62. SetErrorHandler(test.errorHandler)
  63. defer func() {
  64. lock.Lock()
  65. errorHandler = prev
  66. lock.Unlock()
  67. }()
  68. }
  69. Error(&w, errors.New(test.input))
  70. assert.Equal(t, test.expectCode, w.code)
  71. assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
  72. })
  73. }
  74. }
  75. func TestOk(t *testing.T) {
  76. w := tracedResponseWriter{
  77. headers: make(map[string][]string),
  78. }
  79. Ok(&w)
  80. assert.Equal(t, http.StatusOK, w.code)
  81. }
  82. func TestOkJson(t *testing.T) {
  83. w := tracedResponseWriter{
  84. headers: make(map[string][]string),
  85. }
  86. msg := message{Name: "anyone"}
  87. OkJson(&w, msg)
  88. assert.Equal(t, http.StatusOK, w.code)
  89. assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
  90. }
  91. func TestWriteJsonTimeout(t *testing.T) {
  92. // only log it and ignore
  93. w := tracedResponseWriter{
  94. headers: make(map[string][]string),
  95. timeout: true,
  96. }
  97. msg := message{Name: "anyone"}
  98. WriteJson(&w, http.StatusOK, msg)
  99. assert.Equal(t, http.StatusOK, w.code)
  100. }
  101. func TestWriteJsonLessWritten(t *testing.T) {
  102. w := tracedResponseWriter{
  103. headers: make(map[string][]string),
  104. lessWritten: true,
  105. }
  106. msg := message{Name: "anyone"}
  107. WriteJson(&w, http.StatusOK, msg)
  108. assert.Equal(t, http.StatusOK, w.code)
  109. }
  110. type tracedResponseWriter struct {
  111. headers map[string][]string
  112. builder strings.Builder
  113. code int
  114. lessWritten bool
  115. timeout bool
  116. }
  117. func (w *tracedResponseWriter) Header() http.Header {
  118. return w.headers
  119. }
  120. func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) {
  121. if w.timeout {
  122. return 0, http.ErrHandlerTimeout
  123. }
  124. n, err = w.builder.Write(bytes)
  125. if w.lessWritten {
  126. n -= 1
  127. }
  128. return
  129. }
  130. func (w *tracedResponseWriter) WriteHeader(code int) {
  131. w.code = code
  132. }