Pārlūkot izejas kodu

General refactoring. Part 2.

Manu Mtz-Almeida 11 gadi atpakaļ
vecāks
revīzija
aa7b00a083
7 mainītis faili ar 161 papildinājumiem un 145 dzēšanām
  1. 52 50
      auth.go
  2. 11 0
      context.go
  3. 22 20
      gin.go
  4. 55 55
      logger.go
  5. 14 7
      mode.go
  6. 4 3
      routergroup.go
  7. 3 10
      utils.go

+ 52 - 50
auth.go

@@ -16,77 +16,79 @@ const (
 )
 
 type (
-	BasicAuthPair struct {
-		Code string
-		User string
-	}
 	Accounts map[string]string
-	Pairs    []BasicAuthPair
+	authPair struct {
+		Value string
+		User  string
+	}
+	authPairs []authPair
 )
 
-func (a Pairs) Len() int           { return len(a) }
-func (a Pairs) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
-func (a Pairs) Less(i, j int) bool { return a[i].Code < a[j].Code }
+func (a authPairs) Len() int           { return len(a) }
+func (a authPairs) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value }
+
+// Implements a basic Basic HTTP Authorization. It takes as argument a map[string]string where
+// the key is the user name and the value is the password.
+func BasicAuth(accounts Accounts) HandlerFunc {
+	pairs, err := processAccounts(accounts)
+	if err != nil {
+		panic(err)
+	}
+	return func(c *Context) {
+		// Search user in the slice of allowed credentials
+		user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization"))
+		if !ok {
+			// Credentials doesn't match, we return 401 Unauthorized and abort request.
+			c.Writer.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"")
+			c.Fail(401, errors.New("Unauthorized"))
+		} else {
+			// user is allowed, set UserId to key "user" in this context, the userId can be read later using
+			// c.Get(gin.AuthUserKey)
+			c.Set(AuthUserKey, user)
+		}
+	}
+}
 
-func processCredentials(accounts Accounts) (Pairs, error) {
+func processAccounts(accounts Accounts) (authPairs, error) {
 	if len(accounts) == 0 {
-		return nil, errors.New("Empty list of authorized credentials.")
+		return nil, errors.New("Empty list of authorized credentials")
 	}
-	pairs := make(Pairs, 0, len(accounts))
+	pairs := make(authPairs, 0, len(accounts))
 	for user, password := range accounts {
-		if len(user) == 0 || len(password) == 0 {
-			return nil, errors.New("User or password is empty")
+		if len(user) == 0 {
+			return nil, errors.New("User can not be empty")
 		}
 		base := user + ":" + password
-		code := "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
-		pairs = append(pairs, BasicAuthPair{code, user})
+		value := "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
+		pairs = append(pairs, authPair{
+			Value: value,
+			User:  user,
+		})
 	}
 	// We have to sort the credentials in order to use bsearch later.
 	sort.Sort(pairs)
 	return pairs, nil
 }
 
-func secureCompare(given, actual string) bool {
-	if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 {
-		return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1
-	} else {
-		/* Securely compare actual to itself to keep constant time, but always return false */
-		return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false
-	}
-}
-
-func searchCredential(pairs Pairs, auth string) string {
+func searchCredential(pairs authPairs, auth string) (string, bool) {
 	if len(auth) == 0 {
-		return ""
+		return "", false
 	}
 	// Search user in the slice of allowed credentials
-	r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Code >= auth })
-	if r < len(pairs) && secureCompare(pairs[r].Code, auth) {
-		return pairs[r].User
+	r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth })
+	if r < len(pairs) && secureCompare(pairs[r].Value, auth) {
+		return pairs[r].User, true
 	} else {
-		return ""
+		return "", false
 	}
 }
 
