Procházet zdrojové kódy

Require GET in Upgrader.Upgrade.

Return error if the request method is not GET.

Remove all request method tests from the examples.
Gary Burd před 10 roky
rodič
revize
567453a710
5 změnil soubory, kde provedl 23 přidání a 18 odebrání
  1. 20 5
      client_server_test.go
  2. 0 4
      examples/chat/conn.go
  3. 0 5
      examples/command/main.go
  4. 0 4
      examples/echo/server.go
  5. 3 0
      server.go

+ 20 - 5
client_server_test.go

@@ -56,11 +56,6 @@ func newTLSServer(t *testing.T) *cstServer {
 }
 
 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		t.Logf("method %s not allowed", r.Method)
-		http.Error(w, "method not allowed", 405)
-		return
-	}
 	subprotos := Subprotocols(r)
 	if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
 		t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
@@ -287,6 +282,26 @@ func TestDialBadHeader(t *testing.T) {
 	}
 }
 
+func TestBadMethod(t *testing.T) {
+	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		ws, err := cstUpgrader.Upgrade(w, r, nil)
+		if err == nil {
+			t.Errorf("handshake succeeded, expect fail")
+			ws.Close()
+		}
+	}))
+	defer s.Close()
+
+	resp, err := http.PostForm(s.URL, url.Values{})
+	if err != nil {
+		t.Fatalf("PostForm returned error %v", err)
+	}
+	resp.Body.Close()
+	if resp.StatusCode != http.StatusMethodNotAllowed {
+		t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
+	}
+}
+
 func TestHandshake(t *testing.T) {
 	s := newServer(t)
 	defer s.Close()

+ 0 - 4
examples/chat/conn.go

@@ -90,10 +90,6 @@ func (c *connection) writePump() {
 
 // serveWs handles websocket requests from the peer.
 func serveWs(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", 405)
-		return
-	}
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println(err)

+ 0 - 5
examples/command/main.go

@@ -95,11 +95,6 @@ func internalError(ws *websocket.Conn, msg string, err error) {
 var upgrader = websocket.Upgrader{}
 
 func serveWs(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-		return
-	}
-
 	ws, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("upgrade:", err)

+ 0 - 4
examples/echo/server.go

@@ -23,10 +23,6 @@ func echo(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "Not found", 404)
 		return
 	}
-	if r.Method != "GET" {
-		http.Error(w, "Method not allowed", 405)
-		return
-	}
 	c, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Print("upgrade:", err)

+ 3 - 0
server.go

@@ -93,6 +93,9 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
 // request. Use the responseHeader to specify cookies (Set-Cookie) and the
 // application negotiated subprotocol (Sec-Websocket-Protocol).
 func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
+	if r.Method != "GET" {
+		return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
+	}
 	if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
 		return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
 	}