Explorar o código

go.crypto/ssh: prevent channel writes after Close

Fixes golang/go#3810.

This change introduces an atomic boolean to guard the close
of the clientChan. Previously the client code was quite
lax with the ordering of the close messages and could allow
window adjustment or EOF messages to leak after Close had
been signaled.

Consolidating the changes to the serverChan will be handled
in a following CL.

R=agl, kardianos, gustav.paul
CC=golang-dev
https://golang.org/cl/6405064
Dave Cheney %!s(int64=13) %!d(string=hai) anos
pai
achega
c1c0bfbd3a
Modificáronse 4 ficheiros con 89 adicións e 34 borrados
  1. 36 20
      ssh/channel.go
  2. 4 13
      ssh/client.go
  3. 1 1
      ssh/session.go
  4. 48 0
      ssh/session_test.go

+ 36 - 20
ssh/channel.go

@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"sync"
 	"sync"
+	"sync/atomic"
 )
 )
 
 
 // extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
 // extendedDataTypeCode identifies an OpenSSL extended data type. See RFC 4254,
@@ -81,11 +82,7 @@ type channel struct {
 	localId, remoteId uint32
 	localId, remoteId uint32
 	remoteWin         window
 	remoteWin         window
 	maxPacket         uint32
 	maxPacket         uint32
-
-	theyClosed  bool // indicates the close msg has been received from the remote side
-	weClosed    bool // incidates the close msg has been sent from our side
-	theySentEOF bool // used by serverChan
-	dead        bool // used by ServerChan to force close
+	isclosed          uint32 // atomic bool, non zero if true
 }
 }
 
 
 func (c *channel) sendWindowAdj(n int) error {
 func (c *channel) sendWindowAdj(n int) error {
@@ -96,13 +93,6 @@ func (c *channel) sendWindowAdj(n int) error {
 	return c.writePacket(marshal(msgChannelWindowAdjust, msg))
 	return c.writePacket(marshal(msgChannelWindowAdjust, msg))
 }
 }
 
 
-// sendClose signals the intent to close the channel.
-func (c *channel) sendClose() error {
-	return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
-		PeersId: c.remoteId,
-	}))
-}
-
 // sendEOF sends EOF to the server. RFC 4254 Section 5.3
 // sendEOF sends EOF to the server. RFC 4254 Section 5.3
 func (c *channel) sendEOF() error {
 func (c *channel) sendEOF() error {
 	return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
 	return c.writePacket(marshal(msgChannelEOF, channelEOFMsg{
@@ -121,21 +111,36 @@ func (c *channel) sendChannelOpenFailure(reason RejectionReason, message string)
 }
 }
 
 
 func (c *channel) writePacket(b []byte) error {
 func (c *channel) writePacket(b []byte) error {
+	if c.closed() {
+		return io.EOF
+	}
 	if uint32(len(b)) > c.maxPacket {
 	if uint32(len(b)) > c.maxPacket {
 		return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
 		return fmt.Errorf("ssh: cannot write %d bytes, maxPacket is %d bytes", len(b), c.maxPacket)
 	}
 	}
 	return c.conn.writePacket(b)
 	return c.conn.writePacket(b)
 }
 }
 
 
+func (c *channel) closed() bool {
+	return atomic.LoadUint32(&c.isclosed) > 0
+}
+
+func (c *channel) setClosed() bool {
+	return atomic.CompareAndSwapUint32(&c.isclosed, 0, 1)
+}
+
 type serverChan struct {
 type serverChan struct {
 	channel
 	channel
 	// immutable once created
 	// immutable once created
 	chanType  string
 	chanType  string
 	extraData []byte
 	extraData []byte
 
 
-	serverConn *ServerConn
-	myWindow   uint32
-	err        error
+	serverConn  *ServerConn
+	myWindow    uint32
+	weClosed    bool // incidates the close msg has been sent from our side
+	theyClosed  bool // indicates the close msg has been received from the remote side
+	theySentEOF bool
+	dead        bool
+	err         error
 
 
 	pendingRequests []ChannelRequest
 	pendingRequests []ChannelRequest
 	pendingData     []byte
 	pendingData     []byte
@@ -393,6 +398,13 @@ func (c *serverChan) Close() error {
 	return c.sendClose()
 	return c.sendClose()
 }
 }
 
 
+// sendClose signals the intent to close the channel.
+func (c *serverChan) sendClose() error {
+	return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
+		PeersId: c.remoteId,
+	}))
+}
+
 func (c *serverChan) AckRequest(ok bool) error {
 func (c *serverChan) AckRequest(ok bool) error {
 	c.serverConn.lock.Lock()
 	c.serverConn.lock.Lock()
 	defer c.serverConn.lock.Unlock()
 	defer c.serverConn.lock.Unlock()
@@ -477,13 +489,17 @@ func (c *clientChan) waitForChannelOpenResponse() error {
 	return errors.New("ssh: unexpected packet")
 	return errors.New("ssh: unexpected packet")
 }
 }
 
 
-// Close closes the channel. This does not close the underlying connection.
 func (c *clientChan) Close() error {
 func (c *clientChan) Close() error {
-	if !c.weClosed {
-		c.weClosed = true
-		return c.sendClose()
+	if !c.setClosed() {
+		return errors.New("ssh: channel already closed")
 	}
 	}
-	return nil
+	c.stdout.eof()
+	c.stderr.eof()
+	close(c.msg)
+	// TODO(dfc) step around channel.writePacket() because closed() is now true
+	return c.channel.conn.writePacket(marshal(msgChannelClose, channelCloseMsg{
+		PeersId: c.remoteId,
+	}))
 }
 }
 
 
 // A chanWriter represents the stdin of a remote process.
 // A chanWriter represents the stdin of a remote process.

+ 4 - 13
ssh/client.go

@@ -269,14 +269,9 @@ func (c *ClientConn) mainLoop() {
 				if !ok {
 				if !ok {
 					return
 					return
 				}
 				}
-				ch.theyClosed = true
-				ch.stdout.eof()
-				ch.stderr.eof()
-				close(ch.msg)
-				if !ch.weClosed {
-					ch.weClosed = true
-					ch.sendClose()
-				}
+				ch.Close()
+				// TODO(dfc) may need to optimisically remove the 
+				// channel before closing
 				c.chanList.remove(msg.PeersId)
 				c.chanList.remove(msg.PeersId)
 			case *channelEOFMsg:
 			case *channelEOFMsg:
 				ch, ok := c.getChan(msg.PeersId)
 				ch, ok := c.getChan(msg.PeersId)
@@ -506,10 +501,6 @@ func (c *chanList) closeAll() {
 		if ch == nil {
 		if ch == nil {
 			continue
 			continue
 		}
 		}
-
-		ch.theyClosed = true
-		ch.stdout.eof()
-		ch.stderr.eof()
-		close(ch.msg)
+		ch.Close()
 	}
 	}
 }
 }

