Ver código fonte

Use bufio.Reader returned from hijack in upgrade

Use the bufio.Reader returned from hijack if the reader's buffer size is
equal to the buffer size specified in Upgrader.ReadBufferSize.
Gary Burd 8 anos atrás
pai
commit
286b5c9371
3 arquivos alterados com 38 adições e 7 exclusões
  1. 20 1
      conn.go
  2. 14 0
      conn_test.go
  3. 4 6
      server.go

+ 20 - 1
conn.go

@@ -265,6 +265,10 @@ type Conn struct {
 }
 
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
+	return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
+}
+
+func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
 	mu := make(chan bool, 1)
 	mu <- true
 
@@ -274,13 +278,28 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
 	if readBufferSize < maxControlFramePayloadSize {
 		readBufferSize = maxControlFramePayloadSize
 	}
+
+    // Reuse the supplied brw.Reader if brw.Reader's buf is the requested size.
+	var br *bufio.Reader
+	if brw != nil && brw.Reader != nil {
+		// This code assumes that peek on a reset reader returns
+		// bufio.Reader.buf[:0].
+		brw.Reader.Reset(conn)
+		if p, err := brw.Reader.Peek(0); err == nil && cap(p) == readBufferSize {
+			br = brw.Reader
+		}
+	}
+	if br == nil {
+		br = bufio.NewReaderSize(conn, readBufferSize)
+	}
+
 	if writeBufferSize == 0 {
 		writeBufferSize = defaultWriteBufferSize
 	}
 
 	c := &Conn{
 		isServer:               isServer,
-		br:                     bufio.NewReaderSize(conn, readBufferSize),
+		br:                     br,
 		conn:                   conn,
 		mu:                     mu,
 		readFinal:              true,

+ 14 - 0
conn_test.go

@@ -463,3 +463,17 @@ func TestFailedConnectionReadPanic(t *testing.T) {
 	}
 	t.Fatal("should not get here")
 }
+
+func TestBufioReaderReuse(t *testing.T) {
+	brw := bufio.NewReadWriter(bufio.NewReader(nil), nil)
+	c := newConnBRW(nil, false, 0, 0, brw)
+	if c.br != brw.Reader {
+		t.Error("connection did not reuse bufio.Reader")
+	}
+
+	brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 1234), nil) // size must not equal bufio.defaultBufSize
+	c = newConnBRW(nil, false, 0, 0, brw)
+	if c.br == brw.Reader {
+		t.Error("connection reuse bufio.Reader with wrong size")
+	}
+}

+ 4 - 6
server.go

@@ -152,7 +152,6 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 
 	var (
 		netConn net.Conn
-		br      *bufio.Reader
 		err     error
 	)
 
@@ -160,19 +159,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 	if !ok {
 		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
 	}
-	var rw *bufio.ReadWriter
-	netConn, rw, err = h.Hijack()
+	var brw *bufio.ReadWriter
+	netConn, brw, err = h.Hijack()
 	if err != nil {
 		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
 	}
-	br = rw.Reader
 
-	if br.Buffered() > 0 {
+	if brw.Reader.Buffered() > 0 {
 		netConn.Close()
 		return nil, errors.New("websocket: client sent data before handshake is complete")
 	}
 
-	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
+	c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
 	c.subprotocol = subprotocol
 
 	if compress {