patrouter.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package router
import (
	"errors"
	"net/http"
	"path"
	"sort"
	"strings"

	"git.i2edu.net/i2/go-zero/core/search"
	"git.i2edu.net/i2/go-zero/rest/httpx"
	"git.i2edu.net/i2/go-zero/rest/internal/context"
)
  11. const (
  12. allowHeader = "Allow"
  13. allowMethodSeparator = ", "
  14. )
  15. var (
  16. // ErrInvalidMethod is an error that indicates not a valid http method.
  17. ErrInvalidMethod = errors.New("not a valid http method")
  18. // ErrInvalidPath is an error that indicates path is not start with /.
  19. ErrInvalidPath = errors.New("path must begin with '/'")
  20. )
  21. type patRouter struct {
  22. trees map[string]*search.Tree
  23. notFound http.Handler
  24. notAllowed http.Handler
  25. }
  26. // NewRouter returns a httpx.Router.
  27. func NewRouter() httpx.Router {
  28. return &patRouter{
  29. trees: make(map[string]*search.Tree),
  30. }
  31. }
  32. func (pr *patRouter) Handle(method, reqPath string, handler http.Handler) error {
  33. if !validMethod(method) {
  34. return ErrInvalidMethod
  35. }
  36. if len(reqPath) == 0 || reqPath[0] != '/' {
  37. return ErrInvalidPath
  38. }
  39. cleanPath := path.Clean(reqPath)
  40. tree, ok := pr.trees[method]
  41. if ok {
  42. return tree.Add(cleanPath, handler)
  43. }
  44. tree = search.NewTree()
  45. pr.trees[method] = tree
  46. return tree.Add(cleanPath, handler)
  47. }
  48. func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  49. reqPath := path.Clean(r.URL.Path)
  50. if tree, ok := pr.trees[r.Method]; ok {
  51. if result, ok := tree.Search(reqPath); ok {
  52. if len(result.Params) > 0 {
  53. r = context.WithPathVars(r, result.Params)
  54. }
  55. result.Item.(http.Handler).ServeHTTP(w, r)
  56. return
  57. }
  58. }
  59. allows, ok := pr.methodsAllowed(r.Method, reqPath)
  60. if !ok {
  61. pr.handleNotFound(w, r)
  62. return
  63. }
  64. if pr.notAllowed != nil {
  65. pr.notAllowed.ServeHTTP(w, r)
  66. } else {
  67. w.Header().Set(allowHeader, allows)
  68. w.WriteHeader(http.StatusMethodNotAllowed)
  69. }
  70. }
  71. func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
  72. pr.notFound = handler
  73. }
  74. func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
  75. pr.notAllowed = handler
  76. }
  77. func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
  78. if pr.notFound != nil {
  79. pr.notFound.ServeHTTP(w, r)
  80. } else {
  81. http.NotFound(w, r)
  82. }
  83. }
  84. func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
  85. var allows []string
  86. for treeMethod, tree := range pr.trees {
  87. if treeMethod == method {
  88. continue
  89. }
  90. _, ok := tree.Search(path)
  91. if ok {
  92. allows = append(allows, treeMethod)
  93. }
  94. }
  95. if len(allows) > 0 {
  96. return strings.Join(allows, allowMethodSeparator), true
  97. }
  98. return "", false
  99. }
  100. func validMethod(method string) bool {
  101. return method == http.MethodDelete || method == http.MethodGet ||
  102. method == http.MethodHead || method == http.MethodOptions ||
  103. method == http.MethodPatch || method == http.MethodPost ||
  104. method == http.MethodPut
  105. }