瀏覽代碼

Improve the errors returned from ReadJSON.

The JSON decoder returns io.EOF when a message is empty or all
whitespace. Convert io.EOF return values from the JSON decoder to
io.ErrUnexpectedEOF so that applications can distinguish between an
error reading the JSON value and the connection closing.
Gary Burd 10 年之前
父節點
當前提交
2dbddebb82
共有 2 個文件被更改,包括 65 次插入1 次删除
  1. 9 1
      json.go
  2. 56 0
      json_test.go

+ 9 - 1
json.go

@@ -6,6 +6,7 @@ package websocket
 
 import (
 	"encoding/json"
+	"io"
 )
 
 // WriteJSON is deprecated, use c.WriteJSON instead.
@@ -45,5 +46,12 @@ func (c *Conn) ReadJSON(v interface{}) error {
 	if err != nil {
 		return err
 	}
-	return json.NewDecoder(r).Decode(v)
+	err = json.NewDecoder(r).Decode(v)
+	if err == io.EOF {
+		// Decode returns io.EOF when the message is empty or all whitespace.
+		// Convert to io.ErrUnexpectedEOF so that application can distinguish
+		// between an error reading the JSON value and the connection closing.
+		err = io.ErrUnexpectedEOF
+	}
+	return err
 }

+ 56 - 0
json_test.go

@@ -6,6 +6,8 @@ package websocket
 
 import (
 	"bytes"
+	"encoding/json"
+	"io"
 	"reflect"
 	"testing"
 )
@@ -36,6 +38,60 @@ func TestJSON(t *testing.T) {
 	}
 }
 
+func TestPartialJsonRead(t *testing.T) {
+	var buf bytes.Buffer
+	c := fakeNetConn{&buf, &buf}
+	wc := newConn(c, true, 1024, 1024)
+	rc := newConn(c, false, 1024, 1024)
+
+	var v struct {
+		A int
+		B string
+	}
+	v.A = 1
+	v.B = "hello"
+
+	messageCount := 0
+
+	// Partial JSON values.
+
+	data, err := json.Marshal(v)
+	if err != nil {
+		t.Fatal(err)
+	}
+	for i := len(data) - 1; i >= 0; i-- {
+		if err := wc.WriteMessage(TextMessage, data[:i]); err != nil {
+			t.Fatal(err)
+		}
+		messageCount++
+	}
+
+	// Whitespace.
+
+	if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil {
+		t.Fatal(err)
+	}
+	messageCount++
+
+	// Close.
+
+	if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil {
+		t.Fatal(err)
+	}
+
+	for i := 0; i < messageCount; i++ {
+		err := rc.ReadJSON(&v)
+		if err != io.ErrUnexpectedEOF {
+			t.Error("read", i, err)
+		}
+	}
+
+	err = rc.ReadJSON(&v)
+	if err != io.EOF {
+		t.Error("final", err)
+	}
+}
+
 func TestDeprecatedJSON(t *testing.T) {
 	var buf bytes.Buffer
 	c := fakeNetConn{&buf, &buf}