Explorar el Código

Updated Origin check.
The host in the Origin header must match the host of the request by default.

Joachim Bauch hace 11 años
padre
commit
b03dcbad2a
Se han modificado 4 ficheros con 35 adiciones y 12 borrados
  1. 8 2
      examples/autobahn/server.go
  2. 1 4
      examples/chat/conn.go
  3. 1 4
      examples/filewatch/main.go
  4. 25 2
      server.go

+ 8 - 2
examples/autobahn/server.go

@@ -30,7 +30,10 @@ import (
 func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  4096,
-		WriteBufferSize: 4096}
+		WriteBufferSize: 4096,
+		CheckOrigin: func(r *http.Request) bool {
+			return true
+		}}
 	conn, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("Upgrade:", err)
@@ -92,7 +95,10 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
 func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  4096,
-		WriteBufferSize: 4096}
+		WriteBufferSize: 4096,
+		CheckOrigin: func(r *http.Request) bool {
+			return true
+		}}
 	conn, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("Upgrade:", err)

+ 1 - 4
examples/chat/conn.go

@@ -91,10 +91,7 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
 	}
 	u := websocket.Upgrader{
 		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			return r.Header.Get("Origin") == "http://"+r.Host
-		}}
+		WriteBufferSize: 1024}
 	ws, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		if _, ok := err.(websocket.HandshakeError); !ok {

+ 1 - 4
examples/filewatch/main.go

@@ -109,10 +109,7 @@ func writer(ws *websocket.Conn, lastMod time.Time) {
 func serveWs(w http.ResponseWriter, r *http.Request) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			return r.Header.Get("Origin") == "http://"+r.Host
-		}}
+		WriteBufferSize: 1024}
 	ws, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		if _, ok := err.(websocket.HandshakeError); !ok {

+ 25 - 2
server.go

@@ -9,6 +9,7 @@ import (
 	"errors"
 	"net"
 	"net/http"
+	"net/url"
 	"strings"
 	"time"
 )
@@ -42,7 +43,8 @@ type Upgrader struct {
 	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
 
 	// CheckOrigin returns true if the request Origin header is acceptable.
-	// If CheckOrigin is nil, then no origin check is done.
+	// If CheckOrigin is nil, the host in the Origin header must match
+	// the host of the request.
 	CheckOrigin func(r *http.Request) bool
 }
 
@@ -70,6 +72,19 @@ func (u *Upgrader) hasSubprotocol(subprotocol string) bool {
 	return false
 }
 
+// Check if host in Origin header matches host of request
+func (u *Upgrader) checkSameOrigin(r *http.Request) bool {
+	origin := r.Header.Get("Origin")
+	if origin == "" {
+		return false
+	}
+	uri, err := url.ParseRequestURI(origin)
+	if err != nil {
+		return false
+	}
+	return uri.Host == r.Host
+}
+
 // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
 //
 // The responseHeader is included in the response to the client's upgrade
@@ -100,7 +115,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 		return nil, err
 	}
 
-	if u.CheckOrigin != nil && !u.CheckOrigin(r) {
+	checkOrigin := u.CheckOrigin
+	if checkOrigin == nil {
+		checkOrigin = u.checkSameOrigin
+	}
+	if !checkOrigin(r) {
 		err := HandshakeError{"websocket: origin not allowed"}
 		u.returnError(w, r, http.StatusForbidden, err)
 		return nil, err
@@ -229,6 +248,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
 	u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
 		// don't return errors to maintain backwards compatibility
 	}
+	u.CheckOrigin = func(r *http.Request) bool {
+		// allow all connections by default
+		return true
+	}
 	return u.Upgrade(w, r, responseHeader)
 }