Pārlūkot izejas kodu

pool flate readers

Cyrus Katrak 9 gadi atpakaļ
vecāks
revīzija
2db2f66488
2 mainītis faili ar 55 papildinājumiem un 10 dzēšanām
  1. 41 6
      compression.go
  2. 14 4
      conn.go

+ 41 - 6
compression.go

@@ -14,15 +14,22 @@ import (
 
 var (
 	flateWriterPool = sync.Pool{}
+	flateReaderPool = sync.Pool{}
 )
 
-func decompressNoContextTakeover(r io.Reader) io.Reader {
+func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
 	const tail =
 	// Add four bytes as specified in RFC
 	"\x00\x00\xff\xff" +
 		// Add final block to squelch unexpected EOF error from flate reader.
 		"\x01\x00\x00\xff\xff"
-	return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
+
+	i := flateReaderPool.Get()
+	if i == nil {
+		i = flate.NewReader(nil)
+	}
+	i.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
+	return &flateReadWrapper{i.(io.ReadCloser)}
 }
 
 func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
@@ -36,7 +43,7 @@ func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
 		fw = i.(*flate.Writer)
 		fw.Reset(tw)
 	}
-	return &flateWrapper{fw: fw, tw: tw}, err
+	return &flateWriteWrapper{fw: fw, tw: tw}, err
 }
 
 // truncWriter is an io.Writer that writes all but the last four bytes of the
@@ -75,19 +82,19 @@ func (w *truncWriter) Write(p []byte) (int, error) {
 	return n + nn, err
 }
 
-type flateWrapper struct {
+type flateWriteWrapper struct {
 	fw *flate.Writer
 	tw *truncWriter
 }
 
-func (w *flateWrapper) Write(p []byte) (int, error) {
+func (w *flateWriteWrapper) Write(p []byte) (int, error) {
 	if w.fw == nil {
 		return 0, errWriteClosed
 	}
 	return w.fw.Write(p)
 }
 
-func (w *flateWrapper) Close() error {
+func (w *flateWriteWrapper) Close() error {
 	if w.fw == nil {
 		return errWriteClosed
 	}
@@ -103,3 +110,31 @@ func (w *flateWrapper) Close() error {
 	}
 	return err2
 }
+
+type flateReadWrapper struct {
+	fr io.ReadCloser
+}
+
+func (r *flateReadWrapper) Read(p []byte) (int, error) {
+	if r.fr == nil {
+		return 0, io.ErrClosedPipe
+	}
+	n, err := r.fr.Read(p)
+	if err == io.EOF {
+		// Preemptively place the reader back in the pool. This helps with
+		// scenarios where the application does not call NextReader() soon after
+		// this final read.
+		r.Close()
+	}
+	return n, err
+}
+
+func (r *flateReadWrapper) Close() error {
+	if r.fr == nil {
+		return io.ErrClosedPipe
+	}
+	err := r.fr.Close()
+	flateReaderPool.Put(r.fr)
+	r.fr = nil
+	return err
+}

+ 14 - 4
conn.go

@@ -238,6 +238,7 @@ type Conn struct {
 	newCompressionWriter   func(io.WriteCloser) (io.WriteCloser, error)
 
 	// Read fields
+	reader        io.ReadCloser // the current reader returned to the application
 	readErr       error
 	br            *bufio.Reader
 	readRemaining int64 // bytes remaining in current frame.
@@ -253,7 +254,7 @@ type Conn struct {
 	messageReader *messageReader // the current low-level reader
 
 	readDecompress         bool // whether last read frame had RSV1 set
-	newDecompressionReader func(io.Reader) io.Reader
+	newDecompressionReader func(io.Reader) io.ReadCloser
 }
 
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@@ -855,6 +856,11 @@ func (c *Conn) handleProtocolError(message string) error {
 // permanent. Once this method returns a non-nil error, all subsequent calls to
 // this method return the same error.
 func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
+	// Close previous reader, only relevant for decompression.
+	if c.reader != nil {
+		c.reader.Close()
+		c.reader = nil
+	}
 
 	c.messageReader = nil
 	c.readLength = 0
@@ -867,11 +873,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 		}
 		if frameType == TextMessage || frameType == BinaryMessage {
 			c.messageReader = &messageReader{c}
-			var r io.Reader = c.messageReader
+			c.reader = c.messageReader
 			if c.readDecompress {
-				r = c.newDecompressionReader(r)
+				c.reader = c.newDecompressionReader(c.reader)
 			}
-			return frameType, r, nil
+			return frameType, c.reader, nil
 		}
 	}
 
@@ -933,6 +939,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
 	return 0, err
 }
 
+func (r *messageReader) Close() error {
+	return nil
+}
+
 // ReadMessage is a helper method for getting a reader using NextReader and
 // reading from that reader to a buffer.
 func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {