Переглянути джерело

go.crypto/ssh: assorted close related fixes

Fixes golang/go#3810.

Fixes chanWriter Write after close behaviour bug.

Fixes serverChan writePacket after close bug.

Addresses final comments by agl on 6405064, plus various cleanups.

R=agl, kardianos, gustav.paul, fullung
CC=golang-dev
https://golang.org/cl/6479056
Dave Cheney 13 роки тому
батько
коміт
552202b8e3
5 змінених файлів з 51 додано та 60 видалено
  1. 45 41
      ssh/channel.go
  2. 4 3
      ssh/client.go
  3. 1 1
      ssh/server.go
  4. 1 9
      ssh/session_test.go
  5. 0 6
      ssh/transport.go

+ 45 - 41
ssh/channel.go

@@ -82,7 +82,7 @@ type channel struct {
 	localId, remoteId uint32
 	remoteWin         window
 	maxPacket         uint32
-	isclosed          uint32 // atomic bool, non zero if true
+	isClosed          uint32 // atomic bool, non zero if true
 }
 
 func (c *channel) sendWindowAdj(n int) error {
@@ -93,13 +93,20 @@ func (c *channel) sendWindowAdj(n int) error {
 	return c.writePacket(marshal(msgChannelWindowAdjust, msg))
 }
 
-// sendEOF sends EOF to the server. RFC 4254 Section 5.3
+// sendEOF sends EOF to the remote side. RFC 4254 Section 5.3
 func (c *channel) sendEOF() error {
 	return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
 		PeersId: c.remoteId,
 	}))
 }
 
+// sendClose informs the remote side of our intent to close the channel.
+func (c *channel) sendClose() error {
+	return c.conn.writePacket(marshal(msgChannelClose, channelCloseMsg{
+		PeersId: c.remoteId,
+	}))
+}
+
 func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string) error {
 	reject := channelOpenFailureMsg{
 		PeersId:  c.remoteId,
@@ -121,11 +128,11 @@ func (c *channel) writePacket(b []byte) error {
 }
 
 func (c *channel) closed() bool {
-	return atomic.LoadUint32(&c.isclosed) > 0
+	return atomic.LoadUint32(&c.isClosed) > 0
 }
 
 func (c *channel) setClosed() bool {
-	return atomic.CompareAndSwapUint32(&c.isclosed, 0, 1)
+	return atomic.CompareAndSwapUint32(&c.isClosed, 0, 1)
 }
 
 type serverChan struct {
@@ -136,10 +143,9 @@ type serverChan struct {
 
 	serverConn  *ServerConn
 	myWindow    uint32
-	weClosed    bool // incidates the close msg has been sent from our side
 	theyClosed  bool // indicates the close msg has been received from the remote side
 	theySentEOF bool
-	dead        bool
+	isDead      uint32
 	err         error
 
 	pendingRequests []ChannelRequest
@@ -247,7 +253,7 @@ func (edc extendedDataChannel) Write(data []byte) (n int, err error) {
 	const headerLength = 13 // 1 byte message type, 4 bytes remoteId, 4 bytes extended message type, 4 bytes data length
 	c := edc.c
 	for len(data) > 0 {
-		space := uint32(min(int(c.maxPacket-headerLength), len(data)))
+		space := min(c.maxPacket-headerLength, len(data))
 		if space, err = c.getWindowSpace(space); err != nil {
 			return 0, err
 		}
@@ -297,7 +303,7 @@ func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint3
 	}
 
 	for {
-		if c.theySentEOF || c.theyClosed || c.dead {
+		if c.theySentEOF || c.theyClosed || c.dead() {
 			return 0, io.EOF, 0
 		}
 
@@ -315,7 +321,7 @@ func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint3
 		}
 
 		if c.length > 0 {
-			tail := min(c.head+c.length, len(c.pendingData))
+			tail := min(uint32(c.head+c.length), len(c.pendingData))
 			n = copy(data, c.pendingData[c.head:tail])
 			c.head += n
 			c.length -= n
@@ -341,24 +347,24 @@ func (c *serverChan) read(data []byte) (n int, err error, windowAdjustment uint3
 // getWindowSpace takes, at most, max bytes of space from the peer's window. It
 // returns the number of bytes actually reserved.
 func (c *serverChan) getWindowSpace(max uint32) (uint32, error) {
-	var err error
-	// TODO(dfc) This lock and check of c.weClosed is necessary because unlike
-	// clientChan, c.weClosed is observed by more than one goroutine.
-	c.cond.L.Lock()
-	if c.dead || c.weClosed {
-		err = io.EOF
-	}
-	c.cond.L.Unlock()
-	if err != nil {
-		return 0, err
+	if c.dead() || c.closed() {
+		return 0, io.EOF
 	}
 	return c.remoteWin.reserve(max), nil
 }
 
+func (c *serverChan) dead() bool {
+	return atomic.LoadUint32(&c.isDead) > 0
+}
+
+func (c *serverChan) setDead() {
+	atomic.StoreUint32(&c.isDead, 1)
+}
+
 func (c *serverChan) Write(data []byte) (n int, err error) {
 	const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
 	for len(data) > 0 {
-		space := uint32(min(int(c.maxPacket-headerLength), len(data)))
+		space := min(c.maxPacket-headerLength, len(data))
 		if space, err = c.getWindowSpace(space); err != nil {
 			return 0, err
 		}
@@ -384,6 +390,7 @@ func (c *serverChan) Write(data []byte) (n int, err error) {
 	return
 }
 
+// Close signals the intent to close the channel.
 func (c *serverChan) Close() error {
 	c.serverConn.lock.Lock()
 	defer c.serverConn.lock.Unlock()
@@ -392,21 +399,12 @@ func (c *serverChan) Close() error {
 		return c.serverConn.err
 	}
 
-	if c.weClosed {
+	if !c.setClosed() {
 		return errors.New("ssh: channel already closed")
 	}
-	c.weClosed = true
-
 	return c.sendClose()
 }
 
-// sendClose signals the intent to close the channel.
-func (c *serverChan) sendClose() error {
-	return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
-		PeersId: c.remoteId,
-	}))
-}
-
 func (c *serverChan) AckRequest(ok bool) error {
 	c.serverConn.lock.Lock()
 	defer c.serverConn.lock.Unlock()
@@ -491,32 +489,37 @@ func (c *clientChan) waitForChannelOpenResponse() error {
 	return errors.New("ssh: unexpected packet")
 }
 
+// Close signals the intent to close the channel.
 func (c *clientChan) Close() error {
 	if !c.setClosed() {
 		return errors.New("ssh: channel already closed")
 	}
 	c.stdout.eof()
 	c.stderr.eof()
-	close(c.msg)
-	// TODO(dfc) step around channel.writePacket() because closed() is now true
-	return c.channel.conn.writePacket(marshal(msgChannelClose, channelCloseMsg{
-		PeersId: c.remoteId,
-	}))
+	return c.sendClose()
 }
 
 // A chanWriter represents the stdin of a remote process.
 type chanWriter struct {
 	*channel
+	// indicates the writer has been closed. eof is owned by the 
+	// caller of Write/Close. 
+	eof bool
 }
 
 // Write writes data to the remote process's standard input.
 func (w *chanWriter) Write(data []byte) (written int, err error) {
 	const headerLength = 9 // 1 byte message type, 4 bytes remoteId, 4 bytes data length
 	for len(data) > 0 {
+		if w.eof || w.closed() {
+			err = io.EOF
+			return
+		}
 		// never send more data than maxPacket even if
 		// there is sufficent window.
-		n := min(int(w.maxPacket-headerLength), len(data))
-		n = int(w.remoteWin.reserve(uint32(n)))
+		n := min(w.maxPacket-headerLength, len(data))
+		r := w.remoteWin.reserve(n)
+		n = r
 		remoteId := w.remoteId
 		packet := []byte{
 			msgChannelData,
@@ -527,19 +530,20 @@ func (w *chanWriter) Write(data []byte) (written int, err error) {
 			break
 		}
 		data = data[n:]
-		written += n
+		written += int(n)
 	}
 	return
 }
 
-func min(a, b int) int {
-	if a < b {
+func min(a uint32, b int) uint32 {
+	if a < uint32(b) {
 		return a
 	}
-	return b
+	return uint32(b)
 }
 
 func (w *chanWriter) Close() error {
+	w.eof = true
 	return w.sendEOF()
 }
 

+ 4 - 3
ssh/client.go

@@ -249,7 +249,8 @@ func (c *ClientConn) mainLoop() {
 				ch.stderr.write(packet)
 			}
 		default:
-			switch msg := decode(packet).(type) {
+			msg := decode(packet)
+			switch msg := msg.(type) {
 			case *channelOpenMsg:
 				c.handleChanOpen(msg)
 			case *channelOpenConfirmMsg:
@@ -270,8 +271,7 @@ func (c *ClientConn) mainLoop() {
 					return
 				}
 				ch.Close()
-				// TODO(dfc) may need to optimisically remove the 
-				// channel before closing
+				close(ch.msg)
 				c.chanList.remove(msg.PeersId)
 			case *channelEOFMsg:
 				ch, ok := c.getChan(msg.PeersId)
@@ -502,5 +502,6 @@ func (c *chanList) closeAll() {
 			continue
 		}
 		ch.Close()
+		close(ch.msg)
 	}
 }

+ 1 - 1
ssh/server.go

@@ -536,7 +536,7 @@ func (s *ServerConn) Accept() (Channel, error) {
 
 			// TODO(dfc) s.lock protects s.channels but isn't being held here.
 			for _, c := range s.channels {
-				c.dead = true
+				c.setDead()
 				c.handleData(nil)
 			}
 

+ 1 - 9
ssh/session_test.go

@@ -47,7 +47,7 @@ func dial(handler serverType, t *testing.T) *ClientConn {
 		done := make(chan struct{})
 		for {
 			ch, err := conn.Accept()
-			if err == io.EOF {
+			if err == io.EOF || err == io.ErrUnexpectedEOF {
 				return
 			}
 			// We sometimes get ECONNRESET rather than EOF.
@@ -389,10 +389,6 @@ func TestServerStdoutRespectsMaxPacketSize(t *testing.T) {
 }
 
 func TestClientCannotSendAfterEOF(t *testing.T) {
-	// TODO(dfc) currently writes succeed after Close()
-	t.Logf("test skipped")
-	return
-
 	conn := dial(exitWithoutSignalOrStatus, t)
 	defer conn.Close()
 	session, err := conn.NewSession()
@@ -628,10 +624,6 @@ func discardHandler(ch *serverChan, t *testing.T) {
 	// grow the window to avoid being fooled by
 	// the initial 1 << 14 window.
 	ch.sendWindowAdj(1024 * 1024)
-	// TODO(dfc) io.Copy can return a non EOF error here
-	// because closed channel errors can leak here if the
-	// read from ch causes a window adjustment after the 
-	// remote has signaled close.
 	io.Copy(ioutil.Discard, ch)
 }
 

+ 0 - 6
ssh/transport.go

@@ -194,12 +194,6 @@ func (w *writer) writePacket(packet []byte) error {
 	return w.Flush()
 }
 
-// Send a message to the remote peer
-func (t *transport) sendMessage(typ uint8, msg interface{}) error {
-	packet := marshal(typ, msg)
-	return t.writePacket(packet)
-}
-
 func newTransport(conn net.Conn, rand io.Reader) *transport {
 	return &transport{
 		reader: reader{