engine.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. package rest
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "time"
  7. "github.com/justinas/alice"
  8. "github.com/tal-tech/go-zero/core/codec"
  9. "github.com/tal-tech/go-zero/core/load"
  10. "github.com/tal-tech/go-zero/core/stat"
  11. "github.com/tal-tech/go-zero/rest/handler"
  12. "github.com/tal-tech/go-zero/rest/httpx"
  13. "github.com/tal-tech/go-zero/rest/internal"
  14. "github.com/tal-tech/go-zero/rest/router"
  15. )
  // topCpuUsage is the CPU usage value that stands for 100%, expressed in
  // millis (1000m).
  const (
  	topCpuUsage = 1000
  )

  // ErrSignatureConfig is returned when signature verification is enabled in
  // strict mode but no private keys are configured.
  var (
  	ErrSignatureConfig = errors.New("bad config for Signature")
  )
  19. type engine struct {
  20. conf RestConf
  21. routes []featuredRoutes
  22. unauthorizedCallback handler.UnauthorizedCallback
  23. unsignedCallback handler.UnsignedCallback
  24. middlewares []Middleware
  25. shedder load.Shedder
  26. priorityShedder load.Shedder
  27. }
  28. func newEngine(c RestConf) *engine {
  29. srv := &engine{
  30. conf: c,
  31. }
  32. if c.CpuThreshold > 0 {
  33. srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold))
  34. srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold(
  35. (c.CpuThreshold + topCpuUsage) >> 1))
  36. }
  37. return srv
  38. }
  39. func (s *engine) AddRoutes(r featuredRoutes) {
  40. s.routes = append(s.routes, r)
  41. }
  42. func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
  43. s.unauthorizedCallback = callback
  44. }
  45. func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
  46. s.unsignedCallback = callback
  47. }
  48. func (s *engine) Start() error {
  49. return s.StartWithRouter(router.NewRouter())
  50. }
  51. func (s *engine) StartWithRouter(router httpx.Router) error {
  52. if err := s.bindRoutes(router); err != nil {
  53. return err
  54. }
  55. if len(s.conf.CertFile) == 0 && len(s.conf.KeyFile) == 0 {
  56. return internal.StartHttp(s.conf.Host, s.conf.Port, router)
  57. }
  58. return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, router)
  59. }
  60. func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
  61. verifier func(alice.Chain) alice.Chain) alice.Chain {
  62. if fr.jwt.enabled {
  63. if len(fr.jwt.prevSecret) == 0 {
  64. chain = chain.Append(handler.Authorize(fr.jwt.secret,
  65. handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
  66. } else {
  67. chain = chain.Append(handler.Authorize(fr.jwt.secret,
  68. handler.WithPrevSecret(fr.jwt.prevSecret),
  69. handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
  70. }
  71. }
  72. return verifier(chain)
  73. }
  74. func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
  75. verifier, err := s.signatureVerifier(fr.signature)
  76. if err != nil {
  77. return err
  78. }
  79. for _, route := range fr.routes {
  80. if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil {
  81. return err
  82. }
  83. }
  84. return nil
  85. }
  86. func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
  87. route Route, verifier func(chain alice.Chain) alice.Chain) error {
  88. chain := alice.New(
  89. handler.TracingHandler,
  90. s.getLogHandler(),
  91. handler.MaxConns(s.conf.MaxConns),
  92. handler.BreakerHandler(route.Method, route.Path, metrics),
  93. handler.SheddingHandler(s.getShedder(fr.priority), metrics),
  94. handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
  95. handler.RecoverHandler,
  96. handler.MetricHandler(metrics),
  97. handler.PromethousHandler(route.Path),
  98. handler.MaxBytesHandler(s.conf.MaxBytes),
  99. handler.GunzipHandler,
  100. )
  101. chain = s.appendAuthHandler(fr, chain, verifier)
  102. for _, middleware := range s.middlewares {
  103. chain = chain.Append(convertMiddleware(middleware))
  104. }
  105. handle := chain.ThenFunc(route.Handler)
  106. return router.Handle(route.Method, route.Path, handle)
  107. }
  108. func (s *engine) bindRoutes(router httpx.Router) error {
  109. metrics := s.createMetrics()
  110. for _, fr := range s.routes {
  111. if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil {
  112. return err
  113. }
  114. }
  115. return nil
  116. }
  117. func (s *engine) createMetrics() *stat.Metrics {
  118. var metrics *stat.Metrics
  119. if len(s.conf.Name) > 0 {
  120. metrics = stat.NewMetrics(s.conf.Name)
  121. } else {
  122. metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port))
  123. }
  124. return metrics
  125. }
  126. func (s *engine) getLogHandler() func(http.Handler) http.Handler {
  127. if s.conf.Verbose {
  128. return handler.DetailedLogHandler
  129. }
  130. return handler.LogHandler
  131. }
  132. func (s *engine) getShedder(priority bool) load.Shedder {
  133. if priority && s.priorityShedder != nil {
  134. return s.priorityShedder
  135. }
  136. return s.shedder
  137. }
  138. func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
  139. if !signature.enabled {
  140. return func(chain alice.Chain) alice.Chain {
  141. return chain
  142. }, nil
  143. }
  144. if len(signature.PrivateKeys) == 0 {
  145. if signature.Strict {
  146. return nil, ErrSignatureConfig
  147. }
  148. return func(chain alice.Chain) alice.Chain {
  149. return chain
  150. }, nil
  151. }
  152. decrypters := make(map[string]codec.RsaDecrypter)
  153. for _, key := range signature.PrivateKeys {
  154. fingerprint := key.Fingerprint
  155. file := key.KeyFile
  156. decrypter, err := codec.NewRsaDecrypter(file)
  157. if err != nil {
  158. return nil, err
  159. }
  160. decrypters[fingerprint] = decrypter
  161. }
  162. return func(chain alice.Chain) alice.Chain {
  163. if s.unsignedCallback != nil {
  164. return chain.Append(handler.ContentSecurityHandler(
  165. decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
  166. }
  167. return chain.Append(handler.ContentSecurityHandler(
  168. decrypters, signature.Expiry, signature.Strict))
  169. }, nil
  170. }
  171. func (s *engine) use(middleware Middleware) {
  172. s.middlewares = append(s.middlewares, middleware)
  173. }
  174. func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
  175. return func(next http.Handler) http.Handler {
  176. return ware(next.ServeHTTP)
  177. }
  178. }