Просмотр исходного кода

go.crypto/ssh: move common channel methods into an embedded struct

This CL introduces a new struct, channel to hold common shared
functions.

* add a new channel struct, which is embeded in {client,server}Chan.
* move common methods from {client,server}Chan into channel.
* remove unneeded used of serverConn.lock in serverChan
 (transport.writePacket has its own mutex).
* remove filteredConn, introduce conn.

R=agl, gustav.paul
CC=golang-dev
https://golang.org/cl/6128059
Dave Cheney 13 лет назад
Родитель
Сommit
523290a72d
4 измененных файлов с 123 добавлено и 142 удалено
  1. 59 45
      ssh/channel.go
  2. 24 47
      ssh/client.go
  3. 30 28
      ssh/server.go
  4. 10 22
      ssh/transport.go

+ 59 - 45
ssh/channel.go

@@ -70,18 +70,55 @@ const (
 	ResourceShortage
 )
 
+type channel struct {
+	conn              // the underlying transport
+	localId, remoteId uint32
+
+	theyClosed  bool // indicates the close msg has been received from the remote side
+	weClosed    bool // incidates the close msg has been sent from our side
+	theySentEOF bool // used by serverChan
+	dead        bool // used by ServerChan to force close
+}
+
+func (c *channel) sendWindowAdj(n int) error {
+	msg := windowAdjustMsg{
+		PeersId:         c.remoteId,
+		AdditionalBytes: uint32(n),
+	}
+	return c.writePacket(marshal(msgChannelWindowAdjust, msg))
+}
+
+// sendClose signals the intent to close the channel.
+func (c *channel) sendClose() error {
+	return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
+		PeersId: c.remoteId,
+	}))
+}
+
+// sendEOF sends EOF to the server. RFC 4254 Section 5.3
+func (c *channel) sendEOF() error {
+	return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
+		PeersId: c.remoteId,
+	}))
+}
+
+func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
+	reject := channelOpenFailureMsg{
+		PeersId:  c.remoteId,
+		Reason:   reason,
+		Message:  message,
+		Language: "en",
+	}
+	return c.writePacket(marshal(msgChannelOpenFailure, reject))
+}
+
 type serverChan struct {
+	channel
 	// immutable once created
 	chanType  string
 	extraData []byte
 
-	theyClosed  bool
-	theySentEOF bool
-	weClosed    bool
-	dead        bool
-
 	serverConn            *ServerConn
-	localId, remoteId     uint32
 	myWindow, theirWindow uint32
 	maxPacketSize         uint32
 	err                   error
@@ -91,7 +128,6 @@ type serverChan struct {
 	head, length    int
 
 	// This lock is inferior to serverConn.lock
-	lock sync.Mutex
 	cond *sync.Cond
 }
 
@@ -109,7 +145,7 @@ func (c *serverChan) Accept() error {
 		MyWindow:      c.myWindow,
 		MaxPacketSize: c.maxPacketSize,
 	}
-	return c.serverConn.writePacket(marshal(msgChannelOpenConfirm, confirm))
+	return c.writePacket(marshal(msgChannelOpenConfirm, confirm))
 }
 
 func (c *serverChan) Reject(reason RejectionReason, message string) error {
@@ -120,18 +156,12 @@ func (c *serverChan) Reject(reason RejectionReason, message string) error {
 		return c.serverConn.err
 	}
 
-	reject := channelOpenFailureMsg{
-		PeersId:  c.remoteId,
-		Reason:   reason,
-		Message:  message,
-		Language: "en",
-	}
-	return c.serverConn.writePacket(marshal(msgChannelOpenFailure, reject))
+	return c.sendChannelOpenFailure(reason, message)
 }
 
 func (c *serverChan) handlePacket(packet interface{}) {
-	c.lock.Lock()
-	defer c.lock.Unlock()
+	c.cond.L.Lock()
+	defer c.cond.L.Unlock()
 
 	switch packet := packet.(type) {
 	case *channelRequestMsg:
@@ -158,8 +188,8 @@ func (c *serverChan) handlePacket(packet interface{}) {
 }
 
 func (c *serverChan) handleData(data []byte) {
-	c.lock.Lock()
-	defer c.lock.Unlock()
+	c.cond.L.Lock()
+	defer c.cond.L.Unlock()
 
 	// The other side should never send us more than our window.
 	if len(data)+c.length > len(c.pendingData) {
@@ -213,11 +243,7 @@ func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
 		marshalUint32(packet[9:], uint32(len(todo)))
 		copy(packet[13:], todo)
 
-		c.serverConn.lock.Lock()
-		err = c.serverConn.writePacket(packet)
-		c.serverConn.lock.Unlock()
-
-		if err != nil {
+		if err = c.writePacket(packet); err != nil {
 			return
 		}
 
@@ -236,20 +262,15 @@ func (c *serverChan) Read(data []byte) (n int, err error) {
 			PeersId:         c.remoteId,
 			AdditionalBytes: windowAdjustment,
 		})
-		c.serverConn.lock.Lock()
-		err = c.serverConn.writePacket(packet)
-		c.serverConn.lock.Unlock()
-		if err != nil {
-			return
-		}
+		err = c.writePacket(packet)
 	}
 
 	return
 }
 
 func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint32) {
-	c.lock.Lock()
-	defer c.lock.Unlock()
+	c.cond.L.Lock()
+	defer c.cond.L.Unlock()
 
 	if c.err != nil {
 		return 0, c.err, 0
@@ -300,8 +321,8 @@ func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint3
 // getWindowSpace takes, at most, max bytes of space from the peer's window. It
 // returns the number of bytes actually reserved.
 func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
-	c.lock.Lock()
-	defer c.lock.Unlock()
+	c.cond.L.Lock()
+	defer c.cond.L.Unlock()
 
 	for {
 		if c.dead || c.weClosed {
@@ -342,11 +363,7 @@ func (c *serverChan) Write(data []byte) (n int, err error) {
 		marshalUint32(packet[5:], uint32(len(todo)))
 		copy(packet[9:], todo)
 
-		c.serverConn.lock.Lock()
-		err = c.serverConn.writePacket(packet)
-		c.serverConn.lock.Unlock()
-
-		if err != nil {
+		if err = c.writePacket(packet); err != nil {
 			return
 		}
 
@@ -370,10 +387,7 @@ func (c *serverChan) Close() error {
 	}
 	c.weClosed = true
 
-	closeMsg := channelCloseMsg{
-		PeersId: c.remoteId,
-	}
-	return c.serverConn.writePacket(marshal(msgChannelClose, closeMsg))
+	return c.sendClose()
 }
 
 func (c *serverChan) AckRequest(ok bool) error {
@@ -388,13 +402,13 @@ func (c *serverChan) AckRequest(ok bool) error {
 		ack := channelRequestFailureMsg{
 			PeersId: c.remoteId,
 		}
-		return c.serverConn.writePacket(marshal(msgChannelFailure, ack))
+		return c.writePacket(marshal(msgChannelFailure, ack))
 	}
 
 	ack := channelRequestSuccessMsg{
 		PeersId: c.remoteId,
 	}
-	return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
+	return c.writePacket(marshal(msgChannelSuccess, ack))
 }
 
 func (c *serverChan) ChannelType() string {

+ 24 - 47
ssh/client.go

@@ -425,36 +425,35 @@ func (c *ClientConfig) rand() io.Reader {
 // A clientChan represents a single RFC 4254 channel that is multiplexed
 // over a single SSH connection.
 type clientChan struct {
-	packetWriter
-	localId, remoteId uint32
-	stdin             *chanWriter      // receives window adjustments
-	stdout            *chanReader      // receives the payload of channelData messages
-	stderr            *chanReader      // receives the payload of channelExtendedData messages
-	msg               chan interface{} // incoming messages
-	theyClosed        bool             // indicates the close msg has been received from the remote side
-	weClosed          bool             // incidates the close msg has been sent from our side
+	channel
+	stdin  *chanWriter
+	stdout *chanReader
+	stderr *chanReader
+	msg    chan interface{}
 }
 
 // newClientChan returns a partially constructed *clientChan
 // using the local id provided. To be usable clientChan.remoteId
 // needs to be assigned once known.
-func newClientChan(t *transport, localId uint32) *clientChan {
+func newClientChan(cc conn, id uint32) *clientChan {
 	c := &clientChan{
-		packetWriter: t,
-		localId:      localId,
-		msg:          make(chan interface{}, 16),
+		channel: channel{
+			conn:    cc,
+			localId: id,
+		},
+		msg: make(chan interface{}, 16),
 	}
 	c.stdin = &chanWriter{
-		win:        &window{Cond: sync.NewCond(new(sync.Mutex))},
-		clientChan: c,
+		win:     &window{Cond: sync.NewCond(new(sync.Mutex))},
+		channel: &c.channel,
 	}
 	c.stdout = &chanReader{
-		data:       make(chan []byte, 16),
-		clientChan: c,
+		data:    make(chan []byte, 16),
+		channel: &c.channel,
 	}
 	c.stderr = &chanReader{
-		data:       make(chan []byte, 16),
-		clientChan: c,
+		data:    make(chan []byte, 16),
+		channel: &c.channel,
 	}
 	return c
 }
@@ -474,28 +473,6 @@ func (c *clientChan) waitForChannelOpenResponse() error {
 	return errors.New("ssh: unexpected packet")
 }
 
-// sendEOF sends EOF to the server. RFC 4254 Section 5.3
-func (c *clientChan) sendEOF() error {
-	return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
-		PeersId: c.remoteId,
-	}))
-}
-
-// sendClose signals the intent to close the channel.
-func (c *clientChan) sendClose() error {
-	return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
-		PeersId: c.remoteId,
-	}))
-}
-
-func (c *clientChan) sendWindowAdj(n int) error {
-	msg := windowAdjustMsg{
-		PeersId:         c.remoteId,
-		AdditionalBytes: uint32(n),
-	}
-	return c.writePacket(marshal(msgChannelWindowAdjust, msg))
-}
-
 // Close closes the channel. This does not close the underlying connection.
 func (c *clientChan) Close() error {
 	if !c.weClosed {
@@ -565,8 +542,8 @@ func (c *chanList) closeAll() {
 
 // A chanWriter represents the stdin of a remote process.
 type chanWriter struct {
-	win        *window
-	clientChan *clientChan // the channel backing this writer
+	win *window
+	*channel
 }
 
 // Write writes data to the remote process's standard input.
@@ -575,13 +552,13 @@ func (w *chanWriter) Write(data []byte) (written int, err error) {
 		// n cannot be larger than 2^31 as len(data) cannot
 		// be larger than 2^31
 		n := int(w.win.reserve(uint32(len(data))))
-		remoteId := w.clientChan.remoteId
+		remoteId := w.remoteId
 		packet := []byte{
 			msgChannelData,
 			byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId),
 			byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
 		}
-		if err = w.clientChan.writePacket(append(packet, data[:n]...)); err != nil {
+		if err = w.writePacket(append(packet, data[:n]...)); err != nil {
 			break
 		}
 		data = data[n:]
@@ -598,7 +575,7 @@ func min(a, b int) int {
 }
 
 func (w *chanWriter) Close() error {
-	return w.clientChan.sendEOF()
+	return w.sendEOF()
 }
 
 // A chanReader represents stdout or stderr of a remote process.
@@ -608,7 +585,7 @@ type chanReader struct {
 	// it unable to receive new messages from the remote side.
 	data       chan []byte // receives data from remote
 	dataClosed bool        // protects data from being closed twice
-	clientChan *clientChan // the channel backing this reader
+	*channel               // the channel backing this reader
 	buf        []byte
 }
 
@@ -635,7 +612,7 @@ func (r *chanReader) Read(data []byte) (int, error) {
 		if len(r.buf) > 0 {
 			n := copy(data, r.buf)
 			r.buf = r.buf[n:]
-			return n, r.clientChan.sendWindowAdj(n)
+			return n, r.sendWindowAdj(n)
 		}
 		r.buf, ok = <-r.data
 		if !ok {

+ 30 - 28
ssh/server.go

@@ -101,9 +101,8 @@ type ServerConn struct {
 	channels   map[uint32]*serverChan
 	nextChanId uint32
 
-	// lock protects err and also allows Channels to serialise their writes
-	// to out.
-	lock sync.RWMutex
+	// lock protects err and channels.
+	lock sync.Mutex
 	err  error
 
 	// cachedPubKeys contains the cache results of tests for public keys.
@@ -121,12 +120,11 @@ type ServerConn struct {
 // Server returns a new SSH server connection
 // using c as the underlying transport.
 func Server(c net.Conn, config *ServerConfig) *ServerConn {
-	conn := &ServerConn{
+	return &ServerConn{
 		transport: newTransport(c, config.rand()),
 		channels:  make(map[uint32]*serverChan),
 		config:    config,
 	}
-	return conn
 }
 
 // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
@@ -500,6 +498,7 @@ const defaultWindowSize = 32768
 // Accept reads and processes messages on a ServerConn. It must be called
 // in order to demultiplex messages to any resulting Channels.
 func (s *ServerConn) Accept() (Channel, error) {
+	// TODO(dfc) s.lock is not held here so visibility of s.err is not guarenteed.
 	if s.err != nil {
 		return nil, s.err
 	}
@@ -512,6 +511,7 @@ func (s *ServerConn) Accept() (Channel, error) {
 			s.err = err
 			s.lock.Unlock()
 
+			// TODO(dfc) s.lock protects s.channels but isn't being held here.
 			for _, c := range s.channels {
 				c.dead = true
 				c.handleData(nil)
@@ -541,17 +541,20 @@ func (s *ServerConn) Accept() (Channel, error) {
 		default:
 			switch msg := decode(packet).(type) {
 			case *channelOpenMsg:
-				c := new(serverChan)
-				c.chanType = msg.ChanType
-				c.remoteId = msg.PeersId
-				c.theirWindow = msg.PeersWindow
-				c.maxPacketSize = msg.MaxPacketSize
-				c.extraData = msg.TypeSpecificData
-				c.myWindow = defaultWindowSize
-				c.serverConn = s
-				c.cond = sync.NewCond(&c.lock)
-				c.pendingData = make([]byte, c.myWindow)
-
+				c := &serverChan{
+					channel: channel{
+						conn:     s,
+						remoteId: msg.PeersId,
+					},
+					theirWindow:   msg.PeersWindow,
+					chanType:      msg.ChanType,
+					maxPacketSize: msg.MaxPacketSize,
+					extraData:     msg.TypeSpecificData,
+					myWindow:      defaultWindowSize,
+					serverConn:    s,
+					cond:          sync.NewCond(new(sync.Mutex)),
+					pendingData:   make([]byte, defaultWindowSize),
+				}
 				s.lock.Lock()
 				c.localId = s.nextChanId
 				s.nextChanId++
@@ -625,18 +628,6 @@ type Listener struct {
 	config   *ServerConfig
 }
 
-// Accept waits for and returns the next incoming SSH connection.
-// The receiver should call Handshake() in another goroutine 
-// to avoid blocking the accepter.
-func (l *Listener) Accept() (*ServerConn, error) {
-	c, err := l.listener.Accept()
-	if err != nil {
-		return nil, err
-	}
-	conn := Server(c, l.config)
-	return conn, nil
-}
-
 // Addr returns the listener's network address.
 func (l *Listener) Addr() net.Addr {
 	return l.listener.Addr()
@@ -647,6 +638,17 @@ func (l *Listener) Close() error {
 	return l.listener.Close()
 }
 
+// Accept waits for and returns the next incoming SSH connection.
+// The receiver should call Handshake() in another goroutine 
+// to avoid blocking the accepter.
+func (l *Listener) Accept() (*ServerConn, error) {
+	c, err := l.listener.Accept()
+	if err != nil {
+		return nil, err
+	}
+	return Server(c, l.config), nil
+}
+
 // Listen creates an SSH listener accepting connections on
 // the given network address using net.Listen.
 func Listen(network, addr string, config *ServerConfig) (*Listener, error) {

+ 10 - 22
ssh/transport.go

@@ -23,25 +23,14 @@ const (
 	minPaddingSize     = 4 // TODO(huin) should this be configurable?
 )
 
-// filteredConn reduces the set of methods exposed when embeddeding
-// a net.Conn inside ssh.transport.
-// TODO(dfc) suggestions for a better name will be warmly received.
-type filteredConn interface {
-	// Close closes the connection.
-	Close() error
-
-	// LocalAddr returns the local network address.
-	LocalAddr() net.Addr
-
-	// RemoteAddr returns the remote network address.
-	RemoteAddr() net.Addr
-}
-
-// Types implementing packetWriter provide the ability to send packets to
-// an SSH peer.
-type packetWriter interface {
+// conn represents an ssh transport that implements packet based
+// operations.
+type conn interface {
 	// Encrypt and send a packet of data to the remote peer.
 	writePacket(packet []byte) error
+
+	// Close closes the connection.
+	Close() error
 }
 
 // transport represents the SSH connection to the remote peer.
@@ -49,7 +38,7 @@ type transport struct {
 	reader
 	writer
 
-	filteredConn
+	net.Conn
 }
 
 // reader represents the incoming connection state.
@@ -58,9 +47,9 @@ type reader struct {
 	common
 }
 
-// writer represnts the outgoing connection state.
+// writer represents the outgoing connection state.
 type writer struct {
-	*sync.Mutex // protects writer.Writer from concurrent writes
+	sync.Mutex // protects writer.Writer from concurrent writes
 	*bufio.Writer
 	rand io.Reader
 	common
@@ -230,12 +219,11 @@ func newTransport(conn net.Conn, rand io.Reader) *transport {
 		writer: writer{
 			Writer: bufio.NewWriter(conn),
 			rand:   rand,
-			Mutex:  new(sync.Mutex),
 			common: common{
 				cipher: noneCipher{},
 			},
 		},
-		filteredConn: conn,
+		Conn: conn,
 	}
 }