12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- package handler
- import (
- "io/ioutil"
- "log"
- "net/http"
- "net/http/httptest"
- "sync"
- "testing"
- "github.com/stretchr/testify/assert"
- "github.com/tal-tech/go-zero/core/lang"
- )
// conns is the number of concurrent blocking requests each test launches;
// TestMaxConnsHandler also passes it to MaxConns as the limit.
const (
	conns = 4
)
// init redirects the standard logger's output to ioutil.Discard so that
// anything written through the log package during these tests is dropped.
func init() {
	sink := ioutil.Discard
	log.SetOutput(sink)
}
- func TestMaxConnsHandler(t *testing.T) {
- var waitGroup sync.WaitGroup
- waitGroup.Add(conns)
- done := make(chan lang.PlaceholderType)
- defer close(done)
- maxConns := MaxConns(conns)
- handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- waitGroup.Done()
- <-done
- }))
- for i := 0; i < conns; i++ {
- go func() {
- req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
- handler.ServeHTTP(httptest.NewRecorder(), req)
- }()
- }
- waitGroup.Wait()
- req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
- resp := httptest.NewRecorder()
- handler.ServeHTTP(resp, req)
- assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
- }
- func TestWithoutMaxConnsHandler(t *testing.T) {
- const (
- key = "block"
- value = "1"
- )
- var waitGroup sync.WaitGroup
- waitGroup.Add(conns)
- done := make(chan lang.PlaceholderType)
- defer close(done)
- maxConns := MaxConns(0)
- handler := maxConns(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- val := r.Header.Get(key)
- if val == value {
- waitGroup.Done()
- <-done
- }
- }))
- for i := 0; i < conns; i++ {
- go func() {
- req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
- req.Header.Set(key, value)
- handler.ServeHTTP(httptest.NewRecorder(), req)
- }()
- }
- waitGroup.Wait()
- req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
- resp := httptest.NewRecorder()
- handler.ServeHTTP(resp, req)
- assert.Equal(t, http.StatusOK, resp.Code)
- }
|