-// Implements a basic Basic HTTP Authorization. It takes as argument a map[string]string where
-// the key is the user name and the value is the password.
-func BasicAuth(accounts Accounts) HandlerFunc {
-
-	pairs, err := processCredentials(accounts)
-	if err != nil {
-		panic(err)
-	}
-	return func(c *Context) {
-		// Search user in the slice of allowed credentials
-		user := searchCredential(pairs, c.Request.Header.Get("Authorization"))
-		if len(user) == 0 {
-			// Credentials doesn't match, we return 401 Unauthorized and abort request.
-			c.Writer.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"")
-			c.Fail(401, errors.New("Unauthorized"))
-		} else {
-			// user is allowed, set UserId to key "user" in this context, the userId can be read later using
-			// c.Get(gin.AuthUserKey)
-			c.Set(AuthUserKey, user)
-		}
+func secureCompare(given, actual string) bool {
+	if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 {
+		return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1
+	} else {
+		/* Securely compare actual to itself to keep constant time, but always return false */
+		return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false
 	}
 }

+ 11 - 0
context.go

@@ -197,6 +197,17 @@ func (c *Context) MustGet(key string) interface{} {
 	return value
 }
 
+func (c *Context) ClientIP() string {
+	clientIP := c.Request.Header.Get("X-Real-IP")
+	if len(clientIP) == 0 {
+		clientIP = c.Request.Header.Get("X-Forwarded-For")
+	}
+	if len(clientIP) == 0 {
+		clientIP = c.Request.RemoteAddr
+	}
+	return clientIP
+}
+
 /************************************/
 /********* PARSING REQUEST **********/
 /************************************/

+ 22 - 20
gin.go

@@ -29,29 +29,15 @@ type (
 	// Represents the web framework, it wraps the blazing fast httprouter multiplexer and a list of global middlewares.
 	Engine struct {
 		*RouterGroup
-		HTMLRender render.Render
-		pool       sync.Pool
-		allNoRoute []HandlerFunc
-		noRoute    []HandlerFunc
-		router     *httprouter.Router
+		HTMLRender     render.Render
+		Default404Body []byte
+		pool           sync.Pool
+		allNoRoute     []HandlerFunc
+		noRoute        []HandlerFunc
+		router         *httprouter.Router
 	}
 )
 
-func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) {
-	c := engine.createContext(w, req, nil, engine.allNoRoute)
-	// set 404 by default, useful for logging
-	c.Writer.WriteHeader(404)
-	c.Next()
-	if !c.Writer.Written() {
-		if c.Writer.Status() == 404 {
-			c.Data(-1, MIMEPlain, []byte("404 page not found"))
-		} else {
-			c.Writer.WriteHeaderNow()
-		}
-	}
-	engine.reuseContext(c)
-}
-
 // Returns a new blank Engine instance without any middleware attached.
 // The most basic configuration
 func New() *Engine {
@@ -62,6 +48,7 @@ func New() *Engine {
 		engine:       engine,
 	}
 	engine.router = httprouter.New()
+	engine.Default404Body = []byte("404 page not found")
 	engine.router.NotFound = engine.handle404
 	engine.pool.New = func() interface{} {
 		c := &Context{Engine: engine}
@@ -119,6 +106,21 @@ func (engine *Engine) rebuild404Handlers() {
 	engine.allNoRoute = engine.combineHandlers(engine.noRoute)
 }
 
+func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) {
+	c := engine.createContext(w, req, nil, engine.allNoRoute)
+	// set 404 by default, useful for logging
+	c.Writer.WriteHeader(404)
+	c.Next()
+	if !c.Writer.Written() {
+		if c.Writer.Status() == 404 {
+			c.Data(-1, MIMEPlain, engine.Default404Body)
+		} else {
+			c.Writer.WriteHeaderNow()
+		}
+	}
+	engine.reuseContext(c)
+}
+
 // ServeHTTP makes the router implement the http.Handler interface.
 func (engine *Engine) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 	engine.router.ServeHTTP(writer, request)

+ 55 - 55
logger.go

@@ -10,6 +10,17 @@ import (
 	"time"
 )
 
+var (
+	green   = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
+	white   = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
+	yellow  = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
+	red     = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
+	blue    = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
+	magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
+	cyan    = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
+	reset   = string([]byte{27, 91, 48, 109})
+)
+
 func ErrorLogger() HandlerFunc {
 	return ErrorLoggerT(ErrorTypeAll)
 }
@@ -26,17 +37,6 @@ func ErrorLoggerT(typ uint32) HandlerFunc {
 	}
 }
 
-var (
-	green   = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
-	white   = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
-	yellow  = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
-	red     = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
-	blue    = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
-	magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
-	cyan    = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
-	reset   = string([]byte{27, 91, 48, 109})
-)
-
 func Logger() HandlerFunc {
 	stdlogger := log.New(os.Stdout, "", 0)
 	//errlogger := log.New(os.Stderr, "", 0)
@@ -48,58 +48,58 @@ func Logger() HandlerFunc {
 		// Process request
 		c.Next()
 
-		// save the IP of the requester
-		requester := c.Request.Header.Get("X-Real-IP")
-		// if the requester-header is empty, check the forwarded-header
-		if len(requester) == 0 {
-			requester = c.Request.Header.Get("X-Forwarded-For")
-		}
-		// if the requester is still empty, use the hard-coded address from the socket
-		if len(requester) == 0 {
-			requester = c.Request.RemoteAddr
-		}
-
-		var color string
-		code := c.Writer.Status()
-		switch {
-		case code >= 200 && code <= 299:
-			color = green
-		case code >= 300 && code <= 399:
-			color = white
-		case code >= 400 && code <= 499:
-			color = yellow
-		default:
-			color = red
-		}
-
-		var methodColor string
-		method := c.Request.Method
-		switch {
-		case method == "GET":
-			methodColor = blue
-		case method == "POST":
-			methodColor = cyan
-		case method == "PUT":
-			methodColor = yellow
-		case method == "DELETE":
-			methodColor = red
-		case method == "PATCH":
-			methodColor = green
-		case method == "HEAD":
-			methodColor = magenta
-		case method == "OPTIONS":
-			methodColor = white
-		}
+		// Stop timer
 		end := time.Now()
 		latency := end.Sub(start)
+
+		clientIP := c.ClientIP()
+		method := c.Request.Method
+		statusCode := c.Writer.Status()
+		statusColor := colorForStatus(statusCode)
+		methodColor := colorForMethod(method)
+
 		stdlogger.Printf("[GIN] %v |%s %3d %s| %12v | %s |%s  %s %-7s %s\n%s",
 			end.Format("2006/01/02 - 15:04:05"),
-			color, code, reset,
+			statusColor, statusCode, reset,
 			latency,
-			requester,
+			clientIP,
 			methodColor, reset, method,
 			c.Request.URL.Path,
 			c.Errors.String(),
 		)
 	}
 }
