Browse Source

Improve errors.

- Use new closeError type for reporting close frames to the application.
- Use closeError with code 1006 when the peer closes connection without
  sending a close frame. The error io.ErrUnexpectedEOF was used
  previously. This change helps developers distinguish abnormal closure
  and an unexpected EOF in the JSON parser.
Gary Burd 11 years ago
parent
commit
47f93dfaed
2 changed files with 32 additions and 20 deletions
  1. 25 13
      conn.go
  2. 7 7
      conn_test.go

+ 25 - 13
conn.go

@@ -70,18 +70,30 @@ var (
 	ErrReadLimit = errors.New("websocket: read limit exceeded")
 	ErrReadLimit = errors.New("websocket: read limit exceeded")
 )
 )
 
 
-type websocketError struct {
+// netError satisfies the net Error interface.
+type netError struct {
 	msg       string
 	msg       string
 	temporary bool
 	temporary bool
 	timeout   bool
 	timeout   bool
 }
 }
 
 
-func (e *websocketError) Error() string   { return e.msg }
-func (e *websocketError) Temporary() bool { return e.temporary }
-func (e *websocketError) Timeout() bool   { return e.timeout }
+func (e *netError) Error() string   { return e.msg }
+func (e *netError) Temporary() bool { return e.temporary }
+func (e *netError) Timeout() bool   { return e.timeout }
+
+// closeError represents close frame.
+type closeError struct {
+	code int
+	text string
+}
+
+func (e *closeError) Error() string {
+	return "websocket: close " + strconv.Itoa(e.code) + " " + e.text
+}
 
 
 var (
 var (
-	errWriteTimeout        = &websocketError{msg: "websocket: write timeout", timeout: true}
+	errWriteTimeout        = &netError{msg: "websocket: write timeout", timeout: true}
+	errUnexpectedEOF       = &closeError{code: CloseAbnormalClosure, text: io.ErrUnexpectedEOF.Error()}
 	errBadWriteOpCode      = errors.New("websocket: bad write message type")
 	errBadWriteOpCode      = errors.New("websocket: bad write message type")
 	errWriteClosed         = errors.New("websocket: write closed")
 	errWriteClosed         = errors.New("websocket: write closed")
 	errInvalidControlFrame = errors.New("websocket: invalid control frame")
 	errInvalidControlFrame = errors.New("websocket: invalid control frame")
@@ -527,7 +539,7 @@ func (c *Conn) readFull(p []byte) (err error) {
 	if n == len(p) {
 	if n == len(p) {
 		err = nil
 		err = nil
 	} else if err == io.EOF {
 	} else if err == io.EOF {
-		err = io.ErrUnexpectedEOF
+		err = errUnexpectedEOF
 	}
 	}
 	return
 	return
 }
 }
@@ -649,17 +661,17 @@ func (c *Conn) advanceFrame() (int, error) {
 		}
 		}
 	case CloseMessage:
 	case CloseMessage:
 		c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait))
 		c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait))
-		if len(payload) < 2 {
-			return noFrame, io.EOF
+		closeCode := CloseNoStatusReceived
+		closeText := ""
+		if len(payload) >= 2 {
+			closeCode = int(binary.BigEndian.Uint16(payload))
+			closeText = string(payload[2:])
 		}
 		}
-		closeCode := binary.BigEndian.Uint16(payload)
 		switch closeCode {
 		switch closeCode {
 		case CloseNormalClosure, CloseGoingAway:
 		case CloseNormalClosure, CloseGoingAway:
 			return noFrame, io.EOF
 			return noFrame, io.EOF
 		default:
 		default:
-			return noFrame, errors.New("websocket: close " +
-				strconv.Itoa(int(closeCode)) + " " +
-				string(payload[2:]))
+			return noFrame, &closeError{code: closeCode, text: closeText}
 		}
 		}
 	}
 	}
 
 
@@ -739,7 +751,7 @@ func (r messageReader) Read(b []byte) (int, error) {
 
 
 	err := r.c.readErr
 	err := r.c.readErr
 	if err == io.EOF && r.seq == r.c.readSeq {
 	if err == io.EOF && r.seq == r.c.readSeq {
-		err = io.ErrUnexpectedEOF
+		err = errUnexpectedEOF
 	}
 	}
 	return 0, err
 	return 0, err
 }
 }

+ 7 - 7
conn_test.go

@@ -152,7 +152,7 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
 
 
 	w, _ := wc.NextWriter(BinaryMessage)
 	w, _ := wc.NextWriter(BinaryMessage)
 	w.Write(make([]byte, bufSize+bufSize/2))
 	w.Write(make([]byte, bufSize+bufSize/2))
-	wc.WriteControl(CloseMessage, []byte{}, time.Now().Add(10*time.Second))
+	wc.WriteControl(CloseMessage, FormatCloseMessage(CloseNormalClosure, ""), time.Now().Add(10*time.Second))
 	w.Close()
 	w.Close()
 
 
 	op, r, err := rc.NextReader()
 	op, r, err := rc.NextReader()
@@ -160,8 +160,8 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
 		t.Fatalf("NextReader() returned %d, %v", op, err)
 		t.Fatalf("NextReader() returned %d, %v", op, err)
 	}
 	}
 	_, err = io.Copy(ioutil.Discard, r)
 	_, err = io.Copy(ioutil.Discard, r)
-	if err != io.ErrUnexpectedEOF {
-		t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
+	if err != errUnexpectedEOF {
+		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
 	}
 	}
 	_, _, err = rc.NextReader()
 	_, _, err = rc.NextReader()
 	if err != io.EOF {
 	if err != io.EOF {
@@ -184,12 +184,12 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
 		t.Fatalf("NextReader() returned %d, %v", op, err)
 		t.Fatalf("NextReader() returned %d, %v", op, err)
 	}
 	}
 	_, err = io.Copy(ioutil.Discard, r)
 	_, err = io.Copy(ioutil.Discard, r)
-	if err != io.ErrUnexpectedEOF {
-		t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
+	if err != errUnexpectedEOF {
+		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
 	}
 	}
 	_, _, err = rc.NextReader()
 	_, _, err = rc.NextReader()
-	if err != io.ErrUnexpectedEOF {
-		t.Fatalf("NextReader() returned %v, want %v", err, io.ErrUnexpectedEOF)
+	if err != errUnexpectedEOF {
+		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
 	}
 	}
 }
 }