Ver Fonte

Require GET in Upgrader.Upgrade.

Return error if the request method is not GET.

Remove all request method tests from the examples.
Gary Burd há 10 anos atrás
pai
commit
567453a710
5 ficheiros alterados com 23 adições e 18 exclusões
  1. 20 5
      client_server_test.go
  2. 0 4
      examples/chat/conn.go
  3. 0 5
      examples/command/main.go
  4. 0 4
      examples/echo/server.go
  5. 3 0
      server.go

+ 20 - 5
client_server_test.go

@@ -56,11 +56,6 @@ func newTLSServer(t *testing.T) *cstServer {
 }
 
 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		t.Logf("method %s not allowed", r.Method)
-		http.Error(w, "method not allowed", 405)
-		return
-	}
 	subprotos := Subprotocols(r)
 	if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
 		t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
@@ -287,6 +282,26 @@ func TestDialBadHeader(t *testing.T) {
 	}
 }
 
+func TestBadMethod(t *testing.T) {
+	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ws, err := cstUpgrader.Upgrade(w, r, nil)
+		if err == nil {
+			t.Errorf("handshake succeeded, expect fail")
+			ws.Close()
+		}
+	}))
+	defer s.Close()
+
+	resp, err := http.PostForm(s.URL, url.Values{})
+	if err != nil {
+		t.Fatalf("PostForm returned error %v", err)
+	}
+	resp.Body.Close()
+	if resp.StatusCode != http.StatusMethodNotAllowed {
+		t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
+	}
+}
+
 func TestHandshake(t *testing.T) {
 	s := newServer(t)
 	defer s.Close()

+ 0 - 4
examples/chat/conn.go

@@ -90,10 +90,6 @@ func (c *connection) writePump() {
 
 // serveWs handles websocket requests from the peer.
 func serveWs(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", 405)
-		return
-	}
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println(err)

+ 0 - 5
examples/command/main.go

@@ -95,11 +95,6 @@ func internalError(ws *websocket.Conn, msg string, err error) {
 var upgrader = websocket.Upgrader{}
 
 func serveWs(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-		return
-	}
-
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("upgrade:", err)

+ 0 - 4
examples/echo/server.go

@@ -23,10 +23,6 @@ func echo(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "Not found", 404)
 		return
 	}
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", 405)
-		return
-	}
 	c, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Print("upgrade:", err)

+ 3 - 0
server.go

@@ -93,6 +93,9 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
 // request. Use the responseHeader to specify cookies (Set-Cookie) and the
 // application negotiated subprotocol (Sec-Websocket-Protocol).
 func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
+	if r.Method != "GET" {
+		return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
+	}
 	if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
 		return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
 	}