Browse Source

Updated Origin check.
The host in the Origin header must match the host of the request by default.

Joachim Bauch 11 năm trước cách đây
mục cha
commit
b03dcbad2a
4 tập tin đã thay đổi với 35 bổ sung12 xóa
  1. 8 2
      examples/autobahn/server.go
  2. 1 4
      examples/chat/conn.go
  3. 1 4
      examples/filewatch/main.go
  4. 25 2
      server.go

+ 8 - 2
examples/autobahn/server.go

@@ -30,7 +30,10 @@ import (
 func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  4096,
-		WriteBufferSize: 4096}
+		WriteBufferSize: 4096,
+		CheckOrigin: func(r *http.Request) bool {
+			return true
+		}}
 	conn, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("Upgrade:", err)
@@ -92,7 +95,10 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
 func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  4096,
-		WriteBufferSize: 4096}
+		WriteBufferSize: 4096,
+		CheckOrigin: func(r *http.Request) bool {
+			return true
+		}}
 	conn, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("Upgrade:", err)

+ 1 - 4
examples/chat/conn.go

@@ -91,10 +91,7 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
 	}
 	u := websocket.Upgrader{
 		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			return r.Header.Get("Origin") == "http://"+r.Host
-		}}
+		WriteBufferSize: 1024}
 	ws, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		if _, ok := err.(websocket.HandshakeError); !ok {

+ 1 - 4
examples/filewatch/main.go

@@ -109,10 +109,7 @@ func writer(ws *websocket.Conn, lastMod time.Time) {
 func serveWs(w http.ResponseWriter, r *http.Request) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			return r.Header.Get("Origin") == "http://"+r.Host
-		}}
+		WriteBufferSize: 1024}
 	ws, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		if _, ok := err.(websocket.HandshakeError); !ok {

+ 25 - 2
server.go

@@ -9,6 +9,7 @@ import (
 	"errors"
 	"net"
 	"net/http"
+	"net/url"
 	"strings"
 	"time"
 )
@@ -42,7 +43,8 @@ type Upgrader struct {
 	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
 
 	// CheckOrigin returns true if the request Origin header is acceptable.
-	// If CheckOrigin is nil, then no origin check is done.
+	// If CheckOrigin is nil, the host in the Origin header must match
+	// the host of the request.
 	CheckOrigin func(r *http.Request) bool
 }
 
@@ -70,6 +72,19 @@ func (u *Upgrader) hasSubprotocol(subprotocol string) bool {
 	return false
 }
 
+// Check if host in Origin header matches host of request
+func (u *Upgrader) checkSameOrigin(r *http.Request) bool {
+	origin := r.Header.Get("Origin")
+	if origin == "" {
+		return false
+	}
+	uri, err := url.ParseRequestURI(origin)
+	if err != nil {
+		return false
+	}
+	return uri.Host == r.Host
+}
+
 // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
 //
 // The responseHeader is included in the response to the client's upgrade
@@ -100,7 +115,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 		return nil, err
 	}
 
-	if u.CheckOrigin != nil && !u.CheckOrigin(r) {
+	checkOrigin := u.CheckOrigin
+	if checkOrigin == nil {
+		checkOrigin = u.checkSameOrigin
+	}
+	if !checkOrigin(r) {
 		err := HandshakeError{"websocket: origin not allowed"}
 		u.returnError(w, r, http.StatusForbidden, err)
 		return nil, err
@@ -229,6 +248,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
 	u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
 		// don't return errors to maintain backwards compatibility
 	}
+	u.CheckOrigin = func(r *http.Request) bool {
+		// allow all connections by default
+		return true
+	}
 	return u.Upgrade(w, r, responseHeader)
 }