瀏覽代碼

Updated Origin check.
The host in the Origin header must match the host of the request by default.

Joachim Bauch 11 年之前
父節點
當前提交
b03dcbad2a
共有 4 個文件被更改,包括 35 次插入12 次删除
  1. 8 2
      examples/autobahn/server.go
  2. 1 4
      examples/chat/conn.go
  3. 1 4
      examples/filewatch/main.go
  4. 25 2
      server.go

+ 8 - 2
examples/autobahn/server.go

@@ -30,7 +30,10 @@ import (
 func echoCopy(w http.ResponseWriter, r *http.Request, writerOnly bool) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  4096,
-		WriteBufferSize: 4096}
+		WriteBufferSize: 4096,
+		CheckOrigin: func(r *http.Request) bool {
+			return true
+		}}
 	conn, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("Upgrade:", err)
@@ -92,7 +95,10 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
 func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  4096,
-		WriteBufferSize: 4096}
+		WriteBufferSize: 4096,
+		CheckOrigin: func(r *http.Request) bool {
+			return true
+		}}
 	conn, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("Upgrade:", err)

+ 1 - 4
examples/chat/conn.go

@@ -91,10 +91,7 @@ func serveWs(w http.ResponseWriter, r *http.Request) {
 	}
 	u := websocket.Upgrader{
 		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			return r.Header.Get("Origin") == "http://"+r.Host
-		}}
+		WriteBufferSize: 1024}
 	ws, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		if _, ok := err.(websocket.HandshakeError); !ok {

+ 1 - 4
examples/filewatch/main.go

@@ -109,10 +109,7 @@ func writer(ws *websocket.Conn, lastMod time.Time) {
 func serveWs(w http.ResponseWriter, r *http.Request) {
 	u := websocket.Upgrader{
 		ReadBufferSize:  1024,
-		WriteBufferSize: 1024,
-		CheckOrigin: func(r *http.Request) bool {
-			return r.Header.Get("Origin") == "http://"+r.Host
-		}}
+		WriteBufferSize: 1024}
 	ws, err := u.Upgrade(w, r, nil)
 	if err != nil {
 		if _, ok := err.(websocket.HandshakeError); !ok {

+ 25 - 2
server.go

@@ -9,6 +9,7 @@ import (
 	"errors"
 	"net"
 	"net/http"
+	"net/url"
 	"strings"
 	"time"
 )
@@ -42,7 +43,8 @@ type Upgrader struct {
 	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
 
 	// CheckOrigin returns true if the request Origin header is acceptable.
-	// If CheckOrigin is nil, then no origin check is done.
+	// If CheckOrigin is nil, the host in the Origin header must match
+	// the host of the request.
 	CheckOrigin func(r *http.Request) bool
 }
 
@@ -70,6 +72,19 @@ func (u *Upgrader) hasSubprotocol(subprotocol string) bool {
 	return false
 }
 
+// Check if host in Origin header matches host of request
+func (u *Upgrader) checkSameOrigin(r *http.Request) bool {
+	origin := r.Header.Get("Origin")
+	if origin == "" {
+		return false
+	}
+	uri, err := url.ParseRequestURI(origin)
+	if err != nil {
+		return false
+	}
+	return uri.Host == r.Host
+}
+
 // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
 //
 // The responseHeader is included in the response to the client's upgrade
@@ -100,7 +115,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 		return nil, err
 	}
 
-	if u.CheckOrigin != nil && !u.CheckOrigin(r) {
+	checkOrigin := u.CheckOrigin
+	if checkOrigin == nil {
+		checkOrigin = u.checkSameOrigin
+	}
+	if !checkOrigin(r) {
 		err := HandshakeError{"websocket: origin not allowed"}
 		u.returnError(w, r, http.StatusForbidden, err)
 		return nil, err
@@ -229,6 +248,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
 	u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
 		// don't return errors to maintain backwards compatibility
 	}
+	u.CheckOrigin = func(r *http.Request) bool {
+		// allow all connections by default
+		return true
+	}
 	return u.Upgrade(w, r, responseHeader)
 }