maxconnshandler_test.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. package handler
  2. import (
  3. "io/ioutil"
  4. "log"
  5. "net/http"
  6. "net/http/httptest"
  7. "sync"
  8. "testing"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/tal-tech/go-zero/core/lang"
  11. )
  12. const conns = 4
  13. func init() {
  14. log.SetOutput(ioutil.Discard)
  15. }
  16. func TestMaxConnsHandler(t *testing.T) {
  17. var waitGroup sync.WaitGroup
  18. waitGroup.Add(conns)
  19. done := make(chan lang.PlaceholderType)
  20. defer close(done)
  21. maxConns := MaxConns(conns)
  22. handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  23. waitGroup.Done()
  24. <-done
  25. }))
  26. for i := 0; i < conns; i++ {
  27. go func() {
  28. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  29. handler.ServeHTTP(httptest.NewRecorder(), req)
  30. }()
  31. }
  32. waitGroup.Wait()
  33. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  34. resp := httptest.NewRecorder()
  35. handler.ServeHTTP(resp, req)
  36. assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
  37. }
  38. func TestWithoutMaxConnsHandler(t *testing.T) {
  39. const (
  40. key = "block"
  41. value = "1"
  42. )
  43. var waitGroup sync.WaitGroup
  44. waitGroup.Add(conns)
  45. done := make(chan lang.PlaceholderType)
  46. defer close(done)
  47. maxConns := MaxConns(0)
  48. handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  49. val := r.Header.Get(key)
  50. if val == value {
  51. waitGroup.Done()
  52. <-done
  53. }
  54. }))
  55. for i := 0; i < conns; i++ {
  56. go func() {
  57. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  58. req.Header.Set(key, value)
  59. handler.ServeHTTP(httptest.NewRecorder(), req)
  60. }()
  61. }
  62. waitGroup.Wait()
  63. req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
  64. resp := httptest.NewRecorder()
  65. handler.ServeHTTP(resp, req)
  66. assert.Equal(t, http.StatusOK, resp.Code)
  67. }