Jelajahi Sumber

Use bufio.Reader returned from hijack in upgrade

Use the bufio.Reader returned from hijack if the reader's buffer size is
equal to the buffer size specified in Upgrader.ReadBufferSize.
Gary Burd 8 tahun lalu
induk
melakukan
286b5c9371
3 mengubah file dengan 38 tambahan dan 7 penghapusan
  1. 20 1
      conn.go
  2. 14 0
      conn_test.go
  3. 4 6
      server.go

+ 20 - 1
conn.go

@@ -265,6 +265,10 @@ type Conn struct {
 }
 
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
+	return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
+}
+
+func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
 	mu := make(chan bool, 1)
 	mu <- true
 
@@ -274,13 +278,28 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
 	if readBufferSize < maxControlFramePayloadSize {
 		readBufferSize = maxControlFramePayloadSize
 	}
+
+    // Reuse the supplied brw.Reader if brw.Reader's buf is the requested size.
+	var br *bufio.Reader
+	if brw != nil && brw.Reader != nil {
+		// This code assumes that peek on a reset reader returns
+		// bufio.Reader.buf[:0].
+		brw.Reader.Reset(conn)
+		if p, err := brw.Reader.Peek(0); err == nil && cap(p) == readBufferSize {
+			br = brw.Reader
+		}
+	}
+	if br == nil {
+		br = bufio.NewReaderSize(conn, readBufferSize)
+	}
+
 	if writeBufferSize == 0 {
 		writeBufferSize = defaultWriteBufferSize
 	}
 
 	c := &Conn{
 		isServer:               isServer,
-		br:                     bufio.NewReaderSize(conn, readBufferSize),
+		br:                     br,
 		conn:                   conn,
 		mu:                     mu,
 		readFinal:              true,

+ 14 - 0
conn_test.go

@@ -463,3 +463,17 @@ func TestFailedConnectionReadPanic(t *testing.T) {
 	}
 	t.Fatal("should not get here")
 }
+
+func TestBufioReaderReuse(t *testing.T) {
+	brw := bufio.NewReadWriter(bufio.NewReader(nil), nil)
+	c := newConnBRW(nil, false, 0, 0, brw)
+	if c.br != brw.Reader {
+		t.Error("connection did not reuse bufio.Reader")
+	}
+
+	brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 1234), nil) // size must not equal bufio.defaultBufSize
+	c = newConnBRW(nil, false, 0, 0, brw)
+	if c.br == brw.Reader {
+		t.Error("connection reuse bufio.Reader with wrong size")
+	}
+}

+ 4 - 6
server.go

@@ -152,7 +152,6 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 
 	var (
 		netConn net.Conn
-		br      *bufio.Reader
 		err     error
 	)
 
@@ -160,19 +159,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 	if !ok {
 		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
 	}
-	var rw *bufio.ReadWriter
-	netConn, rw, err = h.Hijack()
+	var brw *bufio.ReadWriter
+	netConn, brw, err = h.Hijack()
 	if err != nil {
 		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
 	}
-	br = rw.Reader
 
-	if br.Buffered() > 0 {
+	if brw.Reader.Buffered() > 0 {
 		netConn.Close()
 		return nil, errors.New("websocket: client sent data before handshake is complete")
 	}
 
-	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
+	c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
 	c.subprotocol = subprotocol
 
 	if compress {