Browse Source

go.crypto/ssh: add Stderr() in Channel interface.

Adds support for piping Stderr to the client.

R=golang-dev, dave, agl
CC=golang-dev
https://golang.org/cl/5674081
Daniel Theophanes 13 years ago
parent
commit
6c548e9506
1 changed files with 64 additions and 12 deletions
  1. 64 12
      ssh/channel.go

+ 64 - 12
ssh/channel.go

@@ -10,8 +10,15 @@ import (
 	"sync"
 )
 
+// extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
+// section 5.2.
+type extendedDataTypeCode uint32
+
+// extendedDataStderr is the extended data type that is used for stderr.
+const extendedDataStderr extendedDataTypeCode = 1
+
 // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
-// SSH connection.
+// SSH connection. Channel.Read can return a ChannelRequest as an error.
 type Channel interface {
 	// Accept accepts the channel creation request.
 	Accept() error
@@ -25,6 +32,10 @@ type Channel interface {
 	Write(data []byte) (int, error)
 	Close() error
 
+	// Stderr returns an io.Writer that writes to this channel with the
+	// extended data type set to stderr.
+	Stderr() io.Writer
+
 	// AckRequest either sends an ack or nack to the channel request.
 	AckRequest(ok bool) error
 
@@ -168,6 +179,52 @@ func (c *channel) handleData(data []byte) {
 	c.cond.Signal()
 }
 
+func (c *channel) Stderr() io.Writer {
+	return extendedDataChannel{c: c, t: extendedDataStderr}
+}
+
+// extendedDataChannel is an io.Writer that writes any data to c as extended
+// data of the given type.
+type extendedDataChannel struct {
+	t extendedDataTypeCode
+	c *channel
+}
+
+func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
+	c := edc.c
+	for len(data) > 0 {
+		var space uint32
+		if space, err = c.getWindowSpace(uint32(len(data))); err != nil {
+			return 0, err
+		}
+
+		todo := data
+		if uint32(len(todo)) > space {
+			todo = todo[:space]
+		}
+
+		packet := make([]byte, 1+4+4+4+len(todo))
+		packet[0] = msgChannelExtendedData
+		marshalUint32(packet[1:], c.theirId)
+		marshalUint32(packet[5:], uint32(edc.t))
+		marshalUint32(packet[9:], uint32(len(todo)))
+		copy(packet[13:], todo)
+
+		c.serverConn.lock.Lock()
+		err = c.serverConn.writePacket(packet)
+		c.serverConn.lock.Unlock()
+
+		if err != nil {
+			return
+		}
+
+		n += len(todo)
+		data = data[len(todo):]
+	}
+
+	return
+}
+
 func (c *channel) Read(data []byte) (n int, err error) {
 	c.lock.Lock()
 	defer c.lock.Unlock()
@@ -265,22 +322,17 @@ func (c *channel) Write(data []byte) (n int, err error) {
 
 		packet := make([]byte, 1+4+4+len(todo))
 		packet[0] = msgChannelData
-		packet[1] = byte(c.theirId >> 24)
-		packet[2] = byte(c.theirId >> 16)
-		packet[3] = byte(c.theirId >> 8)
-		packet[4] = byte(c.theirId)
-		packet[5] = byte(len(todo) >> 24)
-		packet[6] = byte(len(todo) >> 16)
-		packet[7] = byte(len(todo) >> 8)
-		packet[8] = byte(len(todo))
+		marshalUint32(packet[1:], c.theirId)
+		marshalUint32(packet[5:], uint32(len(todo)))
 		copy(packet[9:], todo)
 
 		c.serverConn.lock.Lock()
-		if err = c.serverConn.writePacket(packet); err != nil {
-			c.serverConn.lock.Unlock()
+		err = c.serverConn.writePacket(packet)
+		c.serverConn.lock.Unlock()
+
+		if err != nil {
 			return
 		}
-		c.serverConn.lock.Unlock()
 
 		n += len(todo)
 		data = data[len(todo):]