Browse Source

Require GET in Upgrader.Upgrade.

Return error if the request method is not GET.

Remove all request method tests from the examples.
Gary Burd 10 years ago
parent
commit
567453a710
5 changed files with 23 additions and 18 deletions
  1. 20 5
      client_server_test.go
  2. 0 4
      examples/chat/conn.go
  3. 0 5
      examples/command/main.go
  4. 0 4
      examples/echo/server.go
  5. 3 0
      server.go

+ 20 - 5
client_server_test.go

@@ -56,11 +56,6 @@ func newTLSServer(t *testing.T) *cstServer {
 }
 
 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		t.Logf("method %s not allowed", r.Method)
-		http.Error(w, "method not allowed", 405)
-		return
-	}
 	subprotos := Subprotocols(r)
 	if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
 		t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
@@ -287,6 +282,26 @@ func TestDialBadHeader(t *testing.T) {
 	}
 }
 
+func TestBadMethod(t *testing.T) {
+	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ws, err := cstUpgrader.Upgrade(w, r, nil)
+		if err == nil {
+			t.Errorf("handshake succeeded, expect fail")
+			ws.Close()
+		}
+	}))
+	defer s.Close()
+
+	resp, err := http.PostForm(s.URL, url.Values{})
+	if err != nil {
+		t.Fatalf("PostForm returned error %v", err)
+	}
+	resp.Body.Close()
+	if resp.StatusCode != http.StatusMethodNotAllowed {
+		t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
+	}
+}
+
 func TestHandshake(t *testing.T) {
 	s := newServer(t)
 	defer s.Close()

+ 0 - 4
examples/chat/conn.go

@@ -90,10 +90,6 @@ func (c *connection) writePump() {
 
 // serveWs handles websocket requests from the peer.
 func serveWs(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", 405)
-		return
-	}
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println(err)

+ 0 - 5
examples/command/main.go

@@ -95,11 +95,6 @@ func internalError(ws *websocket.Conn, msg string, err error) {
 var upgrader = websocket.Upgrader{}
 
 func serveWs(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-		return
-	}
-
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("upgrade:", err)

+ 0 - 4
examples/echo/server.go

@@ -23,10 +23,6 @@ func echo(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "Not found", 404)
 		return
 	}
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", 405)
-		return
-	}
 	c, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Print("upgrade:", err)

+ 3 - 0
server.go

@@ -93,6 +93,9 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
 // request. Use the responseHeader to specify cookies (Set-Cookie) and the
 // application negotiated subprotocol (Sec-Websocket-Protocol).
 func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
+	if r.Method != "GET" {
+		return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
+	}
 	if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
 		return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
 	}