소스 검색

Cleanup read operations.

- Use io.ReadFull instead of similar function in package.
- Return from Read with partial data. Don't attempt to fill buffer.
- Do not return net.Error with Temporary() == true
Gary Burd 11 년 전
부모
커밋
efd7f76a14
1개의 변경된 파일34개의 추가작업 그리고 37개의 파일을 삭제
  1. 34 37
      conn.go

+ 34 - 37
conn.go

@@ -95,6 +95,13 @@ const (
 	writeWait                  = time.Second
 )
 
+func hideTempErr(err error) error {
+	if e, ok := err.(net.Error); ok && e.Temporary() {
+		err = struct{ error }{err}
+	}
+	return err
+}
+
 func isControl(frameType int) bool {
 	return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
 }
@@ -501,7 +508,7 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
 // SetWriteDeadline sets the write deadline on the underlying network
 // connection. After a write has timed out, the websocket state is corrupt and
 // all future writes will return an error. A zero value for t means writes will
-// not time out 
+// not time out
 func (c *Conn) SetWriteDeadline(t time.Time) error {
 	c.writeDeadline = t
 	return nil
@@ -522,7 +529,7 @@ func (c *Conn) advanceFrame() (int, error) {
 	// 2. Read and parse first two bytes of frame header.
 
 	var b [8]byte
-	if err := c.read(b[:2]); err != nil {
+	if _, err := io.ReadFull(c.br, b[:2]); err != nil {
 		return noFrame, err
 	}
 
@@ -562,12 +569,12 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	switch c.readRemaining {
 	case 126:
-		if err := c.read(b[:2]); err != nil {
+		if _, err := io.ReadFull(c.br, b[:2]); err != nil {
 			return noFrame, err
 		}
 		c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
 	case 127:
-		if err := c.read(b[:8]); err != nil {
+		if _, err := io.ReadFull(c.br, b[:8]); err != nil {
 			return noFrame, err
 		}
 		c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
@@ -581,7 +588,7 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	if mask {
 		c.readMaskPos = 0
-		if err := c.read(c.readMaskKey[:]); err != nil {
+		if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil {
 			return noFrame, err
 		}
 	}
@@ -601,12 +608,15 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	// 6. Read control frame payload.
 
-	payload := make([]byte, c.readRemaining)
-	c.readRemaining = 0
-	if err := c.read(payload); err != nil {
-		return noFrame, err
+	var payload []byte
+	if c.readRemaining > 0 {
+		payload = make([]byte, c.readRemaining)
+		c.readRemaining = 0
+		if _, err := io.ReadFull(c.br, payload); err != nil {
+			return noFrame, err
+		}
+		maskBytes(c.readMaskKey, 0, payload)
 	}
-	maskBytes(c.readMaskKey, 0, payload)
 
 	// 7. Process control frame payload.
 
@@ -643,23 +653,6 @@ func (c *Conn) handleProtocolError(message string) error {
 	return errors.New("websocket: " + message)
 }
 
-func (c *Conn) read(buf []byte) error {
-	var err error
-	for len(buf) > 0 && err == nil {
-		var nn int
-		nn, err = c.br.Read(buf)
-		buf = buf[nn:]
-	}
-	if err == io.EOF {
-		if len(buf) == 0 {
-			err = nil
-		} else {
-			err = io.ErrUnexpectedEOF
-		}
-	}
-	return err
-}
-
 // NextReader returns the next data message received from the peer. The
 // returned messageType is either TextMessage or BinaryMessage.
 //
@@ -674,8 +667,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 	c.readLength = 0
 
 	for c.readErr == nil {
-		var frameType int
-		frameType, c.readErr = c.advanceFrame()
+		frameType, err := c.advanceFrame()
+		if err != nil {
+			c.readErr = hideTempErr(err)
+			break
+		}
 		if frameType == TextMessage || frameType == BinaryMessage {
 			return frameType, messageReader{c, c.readSeq}, nil
 		}
@@ -700,10 +696,11 @@ func (r messageReader) Read(b []byte) (n int, err error) {
 			if int64(len(b)) > r.c.readRemaining {
 				b = b[:r.c.readRemaining]
 			}
-			r.c.readErr = r.c.read(b)
-			r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b)
-			r.c.readRemaining -= int64(len(b))
-			return len(b), r.c.readErr
+			n, err := r.c.br.Read(b)
+			r.c.readErr = hideTempErr(err)
+			r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
+			r.c.readRemaining -= int64(n)
+			return n, r.c.readErr
 		}
 
 		if r.c.readFinal {
@@ -711,10 +708,10 @@ func (r messageReader) Read(b []byte) (n int, err error) {
 			return 0, io.EOF
 		}
 
-		var frameType int
-		frameType, r.c.readErr = r.c.advanceFrame()
-
-		if frameType == TextMessage || frameType == BinaryMessage {
+		frameType, err := r.c.advanceFrame()
+		if err != nil {
+			r.c.readErr = hideTempErr(err)
+		} else if frameType == TextMessage || frameType == BinaryMessage {
 			r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
 		}
 	}