sheddinghandler.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. package handler
  2. import (
  3. "net/http"
  4. "sync"
  5. "github.com/tal-tech/go-zero/core/load"
  6. "github.com/tal-tech/go-zero/core/logx"
  7. "github.com/tal-tech/go-zero/core/stat"
  8. "github.com/tal-tech/go-zero/rest/httpx"
  9. "github.com/tal-tech/go-zero/rest/internal/security"
  10. )
  11. const serviceType = "api"
  12. var (
  13. sheddingStat *load.SheddingStat
  14. lock sync.Mutex
  15. )
  16. // SheddingHandler returns a middleware that does load shedding.
  17. func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler {
  18. if shedder == nil {
  19. return func(next http.Handler) http.Handler {
  20. return next
  21. }
  22. }
  23. ensureSheddingStat()
  24. return func(next http.Handler) http.Handler {
  25. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  26. sheddingStat.IncrementTotal()
  27. promise, err := shedder.Allow()
  28. if err != nil {
  29. metrics.AddDrop()
  30. sheddingStat.IncrementDrop()
  31. logx.Errorf("[http] dropped, %s - %s - %s",
  32. r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())
  33. w.WriteHeader(http.StatusServiceUnavailable)
  34. return
  35. }
  36. cw := &security.WithCodeResponseWriter{Writer: w}
  37. defer func() {
  38. if cw.Code == http.StatusServiceUnavailable {
  39. promise.Fail()
  40. } else {
  41. sheddingStat.IncrementPass()
  42. promise.Pass()
  43. }
  44. }()
  45. next.ServeHTTP(cw, r)
  46. })
  47. }
  48. }
  49. func ensureSheddingStat() {
  50. lock.Lock()
  51. if sheddingStat == nil {
  52. sheddingStat = load.NewSheddingStat(serviceType)
  53. }
  54. lock.Unlock()
  55. }