소스 검색

Updates tree.go + fixes + unit tests

Manu Mtz-Almeida 10 년 전
부모
커밋
f212ae7728
5개의 변경된 파일113개의 추가작업 그리고 25개의 파일을 삭제
  1. 2 0
      gin.go
  2. 3 0
      gin_test.go
  3. 80 2
      routes_test.go
  4. 23 18
      tree.go
  5. 5 5
      tree_test.go

+ 2 - 0
gin.go

@@ -233,6 +233,7 @@ func (engine *Engine) serveAutoRedirect(c *Context, root *node, tsr bool) bool {
 		}
 		debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String())
 		http.Redirect(c.Writer, req, req.URL.String(), code)
+		c.writermem.WriteHeaderNow()
 		return true
 	}
 
@@ -246,6 +247,7 @@ func (engine *Engine) serveAutoRedirect(c *Context, root *node, tsr bool) bool {
 			req.URL.Path = string(fixedPath)
 			debugPrint("redirecting request %d: %s --> %s", code, path, req.URL.String())
 			http.Redirect(c.Writer, req, req.URL.String(), code)
+			c.writermem.WriteHeaderNow()
 			return true
 		}
 	}

+ 3 - 0
gin_test.go

@@ -25,6 +25,9 @@ func TestCreateEngine(t *testing.T) {
 	assert.Equal(t, "/", router.absolutePath)
 	assert.Equal(t, router.engine, router)
 	assert.Empty(t, router.Handlers)
+	assert.True(t, router.RedirectTrailingSlash)
+	assert.True(t, router.RedirectFixedPath)
+	assert.True(t, router.HandleMethodNotAllowed)
 
 	assert.Panics(t, func() { router.handle("", "/", []HandlerFunc{func(_ *Context) {}}) })
 	assert.Panics(t, func() { router.handle("GET", "", []HandlerFunc{func(_ *Context) {}}) })

+ 80 - 2
routes_test.go

@@ -5,6 +5,7 @@
 package gin
 
 import (
+	"fmt"
 	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
@@ -110,18 +111,28 @@ func TestRouteNotOK2(t *testing.T) {
 func TestRouteParamsByName(t *testing.T) {
 	name := ""
 	lastName := ""
+	wild := ""
 	router := New()
-	router.GET("/test/:name/:last_name", func(c *Context) {
+	router.GET("/test/:name/:last_name/*wild", func(c *Context) {
 		name = c.Params.ByName("name")
 		lastName = c.Params.ByName("last_name")
+		wild = c.Params.ByName("wild")
+
+		assert.Equal(t, name, c.ParamValue("name"))
+		assert.Equal(t, lastName, c.ParamValue("last_name"))
+
+		assert.Equal(t, name, c.DefaultParamValue("name", "nothing"))
+		assert.Equal(t, lastName, c.DefaultParamValue("last_name", "nothing"))
+		assert.Equal(t, c.DefaultParamValue("noKey", "default"), "default")
 	})
 	// RUN
-	w := performRequest(router, "GET", "/test/john/smith")
+	w := performRequest(router, "GET", "/test/john/smith/is/super/great")
 
 	// TEST
 	assert.Equal(t, w.Code, 200)
 	assert.Equal(t, name, "john")
 	assert.Equal(t, lastName, "smith")
+	assert.Equal(t, wild, "/is/super/great")
 }
 
 // TestHandleStaticFile - ensure the static file handles properly
@@ -183,3 +194,70 @@ func TestRouteHeadToDir(t *testing.T) {
 	assert.Contains(t, bodyAsString, "gin.go")
 	assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/html; charset=utf-8")
 }
+
+func TestRouteNotAllowed(t *testing.T) {
+	router := New()
+
+	router.POST("/path", func(c *Context) {})
+	w := performRequest(router, "GET", "/path")
+	assert.Equal(t, w.Code, http.StatusMethodNotAllowed)
+
+	router.NoMethod(func(c *Context) {
+		c.String(http.StatusTeapot, "responseText")
+	})
+	w = performRequest(router, "GET", "/path")
+	assert.Equal(t, w.Body.String(), "responseText")
+	assert.Equal(t, w.Code, http.StatusTeapot)
+}
+
+func TestRouterNotFound(t *testing.T) {
+	router := New()
+	router.GET("/path", func(c *Context) {})
+	router.GET("/dir/", func(c *Context) {})
+	router.GET("/", func(c *Context) {})
+
+	testRoutes := []struct {
+		route  string
+		code   int
+		header string
+	}{
+		{"/path/", 301, "map[Location:[/path]]"},   // TSR -/
+		{"/dir", 301, "map[Location:[/dir/]]"},     // TSR +/
+		{"", 301, "map[Location:[/]]"},             // TSR +/
+		{"/PATH", 301, "map[Location:[/path]]"},    // Fixed Case
+		{"/DIR/", 301, "map[Location:[/dir/]]"},    // Fixed Case
+		{"/PATH/", 301, "map[Location:[/path]]"},   // Fixed Case -/
+		{"/DIR", 301, "map[Location:[/dir/]]"},     // Fixed Case +/
+		{"/../path", 301, "map[Location:[/path]]"}, // CleanPath
+		{"/nope", 404, ""},                         // NotFound
+	}
+	for _, tr := range testRoutes {
+		w := performRequest(router, "GET", tr.route)
+		assert.Equal(t, w.Code, tr.code)
+		if w.Code != 404 {
+			assert.Equal(t, fmt.Sprint(w.Header()), tr.header)
+		}
+	}
+
+	// Test custom not found handler
+	var notFound bool
+	router.NoRoute(func(c *Context) {
+		c.AbortWithStatus(404)
+		notFound = true
+	})
+	w := performRequest(router, "GET", "/nope")
+	assert.Equal(t, w.Code, 404)
+	assert.True(t, notFound)
+
+	// Test other method than GET (want 307 instead of 301)
+	router.PATCH("/path", func(c *Context) {})
+	w = performRequest(router, "PATCH", "/path/")
+	assert.Equal(t, w.Code, 307)
+	assert.Equal(t, fmt.Sprint(w.Header()), "map[Location:[/path]]")
+
+	// Test special case where no node for the prefix "/" exists
+	router = New()
+	router.GET("/a", func(c *Context) {})
+	w = performRequest(router, "GET", "/")
+	assert.Equal(t, w.Code, 404)
+}

+ 23 - 18
tree.go

@@ -78,6 +78,7 @@ func (n *node) incrementChildPrio(pos int) int {
 // addRoute adds a node with the given handle to the path.
 // Not concurrency-safe!
 func (n *node) addRoute(path string, handlers []HandlerFunc) {
+	fullPath := path
 	n.priority++
 	numParams := countParams(path)
 
@@ -147,7 +148,9 @@ func (n *node) addRoute(path string, handlers []HandlerFunc) {
 						}
 					}
 
-					panic("conflict with wildcard route")
+					panic("path segment '" + path +
+						"' conflicts with existing wildcard '" + n.path +
+						"' in path '" + fullPath + "'")
 				}
 
 				c := path[0]
@@ -179,23 +182,23 @@ func (n *node) addRoute(path string, handlers []HandlerFunc) {
 					n.incrementChildPrio(len(n.indices) - 1)
 					n = child
 				}
-				n.insertChild(numParams, path, handlers)
+				n.insertChild(numParams, path, fullPath, handlers)
 				return
 
 			} else if i == len(path) { // Make node a (in-path) leaf
 				if n.handlers != nil {
-					panic("a Handle is already registered for this path")
+					panic("handlers are already registered for path ''" + fullPath + "'")
 				}
 				n.handlers = handlers
 			}
 			return
 		}
 	} else { // Empty tree
-		n.insertChild(numParams, path, handlers)
+		n.insertChild(numParams, path, fullPath, handlers)
 	}
 }
 
-func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc) {
+func (n *node) insertChild(numParams uint8, path string, fullPath string, handlers []HandlerFunc) {
 	var offset int // already handled bytes of the path
 
 	// find prefix until first wildcard (beginning with ':'' or '*'')
@@ -205,27 +208,29 @@ func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc)
 			continue
 		}
 
