patrouter.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package router
import (
	"errors"
	"net/http"
	"path"
	"sort"
	"strings"

	"github.com/tal-tech/go-zero/core/search"
	"github.com/tal-tech/go-zero/rest/httpx"
	"github.com/tal-tech/go-zero/rest/internal/context"
)
  11. const (
  12. allowHeader = "Allow"
  13. allowMethodSeparator = ", "
  14. )
  15. var (
  16. ErrInvalidMethod = errors.New("not a valid http method")
  17. ErrInvalidPath = errors.New("path must begin with '/'")
  18. )
  19. type PatRouter struct {
  20. trees map[string]*search.Tree
  21. notFound http.Handler
  22. }
  23. func NewPatRouter() httpx.Router {
  24. return &PatRouter{
  25. trees: make(map[string]*search.Tree),
  26. }
  27. }
  28. func (pr *PatRouter) Handle(method, reqPath string, handler http.Handler) error {
  29. if !validMethod(method) {
  30. return ErrInvalidMethod
  31. }
  32. if len(reqPath) == 0 || reqPath[0] != '/' {
  33. return ErrInvalidPath
  34. }
  35. cleanPath := path.Clean(reqPath)
  36. if tree, ok := pr.trees[method]; ok {
  37. return tree.Add(cleanPath, handler)
  38. } else {
  39. tree = search.NewTree()
  40. pr.trees[method] = tree
  41. return tree.Add(cleanPath, handler)
  42. }
  43. }
  44. func (pr *PatRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  45. reqPath := path.Clean(r.URL.Path)
  46. if tree, ok := pr.trees[r.Method]; ok {
  47. if result, ok := tree.Search(reqPath); ok {
  48. if len(result.Params) > 0 {
  49. r = context.WithPathVars(r, result.Params)
  50. }
  51. result.Item.(http.Handler).ServeHTTP(w, r)
  52. return
  53. }
  54. }
  55. if allow, ok := pr.methodNotAllowed(r.Method, reqPath); ok {
  56. w.Header().Set(allowHeader, allow)
  57. w.WriteHeader(http.StatusMethodNotAllowed)
  58. } else {
  59. pr.handleNotFound(w, r)
  60. }
  61. }
  62. func (pr *PatRouter) SetNotFoundHandler(handler http.Handler) {
  63. pr.notFound = handler
  64. }
  65. func (pr *PatRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
  66. if pr.notFound != nil {
  67. pr.notFound.ServeHTTP(w, r)
  68. } else {
  69. http.NotFound(w, r)
  70. }
  71. }
  72. func (pr *PatRouter) methodNotAllowed(method, path string) (string, bool) {
  73. var allows []string
  74. for treeMethod, tree := range pr.trees {
  75. if treeMethod == method {
  76. continue
  77. }
  78. _, ok := tree.Search(path)
  79. if ok {
  80. allows = append(allows, treeMethod)
  81. }
  82. }
  83. if len(allows) > 0 {
  84. return strings.Join(allows, allowMethodSeparator), true
  85. } else {
  86. return "", false
  87. }
  88. }
  89. func validMethod(method string) bool {
  90. return method == http.MethodDelete || method == http.MethodGet ||
  91. method == http.MethodHead || method == http.MethodOptions ||
  92. method == http.MethodPatch || method == http.MethodPost ||
  93. method == http.MethodPut
  94. }