|
|
@@ -235,9 +235,10 @@ type Conn struct {
|
|
|
writeErr error
|
|
|
|
|
|
enableWriteCompression bool
|
|
|
- newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
|
|
|
+ newCompressionWriter func(io.WriteCloser) io.WriteCloser
|
|
|
|
|
|
// Read fields
|
|
|
+ reader io.ReadCloser // the current reader returned to the application
|
|
|
readErr error
|
|
|
br *bufio.Reader
|
|
|
readRemaining int64 // bytes remaining in current frame.
|
|
|
@@ -253,7 +254,7 @@ type Conn struct {
|
|
|
messageReader *messageReader // the current low-level reader
|
|
|
|
|
|
readDecompress bool // whether last read frame had RSV1 set
|
|
|
- newDecompressionReader func(io.Reader) io.Reader
|
|
|
+ newDecompressionReader func(io.Reader) io.ReadCloser
|
|
|
}
|
|
|
|
|
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
|
|
@@ -443,11 +444,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
|
|
}
|
|
|
c.writer = mw
|
|
|
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
|
|
- w, err := c.newCompressionWriter(c.writer)
|
|
|
- if err != nil {
|
|
|
- c.writer = nil
|
|
|
- return nil, err
|
|
|
- }
|
|
|
+ w := c.newCompressionWriter(c.writer)
|
|
|
mw.compress = true
|
|
|
c.writer = w
|
|
|
}
|
|
|
@@ -855,6 +852,11 @@ func (c *Conn) handleProtocolError(message string) error {
|
|
|
// permanent. Once this method returns a non-nil error, all subsequent calls to
|
|
|
// this method return the same error.
|
|
|
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|
|
+ // Close previous reader, only relevant for decompression.
|
|
|
+ if c.reader != nil {
|
|
|
+ c.reader.Close()
|
|
|
+ c.reader = nil
|
|
|
+ }
|
|
|
|
|
|
c.messageReader = nil
|
|
|
c.readLength = 0
|
|
|
@@ -867,11 +869,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|
|
}
|
|
|
if frameType == TextMessage || frameType == BinaryMessage {
|
|
|
c.messageReader = &messageReader{c}
|
|
|
- var r io.Reader = c.messageReader
|
|
|
+ c.reader = c.messageReader
|
|
|
if c.readDecompress {
|
|
|
- r = c.newDecompressionReader(r)
|
|
|
+ c.reader = c.newDecompressionReader(c.reader)
|
|
|
}
|
|
|
- return frameType, r, nil
|
|
|
+ return frameType, c.reader, nil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -933,6 +935,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
|
|
return 0, err
|
|
|
}
|
|
|
|
|
|
+func (r *messageReader) Close() error {
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
// ReadMessage is a helper method for getting a reader using NextReader and
|
|
|
// reading from that reader to a buffer.
|
|
|
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|