Browse Source

websocket: handle solicited and unsolicited Ping/Pong frames correctly

This change prevents Read from failing with io.EOF, ErrNotImplemented on
exchanging control frames such as ping and pong.

Fixes golang/go#6377.
Fixes golang/go#7825.
Fixes golang/go#10156.

Change-Id: I600cf493de3671d7e3d11e2e12d32f43928b7bfc
Reviewed-on: https://go-review.googlesource.com/13054
Reviewed-by: Andrew Gerrand <adg@golang.org>
Mikio Hara 10 years ago
parent
commit
b963d2882a
3 changed files with 112 additions and 13 deletions
  1. 9 10
      websocket/hybi.go
  2. 9 2
      websocket/hybi_test.go
  3. 94 1
      websocket/websocket_test.go

+ 9 - 10
websocket/hybi.go

@@ -267,7 +267,7 @@ type hybiFrameHandler struct {
 	payloadType byte
 }
 
-func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err error) {
+func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
 	if handler.conn.IsServerConn() {
 		// The client MUST mask all frames sent to the server.
 		if frame.(*hybiFrameReader).header.MaskingKey == nil {
@@ -291,20 +291,19 @@ func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader,
 		handler.payloadType = frame.PayloadType()
 	case CloseFrame:
 		return nil, io.EOF
-	case PingFrame:
-		pingMsg := make([]byte, maxControlFramePayloadLength)
-		n, err := io.ReadFull(frame, pingMsg)
-		if err != nil && err != io.ErrUnexpectedEOF {
+	case PingFrame, PongFrame:
+		b := make([]byte, maxControlFramePayloadLength)
+		n, err := io.ReadFull(frame, b)
+		if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
 			return nil, err
 		}
 		io.Copy(ioutil.Discard, frame)
-		n, err = handler.WritePong(pingMsg[:n])
-		if err != nil {
-			return nil, err
+		if frame.PayloadType() == PingFrame {
+			if _, err := handler.WritePong(b[:n]); err != nil {
+				return nil, err
+			}
 		}
 		return nil, nil
-	case PongFrame:
-		return nil, ErrNotImplemented
 	}
 	return frame, nil
 }

+ 9 - 2
websocket/hybi_test.go

@@ -326,7 +326,7 @@ func testHybiFrame(t *testing.T, testHeader, testPayload, testMaskedPayload []by
 	}
 	payload := make([]byte, len(testPayload))
 	_, err = r.Read(payload)
-	if err != nil {
+	if err != nil && err != io.EOF {
 		t.Errorf("read %v", err)
 	}
 	if !bytes.Equal(testPayload, payload) {
@@ -363,13 +363,20 @@ func TestHybiShortBinaryFrame(t *testing.T) {
 }
 
 func TestHybiControlFrame(t *testing.T) {
-	frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame}
 	payload := []byte("hello")
+
+	frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame}
 	testHybiFrame(t, []byte{0x89, 0x05}, payload, payload, frameHeader)
 
+	frameHeader = &hybiFrameHeader{Fin: true, OpCode: PingFrame}
+	testHybiFrame(t, []byte{0x89, 0x00}, nil, nil, frameHeader)
+
 	frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame}
 	testHybiFrame(t, []byte{0x8A, 0x05}, payload, payload, frameHeader)
 
+	frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame}
+	testHybiFrame(t, []byte{0x8A, 0x00}, nil, nil, frameHeader)
+
 	frameHeader = &hybiFrameHeader{Fin: true, OpCode: CloseFrame}
 	payload = []byte{0x03, 0xe8} // 1000
 	testHybiFrame(t, []byte{0x88, 0x02}, payload, payload, frameHeader)

+ 94 - 1
websocket/websocket_test.go

@@ -24,7 +24,10 @@ import (
 var serverAddr string
 var once sync.Once
 
-func echoServer(ws *Conn) { io.Copy(ws, ws) }
+func echoServer(ws *Conn) {
+	defer ws.Close()
+	io.Copy(ws, ws)
+}
 
 type Count struct {
 	S string
@@ -32,6 +35,7 @@ type Count struct {
 }
 
 func countServer(ws *Conn) {
+	defer ws.Close()
 	for {
 		var count Count
 		err := JSON.Receive(ws, &count)
@@ -47,6 +51,55 @@ func countServer(ws *Conn) {
 	}
 }
 
+type testCtrlAndDataHandler struct {
+	hybiFrameHandler
+}
+
+func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
+	h.hybiFrameHandler.conn.wio.Lock()
+	defer h.hybiFrameHandler.conn.wio.Unlock()
+	w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
+	if err != nil {
+		return 0, err
+	}
+	n, err := w.Write(b)
+	w.Close()
+	return n, err
+}
+
+func ctrlAndDataServer(ws *Conn) {
+	defer ws.Close()
+	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
+	ws.frameHandler = h
+
+	go func() {
+		for i := 0; ; i++ {
+			var b []byte
+			if i%2 != 0 { // with or without payload
+				b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
+			}
+			if _, err := h.WritePing(b); err != nil {
+				break
+			}
+			if _, err := h.WritePong(b); err != nil { // unsolicited pong
+				break
+			}
+			time.Sleep(10 * time.Millisecond)
+		}
+	}()
+
+	b := make([]byte, 128)
+	for {
+		n, err := ws.Read(b)
+		if err != nil {
+			break
+		}
+		if _, err := ws.Write(b[:n]); err != nil {
+			break
+		}
+	}
+}
+
 func subProtocolHandshake(config *Config, req *http.Request) error {
 	for _, proto := range config.Protocol {
 		if proto == "chat" {
@@ -66,6 +119,7 @@ func subProtoServer(ws *Conn) {
 func startServer() {
 	http.Handle("/echo", Handler(echoServer))
 	http.Handle("/count", Handler(countServer))
+	http.Handle("/ctrldata", Handler(ctrlAndDataServer))
 	subproto := Server{
 		Handshake: subProtocolHandshake,
 		Handler:   Handler(subProtoServer),
@@ -492,3 +546,42 @@ func TestOrigin(t *testing.T) {
 		}
 	}
 }
+
+func TestCtrlAndData(t *testing.T) {
+	once.Do(startServer)
+
+	c, err := net.Dial("tcp", serverAddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ws, err := NewClient(newConfig(t, "/ctrldata"), c)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ws.Close()
+
+	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
+	ws.frameHandler = h
+
+	b := make([]byte, 128)
+	for i := 0; i < 2; i++ {
+		data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
+		if _, err := ws.Write(data); err != nil {
+			t.Fatalf("#%d: %v", i, err)
+		}
+		var ctrl []byte
+		if i%2 != 0 { // with or without payload
+			ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
+		}
+		if _, err := h.WritePing(ctrl); err != nil {
+			t.Fatalf("#%d: %v", i, err)
+		}
+		n, err := ws.Read(b)
+		if err != nil {
+			t.Fatalf("#%d: %v", i, err)
+		}
+		if !bytes.Equal(b[:n], data) {
+			t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
+		}
+	}
+}