Browse Source

go.crypto/ssh: prevent channel writes after Close

Fixes golang/go#3810.

This change introduces an atomic boolean to guard the close
of the clientChan. Previously the client code was quite
lax with the ordering of the close messages and could allow
window adjustment or EOF messages to leak after Close had
been signaled.

Consolidating the changes to the serverChan will be handled
in a following CL.

R=agl, kardianos, gustav.paul
CC=golang-dev
https://golang.org/cl/6405064
Dave Cheney 13 years ago
parent
commit
c1c0bfbd3a
4 changed files with 89 additions and 34 deletions
  1. 36 20
      ssh/channel.go
  2. 4 13
      ssh/client.go
  3. 1 1
      ssh/session.go
  4. 48 0
      ssh/session_test.go

+ 36 - 20
ssh/channel.go

@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"io"
 	"sync"
+	"sync/atomic"
 )
 
 // extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
@@ -81,11 +82,7 @@ type channel struct {
 	localId, remoteId uint32
 	remoteWin         window
 	maxPacket         uint32
-
-	theyClosed  bool // indicates the close msg has been received from the remote side
-	weClosed    bool // incidates the close msg has been sent from our side
-	theySentEOF bool // used by serverChan
-	dead        bool // used by ServerChan to force close
+	isclosed          uint32 // atomic bool, non zero if true
 }
 
 func (c *channel) sendWindowAdj(n int) error {
@@ -96,13 +93,6 @@ func (c *channel) sendWindowAdj(n int) error {
 	return c.writePacket(marshal(msgChannelWindowAdjust, msg))
 }
 
-// sendClose signals the intent to close the channel.
-func (c *channel) sendClose() error {
-	return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
-		PeersId: c.remoteId,
-	}))
-}
-
 // sendEOF sends EOF to the server. RFC 4254 Section 5.3
 func (c *channel) sendEOF() error {
 	return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
@@ -121,21 +111,36 @@ func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string)
 }
 
 func (c *channel) writePacket(b []byte) error {
+	if c.closed() {
+		return io.EOF
+	}
 	if uint32(len(b)) > c.maxPacket {
 		return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
 	}
 	return c.conn.writePacket(b)
 }
 
+func (c *channel) closed() bool {
+	return atomic.LoadUint32(&c.isclosed) > 0
+}
+
+func (c *channel) setClosed() bool {
+	return atomic.CompareAndSwapUint32(&c.isclosed, 0, 1)
+}
+
 type serverChan struct {
 	channel
 	// immutable once created
 	chanType  string
 	extraData []byte
 
-	serverConn *ServerConn
-	myWindow   uint32
-	err        error
+	serverConn  *ServerConn
+	myWindow    uint32
+	weClosed    bool // incidates the close msg has been sent from our side
+	theyClosed  bool // indicates the close msg has been received from the remote side
+	theySentEOF bool
+	dead        bool
+	err         error
 
 	pendingRequests []ChannelRequest
 	pendingData     []byte
@@ -393,6 +398,13 @@ func (c *serverChan) Close() error {
 	return c.sendClose()
 }
 
+// sendClose signals the intent to close the channel.
+func (c *serverChan) sendClose() error {
+	return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
+		PeersId: c.remoteId,
+	}))
+}
+
 func (c *serverChan) AckRequest(ok bool) error {
 	c.serverConn.lock.Lock()
 	defer c.serverConn.lock.Unlock()
@@ -477,13 +489,17 @@ func (c *clientChan) waitForChannelOpenResponse() error {
 	return errors.New("ssh: unexpected packet")
 }
 
-// Close closes the channel. This does not close the underlying connection.
 func (c *clientChan) Close() error {
-	if !c.weClosed {
-		c.weClosed = true
-		return c.sendClose()
+	if !c.setClosed() {
+		return errors.New("ssh: channel already closed")
 	}
-	return nil
+	c.stdout.eof()
+	c.stderr.eof()
+	close(c.msg)
+	// TODO(dfc) step around channel.writePacket() because closed() is now true
+	return c.channel.conn.writePacket(marshal(msgChannelClose, channelCloseMsg{
+		PeersId: c.remoteId,
+	}))
 }
 
 // A chanWriter represents the stdin of a remote process.

+ 4 - 13
ssh/client.go

@@ -269,14 +269,9 @@ func (c *ClientConn) mainLoop() {
 				if !ok {
 					return
 				}
-				ch.theyClosed = true
-				ch.stdout.eof()
-				ch.stderr.eof()
-				close(ch.msg)
-				if !ch.weClosed {
-					ch.weClosed = true
-					ch.sendClose()
-				}
+				ch.Close()
+				// TODO(dfc) may need to optimisically remove the 
+				// channel before closing
 				c.chanList.remove(msg.PeersId)
 			case *channelEOFMsg:
 				ch, ok := c.getChan(msg.PeersId)
@@ -506,10 +501,6 @@ func (c *chanList) closeAll() {
 		if ch == nil {
 			continue
 		}
-
-		ch.theyClosed = true
-		ch.stdout.eof()
-		ch.stderr.eof()
-		close(ch.msg)
+		ch.Close()
 	}
 }

+ 1 - 1
ssh/session.go

@@ -370,7 +370,7 @@ func (s *Session) stdin() {
 	}
 	s.copyFuncs = append(s.copyFuncs, func() error {
 		_, err := io.Copy(s.clientChan.stdin, s.Stdin)
-		if err1 := s.clientChan.stdin.Close(); err == nil {
+		if err1 := s.clientChan.stdin.Close(); err == nil && err1 != io.EOF {
 			err = err1
 		}
 		return err

+ 48 - 0
ssh/session_test.go

@@ -380,6 +380,54 @@ func TestServerStdoutRespectsMaxPacketSize(t *testing.T) {
 	}
 }
 
+// TODO(dfc) currently writes succeed after Close()
+func testClientCannotSendAfterEOF(t *testing.T) {
+	conn := dial(shellHandler, t)
+	defer conn.Close()
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("Unable to request new session: %v", err)
+	}
+	defer session.Close()
+	in, err := session.StdinPipe()
+	if err != nil {
+		t.Fatalf("Unable to connect channel stdin: %v", err)
+	}
+	if err := session.Shell(); err != nil {
+		t.Fatalf("Unable to execute command: %v", err)
+	}
+	if err := in.Close(); err != nil {
+		t.Fatalf("Unable to close stdin: %v", err)
+	}
+	if _, err := in.Write([]byte("foo")); err == nil {
+		t.Fatalf("Session write should fail")
+	}
+}
+
+func TestClientCannotSendAfterClose(t *testing.T) {
+	conn := dial(shellHandler, t)
+	defer conn.Close()
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("Unable to request new session: %v", err)
+	}
+	defer session.Close()
+	in, err := session.StdinPipe()
+	if err != nil {
+		t.Fatalf("Unable to connect channel stdin: %v", err)
+	}
+	if err := session.Shell(); err != nil {
+		t.Fatalf("Unable to execute command: %v", err)
+	}
+	// close underlying channel
+	if err := session.channel.Close(); err != nil {
+		t.Fatalf("Unable to close session: %v", err)
+	}
+	if _, err := in.Write([]byte("foo")); err == nil {
+		t.Fatalf("Session write should fail")
+	}
+}
+
 type exitStatusMsg struct {
 	PeersId   uint32
 	Request   string