Procházet zdrojové kódy

Add hooks to support RFC 7692 (per-message compression extension)

Add newCompressionWriter and newDecompressionReader fields to Conn. When
not nil, these functions are used to create a compression/decompression
wrapper around an underlying message writer/reader.

Add code to set and check for RSV1 frame header bit.

Add functions compressNoContextTakeover and decompressNoContextTakeover
for creating no context takeover wrappers around an underlying message
writer/reader.

Work remaining:

- Add fields to Dialer and Upgrader for specifying compression options.
- Add compression negotiation to Dialer and Upgrader.
- Add function to enable/disable write compression:

    // EnableWriteCompression enables and disables write compression of
    // subsequent text and binary messages. This function is a noop if
    // compression was not negotiated with the peer.
    func (c *Conn) EnableWriteCompression(enable bool) {
            c.enableWriteCompression = enable
    }
Gary Burd před 9 roky
rodič
revize
a87eae1d6f
3 změnil soubory, kde provedl 191 přidání a 32 odebrání
  1. 85 0
      compression.go
  2. 31 0
      compression_test.go
  3. 75 32
      conn.go

+ 85 - 0
compression.go

@@ -0,0 +1,85 @@
+// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+	"compress/flate"
+	"errors"
+	"io"
+	"strings"
+)
+
+func decompressNoContextTakeover(r io.Reader) io.Reader {
+	const tail =
+	// Add four bytes as specified in RFC
+	"\x00\x00\xff\xff" +
+		// Add final block to squelch unexpected EOF error from flate reader.
+		"\x01\x00\x00\xff\xff"
+
+	return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
+}
+
+func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
+	tw := &truncWriter{w: w}
+	fw, err := flate.NewWriter(tw, 3)
+	return &flateWrapper{fw: fw, tw: tw}, err
+}
+
+// truncWriter is an io.Writer that writes all but the last four bytes of the
+// stream to another io.Writer.
+type truncWriter struct {
+	w io.WriteCloser
+	n int
+	p [4]byte
+}
+
+func (w *truncWriter) Write(p []byte) (int, error) {
+	n := 0
+
+	// fill buffer first for simplicity.
+	if w.n < len(w.p) {
+		n = copy(w.p[w.n:], p)
+		p = p[n:]
+		w.n += n
+		if len(p) == 0 {
+			return n, nil
+		}
+	}
+
+	m := len(p)
+	if m > len(w.p) {
+		m = len(w.p)
+	}
+
+	if nn, err := w.w.Write(w.p[:m]); err != nil {
+		return n + nn, err
+	}
+
+	copy(w.p[:], w.p[m:])
+	copy(w.p[len(w.p)-m:], p[len(p)-m:])
+	nn, err := w.w.Write(p[:len(p)-m])
+	return n + nn, err
+}
+
+type flateWrapper struct {
+	fw *flate.Writer
+	tw *truncWriter
+}
+
+func (w *flateWrapper) Write(p []byte) (int, error) {
+	return w.fw.Write(p)
+}
+
+func (w *flateWrapper) Close() error {
+	err1 := w.fw.Flush()
+	if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
+		return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
+	}
+	err2 := w.tw.w.Close()
+	if err1 != nil {
+		return err1
+	}
+	return err2
+}

+ 31 - 0
compression_test.go

@@ -0,0 +1,31 @@
+package websocket
+
+import (
+	"bytes"
+	"io"
+	"testing"
+)
+
+type nopCloser struct{ io.Writer }
+
+func (nopCloser) Close() error { return nil }
+
+func TestTruncWriter(t *testing.T) {
+	const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
+	for n := 1; n <= 10; n++ {
+		var b bytes.Buffer
+		w := &truncWriter{w: nopCloser{&b}}
+		p := []byte(data)
+		for len(p) > 0 {
+			m := len(p)
+			if m > n {
+				m = n
+			}
+			w.Write(p[:m])
+			p = p[m:]
+		}
+		if b.String() != data[:len(data)-len(w.p)] {
+			t.Errorf("%d: %q", n, b.String())
+		}
+	}
+}

+ 75 - 32
conn.go

