Ver código fonte

compression: add tests, rename option

Gary Burd 9 anos atrás
pai
commit
a0ef436d00
6 arquivos alterados com 88 adições e 66 exclusões
  1. 4 4
      client.go
  2. 18 3
      client_server_test.go
  3. 56 51
      conn_test.go
  4. 2 2
      doc.go
  5. 5 3
      examples/autobahn/server.go
  6. 3 3
      server.go

+ 4 - 4
client.go

@@ -73,11 +73,11 @@ type Dialer struct {
 	// Subprotocols specifies the client's requested subprotocols.
 	Subprotocols []string
 
-	// CompressionSupported specifies if the client should attempt to negotiate per
-	// message compression (RFC 7692). Setting this value to true does not
+	// EnableCompression specifies if the client should attempt to negotiate
+	// per message compression (RFC 7692). Setting this value to true does not
 	// guarantee that compression will be supported. Currently only "no context
 	// takeover" modes are supported.
-	CompressionSupported bool
+	EnableCompression bool
 }
 
 var errMalformedURL = errors.New("malformed ws or wss URL")
@@ -230,7 +230,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		}
 	}
 
-	if d.CompressionSupported {
+	if d.EnableCompression {
 		req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
 	}
 

+ 18 - 3
client_server_test.go

@@ -20,9 +20,10 @@ import (
 )
 
 var cstUpgrader = Upgrader{
-	Subprotocols:    []string{"p0", "p1"},
-	ReadBufferSize:  1024,
-	WriteBufferSize: 1024,
+	Subprotocols:      []string{"p0", "p1"},
+	ReadBufferSize:    1024,
+	WriteBufferSize:   1024,
+	EnableCompression: true,
 	Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
 		http.Error(w, reason.Error(), status)
 	},
@@ -446,3 +447,17 @@ func TestHostHeader(t *testing.T) {
 
 	sendRecv(t, ws)
 }
+
+func TestDialCompression(t *testing.T) {
+	s := newServer(t)
+	defer s.Close()
+
+	dialer := cstDialer
+	dialer.EnableCompression = true
+	ws, _, err := dialer.Dial(s.URL, nil)
+	if err != nil {
+		t.Fatalf("Dial: %v", err)
+	}
+	defer ws.Close()
+	sendRecv(t, ws)
+}

+ 56 - 51
conn_test.go

@@ -48,60 +48,65 @@ func TestFraming(t *testing.T) {
 		writeBuf[i] = byte(i)
 	}
 
-	for _, isServer := range []bool{true, false} {
-		for _, chunker := range readChunkers {
-
-			var connBuf bytes.Buffer
-			wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
-			rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
-
-			for _, n := range frameSizes {
-				for _, iocopy := range []bool{true, false} {
-					name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy)
-
-					w, err := wc.NextWriter(TextMessage)
-					if err != nil {
-						t.Errorf("%s: wc.NextWriter() returned %v", name, err)
-						continue
-					}
-					var nn int
-					if iocopy {
-						var n64 int64
-						n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
-						nn = int(n64)
-					} else {
-						nn, err = w.Write(writeBuf[:n])
-					}
-					if err != nil || nn != n {
-						t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
-						continue
-					}
-					err = w.Close()
-					if err != nil {
-						t.Errorf("%s: w.Close() returned %v", name, err)
-						continue
-					}
+	for _, compress := range []bool{false, true} {
+		for _, isServer := range []bool{true, false} {
+			for _, chunker := range readChunkers {
+
+				var connBuf bytes.Buffer
+				wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
+				rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
+				if compress {
+					wc.newCompressionWriter = compressNoContextTakeover
+					rc.newDecompressionReader = decompressNoContextTakeover
+				}
+				for _, n := range frameSizes {
+					for _, iocopy := range []bool{true, false} {
+						name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d c:%v", compress, isServer, chunker.name, n, iocopy)
+
+						w, err := wc.NextWriter(TextMessage)
+						if err != nil {
+							t.Errorf("%s: wc.NextWriter() returned %v", name, err)
+							continue
+						}
+						var nn int
+						if iocopy {
+							var n64 int64
+							n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
+							nn = int(n64)
+						} else {
+							nn, err = w.Write(writeBuf[:n])
+						}
+						if err != nil || nn != n {
+							t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
+							continue
+						}
+						err = w.Close()
+						if err != nil {
+							t.Errorf("%s: w.Close() returned %v", name, err)
+							continue
+						}
 
-					opCode, r, err := rc.NextReader()
-					if err != nil || opCode != TextMessage {
-						t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
-						continue
-					}
-					rbuf, err := ioutil.ReadAll(r)
-					if err != nil {
-						t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
-						continue
-					}
+						opCode, r, err := rc.NextReader()
+						if err != nil || opCode != TextMessage {
+							t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
+							continue
+						}
+						rbuf, err := ioutil.ReadAll(r)
+						if err != nil {
+							t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
+							continue
+						}
 
-					if len(rbuf) != n {
-						t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
-						continue
-					}
+						if len(rbuf) != n {
+							t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
+							continue
+						}
 
-					for i, b := range rbuf {
-						if byte(i) != b {
-							t.Errorf("%s: bad byte at offset %d", name, i)
-							break
+						for i, b := range rbuf {
+							if byte(i) != b {
+								t.Errorf("%s: bad byte at offset %d", name, i)
+								break
+							}
 						}
 					}
 				}

+ 2 - 2
doc.go

@@ -153,8 +153,8 @@
 // Compression [Experimental]
 //
 // Per message compression extensions (RFC 7692) are experimentally supported
-// by this package in a limited capacity. Enabling the CompressionSupported
-// option in Dialer or Upgrader will attempt to negotiate per message deflate
+// by this package in a limited capacity. Setting the EnableCompression option
+// to true in Dialer or Upgrader will attempt to negotiate per message deflate
 // support. If compression was successfully negotiated with the connection's
 // peer, any message received in compressed form will be automatically
 // decompressed. All Read methods will return uncompressed bytes.

+ 5 - 3
examples/autobahn/server.go

@@ -8,17 +8,19 @@ package main
 import (
 	"errors"
 	"flag"
-	"github.com/gorilla/websocket"
 	"io"
 	"log"
 	"net/http"
 	"time"
 	"unicode/utf8"
+
+	"github.com/gorilla/websocket"
 )
 
 var upgrader = websocket.Upgrader{
-	ReadBufferSize:  4096,
-	WriteBufferSize: 4096,
+	ReadBufferSize:    4096,
+	WriteBufferSize:   4096,
+	EnableCompression: true,
 	CheckOrigin: func(r *http.Request) bool {
 		return true
 	},

+ 3 - 3
server.go

@@ -47,11 +47,11 @@ type Upgrader struct {
 	// must match the host of the request.
 	CheckOrigin func(r *http.Request) bool
 
-	// CompressionSupported specify if the server should attempt to negotiate per
+	// EnableCompression specify if the server should attempt to negotiate per
 	// message compression (RFC 7692). Setting this value to true does not
 	// guarantee that compression will be supported. Currently only "no context
 	// takeover" modes are supported.
-	CompressionSupported bool
+	EnableCompression bool
 }
 
 func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
@@ -140,7 +140,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
 
 	// Negotiate PMCE
 	var compress bool
-	if u.CompressionSupported {
+	if u.EnableCompression {
 		for _, ext := range parseExtensions(r.Header) {
 			if ext[""] != "permessage-deflate" {
 				continue