breakerhandler_test.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package handler
  2. import (
  3. "fmt"
  4. "net/http"
  5. "net/http/httptest"
  6. "testing"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/tal-tech/go-zero/core/logx"
  9. "github.com/tal-tech/go-zero/core/stat"
  10. )
  // init runs once before the tests in this file. It calls logx.Disable()
  // and then resets the stat package's reporter to nil.
  // NOTE(review): presumably this keeps test output quiet and prevents
  // metrics from being reported anywhere — confirm against logx/stat.
  11. func init() {
  12. logx.Disable()
  13. stat.SetReporter(nil)
  14. }
  15. func TestBreakerHandlerAccept(t *testing.T) {
  16. metrics := stat.NewMetrics("unit-test")
  17. breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
  18. handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  19. w.Header().Set("X-Test", "test")
  20. _, err := w.Write([]byte("content"))
  21. assert.Nil(t, err)
  22. }))
  23. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  24. req.Header.Set("X-Test", "test")
  25. resp := httptest.NewRecorder()
  26. handler.ServeHTTP(resp, req)
  27. assert.Equal(t, http.StatusOK, resp.Code)
  28. assert.Equal(t, "test", resp.Header().Get("X-Test"))
  29. assert.Equal(t, "content", resp.Body.String())
  30. }
  31. func TestBreakerHandlerFail(t *testing.T) {
  32. metrics := stat.NewMetrics("unit-test")
  33. breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
  34. handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  35. w.WriteHeader(http.StatusBadGateway)
  36. }))
  37. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  38. resp := httptest.NewRecorder()
  39. handler.ServeHTTP(resp, req)
  40. assert.Equal(t, http.StatusBadGateway, resp.Code)
  41. }
  42. func TestBreakerHandler_4XX(t *testing.T) {
  43. metrics := stat.NewMetrics("unit-test")
  44. breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
  45. handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  46. w.WriteHeader(http.StatusBadRequest)
  47. }))
  48. for i := 0; i < 1000; i++ {
  49. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  50. resp := httptest.NewRecorder()
  51. handler.ServeHTTP(resp, req)
  52. }
  53. const tries = 100
  54. var pass int
  55. for i := 0; i < tries; i++ {
  56. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  57. resp := httptest.NewRecorder()
  58. handler.ServeHTTP(resp, req)
  59. if resp.Code == http.StatusBadRequest {
  60. pass++
  61. }
  62. }
  63. assert.Equal(t, tries, pass)
  64. }
  65. func TestBreakerHandlerReject(t *testing.T) {
  66. metrics := stat.NewMetrics("unit-test")
  67. breakerHandler := BreakerHandler(http.MethodGet, "/", metrics)
  68. handler := breakerHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  69. w.WriteHeader(http.StatusInternalServerError)
  70. }))
  71. for i := 0; i < 1000; i++ {
  72. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  73. resp := httptest.NewRecorder()
  74. handler.ServeHTTP(resp, req)
  75. }
  76. var drops int
  77. for i := 0; i < 100; i++ {
  78. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  79. resp := httptest.NewRecorder()
  80. handler.ServeHTTP(resp, req)
  81. if resp.Code == http.StatusServiceUnavailable {
  82. drops++
  83. }
  84. }
  85. assert.True(t, drops >= 80, fmt.Sprintf("expected to be greater than 80, but got %d", drops))
  86. }