@@ -18,11 +18,19 @@ import (
 )
 
 const (
+	// Frame header byte 0 bits from Section 5.2 of RFC 6455
+	finalBit = 1 << 7
+	rsv1Bit  = 1 << 6
+	rsv2Bit  = 1 << 5
+	rsv3Bit  = 1 << 4
+
+	// Frame header byte 1 bits from Section 5.2 of RFC 6455
+	maskBit = 1 << 7
+
 	maxFrameHeaderSize         = 2 + 8 + 4 // Fixed header + length + mask
 	maxControlFramePayloadSize = 125
-	finalBit                   = 1 << 7
-	maskBit                    = 1 << 7
-	writeWait                  = time.Second
+
+	writeWait = time.Second
 
 	defaultReadBufferSize  = 4096
 	defaultWriteBufferSize = 4096
@@ -230,17 +238,20 @@ type Conn struct {
 	subprotocol string
 
 	// Write fields
-	mu        chan bool // used as mutex to protect write to conn and closeSent
-	closeSent bool      // true if close message was sent
-
-	// Message writer fields.
+	mu             chan bool // used as mutex to protect write to conn and closeSent
+	closeSent      bool      // whether close message was sent
 	writeErr       error
 	writeBuf       []byte // frame is constructed in this buffer.
 	writePos       int    // end of data in writeBuf.
 	writeFrameType int    // type of the current frame.
 	writeDeadline  time.Time
+	messageWriter  *messageWriter // the current low-level message writer
+	writer         io.WriteCloser // the current writer returned to the application
 	isWriting      bool           // for best-effort concurrent write detection
-	messageWriter  *messageWriter // the current writer
+
+	enableWriteCompression bool
+	writeCompress          bool // whether next call to flushFrame should set RSV1
+	newCompressionWriter   func(io.WriteCloser) (io.WriteCloser, error)
 
 	// Read fields
 	readErr       error
@@ -254,7 +265,10 @@ type Conn struct {
 	handlePong    func(string) error
 	handlePing    func(string) error
 	readErrCount  int
-	messageReader *messageReader // the current reader
+	messageReader *messageReader // the current low-level reader
+
+	readDecompress         bool // whether last read frame had RSV1 set
+	newDecompressionReader func(io.Reader) io.Reader
 }
 
 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@@ -272,14 +286,15 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
 	}
 
 	c := &Conn{
-		isServer:       isServer,
-		br:             bufio.NewReaderSize(conn, readBufferSize),
-		conn:           conn,
-		mu:             mu,
-		readFinal:      true,
-		writeBuf:       make([]byte, writeBufferSize+maxFrameHeaderSize),
-		writeFrameType: noFrame,
-		writePos:       maxFrameHeaderSize,
+		isServer:               isServer,
+		br:                     bufio.NewReaderSize(conn, readBufferSize),
+		conn:                   conn,
+		mu:                     mu,
+		readFinal:              true,
+		writeBuf:               make([]byte, writeBufferSize+maxFrameHeaderSize),
+		writeFrameType:         noFrame,
+		writePos:               maxFrameHeaderSize,
+		enableWriteCompression: true,
 	}
 	c.SetPingHandler(nil)
 	c.SetPongHandler(nil)
@@ -403,8 +418,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
 		return nil, c.writeErr
 	}
 
-	if c.writeFrameType != noFrame {
-		if err := c.flushFrame(true, nil); err != nil {
+	// Close previous writer if not already closed by the application. It's
+	// probably better to return an error in this situation, but we cannot
+	// change this without breaking existing applications.
+	if c.writer != nil {
+		err := c.writer.Close()
+		if err != nil {
 			return nil, err
 		}
 	}
@@ -414,11 +433,24 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
 	}
 
 	c.writeFrameType = messageType
-	w := &messageWriter{c}
-	c.messageWriter = w
+	c.messageWriter = &messageWriter{c}
+
+	var w io.WriteCloser = c.messageWriter
+	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
+		c.writeCompress = true
+		var err error
+		w, err = c.newCompressionWriter(w)
+		if err != nil {
+			c.writer.Close()
+			return nil, err
+		}
+	}
+
 	return w, nil
 }
 
+// flushFrame writes buffered data and extra as a frame to the network. The
+// final argument indicates that this is the last frame in the message.
 func (c *Conn) flushFrame(final bool, extra []byte) error {
 	length := c.writePos - maxFrameHeaderSize + len(extra)
 
@@ -426,6 +458,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
 	if isControl(c.writeFrameType) &&
 		(!final || length > maxControlFramePayloadSize) {
 		c.messageWriter = nil
+		c.writer = nil
 		c.writeFrameType = noFrame
 		c.writePos = maxFrameHeaderSize
 		return errInvalidControlFrame
@@ -435,6 +468,11 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
 	if final {
 		b0 |= finalBit
 	}
+	if c.writeCompress {
+		b0 |= rsv1Bit
+	}
+	c.writeCompress = false
+
 	b1 := byte(0)
 	if !c.isServer {
 		b1 |= maskBit
@@ -494,6 +532,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
 	c.writeFrameType = continuationFrame
 	if final {
 		c.messageWriter = nil
+		c.writer = nil
 		c.writeFrameType = noFrame
 	}
 	return c.writeErr
@@ -526,14 +565,14 @@ func (w *messageWriter) ncopy(max int) (int, error) {
 	return n, nil
 }
 
-func (w *messageWriter) write(final bool, p []byte) (int, error) {
+func (w *messageWriter) Write(p []byte) (int, error) {
 	if err := w.err(); err != nil {
 		return 0, err
 	}
 
 	if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
 		// Don't buffer large messages.
-		err := w.c.flushFrame(final, p)
+		err := w.c.flushFrame(false, p)
 		if err != nil {
 			return 0, err
 		}
@@ -553,10 +592,6 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
 	return nn, nil
 }
 
-func (w *messageWriter) Write(p []byte) (int, error) {
-	return w.write(false, p)
-}
-
 func (w *messageWriter) WriteString(p string) (int, error) {
 	if err := w.err(); err != nil {
 		return 0, err
@@ -658,12 +693,17 @@ func (c *Conn) advanceFrame() (int, error) {
 
 	final := p[0]&finalBit != 0
 	frameType := int(p[0] & 0xf)
-	reserved := int((p[0] >> 4) & 0x7)
 	mask := p[1]&maskBit != 0
 	c.readRemaining = int64(p[1] & 0x7f)
 
-	if reserved != 0 {
-		return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
+	c.readDecompress = false
+	if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
+		c.readDecompress = true
+		p[0] &^= rsv1Bit
+	}
+
+	if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
+		return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
 	}
 
 	switch frameType {
@@ -807,8 +847,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 			break
 		}
 		if frameType == TextMessage || frameType == BinaryMessage {
-			r := &messageReader{c}
-			c.messageReader = r
+			c.messageReader = &messageReader{c}
+			var r io.Reader = c.messageReader
+			if c.readDecompress {
+				r = c.newDecompressionReader(r)
+			}
 			return frameType, r, nil
 		}
 	}