Browse Source

Merge pull request #167 from garyburd/master

enable compression
Gary Burd 9 năm trước cách đây
mục cha
commit
8003df83ee
7 tập tin đã thay đổi với 165 bổ sung57 xóa
  1. 27 0
      client.go
  2. 18 3
      client_server_test.go
  3. 7 0
      conn.go
  4. 56 51
      conn_test.go
  5. 21 0
      doc.go
  6. 5 3
      examples/autobahn/server.go
  7. 31 0
      server.go

+ 27 - 0
client.go

@@ -23,6 +23,8 @@ import (
 // invalid.
 var ErrBadHandshake = errors.New("websocket: bad handshake")
 
+var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
+
 // NewClient creates a new client connection using the given net connection.
 // The URL u specifies the host and request URI. Use requestHeader to specify
 // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
@@ -70,6 +72,12 @@ type Dialer struct {
 
 	// Subprotocols specifies the client's requested subprotocols.
 	Subprotocols []string
+
+	// EnableCompression specifies if the client should attempt to negotiate
+	// per message compression (RFC 7692). Setting this value to true does not
+	// guarantee that compression will be supported. Currently only "no context
+	// takeover" modes are supported.
+	EnableCompression bool
 }
 
 var errMalformedURL = errors.New("malformed ws or wss URL")
@@ -214,6 +222,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 			k == "Connection" ||
 			k == "Sec-Websocket-Key" ||
 			k == "Sec-Websocket-Version" ||
+			k == "Sec-Websocket-Extensions" ||
 			(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
 			return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
 		default:
@@ -221,6 +230,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		}
 	}
 
+	if d.EnableCompression {
+		req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
+	}
+
 	hostPort, hostNoPort := hostPortNoPort(u)
 
 	var proxyURL *url.URL
@@ -337,6 +350,20 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		return nil, resp, ErrBadHandshake
 	}
 
+	for _, ext := range parseExtensions(req.Header) {
+		if ext[""] != "permessage-deflate" {
+			continue
+		}
+		_, snct := ext["server_no_context_takeover"]
+		_, cnct := ext["client_no_context_takeover"]
+		if !snct || !cnct {
+			return nil, resp, errInvalidCompression
+		}
+		conn.newCompressionWriter = compressNoContextTakeover
+		conn.newDecompressionReader = decompressNoContextTakeover
+		break
+	}
+
 	resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
 	conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
 

+ 18 - 3
client_server_test.go

@@ -20,9 +20,10 @@ import (
 )
 
 var cstUpgrader = Upgrader{
-	Subprotocols:    []string{"p0", "p1"},
-	ReadBufferSize:  1024,
-	WriteBufferSize: 1024,
+	Subprotocols:      []string{"p0", "p1"},
+	ReadBufferSize:    1024,
+	WriteBufferSize:   1024,
+	EnableCompression: true,
 	Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
 		http.Error(w, reason.Error(), status)
 	},
@@ -446,3 +447,17 @@ func TestHostHeader(t *testing.T) {
 
 	sendRecv(t, ws)
 }
+
+func TestDialCompression(t *testing.T) {
+	s := newServer(t)
+	defer s.Close()
+
+	dialer := cstDialer
+	dialer.EnableCompression = true
+	ws, _, err := dialer.Dial(s.URL, nil)
+	if err != nil {
+		t.Fatalf("Dial: %v", err)
+	}
+	defer ws.Close()
+	sendRecv(t, ws)
+}

+ 7 - 0
conn.go

@@ -985,6 +985,13 @@ func (c *Conn) UnderlyingConn() net.Conn {
 	return c.conn
 }
 
