Browse Source

Merge pull request #182 from dutchcoders/forwarded-for-fix

Fixed issue allowing to spoof ClientIP()
Javier Provecho Fernandez 11 years ago
parent
commit
0099840c98
2 changed files with 119 additions and 7 deletions
  1. 78 7
      context.go
  2. 41 0
      context_test.go

+ 78 - 7
context.go

@@ -12,7 +12,9 @@ import (
 	"github.com/gin-gonic/gin/render"
 	"github.com/julienschmidt/httprouter"
 	"log"
+	"net"
 	"net/http"
+	"strings"
 )
 
 const (
@@ -197,15 +199,84 @@ func (c *Context) MustGet(key string) interface{} {
 	return value
 }
 
-func (c *Context) ClientIP() string {
-	clientIP := c.Request.Header.Get("X-Real-IP")
-	if len(clientIP) == 0 {
-		clientIP = c.Request.Header.Get("X-Forwarded-For")
+func ipInMasks(ip net.IP, masks []interface{}) bool {
+	for _, proxy := range masks {
+		var mask *net.IPNet
+		var err error
+
+		switch t := proxy.(type) {
+		case string:
+			if _, mask, err = net.ParseCIDR(t); err != nil {
+				panic(err)
+			}
+		case net.IP:
+			mask = &net.IPNet{IP: t, Mask: net.CIDRMask(len(t)*8, len(t)*8)}
+		case net.IPNet:
+			mask = &t
+		}
+
+		if mask.Contains(ip) {
+			return true
+		}
 	}
-	if len(clientIP) == 0 {
-		clientIP = c.Request.RemoteAddr
+
+	return false
+}
+
+// the ForwardedFor middleware unwraps the X-Forwarded-For headers, be careful to only use this
+// middleware if you've got servers in front of this server. The list with (known) proxies and
+// local ips are being filtered out of the forwarded for list, giving the last not local ip being
+// the real client ip.
+func ForwardedFor(proxies ...interface{}) HandlerFunc {
+	if len(proxies) == 0 {
+		// default to local ips
+		var reservedLocalIps = []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16"}
+
+		proxies = make([]interface{}, len(reservedLocalIps))
+
+		for i, v := range reservedLocalIps {
+			proxies[i] = v
+		}
 	}
-	return clientIP
+
+	return func(c *Context) {
+		// the X-Forwarded-For header contains an array with left most the client ip, then
+		// comma separated, all proxies the request passed. The last proxy appears
+		// as the remote address of the request. Returning the client
+		// ip to comply with default RemoteAddr response.
+
+		// check if remoteaddr is local ip or in list of defined proxies
+		remoteIp := net.ParseIP(strings.Split(c.Request.RemoteAddr, ":")[0])
+
+		if !ipInMasks(remoteIp, proxies) {
+			return
+		}
+
+		if forwardedFor := c.Request.Header.Get("X-Forwarded-For"); forwardedFor != "" {
+			parts := strings.Split(forwardedFor, ",")
+
+			for i := len(parts) - 1; i >= 0; i-- {
+				part := parts[i]
+
+				ip := net.ParseIP(strings.TrimSpace(part))
+
+				if ipInMasks(ip, proxies) {
+					continue
+				}
+
+				// returning remote addr conform the original remote addr format
+				c.Request.RemoteAddr = ip.String() + ":0"
+
+				// remove forwarded for address
+				c.Request.Header.Set("X-Forwarded-For", "")
+				return
+			}
+		}
+	}
+}
+
+func (c *Context) ClientIP() string {
+	return c.Request.RemoteAddr
 }
 
 /************************************/

+ 41 - 0
context_test.go

@@ -440,3 +440,44 @@ func TestBindingJSONMalformed(t *testing.T) {
 		t.Errorf("Content-Type should not be application/json, was %s", w.HeaderMap.Get("Content-Type"))
 	}
 }
+
+func TestClientIP(t *testing.T) {
+	r := New()
+
+	var clientIP string = ""
+	r.GET("/", func(c *Context) {
+		clientIP = c.ClientIP()
+	})
+
+	body := bytes.NewBuffer([]byte(""))
+	req, _ := http.NewRequest("GET", "/", body)
+	req.RemoteAddr = "clientip:1234"
+	w := httptest.NewRecorder()
+	r.ServeHTTP(w, req)
+
+	if clientIP != "clientip:1234" {
+		t.Errorf("ClientIP should not be %s, but clientip:1234", clientIP)
+	}
+}
+
+func TestClientIPWithXForwardedForWithProxy(t *testing.T) {
+	r := New()
+	r.Use(ForwardedFor())
+
+	var clientIP string = ""
+	r.GET("/", func(c *Context) {
+		clientIP = c.ClientIP()
+	})
+
+	body := bytes.NewBuffer([]byte(""))
+	req, _ := http.NewRequest("GET", "/", body)
+	req.RemoteAddr = "172.16.8.3:1234"
+	req.Header.Set("X-Real-Ip", "realip")
+	req.Header.Set("X-Forwarded-For", "1.2.3.4, 10.10.0.4, 192.168.0.43, 172.16.8.4")
+	w := httptest.NewRecorder()
+	r.ServeHTTP(w, req)
+
+	if clientIP != "1.2.3.4:0" {
+		t.Errorf("ClientIP should not be %s, but 1.2.3.4:0", clientIP)
+	}
+}