Browse Source

Refactors binding module

Manu Mtz-Almeida 10 years ago
parent
commit
d4413b6e91
10 changed files with 367 additions and 299 deletions
  1. 30 269
      binding/binding.go
  2. 143 0
      binding/form_mapping.go
  3. 23 0
      binding/get_form.go
  4. 26 0
      binding/json.go
  5. 23 0
      binding/post_form.go
  6. 79 0
      binding/validate.go
  7. 25 0
      binding/xml.go
  8. 5 19
      context.go
  9. 10 0
      deprecated.go
  10. 3 11
      gin.go

+ 30 - 269
binding/binding.go

@@ -4,282 +4,43 @@
 
 package binding
 
-import (
-	"encoding/json"
-	"encoding/xml"
-	"errors"
-	"log"
-	"net/http"
-	"reflect"
-	"strconv"
-	"strings"
+import "net/http"
+
+const (
+	MIMEJSON              = "application/json"
+	MIMEHTML              = "text/html"
+	MIMEXML               = "application/xml"
+	MIMEXML2              = "text/xml"
+	MIMEPlain             = "text/plain"
+	MIMEPOSTForm          = "application/x-www-form-urlencoded"
+	MIMEMultipartPOSTForm = "multipart/form-data"
 )
 
-type (
-	Binding interface {
-		Bind(*http.Request, interface{}) error
-	}
-
-	// JSON binding
-	jsonBinding struct{}
-
-	// XML binding
-	xmlBinding struct{}
-
-	// form binding
-	formBinding struct{}
-
-	// multipart form binding
-	multipartFormBinding struct{}
-)
-
-const MAX_MEMORY = 1 * 1024 * 1024
+type Binding interface {
+	Name() string
+	Bind(*http.Request, interface{}) error
+}
 
 var (
-	JSON          = jsonBinding{}
-	XML           = xmlBinding{}
-	Form          = formBinding{} // todo
-	MultipartForm = multipartFormBinding{}
+	JSON     = jsonBinding{}
+	XML      = xmlBinding{}
+	GETForm  = getFormBinding{}
+	POSTForm = postFormBinding{}
 )
 
-func (_ jsonBinding) Bind(req *http.Request, obj interface{}) error {
-	decoder := json.NewDecoder(req.Body)
-	if err := decoder.Decode(obj); err == nil {
-		return Validate(obj)
+func Default(method, contentType string) Binding {
+	if method == "GET" {
+		return GETForm
 	} else {
-		return err
-	}
-}
-
-func (_ xmlBinding) Bind(req *http.Request, obj interface{}) error {
-	decoder := xml.NewDecoder(req.Body)
-	if err := decoder.Decode(obj); err == nil {
-		return Validate(obj)
-	} else {
-		return err
-	}
-}
-
-func (_ formBinding) Bind(req *http.Request, obj interface{}) error {
-	if err := req.ParseForm(); err != nil {
-		return err
-	}
-	if err := mapForm(obj, req.Form); err != nil {
-		return err
-	}
-	return Validate(obj)
-}
-
-func (_ multipartFormBinding) Bind(req *http.Request, obj interface{}) error {
-	if err := req.ParseMultipartForm(MAX_MEMORY); err != nil {
-		return err
-	}
-	if err := mapForm(obj, req.Form); err != nil {
-		return err
-	}
-	return Validate(obj)
-}
-
-func mapForm(ptr interface{}, form map[string][]string) error {
-	typ := reflect.TypeOf(ptr).Elem()
-	formStruct := reflect.ValueOf(ptr).Elem()
-	for i := 0; i < typ.NumField(); i++ {
-		typeField := typ.Field(i)
-		if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" {
-			structField := formStruct.Field(i)
-			if !structField.CanSet() {
-				continue
-			}
-
-			inputValue, exists := form[inputFieldName]
-			if !exists {
-				continue
-			}
-			numElems := len(inputValue)
-			if structField.Kind() == reflect.Slice && numElems > 0 {
-				sliceOf := structField.Type().Elem().Kind()
-				slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
-				for i := 0; i < numElems; i++ {
-					if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil {
-						return err
-					}
-				}
-				formStruct.Field(i).Set(slice)
-			} else {
-				if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
-					return err
-				}
-			}
-		}
-	}
-	return nil
-}
-
-func setIntField(val string, bitSize int, structField reflect.Value) error {
-	if val == "" {
-		val = "0"
-	}
-
-	intVal, err := strconv.ParseInt(val, 10, bitSize)
-	if err == nil {
-		structField.SetInt(intVal)
-	}
-
-	return err
-}
-
-func setUintField(val string, bitSize int, structField reflect.Value) error {
-	if val == "" {
-		val = "0"
-	}
-
-	uintVal, err := strconv.ParseUint(val, 10, bitSize)
-	if err == nil {
-		structField.SetUint(uintVal)
-	}
-
-	return err
-}
-
-func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
-	switch valueKind {
-	case reflect.Int:
-		return setIntField(val, 0, structField)
-	case reflect.Int8:
-		return setIntField(val, 8, structField)
-	case reflect.Int16:
-		return setIntField(val, 16, structField)
-	case reflect.Int32:
-		return setIntField(val, 32, structField)
-	case reflect.Int64:
-		return setIntField(val, 64, structField)
-	case reflect.Uint:
-		return setUintField(val, 0, structField)
-	case reflect.Uint8:
-		return setUintField(val, 8, structField)
-	case reflect.Uint16:
-		return setUintField(val, 16, structField)
-	case reflect.Uint32:
-		return setUintField(val, 32, structField)
-	case reflect.Uint64:
-		return setUintField(val, 64, structField)
-	case reflect.Bool:
-		if val == "" {
-			val = "false"
-		}
-		boolVal, err := strconv.ParseBool(val)
-		if err != nil {
-			return err
-		} else {
-			structField.SetBool(boolVal)
-		}
-	case reflect.Float32:
-		if val == "" {
-			val = "0.0"
-		}
-		floatVal, err := strconv.ParseFloat(val, 32)
-		if err != nil {
-			return err
-		} else {
-			structField.SetFloat(floatVal)
-		}
-	case reflect.Float64:
-		if val == "" {
-			val = "0.0"
-		}
-		floatVal, err := strconv.ParseFloat(val, 64)
-		if err != nil {
-			return err
-		} else {
-			structField.SetFloat(floatVal)
-		}
-	case reflect.String:
-		structField.SetString(val)
-	}
-	return nil
-}
-
-// Don't pass in pointers to bind to. Can lead to bugs. See:
-// https://github.com/codegangsta/martini-contrib/issues/40
-// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659
-func ensureNotPointer(obj interface{}) {
-	if reflect.TypeOf(obj).Kind() == reflect.Ptr {
-		log.Panic("Pointers are not accepted as binding models")
-	}
-}
-
-func Validate(obj interface{}, parents ...string) error {
-	typ := reflect.TypeOf(obj)
-	val := reflect.ValueOf(obj)
-
-	if typ.Kind() == reflect.Ptr {
-		typ = typ.Elem()
-		val = val.Elem()
-	}
-
-	switch typ.Kind() {
-	case reflect.Struct:
-		for i := 0; i < typ.NumField(); i++ {
-			field := typ.Field(i)
-
-			// Allow ignored and unexported fields in the struct
-			if len(field.PkgPath) > 0 || field.Tag.Get("form") == "-" {
-				continue
-			}
-
-			fieldValue := val.Field(i).Interface()
-			zero := reflect.Zero(field.Type).Interface()
-
-			if strings.Index(field.Tag.Get("binding"), "required") > -1 {
-				fieldType := field.Type.Kind()
-				if fieldType == reflect.Struct {
-					if reflect.DeepEqual(zero, fieldValue) {
-						return errors.New("Required " + field.Name)
-					}
-					err := Validate(fieldValue, field.Name)
-					if err != nil {
-						return err
-					}
-				} else if reflect.DeepEqual(zero, fieldValue) {
-					if len(parents) > 0 {
-						return errors.New("Required " + field.Name + " on " + parents[0])
-					} else {
-						return errors.New("Required " + field.Name)
-					}
-				} else if fieldType == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct {
-					err := Validate(fieldValue)
-					if err != nil {
-						return err
-					}
-				}
-			} else {
-				fieldType := field.Type.Kind()
-				if fieldType == reflect.Struct {
-					if reflect.DeepEqual(zero, fieldValue) {
-						continue
-					}
-					err := Validate(fieldValue, field.Name)
-					if err != nil {
-						return err
-					}
-				} else if fieldType == reflect.Slice && field.Type.Elem().Kind() == reflect.Struct {
-					err := Validate(fieldValue, field.Name)
-					if err != nil {
-						return err
-					}
-				}
-			}
-		}
-	case reflect.Slice:
-		for i := 0; i < val.Len(); i++ {
-			fieldValue := val.Index(i).Interface()
-			err := Validate(fieldValue)
-			if err != nil {
-				return err
-			}
+		switch contentType {
+		case MIMEPOSTForm:
+			return POSTForm
+		case MIMEJSON:
+			return JSON
+		case MIMEXML, MIMEXML2:
+			return XML
+		default:
+			return GETForm
 		}
-	default:
-		return nil
 	}
-	return nil
 }

+ 143 - 0
binding/form_mapping.go

@@ -0,0 +1,143 @@
+// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package binding
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"reflect"
+	"strconv"
+)
+
+func mapForm(ptr interface{}, form map[string][]string) error {
+	typ := reflect.TypeOf(ptr).Elem()
+	val := reflect.ValueOf(ptr).Elem()
+	for i := 0; i < typ.NumField(); i++ {
+		typeField := typ.Field(i)
+		structField := val.Field(i)
+		if !structField.CanSet() {
+			continue
+		}
+
+		inputFieldName := typeField.Tag.Get("form")
+		if inputFieldName == "" {
+			inputFieldName = typeField.Name
+		}
+		inputValue, exists := form[inputFieldName]
+		fmt.Println("Field: "+inputFieldName+" Value: ", inputValue)
+
+		if !exists {
+			continue
+		}
+
+		numElems := len(inputValue)
+		if structField.Kind() == reflect.Slice && numElems > 0 {
+			sliceOf := structField.Type().Elem().Kind()
+			slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
+			for i := 0; i < numElems; i++ {
+				if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil {
+					return err
+				}
+			}
+			val.Field(i).Set(slice)
+		} else {
+			if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
+				return err
+			}
+		}
+
+	}
+	return nil
+}
+
+func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
+	switch valueKind {
+	case reflect.Int:
+		return setIntField(val, 0, structField)
+	case reflect.Int8:
+		return setIntField(val, 8, structField)
+	case reflect.Int16:
+		return setIntField(val, 16, structField)
+	case reflect.Int32:
+		return setIntField(val, 32, structField)
+	case reflect.Int64:
+		return setIntField(val, 64, structField)
+	case reflect.Uint:
+		return setUintField(val, 0, structField)
+	case reflect.Uint8:
+		return setUintField(val, 8, structField)
+	case reflect.Uint16:
+		return setUintField(val, 16, structField)
+	case reflect.Uint32:
+		return setUintField(val, 32, structField)
+	case reflect.Uint64:
+		return setUintField(val, 64, structField)
+	case reflect.Bool:
+		return setBoolField(val, structField)
+	case reflect.Float32:
+		return setFloatField(val, 32, structField)
+	case reflect.Float64:
+		return setFloatField(val, 64, structField)
+	case reflect.String:
+		structField.SetString(val)
+	default:
+		return errors.New("Unknown type")
+	}
+	return nil
+}
+
+func setIntField(val string, bitSize int, field reflect.Value) error {
+	if val == "" {
+		val = "0"
+	}
+	intVal, err := strconv.ParseInt(val, 10, bitSize)
+	if err == nil {
+		field.SetInt(intVal)
+	}
+	return err
+}
+
+func setUintField(val string, bitSize int, field reflect.Value) error {
+	if val == "" {
+		val = "0"
+	}
+	uintVal, err := strconv.ParseUint(val, 10, bitSize)
+	if err == nil {
+		field.SetUint(uintVal)
+	}
+	return err
+}
+
+func setBoolField(val string, field reflect.Value) error {
+	if val == "" {
+		val = "false"
+	}
+	boolVal, err := strconv.ParseBool(val)
+	if err == nil {
+		field.SetBool(boolVal)
+	}
+	return nil
+}
+
+func setFloatField(val string, bitSize int, field reflect.Value) error {
+	if val == "" {
+		val = "0.0"
+	}
+	floatVal, err := strconv.ParseFloat(val, bitSize)
+	if err == nil {
+		field.SetFloat(floatVal)
+	}
+	return err
+}
+
+// Don't pass in pointers to bind to. Can lead to bugs. See:
+// https://github.com/codegangsta/martini-contrib/issues/40
+// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659
+func ensureNotPointer(obj interface{}) {
+	if reflect.TypeOf(obj).Kind() == reflect.Ptr {
+		log.Panic("Pointers are not accepted as binding models")
+	}
+}

