瀏覽代碼

Provide all close frame data to application

- Export closeError.
- Do not convert normal closure and going away to io.EOF.
Gary Burd 10 年之前
父節點
當前提交
b6ab76f1fe
共有 4 個文件被更改,包括 23 次插入23 次删除
  1. 12 13
      conn.go
  2. 8 5
      conn_test.go
  3. 1 3
      json.go
  4. 2 2
      json_test.go

+ 12 - 13
conn.go

@@ -88,19 +88,23 @@ func (e *netError) Error() string   { return e.msg }
 func (e *netError) Temporary() bool { return e.temporary }
 func (e *netError) Timeout() bool   { return e.timeout }
 
-// closeError represents close frame.
-type closeError struct {
-	code int
-	text string
+// CloseError represents close frame.
+type CloseError struct {
+
+	// Code is defined in RFC 6455, section 11.7.
+	Code int
+
+	// Text is the optional text payload.
+	Text string
 }
 
-func (e *closeError) Error() string {
-	return "websocket: close " + strconv.Itoa(e.code) + " " + e.text
+func (e *CloseError) Error() string {
+	return "websocket: close " + strconv.Itoa(e.Code) + " " + e.Text
 }
 
 var (
 	errWriteTimeout        = &netError{msg: "websocket: write timeout", timeout: true}
-	errUnexpectedEOF       = &closeError{code: CloseAbnormalClosure, text: io.ErrUnexpectedEOF.Error()}
+	errUnexpectedEOF       = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
 	errBadWriteOpCode      = errors.New("websocket: bad write message type")
 	errWriteClosed         = errors.New("websocket: write closed")
 	errInvalidControlFrame = errors.New("websocket: invalid control frame")
@@ -673,12 +677,7 @@ func (c *Conn) advanceFrame() (int, error) {
 			closeCode = int(binary.BigEndian.Uint16(payload))
 			closeText = string(payload[2:])
 		}
-		switch closeCode {
-		case CloseNormalClosure, CloseGoingAway:
-			return noFrame, io.EOF
-		default:
-			return noFrame, &closeError{code: closeCode, text: closeText}
-		}
+		return noFrame, &CloseError{Code: closeCode, Text: closeText}
 	}
 
 	return frameType, nil

+ 8 - 5
conn_test.go

@@ -10,6 +10,7 @@ import (
 	"io"
 	"io/ioutil"
 	"net"
+	"reflect"
 	"testing"
 	"testing/iotest"
 	"time"
@@ -146,13 +147,15 @@ func TestControl(t *testing.T) {
 func TestCloseBeforeFinalFrame(t *testing.T) {
 	const bufSize = 512
 
+	expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
+
 	var b1, b2 bytes.Buffer
 	wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
 	rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
 
 	w, _ := wc.NextWriter(BinaryMessage)
 	w.Write(make([]byte, bufSize+bufSize/2))
-	wc.WriteControl(CloseMessage, FormatCloseMessage(CloseNormalClosure, ""), time.Now().Add(10*time.Second))
+	wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
 	w.Close()
 
 	op, r, err := rc.NextReader()
@@ -160,12 +163,12 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
 		t.Fatalf("NextReader() returned %d, %v", op, err)
 	}
 	_, err = io.Copy(ioutil.Discard, r)
-	if err != errUnexpectedEOF {
-		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
+	if !reflect.DeepEqual(err, expectedErr) {
+		t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
 	}
 	_, _, err = rc.NextReader()
-	if err != io.EOF {
-		t.Fatalf("NextReader() returned %v, want %v", err, io.EOF)
+	if !reflect.DeepEqual(err, expectedErr) {
+		t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
 	}
 }
 

+ 1 - 3
json.go

@@ -48,9 +48,7 @@ func (c *Conn) ReadJSON(v interface{}) error {
 	}
 	err = json.NewDecoder(r).Decode(v)
 	if err == io.EOF {
-		// Decode returns io.EOF when the message is empty or all whitespace.
-		// Convert to io.ErrUnexpectedEOF so that application can distinguish
-		// between an error reading the JSON value and the connection closing.
+		// One value is expected in the message.
 		err = io.ErrUnexpectedEOF
 	}
 	return err

+ 2 - 2
json_test.go

@@ -38,7 +38,7 @@ func TestJSON(t *testing.T) {
 	}
 }
 
-func TestPartialJsonRead(t *testing.T) {
+func TestPartialJSONRead(t *testing.T) {
 	var buf bytes.Buffer
 	c := fakeNetConn{&buf, &buf}
 	wc := newConn(c, true, 1024, 1024)
@@ -87,7 +87,7 @@ func TestPartialJsonRead(t *testing.T) {
 	}
 
 	err = rc.ReadJSON(&v)
-	if err != io.EOF {
+	if _, ok := err.(*CloseError); !ok {
 		t.Error("final", err)
 	}
 }