+ 1 - 1
ssh/session.go

@@ -370,7 +370,7 @@ func (s *Session) stdin() {
 	}
 	}
 	s.copyFuncs = append(s.copyFuncs, func() error {
 	s.copyFuncs = append(s.copyFuncs, func() error {
 		_, err := io.Copy(s.clientChan.stdin, s.Stdin)
 		_, err := io.Copy(s.clientChan.stdin, s.Stdin)
-		if err1 := s.clientChan.stdin.Close(); err == nil {
+		if err1 := s.clientChan.stdin.Close(); err == nil && err1 != io.EOF {
 			err = err1
 			err = err1
 		}
 		}
 		return err
 		return err

+ 48 - 0
ssh/session_test.go

@@ -380,6 +380,54 @@ func TestServerStdoutRespectsMaxPacketSize(t *testing.T) {
 	}
 	}
 }
 }
 
 
+// TODO(dfc) currently writes succeed after Close()
+func testClientCannotSendAfterEOF(t *testing.T) {
+	conn := dial(shellHandler, t)
+	defer conn.Close()
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("Unable to request new session: %v", err)
+	}
+	defer session.Close()
+	in, err := session.StdinPipe()
+	if err != nil {
+		t.Fatalf("Unable to connect channel stdin: %v", err)
+	}
+	if err := session.Shell(); err != nil {
+		t.Fatalf("Unable to execute command: %v", err)
+	}
+	if err := in.Close(); err != nil {
+		t.Fatalf("Unable to close stdin: %v", err)
+	}
+	if _, err := in.Write([]byte("foo")); err == nil {
+		t.Fatalf("Session write should fail")
+	}
+}
+
+func TestClientCannotSendAfterClose(t *testing.T) {
+	conn := dial(shellHandler, t)
+	defer conn.Close()
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("Unable to request new session: %v", err)
+	}
+	defer session.Close()
+	in, err := session.StdinPipe()
+	if err != nil {
+		t.Fatalf("Unable to connect channel stdin: %v", err)
+	}
+	if err := session.Shell(); err != nil {
+		t.Fatalf("Unable to execute command: %v", err)
+	}
+	// close underlying channel
+	if err := session.channel.Close(); err != nil {
+		t.Fatalf("Unable to close session: %v", err)
+	}
+	if _, err := in.Write([]byte("foo")); err == nil {
+		t.Fatalf("Session write should fail")
+	}
+}
+
 type exitStatusMsg struct {
 type exitStatusMsg struct {
 	PeersId   uint32
 	PeersId   uint32
 	Request   string
 	Request   string