Przeglądaj źródła

Fix Read() to return errUnexpectedEOF when EOF is received before all bytes in the frame have been read

Tarmigan Casebolt 9 lat temu
rodzic
commit
ae46df13e9
2 zmienionych plików z 30 dodań i 0 usunięć
  1. 3 0
      conn.go
  2. 27 0
      conn_test.go

+ 3 - 0
conn.go

@@ -821,6 +821,9 @@ func (r messageReader) Read(b []byte) (int, error) {
 				r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
 			}
 			r.c.readRemaining -= int64(n)
+			if r.c.readRemaining > 0 && r.c.readErr == io.EOF {
+				r.c.readErr = errUnexpectedEOF
+			}
 			return n, r.c.readErr
 		}
 

+ 27 - 0
conn_test.go

@@ -174,6 +174,33 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
 	}
 }
 
+func TestEOFWithinFrame(t *testing.T) {
+	const bufSize = 512
+
+	var b bytes.Buffer
+	wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
+	rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
+
+	w, _ := wc.NextWriter(BinaryMessage)
+	w.Write(make([]byte, bufSize))
+	w.Close()
+
+	b.Truncate(bufSize / 2)
+
+	op, r, err := rc.NextReader()
+	if op != BinaryMessage || err != nil {
+		t.Fatalf("NextReader() returned %d, %v", op, err)
+	}
+	_, err = io.Copy(ioutil.Discard, r)
+	if err != errUnexpectedEOF {
+		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
+	}
+	_, _, err = rc.NextReader()
+	if err != errUnexpectedEOF {
+		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
+	}
+}
+
 func TestEOFBeforeFinalFrame(t *testing.T) {
 	const bufSize = 512