Prechádzať zdrojové kódy

Require GET in Upgrader.Upgrade.

Return error if the request method is not GET.

Remove all request method tests from the examples.
Gary Burd 10 rokov pred
rodič
commit
567453a710

+ 20 - 5
client_server_test.go

@@ -56,11 +56,6 @@ func newTLSServer(t *testing.T) *cstServer {
 }
 
 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		t.Logf("method %s not allowed", r.Method)
-		http.Error(w, "method not allowed", 405)
-		return
-	}
 	subprotos := Subprotocols(r)
 	if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
 		t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
@@ -287,6 +282,26 @@ func TestDialBadHeader(t *testing.T) {
 	}
 }
 
+func TestBadMethod(t *testing.T) {
+	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ws, err := cstUpgrader.Upgrade(w, r, nil)
+		if err == nil {
+			t.Errorf("handshake succeeded, expect fail")
+			ws.Close()
+		}
+	}))
+	defer s.Close()
+
+	resp, err := http.PostForm(s.URL, url.Values{})
+	if err != nil {
+		t.Fatalf("PostForm returned error %v", err)
+	}
+	resp.Body.Close()
+	if resp.StatusCode != http.StatusMethodNotAllowed {
+		t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
+	}
+}
+
 func TestHandshake(t *testing.T) {
 	s := newServer(t)
 	defer s.Close()

+ 0 - 4
examples/chat/conn.go

@@ -90,10 +90,6 @@ func (c *connection) writePump() {
 
 // serveWs handles websocket requests from the peer.
 func serveWs(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", 405)
-		return
-	}
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println(err)

+ 0 - 5
examples/command/main.go

@@ -95,11 +95,6 @@ func internalError(ws *websocket.Conn, msg string, err error) {
 var upgrader = websocket.Upgrader{}
 
 func serveWs(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-		return
-	}
-
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("upgrade:", err)

+ 0 - 4
examples/echo/server.go

@@ -23,10 +23,6 @@ func echo(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "Not found", 404)
 		return
 	}
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", 405)
-		return
-	}
 	c, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Print("upgrade:", err)

+ 3 - 0
server.go

@@ -93,6 +93,9 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
 // request. Use the responseHeader to specify cookies (Set-Cookie) and the
 // application negotiated subprotocol (Sec-Websocket-Protocol).
 func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
+	if r.Method != "GET" {
+		return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
+	}
 	if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
 		return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
 	}