Browse Source

go.crypto/ssh: reuse packet buffer for channel writes.

Test that different extended data streams within a channel are
thread-safe.

benchmark             old MB/s     new MB/s     speedup
BenchmarkEndToEnd     79.26        87.98        1.11x

benchmark                          old allocs     new allocs     delta
BenchmarkEndToEnd                  110            73             -33.64%

benchmark                          old bytes     new bytes     delta
BenchmarkEndToEnd                  2605720       1299768       -50.12%

LGTM=dave, jpsugar
R=agl, dave, jpsugar
CC=golang-codereviews
https://golang.org/cl/136420043
Han-Wen Nienhuys 11 years ago
parent
commit
fc84ae5437
3 changed files with 74 additions and 7 deletions
  1. 29 6
      ssh/channel.go
  2. 3 1
      ssh/mempipe_test.go
  3. 42 0
      ssh/mux_test.go

+ 29 - 6
ssh/channel.go

@@ -69,8 +69,10 @@ type Channel interface {
 	// if the data stream is closed or blocked by flow control.
 	// if the data stream is closed or blocked by flow control.
 	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
 	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
 
 
-	// Stderr returns an io.ReadWriter that writes to this channel with the
-	// extended data type set to stderr.
+	// Stderr returns an io.ReadWriter that writes to this channel
+	// with the extended data type set to stderr. Stderr may
+	// safely be read and written from a different goroutine than
+	// Read and Write respectively.
 	Stderr() io.ReadWriter
 	Stderr() io.ReadWriter
 }
 }
 
 
@@ -188,11 +190,15 @@ type channel struct {
 	myWindow uint32
 	myWindow uint32
 
 
 	// writeMu serializes calls to mux.conn.writePacket() and
 	// writeMu serializes calls to mux.conn.writePacket() and
-	// protects sentClose. This mutex must be different from
-	// windowMu, as writePacket can block if there is a key
-	// exchange pending
+	// protects sentClose and packetPool. This mutex must be
+	// different from windowMu, as writePacket can block if there
+	// is a key exchange pending.
 	writeMu   sync.Mutex
 	writeMu   sync.Mutex
 	sentClose bool
 	sentClose bool
+
+	// packetPool has a buffer for each extended channel ID to
+	// save allocations during writes.
+	packetPool map[uint32][]byte
 }
 }
 
 
 // writePacket sends a packet. If the packet is a channel close, it updates
 // writePacket sends a packet. If the packet is a channel close, it updates
@@ -233,14 +239,26 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
 		opCode = msgChannelExtendedData
 		opCode = msgChannelExtendedData
 	}
 	}
 
 
+	c.writeMu.Lock()
+	packet := c.packetPool[extendedCode]
+	// We don't remove the buffer from packetPool, so
+	// WriteExtended calls from different goroutines will be
+	// flagged as errors by the race detector.
+	c.writeMu.Unlock()
+
 	for len(data) > 0 {
 	for len(data) > 0 {
 		space := min(c.maxRemotePayload, len(data))
 		space := min(c.maxRemotePayload, len(data))
 		if space, err = c.remoteWin.reserve(space); err != nil {
 		if space, err = c.remoteWin.reserve(space); err != nil {
 			return n, err
 			return n, err
 		}
 		}
+		if want := headerLength + space; uint32(cap(packet)) < want {
+			packet = make([]byte, want)
+		} else {
+			packet = packet[:want]
+		}
+
 		todo := data[:space]
 		todo := data[:space]
 
 
-		packet := make([]byte, headerLength+uint32(len(todo)))
 		packet[0] = opCode
 		packet[0] = opCode
 		binary.BigEndian.PutUint32(packet[1:], c.remoteId)
 		binary.BigEndian.PutUint32(packet[1:], c.remoteId)
 		if extendedCode > 0 {
 		if extendedCode > 0 {
@@ -256,6 +274,10 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
 		data = data[len(todo):]
 		data = data[len(todo):]
 	}
 	}
 
 
+	c.writeMu.Lock()
+	c.packetPool[extendedCode] = packet
+	c.writeMu.Unlock()
+
 	return n, err
 	return n, err
 }
 }
 
 
@@ -442,6 +464,7 @@ func (m *mux) newChannel(chanType string, direction channelDirection, extraData
 		chanType:         chanType,
 		chanType:         chanType,
 		extraData:        extraData,
 		extraData:        extraData,
 		mux:              m,
 		mux:              m,
+		packetPool:       make(map[uint32][]byte),
 	}
 	}
 	ch.localId = m.chanList.add(ch)
 	ch.localId = m.chanList.add(ch)
 	return ch
 	return ch

+ 3 - 1
ssh/mempipe_test.go

@@ -59,7 +59,9 @@ func (t *memTransport) writePacket(p []byte) error {
 	if t.write.eof {
 	if t.write.eof {
 		return io.EOF
 		return io.EOF
 	}
 	}
-	t.write.pending = append(t.write.pending, p)
+	c := make([]byte, len(p))
+	copy(c, p)
+	t.write.pending = append(t.write.pending, c)
 	t.write.Cond.Signal()
 	t.write.Cond.Signal()
 	return nil
 	return nil
 }
 }

+ 42 - 0
ssh/mux_test.go

@@ -49,6 +49,48 @@ func channelPair(t *testing.T) (*channel, *channel, *mux) {
 	return <-res, ch, c
 	return <-res, ch, c
 }
 }
 
 
+// Test that stderr and stdout can be addressed from different
+// goroutines. This is intended for use with the race detector.
+func TestMuxChannelExtendedThreadSafety(t *testing.T) {
+	writer, reader, mux := channelPair(t)
+	defer writer.Close()
+	defer reader.Close()
+	defer mux.Close()
+
+	var wr, rd sync.WaitGroup
+	magic := "hello world"
+
+	wr.Add(2)
+	go func() {
+		io.WriteString(writer, magic)
+		wr.Done()
+	}()
+	go func() {
+		io.WriteString(writer.Stderr(), magic)
+		wr.Done()
+	}()
+
+	rd.Add(2)
+	go func() {
+		c, err := ioutil.ReadAll(reader)
+		if string(c) != magic {
+			t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
+		}
+		rd.Done()
+	}()
+	go func() {
+		c, err := ioutil.ReadAll(reader.Stderr())
+		if string(c) != magic {
+			t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
+		}
+		rd.Done()
+	}()
+
+	wr.Wait()
+	writer.CloseWrite()
+	rd.Wait()
+}
+
 func TestMuxReadWrite(t *testing.T) {
 func TestMuxReadWrite(t *testing.T) {
 	s, c, mux := channelPair(t)
 	s, c, mux := channelPair(t)
 	defer s.Close()
 	defer s.Close()