Explorar el Código

Adds option for listing directory files + better unit tests

Manu Mtz-Almeida hace 10 años
padre
commit
cac77e04e3
Se han modificado 4 ficheros con 150 adiciones y 39 borrados
  1. 85 12
      middleware_test.go
  2. 14 5
      routergroup.go
  3. 11 0
      routergroup_test.go
  4. 40 22
      routes_test.go

+ 85 - 12
middleware_test.go

@@ -9,6 +9,7 @@ import (
 
 
 	"testing"
 	"testing"
 
 
+	"github.com/manucorporat/sse"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
@@ -27,10 +28,10 @@ func TestMiddlewareGeneralCase(t *testing.T) {
 		signature += "D"
 		signature += "D"
 	})
 	})
 	router.NoRoute(func(c *Context) {
 	router.NoRoute(func(c *Context) {
-		signature += "X"
+		signature += " X "
 	})
 	})
 	router.NoMethod(func(c *Context) {
 	router.NoMethod(func(c *Context) {
-		signature += "X"
+		signature += " XX "
 	})
 	})
 	// RUN
 	// RUN
 	w := performRequest(router, "GET", "/")
 	w := performRequest(router, "GET", "/")
@@ -40,8 +41,7 @@ func TestMiddlewareGeneralCase(t *testing.T) {
 	assert.Equal(t, signature, "ACDB")
 	assert.Equal(t, signature, "ACDB")
 }
 }
 
 
-// TestBadAbortHandlersChain - ensure that Abort after switch context will not interrupt pending handlers
-func TestMiddlewareNextOrder(t *testing.T) {
+func TestMiddlewareNoRoute(t *testing.T) {
 	signature := ""
 	signature := ""
 	router := New()
 	router := New()
 	router.Use(func(c *Context) {
 	router.Use(func(c *Context) {
@@ -52,6 +52,9 @@ func TestMiddlewareNextOrder(t *testing.T) {
 	router.Use(func(c *Context) {
 	router.Use(func(c *Context) {
 		signature += "C"
 		signature += "C"
 		c.Next()
 		c.Next()
+		c.Next()
+		c.Next()
+		c.Next()
 		signature += "D"
 		signature += "D"
 	})
 	})
 	router.NoRoute(func(c *Context) {
 	router.NoRoute(func(c *Context) {
@@ -63,6 +66,9 @@ func TestMiddlewareNextOrder(t *testing.T) {
 		c.Next()
 		c.Next()
 		signature += "H"
 		signature += "H"
 	})
 	})
+	router.NoMethod(func(c *Context) {
+		signature += " X "
+	})
 	// RUN
 	// RUN
 	w := performRequest(router, "GET", "/")
 	w := performRequest(router, "GET", "/")
 
 
@@ -71,30 +77,65 @@ func TestMiddlewareNextOrder(t *testing.T) {
 	assert.Equal(t, signature, "ACEGHFDB")
 	assert.Equal(t, signature, "ACEGHFDB")
 }
 }
 
 
-// TestAbortHandlersChain - ensure that Abort interrupt used middlewares in fifo order
-func TestMiddlewareAbortHandlersChain(t *testing.T) {
+func TestMiddlewareNoMethod(t *testing.T) {
 	signature := ""
 	signature := ""
 	router := New()
 	router := New()
 	router.Use(func(c *Context) {
 	router.Use(func(c *Context) {
 		signature += "A"
 		signature += "A"
+		c.Next()
+		signature += "B"
 	})
 	})
 	router.Use(func(c *Context) {
 	router.Use(func(c *Context) {
 		signature += "C"
 		signature += "C"
-		c.AbortWithStatus(409)
 		c.Next()
 		c.Next()
 		signature += "D"
 		signature += "D"
 	})
 	})
-	router.GET("/", func(c *Context) {
+	router.NoMethod(func(c *Context) {
+		signature += "E"
+		c.Next()
+		signature += "F"
+	}, func(c *Context) {
+		signature += "G"
+		c.Next()
+		signature += "H"
+	})
+	router.NoRoute(func(c *Context) {
+		signature += " X "
+	})
+	router.POST("/", func(c *Context) {
+		signature += " XX "
+	})
+	// RUN
+	w := performRequest(router, "GET", "/")
+
+	// TEST
+	assert.Equal(t, w.Code, 405)
+	assert.Equal(t, signature, "ACEGHFDB")
+}
+
+func TestMiddlewareAbort(t *testing.T) {
+	signature := ""
+	router := New()
+	router.Use(func(c *Context) {
+		signature += "A"
+	})
+	router.Use(func(c *Context) {
+		signature += "C"
+		c.AbortWithStatus(401)
+		c.Next()
 		signature += "D"
 		signature += "D"
+	})
+	router.GET("/", func(c *Context) {
+		signature += " X "
 		c.Next()
 		c.Next()
-		signature += "E"
+		signature += " XX "
 	})
 	})
 
 
 	// RUN
 	// RUN
 	w := performRequest(router, "GET", "/")
 	w := performRequest(router, "GET", "/")
 
 
 	// TEST
 	// TEST
-	assert.Equal(t, w.Code, 409)
+	assert.Equal(t, w.Code, 401)
 	assert.Equal(t, signature, "ACD")
 	assert.Equal(t, signature, "ACD")
 }
 }
 
 
@@ -103,8 +144,8 @@ func TestMiddlewareAbortHandlersChainAndNext(t *testing.T) {
 	router := New()
 	router := New()
 	router.Use(func(c *Context) {
 	router.Use(func(c *Context) {
 		signature += "A"
 		signature += "A"
-		c.AbortWithStatus(410)
 		c.Next()
 		c.Next()
+		c.AbortWithStatus(410)
 		signature += "B"
 		signature += "B"
 
 
 	})
 	})
@@ -117,7 +158,7 @@ func TestMiddlewareAbortHandlersChainAndNext(t *testing.T) {
 
 
 	// TEST
 	// TEST
 	assert.Equal(t, w.Code, 410)
 	assert.Equal(t, w.Code, 410)
-	assert.Equal(t, signature, "AB")
+	assert.Equal(t, signature, "ACB")
 }
 }
 
 
 // TestFailHandlersChain - ensure that Fail interrupt used middlewares in fifo order as
 // TestFailHandlersChain - ensure that Fail interrupt used middlewares in fifo order as
@@ -142,3 +183,35 @@ func TestMiddlewareFailHandlersChain(t *testing.T) {
 	assert.Equal(t, w.Code, 500)
 	assert.Equal(t, w.Code, 500)
 	assert.Equal(t, signature, "A")
 	assert.Equal(t, signature, "A")
 }
 }
+
+func TestMiddlewareWrite(t *testing.T) {
+	router := New()
+	router.Use(func(c *Context) {
+		c.String(400, "hola\n")
+	})
+	router.Use(func(c *Context) {
+		c.XML(400, H{"foo": "bar"})
+	})
+	router.Use(func(c *Context) {
+		c.JSON(400, H{"foo": "bar"})
+	})
+	router.GET("/", func(c *Context) {
+		c.JSON(400, H{"foo": "bar"})
+	}, func(c *Context) {
+		c.Render(400, sse.Event{
+			Event: "test",
+			Data:  "message",
+		})
+	})
+
+	w := performRequest(router, "GET", "/")
+
+	assert.Equal(t, w.Code, 400)
+	assert.Equal(t, w.Body.String(), `hola
+<map><foo>bar</foo></map>{"foo":"bar"}
+{"foo":"bar"}
+event: test
+data: message
+
+`)
+}

+ 14 - 5
routergroup.go

@@ -119,11 +119,14 @@ func (group *RouterGroup) StaticFile(relativePath, filepath string) {
 // use :
 // use :
 //     router.Static("/static", "/var/www")
 //     router.Static("/static", "/var/www")
 func (group *RouterGroup) Static(relativePath, root string) {
 func (group *RouterGroup) Static(relativePath, root string) {
-	group.StaticFS(relativePath, http.Dir(root))
+	group.StaticFS(relativePath, http.Dir(root), false)
 }
 }
 
 
-func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) {
-	handler := group.createStaticHandler(relativePath, fs)
+func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem, listDirectory bool) {
+	if strings.Contains(relativePath, ":") || strings.Contains(relativePath, "*") {
+		panic("URL parameters can not be used when serving a static folder")
+	}
+	handler := group.createStaticHandler(relativePath, fs, listDirectory)
 	relativePath = path.Join(relativePath, "/*filepath")
 	relativePath = path.Join(relativePath, "/*filepath")
 
 
 	// Register GET and HEAD handlers
 	// Register GET and HEAD handlers
@@ -131,10 +134,16 @@ func (group *RouterGroup) StaticFS(relativePath string, fs http.FileSystem) {
 	group.HEAD(relativePath, handler)
 	group.HEAD(relativePath, handler)
 }
 }
 
 
-func (group *RouterGroup) createStaticHandler(relativePath string, fs http.FileSystem) func(*Context) {
+func (group *RouterGroup) createStaticHandler(relativePath string, fs http.FileSystem, listDirectory bool) HandlerFunc {
 	absolutePath := group.calculateAbsolutePath(relativePath)
 	absolutePath := group.calculateAbsolutePath(relativePath)
 	fileServer := http.StripPrefix(absolutePath, http.FileServer(fs))
 	fileServer := http.StripPrefix(absolutePath, http.FileServer(fs))
-	return WrapH(fileServer)
+	return func(c *Context) {
+		if !listDirectory && lastChar(c.Request.URL.Path) == '/' {
+			http.NotFound(c.Writer, c.Request)
+			return
+		}
+		fileServer.ServeHTTP(c.Writer, c.Request)
+	}
 }
 }
 
 
 func (group *RouterGroup) combineHandlers(handlers HandlersChain) HandlersChain {
 func (group *RouterGroup) combineHandlers(handlers HandlersChain) HandlersChain {

+ 11 - 0
routergroup_test.go

@@ -88,6 +88,17 @@ func performRequestInGroup(t *testing.T, method string) {
 	assert.Equal(t, w.Body.String(), "the method was "+method+" and index 1")
 	assert.Equal(t, w.Body.String(), "the method was "+method+" and index 1")
 }
 }
 
 
+func TestRouterGroupInvalidStatic(t *testing.T) {
+	router := New()
+	assert.Panics(t, func() {
+		router.Static("/path/:param", "/")
+	})
+
+	assert.Panics(t, func() {
+		router.Static("/path/*param", "/")
+	})
+}
+
 func TestRouterGroupInvalidStaticFile(t *testing.T) {
 func TestRouterGroupInvalidStaticFile(t *testing.T) {
 	router := New()
 	router := New()
 	assert.Panics(t, func() {
 	assert.Panics(t, func() {

+ 40 - 22
routes_test.go

@@ -78,32 +78,41 @@ func testRouteNotOK2(method string, t *testing.T) {
 }
 }
 
 
 func TestRouterGroupRouteOK(t *testing.T) {
 func TestRouterGroupRouteOK(t *testing.T) {
+	testRouteOK("GET", t)
 	testRouteOK("POST", t)
 	testRouteOK("POST", t)
-	testRouteOK("DELETE", t)
-	testRouteOK("PATCH", t)
 	testRouteOK("PUT", t)
 	testRouteOK("PUT", t)
-	testRouteOK("OPTIONS", t)
+	testRouteOK("PATCH", t)
 	testRouteOK("HEAD", t)
 	testRouteOK("HEAD", t)
+	testRouteOK("OPTIONS", t)
+	testRouteOK("DELETE", t)
+	testRouteOK("CONNECT", t)
+	testRouteOK("TRACE", t)
 }
 }
 
 
 // TestSingleRouteOK tests that POST route is correctly invoked.
 // TestSingleRouteOK tests that POST route is correctly invoked.
 func TestRouteNotOK(t *testing.T) {
 func TestRouteNotOK(t *testing.T) {
+	testRouteNotOK("GET", t)
 	testRouteNotOK("POST", t)
 	testRouteNotOK("POST", t)
-	testRouteNotOK("DELETE", t)
-	testRouteNotOK("PATCH", t)
 	testRouteNotOK("PUT", t)
 	testRouteNotOK("PUT", t)
-	testRouteNotOK("OPTIONS", t)
+	testRouteNotOK("PATCH", t)
 	testRouteNotOK("HEAD", t)
 	testRouteNotOK("HEAD", t)
+	testRouteNotOK("OPTIONS", t)
+	testRouteNotOK("DELETE", t)
+	testRouteNotOK("CONNECT", t)
+	testRouteNotOK("TRACE", t)
 }
 }
 
 
 // TestSingleRouteOK tests that POST route is correctly invoked.
 // TestSingleRouteOK tests that POST route is correctly invoked.
 func TestRouteNotOK2(t *testing.T) {
 func TestRouteNotOK2(t *testing.T) {
+	testRouteNotOK2("GET", t)
 	testRouteNotOK2("POST", t)
 	testRouteNotOK2("POST", t)
-	testRouteNotOK2("DELETE", t)
-	testRouteNotOK2("PATCH", t)
 	testRouteNotOK2("PUT", t)
 	testRouteNotOK2("PUT", t)
-	testRouteNotOK2("OPTIONS", t)
+	testRouteNotOK2("PATCH", t)
 	testRouteNotOK2("HEAD", t)
 	testRouteNotOK2("HEAD", t)
+	testRouteNotOK2("OPTIONS", t)
+	testRouteNotOK2("DELETE", t)
+	testRouteNotOK2("CONNECT", t)
+	testRouteNotOK2("TRACE", t)
 }
 }
 
 
 // TestContextParamsGet tests that a parameter can be parsed from the URL.
 // TestContextParamsGet tests that a parameter can be parsed from the URL.
@@ -142,25 +151,35 @@ func TestRouteStaticFile(t *testing.T) {
 		t.Error(err)
 		t.Error(err)
 	}
 	}
 	defer os.Remove(f.Name())
 	defer os.Remove(f.Name())
-	filePath := path.Join("/", path.Base(f.Name()))
 	f.WriteString("Gin Web Framework")
 	f.WriteString("Gin Web Framework")
 	f.Close()
 	f.Close()
 
 
+	dir, filename := path.Split(f.Name())
+
 	// SETUP gin
 	// SETUP gin
 	router := New()
 	router := New()
-	router.Static("./", testRoot)
+	router.Static("/using_static", dir)
+	router.StaticFile("/result", f.Name())
 
 
-	w := performRequest(router, "GET", filePath)
+	w := performRequest(router, "GET", "/using_static/"+filename)
+	w2 := performRequest(router, "GET", "/result")
 
 
+	assert.Equal(t, w, w2)
 	assert.Equal(t, w.Code, 200)
 	assert.Equal(t, w.Code, 200)
 	assert.Equal(t, w.Body.String(), "Gin Web Framework")
 	assert.Equal(t, w.Body.String(), "Gin Web Framework")
 	assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8")
 	assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8")
+
+	w3 := performRequest(router, "HEAD", "/using_static/"+filename)
+	w4 := performRequest(router, "HEAD", "/result")
+
+	assert.Equal(t, w3, w4)
+	assert.Equal(t, w3.Code, 200)
 }
 }
 
 
 // TestHandleStaticDir - ensure the root/sub dir handles properly
 // TestHandleStaticDir - ensure the root/sub dir handles properly
-func TestRouteStaticDir(t *testing.T) {
+func TestRouteStaticListingDir(t *testing.T) {
 	router := New()
 	router := New()
-	router.Static("/", "./")
+	router.StaticFS("/", http.Dir("./"), true)
 
 
 	w := performRequest(router, "GET", "/")
 	w := performRequest(router, "GET", "/")
 
 
@@ -170,15 +189,14 @@ func TestRouteStaticDir(t *testing.T) {
 }
 }
 
 
 // TestHandleHeadToDir - ensure the root/sub dir handles properly
 // TestHandleHeadToDir - ensure the root/sub dir handles properly
-func TestRouteHeadToDir(t *testing.T) {
+func TestRouteStaticNoListing(t *testing.T) {
 	router := New()
 	router := New()
 	router.Static("/", "./")
 	router.Static("/", "./")
 
 
-	w := performRequest(router, "HEAD", "/")
+	w := performRequest(router, "GET", "/")
 
 
-	assert.Equal(t, w.Code, 200)
-	assert.Contains(t, w.Body.String(), "gin.go")
-	assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/html; charset=utf-8")
+	assert.Equal(t, w.Code, 404)
+	assert.NotContains(t, w.Body.String(), "gin.go")
 }
 }
 
 
 func TestRouterMiddlewareAndStatic(t *testing.T) {
 func TestRouterMiddlewareAndStatic(t *testing.T) {
@@ -190,11 +208,11 @@ func TestRouterMiddlewareAndStatic(t *testing.T) {
 	})
 	})
 	static.Static("/", "./")
 	static.Static("/", "./")
 
 
-	w := performRequest(router, "GET", "/")
+	w := performRequest(router, "GET", "/gin.go")
 
 
 	assert.Equal(t, w.Code, 200)
 	assert.Equal(t, w.Code, 200)
-	assert.Contains(t, w.Body.String(), "gin.go")
-	assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/html; charset=utf-8")
+	assert.Contains(t, w.Body.String(), "package gin")
+	assert.Equal(t, w.HeaderMap.Get("Content-Type"), "text/plain; charset=utf-8")
 	assert.NotEqual(t, w.HeaderMap.Get("Last-Modified"), "Mon, 02 Jan 2006 15:04:05 MST")
 	assert.NotEqual(t, w.HeaderMap.Get("Last-Modified"), "Mon, 02 Jan 2006 15:04:05 MST")
 	assert.Equal(t, w.HeaderMap.Get("Expires"), "Mon, 02 Jan 2006 15:04:05 MST")
 	assert.Equal(t, w.HeaderMap.Get("Expires"), "Mon, 02 Jan 2006 15:04:05 MST")
 	assert.Equal(t, w.HeaderMap.Get("x-GIN"), "Gin Framework")
 	assert.Equal(t, w.HeaderMap.Get("x-GIN"), "Gin Framework")