Browse Source

Use bufio.Reader returned from hijack in upgrade

Use the bufio.Reader returned from hijack if the reader's buffer size is
equal to the buffer size specified in Upgrader.ReadBufferSize.
Gary Burd 9 years ago
parent
commit
286b5c9371
3 changed files with 38 additions and 7 deletions
  1. 20 1
      conn.go
  2. 14 0
      conn_test.go
  3. 4 6
      server.go

+ 20 - 1
conn.go

@@ -265,6 +265,10 @@ type Conn struct {
 }
 }
 
 
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
+	return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
+}
+
+func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
 	mu := make(chan bool, 1)
 	mu := make(chan bool, 1)
 	mu <- true
 	mu <- true
 
 
@@ -274,13 +278,28 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
 	if readBufferSize < maxControlFramePayloadSize {
 	if readBufferSize < maxControlFramePayloadSize {
 		readBufferSize = maxControlFramePayloadSize
 		readBufferSize = maxControlFramePayloadSize
 	}
 	}
+
+    // Reuse the supplied brw.Reader if brw.Reader's buf is the requested size.
+	var br *bufio.Reader
+	if brw != nil && brw.Reader != nil {
+		// This code assumes that peek on a reset reader returns
+		// bufio.Reader.buf[:0].
+		brw.Reader.Reset(conn)
+		if p, err := brw.Reader.Peek(0); err == nil && cap(p) == readBufferSize {
+			br = brw.Reader
+		}
+	}
+	if br == nil {
+		br = bufio.NewReaderSize(conn, readBufferSize)
+	}
+
 	if writeBufferSize == 0 {
 	if writeBufferSize == 0 {
 		writeBufferSize = defaultWriteBufferSize
 		writeBufferSize = defaultWriteBufferSize
 	}
 	}
 
 
 	c := &Conn{
 	c := &Conn{
 		isServer:               isServer,
 		isServer:               isServer,
-		br:                     bufio.NewReaderSize(conn, readBufferSize),
+		br:                     br,
 		conn:                   conn,
 		conn:                   conn,
 		mu:                     mu,
 		mu:                     mu,
 		readFinal:              true,
 		readFinal:              true,

+ 14 - 0
conn_test.go

@@ -463,3 +463,17 @@ func TestFailedConnectionReadPanic(t *testing.T) {
 	}
 	}
 	t.Fatal("should not get here")
 	t.Fatal("should not get here")
 }
 }
+
+func TestBufioReaderReuse(t *testing.T) {
+	brw := bufio.NewReadWriter(bufio.NewReader(nil), nil)
+	c := newConnBRW(nil, false, 0, 0, brw)
+	if c.br != brw.Reader {
+		t.Error("connection did not reuse bufio.Reader")
+	}
+
+	brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 1234), nil) // size must not equal bufio.defaultBufSize
+	c = newConnBRW(nil, false, 0, 0, brw)
+	if c.br == brw.Reader {
+		t.Error("connection reuse bufio.Reader with wrong size")
+	}
+}

+ 4 - 6
server.go

@@ -152,7 +152,6 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 
 
 	var (
 	var (
 		netConn net.Conn
 		netConn net.Conn
-		br      *bufio.Reader
 		err     error
 		err     error
 	)
 	)
 
 
@@ -160,19 +159,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 	if !ok {
 	if !ok {
 		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
 		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
 	}
 	}
-	var rw *bufio.ReadWriter
-	netConn, rw, err = h.Hijack()
+	var brw *bufio.ReadWriter
+	netConn, brw, err = h.Hijack()
 	if err != nil {
 	if err != nil {
 		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
 		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
 	}
 	}
-	br = rw.Reader
 
 
-	if br.Buffered() > 0 {
+	if brw.Reader.Buffered() > 0 {
 		netConn.Close()
 		netConn.Close()
 		return nil, errors.New("websocket: client sent data before handshake is complete")
 		return nil, errors.New("websocket: client sent data before handshake is complete")
 	}
 	}
 
 
-	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
+	c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
 	c.subprotocol = subprotocol
 	c.subprotocol = subprotocol
 
 
 	if compress {
 	if compress {