+ 23 - 0
binding/get_form.go

@@ -0,0 +1,23 @@
+// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package binding
+
+import "net/http"
+
+type getFormBinding struct{}
+
+func (_ getFormBinding) Name() string {
+	return "get_form"
+}
+
+func (_ getFormBinding) Bind(req *http.Request, obj interface{}) error {
+	if err := req.ParseForm(); err != nil {
+		return err
+	}
+	if err := mapForm(obj, req.Form); err != nil {
+		return err
+	}
+	return Validate(obj)
+}

+ 26 - 0
binding/json.go

@@ -0,0 +1,26 @@
+// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package binding
+
+import (
+	"encoding/json"
+
+	"net/http"
+)
+
+type jsonBinding struct{}
+
+func (_ jsonBinding) Name() string {
+	return "json"
+}
+
+func (_ jsonBinding) Bind(req *http.Request, obj interface{}) error {
+	decoder := json.NewDecoder(req.Body)
+	if err := decoder.Decode(obj); err == nil {
+		return Validate(obj)
+	} else {
+		return err
+	}
+}

+ 23 - 0
binding/post_form.go

@@ -0,0 +1,23 @@
+// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package binding
+
+import "net/http"
+
+type postFormBinding struct{}
+
+func (_ postFormBinding) Name() string {
+	return "post_form"
+}
+
+func (_ postFormBinding) Bind(req *http.Request, obj interface{}) error {
+	if err := req.ParseForm(); err != nil {
+		return err
+	}
+	if err := mapForm(obj, req.PostForm); err != nil {
+		return err
+	}
+	return Validate(obj)
+}

