patrouter.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package router
  import (
	"errors"
	"net/http"
	"path"
	"sort"
	"strings"

	"github.com/tal-tech/go-zero/core/search"
	"github.com/tal-tech/go-zero/rest/httpx"
	"github.com/tal-tech/go-zero/rest/internal/context"
)
  // Header plumbing for 405 Method Not Allowed responses.
const (
	// allowHeader is the response header that lists the methods a path
	// does support.
	allowHeader = "Allow"
	// allowMethodSeparator joins the individual method names in allowHeader.
	allowMethodSeparator = ", "
)

// Errors returned by Handle when a route cannot be registered.
var (
	// ErrInvalidMethod is returned when the method is not one of DELETE,
	// GET, HEAD, OPTIONS, PATCH, POST or PUT.
	ErrInvalidMethod = errors.New("not a valid http method")
	// ErrInvalidPath is returned when the route path is empty or does not
	// start with '/'.
	ErrInvalidPath = errors.New("path must begin with '/'")
)
  19. type patRouter struct {
  20. trees map[string]*search.Tree
  21. notFound http.Handler
  22. notAllowed http.Handler
  23. }
  24. func NewRouter() httpx.Router {
  25. return &patRouter{
  26. trees: make(map[string]*search.Tree),
  27. }
  28. }
  29. func (pr *patRouter) Handle(method, reqPath string, handler http.Handler) error {
  30. if !validMethod(method) {
  31. return ErrInvalidMethod
  32. }
  33. if len(reqPath) == 0 || reqPath[0] != '/' {
  34. return ErrInvalidPath
  35. }
  36. cleanPath := path.Clean(reqPath)
  37. tree, ok := pr.trees[method]
  38. if ok {
  39. return tree.Add(cleanPath, handler)
  40. }
  41. tree = search.NewTree()
  42. pr.trees[method] = tree
  43. return tree.Add(cleanPath, handler)
  44. }
  45. func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  46. reqPath := path.Clean(r.URL.Path)
  47. if tree, ok := pr.trees[r.Method]; ok {
  48. if result, ok := tree.Search(reqPath); ok {
  49. if len(result.Params) > 0 {
  50. r = context.WithPathVars(r, result.Params)
  51. }
  52. result.Item.(http.Handler).ServeHTTP(w, r)
  53. return
  54. }
  55. }
  56. allows, ok := pr.methodsAllowed(r.Method, reqPath)
  57. if !ok {
  58. pr.handleNotFound(w, r)
  59. return
  60. }
  61. if pr.notAllowed != nil {
  62. pr.notAllowed.ServeHTTP(w, r)
  63. } else {
  64. w.Header().Set(allowHeader, allows)
  65. w.WriteHeader(http.StatusMethodNotAllowed)
  66. }
  67. }
  68. func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
  69. pr.notFound = handler
  70. }
  71. func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
  72. pr.notAllowed = handler
  73. }
  74. func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
  75. if pr.notFound != nil {
  76. pr.notFound.ServeHTTP(w, r)
  77. } else {
  78. http.NotFound(w, r)
  79. }
  80. }
  81. func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
  82. var allows []string
  83. for treeMethod, tree := range pr.trees {
  84. if treeMethod == method {
  85. continue
  86. }
  87. _, ok := tree.Search(path)
  88. if ok {
  89. allows = append(allows, treeMethod)
  90. }
  91. }
  92. if len(allows) > 0 {
  93. return strings.Join(allows, allowMethodSeparator), true
  94. }
  95. return "", false
  96. }
  // validMethod reports whether Handle accepts method for route registration.
// Only DELETE, GET, HEAD, OPTIONS, PATCH, POST and PUT qualify. The
// comparison is exact and case-sensitive, so "get", CONNECT and TRACE are
// all rejected.
func validMethod(method string) bool {
	switch method {
	case http.MethodDelete, http.MethodGet, http.MethodHead,
		http.MethodOptions, http.MethodPatch, http.MethodPost,
		http.MethodPut:
		return true
	default:
		return false
	}
}