+// EnableWriteCompression enables and disables write compression of
+// subsequent text and binary messages. This function is a noop if
+// compression was not negotiated with the peer.
+func (c *Conn) EnableWriteCompression(enable bool) {
+	c.enableWriteCompression = enable
+}
+
 // FormatCloseMessage formats closeCode and text as a WebSocket close message.
 func FormatCloseMessage(closeCode int, text string) []byte {
 	buf := make([]byte, 2+len(text))

+ 56 - 51
conn_test.go

@@ -48,60 +48,65 @@ func TestFraming(t *testing.T) {
 		writeBuf[i] = byte(i)
 	}
 
-	for _, isServer := range []bool{true, false} {
-		for _, chunker := range readChunkers {
-
-			var connBuf bytes.Buffer
-			wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
-			rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
-
-			for _, n := range frameSizes {
-				for _, iocopy := range []bool{true, false} {
-					name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy)
-
-					w, err := wc.NextWriter(TextMessage)
-					if err != nil {
-						t.Errorf("%s: wc.NextWriter() returned %v", name, err)
-						continue
-					}
-					var nn int
-					if iocopy {
-						var n64 int64
-						n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
-						nn = int(n64)
-					} else {
-						nn, err = w.Write(writeBuf[:n])
-					}
-					if err != nil || nn != n {
-						t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
-						continue
-					}
-					err = w.Close()
-					if err != nil {
-						t.Errorf("%s: w.Close() returned %v", name, err)
-						continue
-					}
+	for _, compress := range []bool{false, true} {
+		for _, isServer := range []bool{true, false} {
+			for _, chunker := range readChunkers {
+
+				var connBuf bytes.Buffer
+				wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
+				rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
+				if compress {
+					wc.newCompressionWriter = compressNoContextTakeover
+					rc.newDecompressionReader = decompressNoContextTakeover
+				}
+				for _, n := range frameSizes {
+					for _, iocopy := range []bool{true, false} {
+						name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d c:%v", compress, isServer, chunker.name, n, iocopy)
+
+						w, err := wc.NextWriter(TextMessage)
+						if err != nil {
+							t.Errorf("%s: wc.NextWriter() returned %v", name, err)
+							continue
+						}
+						var nn int
+						if iocopy {
+							var n64 int64
+							n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
+							nn = int(n64)
+						} else {
+							nn, err = w.Write(writeBuf[:n])
+						}
+						if err != nil || nn != n {
+							t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
+							continue
+						}
+						err = w.Close()
+						if err != nil {
+							t.Errorf("%s: w.Close() returned %v", name, err)
+							continue
+						}
 
-					opCode, r, err := rc.NextReader()
-					if err != nil || opCode != TextMessage {
-						t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
-						continue
-					}
-					rbuf, err := ioutil.ReadAll(r)
-					if err != nil {
-						t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
-						continue
-					}
+						opCode, r, err := rc.NextReader()
+						if err != nil || opCode != TextMessage {
+							t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
+							continue
+						}
+						rbuf, err := ioutil.ReadAll(r)
+						if err != nil {
+							t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
+							continue
+						}
 
-					if len(rbuf) != n {
-						t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
-						continue
-					}
+						if len(rbuf) != n {
+							t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
+							continue
+						}
 
-					for i, b := range rbuf {
-						if byte(i) != b {
-							t.Errorf("%s: bad byte at offset %d", name, i)
-							break
+						for i, b := range rbuf {
+							if byte(i) != b {
+								t.Errorf("%s: bad byte at offset %d", name, i)
+								break
+							}
 						}
 					}
 				}

+ 21 - 0
doc.go

@@ -149,4 +149,25 @@
 // The deprecated Upgrade function does not enforce an origin policy. It's the
 // application's responsibility to check the Origin header before calling
 // Upgrade.
+//
+// Compression [Experimental]
+//
+// Per message compression extensions (RFC 7692) are experimentally supported
+// by this package in a limited capacity. Setting the EnableCompression option
+// to true in Dialer or Upgrader will attempt to negotiate per message deflate
+// support. If compression was successfully negotiated with the connection's
+// peer, any message received in compressed form will be automatically
+// decompressed. All Read methods will return uncompressed bytes.
+//
+// Per message compression of messages written to a connection can be enabled
+// or disabled by calling the corresponding Conn method:
+//
+// conn.EnableWriteCompression(true)
+//
+// Currently this package does not support compression with "context takeover".
+// This means that messages must be compressed and decompressed in isolation,
+// without retaining sliding window or dictionary state across messages. For
+// more details refer to RFC 7692.
+//
+// Use of compression is experimental and may result in decreased performance.
 package websocket

+ 5 - 3
examples/autobahn/server.go

@@ -8,17 +8,19 @@ package main
 import (
 	"errors"
 	"flag"
-	"github.com/gorilla/websocket"
 	"io"
 	"log"
 	"net/http"
 	"time"
 	"unicode/utf8"
+
+	"github.com/gorilla/websocket"
 )
 
 var upgrader = websocket.Upgrader{
-	ReadBufferSize:  4096,
-	WriteBufferSize: 4096,
+	ReadBufferSize:    4096,
+	WriteBufferSize:   4096,
+	EnableCompression: true,
 	CheckOrigin: func(r *http.Request) bool {
 		return true
 	},

+ 31 - 0
server.go

@@ -46,6 +46,12 @@ type Upgrader struct {
 	// CheckOrigin is nil, the host in the Origin header must not be set or
 	// must match the host of the request.
 	CheckOrigin func(r *http.Request) bool
+
+	// EnableCompression specify if the server should attempt to negotiate per
+	// message compression (RFC 7692). Setting this value to true does not
+	// guarantee that compression will be supported. Currently only "no context
+	// takeover" modes are supported.
+	EnableCompression bool
 }
 
 func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
@@ -100,6 +106,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 	if r.Method != "GET" {
 		return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
 	}
+
+	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
+		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific Sec-Websocket-Extensions headers are unsupported")
+	}
+
 	if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
 		return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
 	}
@@ -127,6 +138,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 
 	subprotocol := u.selectSubprotocol(r, responseHeader)
 
+	// Negotiate PMCE
+	var compress bool
+	if u.EnableCompression {
+		for _, ext := range parseExtensions(r.Header) {
+			if ext[""] != "permessage-deflate" {
+				continue
+			}
+			compress = true
+			break
+		}
+	}
+
 	var (
 		netConn net.Conn
 		br      *bufio.Reader
@@ -152,6 +175,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
 	c.subprotocol = subprotocol
 
+	if compress {
+		c.newCompressionWriter = compressNoContextTakeover
+		c.newDecompressionReader = decompressNoContextTakeover
+	}
+
 	p := c.writeBuf[:0]
 	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
 	p = append(p, computeAcceptKey(challengeKey)...)
@@ -161,6 +189,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 		p = append(p, c.subprotocol...)
 		p = append(p, "\r\n"...)
 	}
+	if compress {
+		p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
+	}
 	for k, vs := range responseHeader {
 		if k == "Sec-Websocket-Protocol" {
 			continue