+ 79 - 0
binding/validate.go

@@ -0,0 +1,79 @@
+// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package binding
+
+import (
+	"errors"
+	"reflect"
+	"strings"
+)
+
+func Validate(obj interface{}) error {
+	return validate(obj, "{{ROOT}}")
+}
+
+func validate(obj interface{}, parent string) error {
+	typ, val := inspectObject(obj)
+	switch typ.Kind() {
+	case reflect.Struct:
+		return validateStruct(typ, val, parent)
+
+	case reflect.Slice:
+		return validateSlice(typ, val, parent)
+
+	default:
+		return errors.New("The object is not a slice or struct.")
+	}
+}
+
+func inspectObject(obj interface{}) (typ reflect.Type, val reflect.Value) {
+	typ = reflect.TypeOf(obj)
+	val = reflect.ValueOf(obj)
+	if typ.Kind() == reflect.Ptr {
+		typ = typ.Elem()
+		val = val.Elem()
+	}
+	return
+}
+
+func validateSlice(typ reflect.Type, val reflect.Value, parent string) error {
+	if typ.Elem().Kind() == reflect.Struct {
+		for i := 0; i < val.Len(); i++ {
+			itemValue := val.Index(i).Interface()
+			if err := validate(itemValue, parent); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func validateStruct(typ reflect.Type, val reflect.Value, parent string) error {
+	for i := 0; i < typ.NumField(); i++ {
+		field := typ.Field(i)
+		// Allow ignored and unexported fields in the struct
+		// TODO should include  || field.Tag.Get("form") == "-"
+		if len(field.PkgPath) > 0 {
+			continue
+		}
+
+		fieldValue := val.Field(i).Interface()
+		requiredField := strings.Index(field.Tag.Get("binding"), "required") > -1
+
+		if requiredField {
+			zero := reflect.Zero(field.Type).Interface()
+			if reflect.DeepEqual(zero, fieldValue) {
+				return errors.New("Required " + field.Name + " in " + parent)
+			}
+		}
+		fieldType := field.Type.Kind()
+		if fieldType == reflect.Struct || fieldType == reflect.Slice {
+			if err := validate(fieldValue, field.Name); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}

+ 25 - 0
binding/xml.go

@@ -0,0 +1,25 @@
+// Copyright 2014 Manu Martinez-Almeida.  All rights reserved.
+// Use of this source code is governed by a MIT style
+// license that can be found in the LICENSE file.
+
+package binding
+
+import (
+	"encoding/xml"
+	"net/http"
+)
+
+type xmlBinding struct{}
+
+func (_ xmlBinding) Name() string {
+	return "xml"
+}
+
+func (_ xmlBinding) Bind(req *http.Request, obj interface{}) error {
+	decoder := xml.NewDecoder(req.Body)
+	if err := decoder.Decode(obj); err == nil {
+		return Validate(obj)
+	} else {
+		return err
+	}
+}

+ 5 - 19
context.go

@@ -179,21 +179,7 @@ func (c *Context) ContentType() string {
 // else --> returns an error
 // if Parses the request's body as JSON if Content-Type == "application/json"  using JSON or XML  as a JSON input. It decodes the json payload into the struct specified as a pointer.Like ParseBody() but this method also writes a 400 error if the json is not valid.
 func (c *Context) Bind(obj interface{}) bool {
-	var b binding.Binding
-	ctype := filterFlags(c.Request.Header.Get("Content-Type"))
-	switch {
-	case c.Request.Method == "GET" || ctype == MIMEPOSTForm:
-		b = binding.Form
-	case ctype == MIMEMultipartPOSTForm:
-		b = binding.MultipartForm
-	case ctype == MIMEJSON:
-		b = binding.JSON
-	case ctype == MIMEXML || ctype == MIMEXML2:
-		b = binding.XML
-	default:
-		c.Fail(400, errors.New("unknown content-type: "+ctype))
-		return false
-	}
+	b := binding.Default(c.Request.Method, c.ContentType())
 	return c.BindWith(obj, b)
 }
 
@@ -283,18 +269,18 @@ type Negotiate struct {
 
 func (c *Context) Negotiate(code int, config Negotiate) {
 	switch c.NegotiateFormat(config.Offered...) {
-	case MIMEJSON:
+	case binding.MIMEJSON:
 		data := chooseData(config.JSONData, config.Data)
 		c.JSON(code, data)
 
-	case MIMEHTML:
-		data := chooseData(config.HTMLData, config.Data)
+	case binding.MIMEHTML:
 		if len(config.HTMLPath) == 0 {
 			log.Panic("negotiate config is wrong. html path is needed")
 		}
+		data := chooseData(config.HTMLData, config.Data)
 		c.HTML(code, config.HTMLPath, data)
 
-	case MIMEXML:
+	case binding.MIMEXML:
 		data := chooseData(config.XMLData, config.Data)
 		c.XML(code, data)
 

+ 10 - 0
deprecated.go

@@ -13,6 +13,16 @@ import (
 	"github.com/gin-gonic/gin/binding"
 )
 
+const (
+	MIMEJSON              = binding.MIMEJSON
+	MIMEHTML              = binding.MIMEHTML
+	MIMEXML               = binding.MIMEXML
+	MIMEXML2              = binding.MIMEXML2
+	MIMEPlain             = binding.MIMEPlain
+	MIMEPOSTForm          = binding.MIMEPOSTForm
+	MIMEMultipartPOSTForm = binding.MIMEMultipartPOSTForm
+)
+
 // DEPRECATED, use Bind() instead.
 // Like ParseBody() but this method also writes a 400 error if the json is not valid.
 func (c *Context) EnsureBody(item interface{}) bool {

+ 3 - 11
gin.go

@@ -9,19 +9,11 @@ import (
 	"net/http"
 	"sync"
 
+	"github.com/gin-gonic/gin/binding"
 	"github.com/gin-gonic/gin/render"
 	"github.com/julienschmidt/httprouter"
 )
 
-const (
-	MIMEJSON              = "application/json"
-	MIMEHTML              = "text/html"
-	MIMEXML               = "application/xml"
-	MIMEXML2              = "text/xml"
-	MIMEPlain             = "text/plain"
-	MIMEPOSTForm          = "application/x-www-form-urlencoded"
-	MIMEMultipartPOSTForm = "multipart/form-data"
-)
 
 type (
 	HandlerFunc func(*Context)
@@ -147,7 +139,7 @@ func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) {
 	c.Next()
 	if !c.Writer.Written() {
 		if c.Writer.Status() == 404 {
-			c.Data(-1, MIMEPlain, engine.Default404Body)
+			c.Data(-1, binding.MIMEPlain, engine.Default404Body)
 		} else {
 			c.Writer.WriteHeaderNow()
 		}
@@ -162,7 +154,7 @@ func (engine *Engine) handle405(w http.ResponseWriter, req *http.Request) {
 	c.Next()
 	if !c.Writer.Written() {
 		if c.Writer.Status() == 405 {
-			c.Data(-1, MIMEPlain, engine.Default405Body)
+			c.Data(-1, binding.MIMEPlain, engine.Default405Body)
 		} else {
 			c.Writer.WriteHeaderNow()
 		}