|
|
@@ -516,6 +516,22 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
|
|
|
|
|
|
// Read methods
|
|
|
|
|
|
+// readFull is like io.ReadFull except that io.EOF is never returned.
|
|
|
+func (c *Conn) readFull(p []byte) (err error) {
|
|
|
+ var n int
|
|
|
+ for n < len(p) && err == nil {
|
|
|
+ var nn int
|
|
|
+ nn, err = c.br.Read(p[n:])
|
|
|
+ n += nn
|
|
|
+ }
|
|
|
+ if n == len(p) {
|
|
|
+ err = nil
|
|
|
+ } else if err == io.EOF {
|
|
|
+ err = io.ErrUnexpectedEOF
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
func (c *Conn) advanceFrame() (int, error) {
|
|
|
|
|
|
// 1. Skip remainder of previous frame.
|
|
|
@@ -529,7 +545,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
// 2. Read and parse first two bytes of frame header.
|
|
|
|
|
|
var b [8]byte
|
|
|
- if _, err := io.ReadFull(c.br, b[:2]); err != nil {
|
|
|
+ if err := c.readFull(b[:2]); err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
|
|
|
@@ -569,12 +585,12 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
|
|
|
switch c.readRemaining {
|
|
|
case 126:
|
|
|
- if _, err := io.ReadFull(c.br, b[:2]); err != nil {
|
|
|
+ if err := c.readFull(b[:2]); err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
|
|
|
case 127:
|
|
|
- if _, err := io.ReadFull(c.br, b[:8]); err != nil {
|
|
|
+ if err := c.readFull(b[:8]); err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
|
|
|
@@ -588,7 +604,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
|
|
|
if mask {
|
|
|
c.readMaskPos = 0
|
|
|
- if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil {
|
|
|
+ if err := c.readFull(c.readMaskKey[:]); err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
}
|
|
|
@@ -612,7 +628,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
if c.readRemaining > 0 {
|
|
|
payload = make([]byte, c.readRemaining)
|
|
|
c.readRemaining = 0
|
|
|
- if _, err := io.ReadFull(c.br, payload); err != nil {
|
|
|
+ if err := c.readFull(payload); err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
if c.isServer {
|
|
|
@@ -686,7 +702,7 @@ type messageReader struct {
|
|
|
seq int
|
|
|
}
|
|
|
|
|
|
-func (r messageReader) Read(b []byte) (n int, err error) {
|
|
|
+func (r messageReader) Read(b []byte) (int, error) {
|
|
|
|
|
|
if r.seq != r.c.readSeq {
|
|
|
return 0, io.EOF
|
|
|
@@ -713,13 +729,19 @@ func (r messageReader) Read(b []byte) (n int, err error) {
|
|
|
}
|
|
|
|
|
|
frameType, err := r.c.advanceFrame()
|
|
|
- if err != nil {
|
|
|
+ switch {
|
|
|
+ case err != nil:
|
|
|
r.c.readErr = hideTempErr(err)
|
|
|
- } else if frameType == TextMessage || frameType == BinaryMessage {
|
|
|
+ case frameType == TextMessage || frameType == BinaryMessage:
|
|
|
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
|
|
}
|
|
|
}
|
|
|
- return 0, r.c.readErr
|
|
|
+
|
|
|
+ err := r.c.readErr
|
|
|
+ if err == io.EOF && r.seq == r.c.readSeq {
|
|
|
+ err = io.ErrUnexpectedEOF
|
|
|
+ }
|
|
|
+ return 0, err
|
|
|
}
|
|
|
|
|
|
// ReadMessage is a helper method for getting a reader using NextReader and
|