فهرست منبع

Updated Origin check.
The host in the Origin header must match the host of the request by default.

Joachim Bauch 11 سال پیش
والد
کامیت
b03dcbad2a
4فایلهای تغییر یافته به همراه35 افزوده شده و 12 حذف شده
  1. 8 2
      examples/autobahn/server.go
  2. 1 4
      examples/chat/conn.go
  3. 1 4
      examples/filewatch/main.go
  4. 25 2
      server.go

+ 8 - 2
examples/autobahn/server.go

@@ -30,7 +30,10 @@ import (
 func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  4096,
-		WriteBufferSize: 4096}
+		WriteBufferSize: 4096,
+		CheckOrigin: func(r *http.Request) bool {
+			return true
+		}}
 	conn, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("Upgrade:", err)
@@ -92,7 +95,10 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
 func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  4096,
-		WriteBufferSize: 4096}
+		WriteBufferSize: 4096,
+		CheckOrigin: func(r *http.Request) bool {
+			return true
+		}}
 	conn, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("Upgrade:", err)

+ 1 - 4
examples/chat/conn.go

@@ -91,10 +91,7 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
 	}
 	u := websocket.Upgrader{
 		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			return r.Header.Get("Origin") == "http://"+r.Host
-		}}
+		WriteBufferSize: 1024}
 	ws, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		if _, ok := err.(websocket.HandshakeError); !ok {

+ 1 - 4
examples/filewatch/main.go

@@ -109,10 +109,7 @@ func writer(ws *websocket.Conn, lastMod time.Time) {
 func serveWs(w http.ResponseWriter, r *http.Request) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			return r.Header.Get("Origin") == "http://"+r.Host
-		}}
+		WriteBufferSize: 1024}
 	ws, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		if _, ok := err.(websocket.HandshakeError); !ok {

+ 25 - 2
server.go

@@ -9,6 +9,7 @@ import (
 	"errors"
 	"net"
 	"net/http"
+	"net/url"
 	"strings"
 	"time"
 )
@@ -42,7 +43,8 @@ type Upgrader struct {
 	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
 
 	// CheckOrigin returns true if the request Origin header is acceptable.
-	// If CheckOrigin is nil, then no origin check is done.
+	// If CheckOrigin is nil, the host in the Origin header must match
+	// the host of the request.
 	CheckOrigin func(r *http.Request) bool
 }
 
@@ -70,6 +72,19 @@ func (u *Upgrader) hasSubprotocol(subprotocol string) bool {
 	return false
 }
 
+// Check if host in Origin header matches host of request
+func (u *Upgrader) checkSameOrigin(r *http.Request) bool {
+	origin := r.Header.Get("Origin")
+	if origin == "" {
+		return false
+	}
+	uri, err := url.ParseRequestURI(origin)
+	if err != nil {
+		return false
+	}
+	return uri.Host == r.Host
+}
+
 // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
 //
 // The responseHeader is included in the response to the client's upgrade
@@ -100,7 +115,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 		return nil, err
 	}
 
-	if u.CheckOrigin != nil && !u.CheckOrigin(r) {
+	checkOrigin := u.CheckOrigin
+	if checkOrigin == nil {
+		checkOrigin = u.checkSameOrigin
+	}
+	if !checkOrigin(r) {
 		err := HandshakeError{"websocket: origin not allowed"}
 		u.returnError(w, r, http.StatusForbidden, err)
 		return nil, err
@@ -229,6 +248,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
 	u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
 		// don't return errors to maintain backwards compatibility
 	}
+	u.CheckOrigin = func(r *http.Request) bool {
+		// allow all connections by default
+		return true
+	}
 	return u.Upgrade(w, r, responseHeader)
 }