-		// check if this Node existing children which would be
-		// unreachable if we insert the wildcard here
-		if len(n.children) > 0 {
-			panic("wildcard route conflicts with existing children")
-		}
-
 		// find wildcard end (either '/' or path end)
 		end := i + 1
 		for end < max && path[end] != '/' {
 			switch path[end] {
 			// the wildcard name must not contain ':' and '*'
 			case ':', '*':
-				panic("only one wildcard per path segment is allowed")
+				panic("only one wildcard per path segment is allowed, has: '" +
+					path[i:] + "' in path '" + fullPath + "'")
 			default:
 				end++
 			}
 		}
 
+		// check if this Node existing children which would be
+		// unreachable if we insert the wildcard here
+		if len(n.children) > 0 {
+			panic("wildcard route '" + path[i:end] +
+				"' conflicts with existing children in path '" + fullPath + "'")
+		}
+
 		// check if the wildcard has a name
 		if end-i < 2 {
-			panic("wildcards must be named with a non-empty name")
+			panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
 		}
 
 		if c == ':' { // param
@@ -261,17 +266,17 @@ func (n *node) insertChild(numParams uint8, path string, handlers []HandlerFunc)
 
 		} else { // catchAll
 			if end != max || numParams > 1 {
-				panic("catch-all routes are only allowed at the end of the path")
+				panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
 			}
 
 			if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
-				panic("catch-all conflicts with existing handle for the path segment root")
+				panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
 			}
 
 			// currently fixed width 1 for '/'
 			i--
 			if path[i] != '/' {
-				panic("no / before catch-all")
+				panic("no / before catch-all in path '" + fullPath + "'")
 			}
 
 			n.path = path[offset:i]
