浏览代码

Prepared Messages (#211)

Gary Burd 8 年之前
父节点
当前提交
804cb600d0
共有 6 个文件被更改,包括 357 次插入7 次删除
  1. 22 1
      conn.go
  2. 134 0
      conn_broadcast_test.go
  3. 1 0
      examples/autobahn/fuzzingclient.json
  4. 23 6
      examples/autobahn/server.go
  5. 103 0
      prepared.go
  6. 74 0
      prepared_test.go

+ 22 - 1
conn.go

@@ -659,12 +659,33 @@ func (w *messageWriter) Close() error {
 	return nil
 }
 
+// WritePreparedMessage writes prepared message into connection.
+func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
+	frameType, frameData, err := pm.frame(prepareKey{
+		isServer:         c.isServer,
+		compress:         c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
+		compressionLevel: c.compressionLevel,
+	})
+	if err != nil {
+		return err
+	}
+	if c.isWriting {
+		panic("concurrent write to websocket connection")
+	}
+	c.isWriting = true
+	err = c.write(frameType, c.writeDeadline, frameData, nil)
+	if !c.isWriting {
+		panic("concurrent write to websocket connection")
+	}
+	c.isWriting = false
+	return err
+}
+
 // WriteMessage is a helper method for getting a writer using NextWriter,
 // writing the message and closing the writer.
 func (c *Conn) WriteMessage(messageType int, data []byte) error {
 
 	if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
-
 		// Fast path with no allocations and single frame.
 
 		if err := c.prepWrite(messageType); err != nil {

+ 134 - 0
conn_broadcast_test.go

@@ -0,0 +1,134 @@
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.7
+
+package websocket
+
+import (
+	"io"
+	"io/ioutil"
+	"sync/atomic"
+	"testing"
+)
+
+// broadcastBench allows to run broadcast benchmarks.
+// In every broadcast benchmark we create many connections, then send the same
+// message into every connection and wait for all writes complete. This emulates
+// an application where many connections listen to the same data - i.e. PUB/SUB
+// scenarios with many subscribers in one channel.
+type broadcastBench struct {
+	w           io.Writer
+	message     *broadcastMessage
+	closeCh     chan struct{}
+	doneCh      chan struct{}
+	count       int32
+	conns       []*broadcastConn
+	compression bool
+	usePrepared bool
+}
+
+type broadcastMessage struct {
+	payload  []byte
+	prepared *PreparedMessage
+}
+
+type broadcastConn struct {
+	conn  *Conn
+	msgCh chan *broadcastMessage
+}
+
+func newBroadcastConn(c *Conn) *broadcastConn {
+	return &broadcastConn{
+		conn:  c,
+		msgCh: make(chan *broadcastMessage, 1),
+	}
+}
+
+func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
+	bench := &broadcastBench{
+		w:           ioutil.Discard,
+		doneCh:      make(chan struct{}),
+		closeCh:     make(chan struct{}),
+		usePrepared: usePrepared,
+		compression: compression,
+	}
+	msg := &broadcastMessage{
+		payload: textMessages(1)[0],
+	}
+	if usePrepared {
+		pm, _ := NewPreparedMessage(TextMessage, msg.payload)
+		msg.prepared = pm
+	}
+	bench.message = msg
+	bench.makeConns(10000)
+	return bench
+}
+
+func (b *broadcastBench) makeConns(numConns int) {
+	conns := make([]*broadcastConn, numConns)
+
+	for i := 0; i < numConns; i++ {
+		c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
+		if b.compression {
+			c.enableWriteCompression = true
+			c.newCompressionWriter = compressNoContextTakeover
+		}
+		conns[i] = newBroadcastConn(c)
+		go func(c *broadcastConn) {
+			for {
+				select {
+				case msg := <-c.msgCh:
+					if b.usePrepared {
+						c.conn.WritePreparedMessage(msg.prepared)
+					} else {
+						c.conn.WriteMessage(TextMessage, msg.payload)
+					}
+					val := atomic.AddInt32(&b.count, 1)
+					if val%int32(numConns) == 0 {
+						b.doneCh <- struct{}{}
+					}
+				case <-b.closeCh:
+					return
+				}
+			}
+		}(conns[i])
+	}
+	b.conns = conns
+}
+
+func (b *broadcastBench) close() {
+	close(b.closeCh)
+}
+
+func (b *broadcastBench) runOnce() {
+	for _, c := range b.conns {
+		c.msgCh <- b.message
+	}
+	<-b.doneCh
+}
+
+func BenchmarkBroadcast(b *testing.B) {
+	benchmarks := []struct {
+		name        string
+		usePrepared bool
+		compression bool
+	}{
+		{"NoCompression", false, false},
+		{"WithCompression", false, true},
+		{"NoCompressionPrepared", true, false},
+		{"WithCompressionPrepared", true, true},
+	}
+	for _, bm := range benchmarks {
+		b.Run(bm.name, func(b *testing.B) {
+			bench := newBroadcastBench(bm.usePrepared, bm.compression)
+			defer bench.close()
+			b.ResetTimer()
+			for i := 0; i < b.N; i++ {
+				bench.runOnce()
+			}
+			b.ReportAllocs()
+		})
+	}
+}

+ 1 - 0
examples/autobahn/fuzzingclient.json

@@ -4,6 +4,7 @@
    "outdir": "./reports/clients",
    "servers": [
         {"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}},
+        {"agent": "ReadAllWritePreparedMessage", "url": "ws://localhost:9000/p", "options": {"version": 18}},
         {"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}},
         {"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}},
         {"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}}

+ 23 - 6
examples/autobahn/server.go

@@ -85,7 +85,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
 
 // echoReadAll echoes messages from the client by reading the entire message
 // with ioutil.ReadAll.
-func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) {
+func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
 	conn, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 		log.Println("Upgrade:", err)
@@ -109,9 +109,21 @@ func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) {
 			}
 		}
 		if writeMessage {
-			err = conn.WriteMessage(mt, b)
-			if err != nil {
-				log.Println("WriteMessage:", err)
+			if !writePrepared {
+				err = conn.WriteMessage(mt, b)
+				if err != nil {
+					log.Println("WriteMessage:", err)
+				}
+			} else {
+				pm, err := websocket.NewPreparedMessage(mt, b)
+				if err != nil {
+					log.Println("NewPreparedMessage:", err)
+					return
+				}
+				err = conn.WritePreparedMessage(pm)
+				if err != nil {
+					log.Println("WritePreparedMessage:", err)
+				}
 			}
 		} else {
 			w, err := conn.NextWriter(mt)
@@ -132,11 +144,15 @@ func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) {
 }
 
 func echoReadAllWriter(w http.ResponseWriter, r *http.Request) {
-	echoReadAll(w, r, false)
+	echoReadAll(w, r, false, false)
 }
 
 func echoReadAllWriteMessage(w http.ResponseWriter, r *http.Request) {
-	echoReadAll(w, r, true)
+	echoReadAll(w, r, true, false)
+}
+
+func echoReadAllWritePreparedMessage(w http.ResponseWriter, r *http.Request) {
+	echoReadAll(w, r, true, true)
 }
 
 func serveHome(w http.ResponseWriter, r *http.Request) {
@@ -161,6 +177,7 @@ func main() {
 	http.HandleFunc("/f", echoCopyFull)
 	http.HandleFunc("/r", echoReadAllWriter)
 	http.HandleFunc("/m", echoReadAllWriteMessage)
+	http.HandleFunc("/p", echoReadAllWritePreparedMessage)
 	err := http.ListenAndServe(*addr, nil)
 	if err != nil {
 		log.Fatal("ListenAndServe: ", err)

+ 103 - 0
prepared.go

@@ -0,0 +1,103 @@
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+	"bytes"
+	"net"
+	"sync"
+	"time"
+)
+
+// PreparedMessage caches on the wire representations of a message payload.
+// Use PreparedMessage to efficiently send a message payload to multiple
+// connections. PreparedMessage is especially useful when compression is used
+// because the CPU and memory expensive compression operation can be executed
+// once for a given set of compression options.
+type PreparedMessage struct {
+	messageType int
+	data        []byte
+	err         error
+	mu          sync.Mutex
+	frames      map[prepareKey]*preparedFrame
+}
+
+// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
+type prepareKey struct {
+	isServer         bool
+	compress         bool
+	compressionLevel int
+}
+
+// preparedFrame contains data in wire representation.
+type preparedFrame struct {
+	once sync.Once
+	data []byte
+}
+
+// NewPreparedMessage returns an initialized PreparedMessage. You can then send
+// it to connection using WritePreparedMessage method. Valid wire
+// representation will be calculated lazily only once for a set of current
+// connection options.
+func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) {
+	pm := &PreparedMessage{
+		messageType: messageType,
+		frames:      make(map[prepareKey]*preparedFrame),
+		data:        data,
+	}
+
+	// Prepare a plain server frame.
+	_, frameData, err := pm.frame(prepareKey{isServer: true, compress: false})
+	if err != nil {
+		return nil, err
+	}
+
+	// To protect against caller modifying the data argument, remember the data
+	// copied to the plain server frame.
+	pm.data = frameData[len(frameData)-len(data):]
+	return pm, nil
+}
+
+func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
+	pm.mu.Lock()
+	frame, ok := pm.frames[key]
+	if !ok {
+		frame = &preparedFrame{}
+		pm.frames[key] = frame
+	}
+	pm.mu.Unlock()
+
+	var err error
+	frame.once.Do(func() {
+		// Prepare a frame using a 'fake' connection.
+		// TODO: Refactor code in conn.go to allow more direct construction of
+		// the frame.
+		mu := make(chan bool, 1)
+		mu <- true
+		var nc prepareConn
+		c := &Conn{
+			conn:                   &nc,
+			mu:                     mu,
+			isServer:               key.isServer,
+			compressionLevel:       key.compressionLevel,
+			enableWriteCompression: true,
+			writeBuf:               make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
+		}
+		if key.compress {
+			c.newCompressionWriter = compressNoContextTakeover
+		}
+		err = c.WriteMessage(pm.messageType, pm.data)
+		frame.data = nc.buf.Bytes()
+	})
+	return pm.messageType, frame.data, err
+}
+
+type prepareConn struct {
+	buf bytes.Buffer
+	net.Conn
+}
+
+func (pc *prepareConn) Write(p []byte) (int, error)        { return pc.buf.Write(p) }
+func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }

+ 74 - 0
prepared_test.go

@@ -0,0 +1,74 @@
+// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+	"bytes"
+	"compress/flate"
+	"math/rand"
+	"testing"
+)
+
+var preparedMessageTests = []struct {
+	messageType            int
+	isServer               bool
+	enableWriteCompression bool
+	compressionLevel       int
+}{
+	// Server
+	{TextMessage, true, false, flate.BestSpeed},
+	{TextMessage, true, true, flate.BestSpeed},
+	{TextMessage, true, true, flate.BestCompression},
+	{PingMessage, true, false, flate.BestSpeed},
+	{PingMessage, true, true, flate.BestSpeed},
+
+	// Client
+	{TextMessage, false, false, flate.BestSpeed},
+	{TextMessage, false, true, flate.BestSpeed},
+	{TextMessage, false, true, flate.BestCompression},
+	{PingMessage, false, false, flate.BestSpeed},
+	{PingMessage, false, true, flate.BestSpeed},
+}
+
+func TestPreparedMessage(t *testing.T) {
+	for _, tt := range preparedMessageTests {
+		var data = []byte("this is a test")
+		var buf bytes.Buffer
+		c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024)
+		if tt.enableWriteCompression {
+			c.newCompressionWriter = compressNoContextTakeover
+		}
+		c.SetCompressionLevel(tt.compressionLevel)
+
+		// Seed random number generator for consistent frame mask.
+		rand.Seed(1234)
+
+		if err := c.WriteMessage(tt.messageType, data); err != nil {
+			t.Fatal(err)
+		}
+		want := buf.String()
+
+		pm, err := NewPreparedMessage(tt.messageType, data)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		// Scribble on data to ensure that NewPreparedMessage takes a snapshot.
+		copy(data, "hello world")
+
+		// Seed random number generator for consistent frame mask.
+		rand.Seed(1234)
+
+		buf.Reset()
+		if err := c.WritePreparedMessage(pm); err != nil {
+			t.Fatal(err)
+		}
+		got := buf.String()
+
+		if got != want {
+			t.Errorf("write message != prepared message for %+v", tt)
+		}
+	}
+}