+
+func colorForStatus(code int) string {
+	switch {
+	case code >= 200 && code <= 299:
+		return green
+	case code >= 300 && code <= 399:
+		return white
+	case code >= 400 && code <= 499:
+		return yellow
+	default:
+		return red
+	}
+}
+
+func colorForMethod(method string) string {
+	switch {
+	case method == "GET":
+		return blue
+	case method == "POST":
+		return cyan
+	case method == "PUT":
+		return yellow
+	case method == "DELETE":
+		return red
+	case method == "PATCH":
+		return green
+	case method == "HEAD":
+		return magenta
+	case method == "OPTIONS":
+		return white
+	default:
+		return reset
+	}
+}

+ 14 - 7
mode.go

@@ -5,6 +5,7 @@
 package gin
 
 import (
+	"fmt"
 	"os"
 )
 
@@ -24,6 +25,15 @@ const (
 var gin_mode int = debugCode
 var mode_name string = DebugMode
 
+func init() {
+	value := os.Getenv(GIN_MODE)
+	if len(value) == 0 {
+		SetMode(DebugMode)
+	} else {
+		SetMode(value)
+	}
+}
+
 func SetMode(value string) {
 	switch value {
 	case DebugMode:
@@ -33,7 +43,7 @@ func SetMode(value string) {
 	case TestMode:
 		gin_mode = testCode
 	default:
-		panic("gin mode unknown, the allowed modes are: " + DebugMode + " and " + ReleaseMode)
+		panic("gin mode unknown: " + value)
 	}
 	mode_name = value
 }
@@ -46,11 +56,8 @@ func IsDebugging() bool {
 	return gin_mode == debugCode
 }
 
-func init() {
-	value := os.Getenv(GIN_MODE)
-	if len(value) == 0 {
-		SetMode(DebugMode)
-	} else {
-		SetMode(value)
+func debugPrint(format string, values ...interface{}) {
+	if IsDebugging() {
+		fmt.Printf("[GIN-debug] "+format, values)
 	}
 }

+ 4 - 3
routergroup.go

@@ -48,7 +48,7 @@ func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers []Han
 	handlers = group.combineHandlers(handlers)
 	if IsDebugging() {
 		nuHandlers := len(handlers)
-		handlerName := nameOfFuncion(handlers[nuHandlers-1])
+		handlerName := nameOfFunction(handlers[nuHandlers-1])
 		debugPrint("%-5s %-25s --> %s (%d handlers)\n", httpMethod, absolutePath, handlerName, nuHandlers)
 	}
 
@@ -105,6 +105,8 @@ func (group *RouterGroup) Static(relativePath, root string) {
 	absolutePath := group.calculateAbsolutePath(relativePath)
 	handler := group.createStaticHandler(absolutePath, root)
 	absolutePath = path.Join(absolutePath, "/*filepath")
+
+	// Register GET and HEAD handlers
 	group.GET(absolutePath, handler)
 	group.HEAD(absolutePath, handler)
 }
@@ -120,8 +122,7 @@ func (group *RouterGroup) combineHandlers(handlers []HandlerFunc) []HandlerFunc
 	finalSize := len(group.Handlers) + len(handlers)
 	mergedHandlers := make([]HandlerFunc, 0, finalSize)
 	mergedHandlers = append(mergedHandlers, group.Handlers...)
-	mergedHandlers = append(mergedHandlers, handlers...)
-	return mergedHandlers
+	return append(mergedHandlers, handlers...)
 }
 
 func (group *RouterGroup) calculateAbsolutePath(relativePath string) string {

+ 3 - 10
utils.go

@@ -6,7 +6,6 @@ package gin
 
 import (
 	"encoding/xml"
-	"fmt"
 	"reflect"
 	"runtime"
 	"strings"
@@ -39,20 +38,14 @@ func (h H) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
 }
 
 func filterFlags(content string) string {
-	for i, a := range content {
-		if a == ' ' || a == ';' {
+	for i, char := range content {
+		if char == ' ' || char == ';' {
 			return content[:i]
 		}
 	}
 	return content
 }
 
-func debugPrint(format string, values ...interface{}) {
-	if IsDebugging() {
-		fmt.Printf("[GIN-debug] "+format, values)
-	}
-}
-
 func chooseData(custom, wildcard interface{}) interface{} {
 	if custom == nil {
 		if wildcard == nil {
@@ -84,6 +77,6 @@ func lastChar(str string) uint8 {
 	return str[size-1]
 }
 
-func nameOfFuncion(f interface{}) string {
+func nameOfFunction(f interface{}) string {
 	return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
 }