Преглед на файлове

Cleanup EOF handling.

- Modify data message reader to return io.ErrUnexpectedEOF if a close
  message is received before the final frame of the message.
- Modify NextReader to return io.ErrUnexpectedEOF if underlying
  connection returns io.EOF before a close message.
Gary Burd преди 11 години
родител
ревизия
10afcadf69
променени са 2 файла, в които са добавени 81 реда и са изтрити 9 реда
  1. 31 9
      conn.go
  2. 50 0
      conn_test.go

+ 31 - 9
conn.go

@@ -516,6 +516,22 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
 
 // Read methods
 
+// readFull is like io.ReadFull except that io.EOF is never returned.
+func (c *Conn) readFull(p []byte) (err error) {
+	var n int
+	for n < len(p) && err == nil {
+		var nn int
+		nn, err = c.br.Read(p[n:])
+		n += nn
+	}
+	if n == len(p) {
+		err = nil
+	} else if err == io.EOF {
+		err = io.ErrUnexpectedEOF
+	}
+	return
+}
+
 func (c *Conn) advanceFrame() (int, error) {
 
 	// 1. Skip remainder of previous frame.
@@ -529,7 +545,7 @@ func (c *Conn) advanceFrame() (int, error) {
 	// 2. Read and parse first two bytes of frame header.
 
 	var b [8]byte
-	if _, err := io.ReadFull(c.br, b[:2]); err != nil {
+	if err := c.readFull(b[:2]); err != nil {
 		return noFrame, err
 	}
 
@@ -569,12 +585,12 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	switch c.readRemaining {
 	case 126:
-		if _, err := io.ReadFull(c.br, b[:2]); err != nil {
+		if err := c.readFull(b[:2]); err != nil {
 			return noFrame, err
 		}
 		c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
 	case 127:
-		if _, err := io.ReadFull(c.br, b[:8]); err != nil {
+		if err := c.readFull(b[:8]); err != nil {
 			return noFrame, err
 		}
 		c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
@@ -588,7 +604,7 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	if mask {
 		c.readMaskPos = 0
-		if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil {
+		if err := c.readFull(c.readMaskKey[:]); err != nil {
 			return noFrame, err
 		}
 	}
@@ -612,7 +628,7 @@ func (c *Conn) advanceFrame() (int, error) {
 	if c.readRemaining > 0 {
 		payload = make([]byte, c.readRemaining)
 		c.readRemaining = 0
-		if _, err := io.ReadFull(c.br, payload); err != nil {
+		if err := c.readFull(payload); err != nil {
 			return noFrame, err
 		}
 		if c.isServer {
@@ -686,7 +702,7 @@ type messageReader struct {
 	seq int
 }
 
-func (r messageReader) Read(b []byte) (n int, err error) {
+func (r messageReader) Read(b []byte) (int, error) {
 
 	if r.seq != r.c.readSeq {
 		return 0, io.EOF
@@ -713,13 +729,19 @@ func (r messageReader) Read(b []byte) (n int, err error) {
 		}
 
 		frameType, err := r.c.advanceFrame()
-		if err != nil {
+		switch {
+		case err != nil:
 			r.c.readErr = hideTempErr(err)
-		} else if frameType == TextMessage || frameType == BinaryMessage {
+		case frameType == TextMessage || frameType == BinaryMessage:
 			r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
 		}
 	}
-	return 0, r.c.readErr
+
+	err := r.c.readErr
+	if err == io.EOF && r.seq == r.c.readSeq {
+		err = io.ErrUnexpectedEOF
+	}
+	return 0, err
 }
 
 // ReadMessage is a helper method for getting a reader using NextReader and

+ 50 - 0
conn_test.go

@@ -143,6 +143,56 @@ func TestControl(t *testing.T) {
 	}
 }
 
+func TestCloseBeforeFinalFrame(t *testing.T) {
+	const bufSize = 512
+
+	var b1, b2 bytes.Buffer
+	wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
+	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
+
+	w, _ := wc.NextWriter(BinaryMessage)
+	w.Write(make([]byte, bufSize+bufSize/2))
+	wc.WriteControl(CloseMessage, []byte{}, time.Now().Add(10*time.Second))
+	w.Close()
+
+	op, r, err := rc.NextReader()
+	if op != BinaryMessage || err != nil {
+		t.Fatalf("NextReader() returned %d, %v", op, err)
+	}
+	_, err = io.Copy(ioutil.Discard, r)
+	if err != io.ErrUnexpectedEOF {
+		t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
+	}
+	_, _, err = rc.NextReader()
+	if err != io.EOF {
+		t.Fatalf("NextReader() returned %v, want %v", err, io.EOF)
+	}
+}
+
+func TestEOFBeforeFinalFrame(t *testing.T) {
+	const bufSize = 512
+
+	var b1, b2 bytes.Buffer
+	wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
+	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
+
+	w, _ := wc.NextWriter(BinaryMessage)
+	w.Write(make([]byte, bufSize+bufSize/2))
+
+	op, r, err := rc.NextReader()
+	if op != BinaryMessage || err != nil {
+		t.Fatalf("NextReader() returned %d, %v", op, err)
+	}
+	_, err = io.Copy(ioutil.Discard, r)
+	if err != io.ErrUnexpectedEOF {
+		t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
+	}
+	_, _, err = rc.NextReader()
+	if err != io.ErrUnexpectedEOF {
+		t.Fatalf("NextReader() returned %v, want %v", err, io.ErrUnexpectedEOF)
+	}
+}
+
 func TestReadLimit(t *testing.T) {
 
 	const readLimit = 512