Quellcode durchsuchen

Merge branch 'compress'

Gary Burd vor 9 Jahren
Ursprung
Commit
5ddbd28fbd
3 geänderte Dateien mit 65 neuen und 29 gelöschten Zeilen
  1. 47 17
      compression.go
  2. 16 10
      conn.go
  3. 2 2
      doc.go

+ 47 - 17
compression.go

@@ -13,30 +13,32 @@ import (
 )
 
 var (
-	flateWriterPool = sync.Pool{}
+	flateWriterPool = sync.Pool{New: func() interface{} {
+		fw, _ := flate.NewWriter(nil, 3)
+		return fw
+	}}
+	flateReaderPool = sync.Pool{New: func() interface{} {
+		return flate.NewReader(nil)
+	}}
 )
 
-func decompressNoContextTakeover(r io.Reader) io.Reader {
+func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
 	const tail =
 	// Add four bytes as specified in RFC
 	"\x00\x00\xff\xff" +
 		// Add final block to squelch unexpected EOF error from flate reader.
 		"\x01\x00\x00\xff\xff"
-	return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
+
+	fr, _ := flateReaderPool.Get().(io.ReadCloser)
+	fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
+	return &flateReadWrapper{fr}
 }
 
-func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
+func compressNoContextTakeover(w io.WriteCloser) io.WriteCloser {
 	tw := &truncWriter{w: w}
-	i := flateWriterPool.Get()
-	var fw *flate.Writer
-	var err error
-	if i == nil {
-		fw, err = flate.NewWriter(tw, 3)
-	} else {
-		fw = i.(*flate.Writer)
-		fw.Reset(tw)
-	}
-	return &flateWrapper{fw: fw, tw: tw}, err
+	fw, _ := flateWriterPool.Get().(*flate.Writer)
+	fw.Reset(tw)
+	return &flateWriteWrapper{fw: fw, tw: tw}
 }
 
 // truncWriter is an io.Writer that writes all but the last four bytes of the
@@ -75,19 +77,19 @@ func (w *truncWriter) Write(p []byte) (int, error) {
 	return n + nn, err
 }
 
-type flateWrapper struct {
+type flateWriteWrapper struct {
 	fw *flate.Writer
 	tw *truncWriter
 }
 
-func (w *flateWrapper) Write(p []byte) (int, error) {
+func (w *flateWriteWrapper) Write(p []byte) (int, error) {
 	if w.fw == nil {
 		return 0, errWriteClosed
 	}
 	return w.fw.Write(p)
 }
 
-func (w *flateWrapper) Close() error {
+func (w *flateWriteWrapper) Close() error {
 	if w.fw == nil {
 		return errWriteClosed
 	}
@@ -103,3 +105,31 @@ func (w *flateWrapper) Close() error {
 	}
 	return err2
 }
+
+type flateReadWrapper struct {
+	fr io.ReadCloser
+}
+
+func (r *flateReadWrapper) Read(p []byte) (int, error) {
+	if r.fr == nil {
+		return 0, io.ErrClosedPipe
+	}
+	n, err := r.fr.Read(p)
+	if err == io.EOF {
+		// Preemptively place the reader back in the pool. This helps with
+		// scenarios where the application does not call NextReader() soon after
+		// this final read.
+		r.Close()
+	}
+	return n, err
+}
+
+func (r *flateReadWrapper) Close() error {
+	if r.fr == nil {
+		return io.ErrClosedPipe
+	}
+	err := r.fr.Close()
+	flateReaderPool.Put(r.fr)
+	r.fr = nil
+	return err
+}

+ 16 - 10
conn.go

@@ -235,9 +235,10 @@ type Conn struct {
 	writeErr   error
 
 	enableWriteCompression bool
-	newCompressionWriter   func(io.WriteCloser) (io.WriteCloser, error)
+	newCompressionWriter   func(io.WriteCloser) io.WriteCloser
 
 	// Read fields
+	reader        io.ReadCloser // the current reader returned to the application
 	readErr       error
 	br            *bufio.Reader
 	readRemaining int64 // bytes remaining in current frame.
@@ -253,7 +254,7 @@ type Conn struct {
 	messageReader *messageReader // the current low-level reader
 
 	readDecompress         bool // whether last read frame had RSV1 set
-	newDecompressionReader func(io.Reader) io.Reader
+	newDecompressionReader func(io.Reader) io.ReadCloser
 }
 
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@@ -443,11 +444,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
 	}
 	c.writer = mw
 	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
-		w, err := c.newCompressionWriter(c.writer)
-		if err != nil {
-			c.writer = nil
-			return nil, err
-		}
+		w := c.newCompressionWriter(c.writer)
 		mw.compress = true
 		c.writer = w
 	}
@@ -855,6 +852,11 @@ func (c *Conn) handleProtocolError(message string) error {
 // permanent. Once this method returns a non-nil error, all subsequent calls to
 // this method return the same error.
 func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
+	// Close previous reader, only relevant for decompression.
+	if c.reader != nil {
+		c.reader.Close()
+		c.reader = nil
+	}
 
 	c.messageReader = nil
 	c.readLength = 0
@@ -867,11 +869,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 		}
 		if frameType == TextMessage || frameType == BinaryMessage {
 			c.messageReader = &messageReader{c}
-			var r io.Reader = c.messageReader
+			c.reader = c.messageReader
 			if c.readDecompress {
-				r = c.newDecompressionReader(r)
+				c.reader = c.newDecompressionReader(c.reader)
 			}
-			return frameType, r, nil
+			return frameType, c.reader, nil
 		}
 	}
 
@@ -933,6 +935,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
 	return 0, err
 }
 
+func (r *messageReader) Close() error {
+	return nil
+}
+
 // ReadMessage is a helper method for getting a reader using NextReader and
 // reading from that reader to a buffer.
 func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {

+ 2 - 2
doc.go

@@ -150,7 +150,7 @@
 // application's responsibility to check the Origin header before calling
 // Upgrade.
 //
-// Compression [Experimental]
+// Compression
 //
 // Per message compression extensions (RFC 7692) are experimentally supported
 // by this package in a limited capacity. Setting the EnableCompression option
@@ -162,7 +162,7 @@
 // Per message compression of messages written to a connection can be enabled
 // or disabled by calling the corresponding Conn method:
 //
-// conn.EnableWriteCompression(true)
+//  conn.EnableWriteCompression(true)
 //
 // Currently this package does not support compression with "context takeover".
 // This means that messages must be compressed and decompressed in isolation,