Переглянути джерело

Pass MaxMultipartMemory when FormFile is called (#1600)

When `gin.Context.FormFile("...")` is called the `engine.MaxMultipartMemory` is never used. This PR makes sure that the `MaxMultipartMemory` is passed and removes 2 calls to `http.Request.ParseForm` since they are called from `http.Request.ParseMultipartForm`
Ismail Gjevori 7 роки тому
батько
коміт
dbc330b804
2 змінених файлів з 18 додано та 2 видалено
  1. 5 2
      context.go
  2. 13 0
      context_test.go

+ 5 - 2
context.go

@@ -414,7 +414,6 @@ func (c *Context) PostFormArray(key string) []string {
 // a boolean value whether at least one value exists for the given key.
 func (c *Context) GetPostFormArray(key string) ([]string, bool) {
 	req := c.Request
-	req.ParseForm()
 	req.ParseMultipartForm(c.engine.MaxMultipartMemory)
 	if values := req.PostForm[key]; len(values) > 0 {
 		return values, true
@@ -437,7 +436,6 @@ func (c *Context) PostFormMap(key string) map[string]string {
 // whether at least one value exists for the given key.
 func (c *Context) GetPostFormMap(key string) (map[string]string, bool) {
 	req := c.Request
-	req.ParseForm()
 	req.ParseMultipartForm(c.engine.MaxMultipartMemory)
 	dicts, exist := c.get(req.PostForm, key)
 
@@ -465,6 +463,11 @@ func (c *Context) get(m map[string][]string, key string) (map[string]string, boo
 
 // FormFile returns the first file for the provided form key.
 func (c *Context) FormFile(name string) (*multipart.FileHeader, error) {
+	if c.Request.MultipartForm == nil {
+		if err := c.Request.ParseMultipartForm(c.engine.MaxMultipartMemory); err != nil {
+			return nil, err
+		}
+	}
 	_, fh, err := c.Request.FormFile(name)
 	return fh, err
 }

+ 13 - 0
context_test.go

@@ -84,6 +84,19 @@ func TestContextFormFile(t *testing.T) {
 	assert.NoError(t, c.SaveUploadedFile(f, "test"))
 }
 
+func TestContextFormFileFailed(t *testing.T) {
+	buf := new(bytes.Buffer)
+	mw := multipart.NewWriter(buf)
+	mw.Close()
+	c, _ := CreateTestContext(httptest.NewRecorder())
+	c.Request, _ = http.NewRequest("POST", "/", nil)
+	c.Request.Header.Set("Content-Type", mw.FormDataContentType())
+	c.engine.MaxMultipartMemory = 8 << 20
+	f, err := c.FormFile("file")
+	assert.Error(t, err)
+	assert.Nil(t, f)
+}
+
 func TestContextMultipartForm(t *testing.T) {
 	buf := new(bytes.Buffer)
 	mw := multipart.NewWriter(buf)