Browse Source

Fix Read() to return errUnexpectedEOF when EOF is received before all bytes in the frame have been read

Tarmigan Casebolt 10 năm trước cách đây
mục cha
commit
ae46df13e9
2 tập tin đã thay đổi với 30 bổ sung0 xóa
  1. 3 0
      conn.go
  2. 27 0
      conn_test.go

+ 3 - 0
conn.go

@@ -821,6 +821,9 @@ func (r messageReader) Read(b []byte) (int, error) {
 				r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
 			}
 			r.c.readRemaining -= int64(n)
+			if r.c.readRemaining > 0 && r.c.readErr == io.EOF {
+				r.c.readErr = errUnexpectedEOF
+			}
 			return n, r.c.readErr
 		}
 

+ 27 - 0
conn_test.go

@@ -174,6 +174,33 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
 	}
 }
 
+func TestEOFWithinFrame(t *testing.T) {
+	const bufSize = 512
+
+	var b bytes.Buffer
+	wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
+	rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
+
+	w, _ := wc.NextWriter(BinaryMessage)
+	w.Write(make([]byte, bufSize))
+	w.Close()
+
+	b.Truncate(bufSize / 2)
+
+	op, r, err := rc.NextReader()
+	if op != BinaryMessage || err != nil {
+		t.Fatalf("NextReader() returned %d, %v", op, err)
+	}
+	_, err = io.Copy(ioutil.Discard, r)
+	if err != errUnexpectedEOF {
+		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
+	}
+	_, _, err = rc.NextReader()
+	if err != errUnexpectedEOF {
+		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
+	}
+}
+
 func TestEOFBeforeFinalFrame(t *testing.T) {
 	const bufSize = 512