Преглед на файлове

ssh: handle bad servers better.

This change prevents bad servers from crashing a client by sending an
invalid channel ID. It also makes the client disconnect in more cases
of invalid messages from a server and cleans up the client channels
in the event of a disconnect.

R=dave
CC=golang-dev
https://golang.org/cl/6099050
Adam Langley преди 13 години
родител
ревизия
bcdd6a2fd3
променени са 2 файла, в които са добавени 75 реда и са изтрити 17 реда
  1. 50 17
      ssh/client.go
  2. 25 0
      ssh/session_test.go

+ 50 - 17
ssh/client.go

@@ -184,8 +184,16 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 // mainLoop reads incoming messages and routes channel messages
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
-	// TODO(dfc) signal the underlying close to all channels
-	defer c.Close()
+	defer func() {
+		// We don't check, for example, that the channel IDs from the
+		// server are valid before using them. Thus a bad server can
+		// cause us to panic, but we don't want to crash the program.
+		recover()
+
+		c.Close()
+		c.closeAll()
+	}()
+
 	for {
 		packet, err := c.readPacket()
 		if err != nil {
@@ -199,28 +207,34 @@ func (c *ClientConn) mainLoop() {
 		case msgChannelData:
 			if len(packet) < 9 {
 				// malformed data packet
-				break
+				return
 			}
 			peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
-			if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
-				packet = packet[9:]
-				c.getChan(peersId).stdout.handleData(packet[:length])
+			length := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
+			packet = packet[9:]
+
+			if length != uint32(len(packet)) {
+				return
 			}
+			c.getChan(peersId).stdout.handleData(packet)
 		case msgChannelExtendedData:
 			if len(packet) < 13 {
 				// malformed data packet
-				break
+				return
 			}
 			peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
 			datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
-			if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
-				packet = packet[13:]
-				// RFC 4254 5.2 defines data_type_code 1 to be data destined
-				// for stderr on interactive sessions. Other data types are
-				// silently discarded.
-				if datatype == 1 {
-					c.getChan(peersId).stderr.handleData(packet[:length])
-				}
+			length := uint32(packet[9])<<24 | uint32(packet[10])<<16 | uint32(packet[11])<<8 | uint32(packet[12])
+			packet = packet[13:]
+
+			if length != uint32(len(packet)) {
+				return
+			}
+			// RFC 4254 5.2 defines data_type_code 1 to be data destined
+			// for stderr on interactive sessions. Other data types are
+			// silently discarded.
+			if datatype == 1 {
+				c.getChan(peersId).stderr.handleData(packet)
 			}
 		default:
 			switch msg := decode(packet).(type) {
@@ -256,10 +270,10 @@ func (c *ClientConn) mainLoop() {
 			case *windowAdjustMsg:
 				if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) {
 					// invalid window update
-					break
+					return
 				}
 			case *disconnectMsg:
-				break
+				return
 			default:
 				fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
 			}
@@ -408,6 +422,9 @@ func (c *chanlist) newChan(t *transport) *clientChan {
 func (c *chanlist) getChan(id uint32) *clientChan {
 	c.Lock()
 	defer c.Unlock()
+	if id >= uint32(len(c.chans)) {
+		return nil
+	}
 	return c.chans[int(id)]
 }
 
@@ -417,6 +434,22 @@ func (c *chanlist) remove(id uint32) {
 	c.chans[int(id)] = nil
 }
 
+func (c *chanlist) closeAll() {
+	c.Lock()
+	defer c.Unlock()
+
+	for _, ch := range c.chans {
+		if ch == nil {
+			continue
+		}
+
+		ch.theyClosed = true
+		ch.stdout.eof()
+		ch.stderr.eof()
+		close(ch.msg)
+	}
+}
+
 // A chanWriter represents the stdin of a remote process.
 type chanWriter struct {
 	win        *window

+ 25 - 0
ssh/session_test.go

@@ -275,6 +275,20 @@ func TestExitWithoutStatusOrSignal(t *testing.T) {
 	}
 }
 
+func TestInvalidServerMessage(t *testing.T) {
+	conn := dial(sendInvalidRecord, t)
+	defer conn.Close()
+	session, err := conn.NewSession()
+	if err != nil {
+		t.Fatalf("Unable to request new session: %s", err)
+	}
+	// Make sure that we closed all the clientChans when the connection
+	// failed.
+	session.wait()
+
+	defer session.Close()
+}
+
 type exitStatusMsg struct {
 	PeersId   uint32
 	Request   string
@@ -373,3 +387,14 @@ func sendSignal(signal string, ch *channel) {
 	}
 	ch.serverConn.writePacket(marshal(msgChannelRequest, sig))
 }
+
+func sendInvalidRecord(ch *channel) {
+	defer ch.Close()
+	packet := make([]byte, 1+4+4+1)
+	packet[0] = msgChannelData
+	marshalUint32(packet[1:], 29348723 /* invalid channel id */)
+	marshalUint32(packet[5:], 1)
+	packet[9] = 42
+
+	ch.serverConn.writePacket(packet)
+}