|
|
@@ -18,11 +18,19 @@ import (
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
+ // Frame header byte 0 bits from Section 5.2 of RFC 6455
|
|
|
+ finalBit = 1 << 7
|
|
|
+ rsv1Bit = 1 << 6
|
|
|
+ rsv2Bit = 1 << 5
|
|
|
+ rsv3Bit = 1 << 4
|
|
|
+
|
|
|
+ // Frame header byte 1 bits from Section 5.2 of RFC 6455
|
|
|
+ maskBit = 1 << 7
|
|
|
+
|
|
|
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
|
|
|
maxControlFramePayloadSize = 125
|
|
|
- finalBit = 1 << 7
|
|
|
- maskBit = 1 << 7
|
|
|
- writeWait = time.Second
|
|
|
+
|
|
|
+ writeWait = time.Second
|
|
|
|
|
|
defaultReadBufferSize = 4096
|
|
|
defaultWriteBufferSize = 4096
|
|
|
@@ -230,17 +238,20 @@ type Conn struct {
|
|
|
subprotocol string
|
|
|
|
|
|
// Write fields
|
|
|
- mu chan bool // used as mutex to protect write to conn and closeSent
|
|
|
- closeSent bool // true if close message was sent
|
|
|
-
|
|
|
- // Message writer fields.
|
|
|
+ mu chan bool // used as mutex to protect write to conn and closeSent
|
|
|
+ closeSent bool // whether close message was sent
|
|
|
writeErr error
|
|
|
writeBuf []byte // frame is constructed in this buffer.
|
|
|
writePos int // end of data in writeBuf.
|
|
|
writeFrameType int // type of the current frame.
|
|
|
writeDeadline time.Time
|
|
|
+ messageWriter *messageWriter // the current low-level message writer
|
|
|
+ writer io.WriteCloser // the current writer returned to the application
|
|
|
isWriting bool // for best-effort concurrent write detection
|
|
|
- messageWriter *messageWriter // the current writer
|
|
|
+
|
|
|
+ enableWriteCompression bool
|
|
|
+ writeCompress bool // whether next call to flushFrame should set RSV1
|
|
|
+ newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
|
|
|
|
|
|
// Read fields
|
|
|
readErr error
|
|
|
@@ -254,7 +265,10 @@ type Conn struct {
|
|
|
handlePong func(string) error
|
|
|
handlePing func(string) error
|
|
|
readErrCount int
|
|
|
- messageReader *messageReader // the current reader
|
|
|
+ messageReader *messageReader // the current low-level reader
|
|
|
+
|
|
|
+ readDecompress bool // whether last read frame had RSV1 set
|
|
|
+ newDecompressionReader func(io.Reader) io.Reader
|
|
|
}
|
|
|
|
|
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
|
|
@@ -272,14 +286,15 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
|
|
|
}
|
|
|
|
|
|
c := &Conn{
|
|
|
- isServer: isServer,
|
|
|
- br: bufio.NewReaderSize(conn, readBufferSize),
|
|
|
- conn: conn,
|
|
|
- mu: mu,
|
|
|
- readFinal: true,
|
|
|
- writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
|
|
|
- writeFrameType: noFrame,
|
|
|
- writePos: maxFrameHeaderSize,
|
|
|
+ isServer: isServer,
|
|
|
+ br: bufio.NewReaderSize(conn, readBufferSize),
|
|
|
+ conn: conn,
|
|
|
+ mu: mu,
|
|
|
+ readFinal: true,
|
|
|
+ writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
|
|
|
+ writeFrameType: noFrame,
|
|
|
+ writePos: maxFrameHeaderSize,
|
|
|
+ enableWriteCompression: true,
|
|
|
}
|
|
|
c.SetPingHandler(nil)
|
|
|
c.SetPongHandler(nil)
|
|
|
@@ -403,8 +418,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
|
|
return nil, c.writeErr
|
|
|
}
|
|
|
|
|
|
- if c.writeFrameType != noFrame {
|
|
|
- if err := c.flushFrame(true, nil); err != nil {
|
|
|
+ // Close previous writer if not already closed by the application. It's
|
|
|
+ // probably better to return an error in this situation, but we cannot
|
|
|
+ // change this without breaking existing applications.
|
|
|
+ if c.writer != nil {
|
|
|
+ err := c.writer.Close()
|
|
|
+ if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
}
|
|
|
@@ -414,11 +433,24 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
|
|
}
|
|
|
|
|
|
c.writeFrameType = messageType
|
|
|
- w := &messageWriter{c}
|
|
|
- c.messageWriter = w
|
|
|
+ c.messageWriter = &messageWriter{c}
|
|
|
+
|
|
|
+ var w io.WriteCloser = c.messageWriter
|
|
|
+ if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
|
|
+ c.writeCompress = true
|
|
|
+ var err error
|
|
|
+ w, err = c.newCompressionWriter(w)
|
|
|
+ if err != nil {
|
|
|
+ c.writer.Close()
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
return w, nil
|
|
|
}
|
|
|
|
|
|
+// flushFrame writes buffered data and extra as a frame to the network. The
|
|
|
+// final argument indicates that this is the last frame in the message.
|
|
|
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
length := c.writePos - maxFrameHeaderSize + len(extra)
|
|
|
|
|
|
@@ -426,6 +458,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
if isControl(c.writeFrameType) &&
|
|
|
(!final || length > maxControlFramePayloadSize) {
|
|
|
c.messageWriter = nil
|
|
|
+ c.writer = nil
|
|
|
c.writeFrameType = noFrame
|
|
|
c.writePos = maxFrameHeaderSize
|
|
|
return errInvalidControlFrame
|
|
|
@@ -435,6 +468,11 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
if final {
|
|
|
b0 |= finalBit
|
|
|
}
|
|
|
+ if c.writeCompress {
|
|
|
+ b0 |= rsv1Bit
|
|
|
+ }
|
|
|
+ c.writeCompress = false
|
|
|
+
|
|
|
b1 := byte(0)
|
|
|
if !c.isServer {
|
|
|
b1 |= maskBit
|
|
|
@@ -494,6 +532,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|
|
c.writeFrameType = continuationFrame
|
|
|
if final {
|
|
|
c.messageWriter = nil
|
|
|
+ c.writer = nil
|
|
|
c.writeFrameType = noFrame
|
|
|
}
|
|
|
return c.writeErr
|
|
|
@@ -526,14 +565,14 @@ func (w *messageWriter) ncopy(max int) (int, error) {
|
|
|
return n, nil
|
|
|
}
|
|
|
|
|
|
-func (w *messageWriter) write(final bool, p []byte) (int, error) {
|
|
|
+func (w *messageWriter) Write(p []byte) (int, error) {
|
|
|
if err := w.err(); err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
|
|
|
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
|
|
|
// Don't buffer large messages.
|
|
|
- err := w.c.flushFrame(final, p)
|
|
|
+ err := w.c.flushFrame(false, p)
|
|
|
if err != nil {
|
|
|
return 0, err
|
|
|
}
|
|
|
@@ -553,10 +592,6 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
|
|
|
return nn, nil
|
|
|
}
|
|
|
|
|
|
-func (w *messageWriter) Write(p []byte) (int, error) {
|
|
|
- return w.write(false, p)
|
|
|
-}
|
|
|
-
|
|
|
func (w *messageWriter) WriteString(p string) (int, error) {
|
|
|
if err := w.err(); err != nil {
|
|
|
return 0, err
|
|
|
@@ -658,12 +693,17 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
|
|
|
final := p[0]&finalBit != 0
|
|
|
frameType := int(p[0] & 0xf)
|
|
|
- reserved := int((p[0] >> 4) & 0x7)
|
|
|
mask := p[1]&maskBit != 0
|
|
|
c.readRemaining = int64(p[1] & 0x7f)
|
|
|
|
|
|
- if reserved != 0 {
|
|
|
- return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
|
|
+ c.readDecompress = false
|
|
|
+ if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
|
|
+ c.readDecompress = true
|
|
|
+ p[0] &^= rsv1Bit
|
|
|
+ }
|
|
|
+
|
|
|
+ if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
|
|
|
+ return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
|
|
|
}
|
|
|
|
|
|
switch frameType {
|
|
|
@@ -807,8 +847,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|
|
break
|
|
|
}
|
|
|
if frameType == TextMessage || frameType == BinaryMessage {
|
|
|
- r := &messageReader{c}
|
|
|
- c.messageReader = r
|
|
|
+ c.messageReader = &messageReader{c}
|
|
|
+ var r io.Reader = c.messageReader
|
|
|
+ if c.readDecompress {
|
|
|
+ r = c.newDecompressionReader(r)
|
|
|
+ }
|
|
|
return frameType, r, nil
|
|
|
}
|
|
|
}
|