浏览代码

Use bufio.Reader returned from hijack in upgrade

Use the bufio.Reader returned from hijack if the reader's buffer size is
equal to the buffer size specified in Upgrader.ReadBufferSize.
Gary Burd 8 年之前
父节点
当前提交
286b5c9371
共有 3 个文件被更改,包括 38 次插入7 次删除
  1. 20 1
      conn.go
  2. 14 0
      conn_test.go
  3. 4 6
      server.go

+ 20 - 1
conn.go

@@ -265,6 +265,10 @@ type Conn struct {
 }
 
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
+	return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
+}
+
+func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
 	mu := make(chan bool, 1)
 	mu <- true
 
@@ -274,13 +278,28 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
 	if readBufferSize < maxControlFramePayloadSize {
 		readBufferSize = maxControlFramePayloadSize
 	}
+
+    // Reuse the supplied brw.Reader if brw.Reader's buf is the requested size.
+	var br *bufio.Reader
+	if brw != nil && brw.Reader != nil {
+		// This code assumes that peek on a reset reader returns
+		// bufio.Reader.buf[:0].
+		brw.Reader.Reset(conn)
+		if p, err := brw.Reader.Peek(0); err == nil && cap(p) == readBufferSize {
+			br = brw.Reader
+		}
+	}
+	if br == nil {
+		br = bufio.NewReaderSize(conn, readBufferSize)
+	}
+
 	if writeBufferSize == 0 {
 		writeBufferSize = defaultWriteBufferSize
 	}
 
 	c := &Conn{
 		isServer:               isServer,
-		br:                     bufio.NewReaderSize(conn, readBufferSize),
+		br:                     br,
 		conn:                   conn,
 		mu:                     mu,
 		readFinal:              true,

+ 14 - 0
conn_test.go

@@ -463,3 +463,17 @@ func TestFailedConnectionReadPanic(t *testing.T) {
 	}
 	t.Fatal("should not get here")
 }
+
+func TestBufioReaderReuse(t *testing.T) {
+	brw := bufio.NewReadWriter(bufio.NewReader(nil), nil)
+	c := newConnBRW(nil, false, 0, 0, brw)
+	if c.br != brw.Reader {
+		t.Error("connection did not reuse bufio.Reader")
+	}
+
+	brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 1234), nil) // size must not equal bufio.defaultBufSize
+	c = newConnBRW(nil, false, 0, 0, brw)
+	if c.br == brw.Reader {
+		t.Error("connection reuse bufio.Reader with wrong size")
+	}
+}

+ 4 - 6
server.go

@@ -152,7 +152,6 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 
 	var (
 		netConn net.Conn
-		br      *bufio.Reader
 		err     error
 	)
 
@@ -160,19 +159,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 	if !ok {
 		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
 	}
-	var rw *bufio.ReadWriter
-	netConn, rw, err = h.Hijack()
+	var brw *bufio.ReadWriter
+	netConn, brw, err = h.Hijack()
 	if err != nil {
 		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
 	}
-	br = rw.Reader
 
-	if br.Buffered() > 0 {
+	if brw.Reader.Buffered() > 0 {
 		netConn.Close()
 		return nil, errors.New("websocket: client sent data before handshake is complete")
 	}
 
-	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
+	c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
 	c.subprotocol = subprotocol
 
 	if compress {