Kaynağa Gözat

Do not allow duplicate headers in handshake

Gary Burd 10 yıl önce
ebeveyn
işleme
a4e0143e1f
2 değiştirilmiş dosya ile 28 ekleme ve 2 silme
  1. 9 2
      client.go
  2. 19 0
      client_server_test.go

+ 9 - 2
client.go

@@ -197,11 +197,18 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
 	}
 	for k, vs := range requestHeader {
-		if k == "Host" {
+		switch {
+		case k == "Host":
 			if len(vs) > 0 {
 				req.Host = vs[0]
 			}
-		} else {
+		case k == "Upgrade" ||
+			k == "Connection" ||
+			k == "Sec-Websocket-Key" ||
+			k == "Sec-Websocket-Version" ||
+			(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
+			return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
+		default:
 			req.Header[k] = vs
 		}
 	}

+ 19 - 0
client_server_test.go

@@ -268,6 +268,25 @@ func TestDialBadOrigin(t *testing.T) {
 	}
 }
 
+func TestDialBadHeader(t *testing.T) {
+	s := newServer(t)
+	defer s.Close()
+
+	for _, k := range []string{"Upgrade",
+		"Connection",
+		"Sec-Websocket-Key",
+		"Sec-Websocket-Version",
+		"Sec-Websocket-Protocol"} {
+		h := http.Header{}
+		h.Set(k, "bad")
+		ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
+		if err == nil {
+			ws.Close()
+			t.Errorf("Dial with header %s returned nil", k)
+		}
+	}
+}
+
 func TestHandshake(t *testing.T) {
 	s := newServer(t)
 	defer s.Close()