@@ -394,7 +399,7 @@ walk: // Outer loop for walking the tree
 					return
 
 				default:
-					panic("Invalid node type")
+					panic("invalid node type")
 				}
 			}
 		} else if path == n.path {
@@ -505,7 +510,7 @@ func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPa
 				return append(ciPath, path...), true
 
 			default:
-				panic("Invalid node type")
+				panic("invalid node type")
 			}
 		} else {
 			// We should have reached the node containing the handle.

+ 5 - 5
tree_test.go

@@ -357,7 +357,7 @@ func TestTreeDoubleWildcard(t *testing.T) {
 			tree.addRoute(route, nil)
 		})
 
-		if rs, ok := recv.(string); !ok || rs != panicMsg {
+		if rs, ok := recv.(string); !ok || !strings.HasPrefix(rs, panicMsg) {
 			t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv)
 		}
 	}
@@ -594,15 +594,15 @@ func TestTreeInvalidNodeType(t *testing.T) {
 	recv := catchPanic(func() {
 		tree.getValue("/test", nil)
 	})
-	if rs, ok := recv.(string); !ok || rs != "Invalid node type" {
-		t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv)
+	if rs, ok := recv.(string); !ok || rs != "invalid node type" {
+		t.Fatalf(`Expected panic "invalid node type", got "%v"`, recv)
 	}
 
 	// case-insensitive lookup
 	recv = catchPanic(func() {
 		tree.findCaseInsensitivePath("/test", true)
 	})
-	if rs, ok := recv.(string); !ok || rs != "Invalid node type" {
-		t.Fatalf(`Expected panic "Invalid node type", got "%v"`, recv)
+	if rs, ok := recv.(string); !ok || rs != "invalid node type" {
+		t.Fatalf(`Expected panic "invalid node type", got "%v"`, recv)
 	}
 }