Procházet zdrojové kódy

Improve errors.

- Use new closeError type for reporting close frames to the application.
- Use closeError with code 1006 when the peer closes connection without
  sending a close frame. The error io.ErrUnexpectedEOF was used
  previously. This change helps developers distinguish abnormal closure
  and an unexpected EOF in the JSON parser.
Gary Burd před 11 roky
rodič
revize
47f93dfaed
2 změnil soubory, kde provedl 32 přidání a 20 odebrání
  1. 25 13
      conn.go
  2. 7 7
      conn_test.go

+ 25 - 13
conn.go

@@ -70,18 +70,30 @@ var (
 	ErrReadLimit = errors.New("websocket: read limit exceeded")
 	ErrReadLimit = errors.New("websocket: read limit exceeded")
 )
 )
 
 
-type websocketError struct {
+// netError satisfies the net Error interface.
+type netError struct {
 	msg       string
 	msg       string
 	temporary bool
 	temporary bool
 	timeout   bool
 	timeout   bool
 }
 }
 
 
-func (e *websocketError) Error() string   { return e.msg }
-func (e *websocketError) Temporary() bool { return e.temporary }
-func (e *websocketError) Timeout() bool   { return e.timeout }
+func (e *netError) Error() string   { return e.msg }
+func (e *netError) Temporary() bool { return e.temporary }
+func (e *netError) Timeout() bool   { return e.timeout }
+
+// closeError represents close frame.
+type closeError struct {
+	code int
+	text string
+}
+
+func (e *closeError) Error() string {
+	return "websocket: close " + strconv.Itoa(e.code) + " " + e.text
+}
 
 
 var (
 var (
-	errWriteTimeout        = &websocketError{msg: "websocket: write timeout", timeout: true}
+	errWriteTimeout        = &netError{msg: "websocket: write timeout", timeout: true}
+	errUnexpectedEOF       = &closeError{code: CloseAbnormalClosure, text: io.ErrUnexpectedEOF.Error()}
 	errBadWriteOpCode      = errors.New("websocket: bad write message type")
 	errBadWriteOpCode      = errors.New("websocket: bad write message type")
 	errWriteClosed         = errors.New("websocket: write closed")
 	errWriteClosed         = errors.New("websocket: write closed")
 	errInvalidControlFrame = errors.New("websocket: invalid control frame")
 	errInvalidControlFrame = errors.New("websocket: invalid control frame")
@@ -527,7 +539,7 @@ func (c *Conn) readFull(p []byte) (err error) {
 	if n == len(p) {
 	if n == len(p) {
 		err = nil
 		err = nil
 	} else if err == io.EOF {
 	} else if err == io.EOF {
-		err = io.ErrUnexpectedEOF
+		err = errUnexpectedEOF
 	}
 	}
 	return
 	return
 }
 }
@@ -649,17 +661,17 @@ func (c *Conn) advanceFrame() (int, error) {
 		}
 		}
 	case CloseMessage:
 	case CloseMessage:
 		c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait))
 		c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait))
-		if len(payload) < 2 {
-			return noFrame, io.EOF
+		closeCode := CloseNoStatusReceived
+		closeText := ""
+		if len(payload) >= 2 {
+			closeCode = int(binary.BigEndian.Uint16(payload))
+			closeText = string(payload[2:])
 		}
 		}
-		closeCode := binary.BigEndian.Uint16(payload)
 		switch closeCode {
 		switch closeCode {
 		case CloseNormalClosure, CloseGoingAway:
 		case CloseNormalClosure, CloseGoingAway:
 			return noFrame, io.EOF
 			return noFrame, io.EOF
 		default:
 		default:
-			return noFrame, errors.New("websocket: close " +
-				strconv.Itoa(int(closeCode)) + " " +
-				string(payload[2:]))
+			return noFrame, &closeError{code: closeCode, text: closeText}
 		}
 		}
 	}
 	}
 
 
@@ -739,7 +751,7 @@ func (r messageReader) Read(b []byte) (int, error) {
 
 
 	err := r.c.readErr
 	err := r.c.readErr
 	if err == io.EOF && r.seq == r.c.readSeq {
 	if err == io.EOF && r.seq == r.c.readSeq {
-		err = io.ErrUnexpectedEOF
+		err = errUnexpectedEOF
 	}
 	}
 	return 0, err
 	return 0, err
 }
 }

+ 7 - 7
conn_test.go

@@ -152,7 +152,7 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
 
 
 	w, _ := wc.NextWriter(BinaryMessage)
 	w, _ := wc.NextWriter(BinaryMessage)
 	w.Write(make([]byte, bufSize+bufSize/2))
 	w.Write(make([]byte, bufSize+bufSize/2))
-	wc.WriteControl(CloseMessage, []byte{}, time.Now().Add(10*time.Second))
+	wc.WriteControl(CloseMessage, FormatCloseMessage(CloseNormalClosure, ""), time.Now().Add(10*time.Second))
 	w.Close()
 	w.Close()
 
 
 	op, r, err := rc.NextReader()
 	op, r, err := rc.NextReader()
@@ -160,8 +160,8 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
 		t.Fatalf("NextReader() returned %d, %v", op, err)
 		t.Fatalf("NextReader() returned %d, %v", op, err)
 	}
 	}
 	_, err = io.Copy(ioutil.Discard, r)
 	_, err = io.Copy(ioutil.Discard, r)
-	if err != io.ErrUnexpectedEOF {
-		t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
+	if err != errUnexpectedEOF {
+		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
 	}
 	}
 	_, _, err = rc.NextReader()
 	_, _, err = rc.NextReader()
 	if err != io.EOF {
 	if err != io.EOF {
@@ -184,12 +184,12 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
 		t.Fatalf("NextReader() returned %d, %v", op, err)
 		t.Fatalf("NextReader() returned %d, %v", op, err)
 	}
 	}
 	_, err = io.Copy(ioutil.Discard, r)
 	_, err = io.Copy(ioutil.Discard, r)
-	if err != io.ErrUnexpectedEOF {
-		t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
+	if err != errUnexpectedEOF {
+		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
 	}
 	}
 	_, _, err = rc.NextReader()
 	_, _, err = rc.NextReader()
-	if err != io.ErrUnexpectedEOF {
-		t.Fatalf("NextReader() returned %v, want %v", err, io.ErrUnexpectedEOF)
+	if err != errUnexpectedEOF {
+		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
 	}
 	}
 }
 }