Ver código fonte

go.crypto/ssh: avoid recover() when handling invalid channel ids

This proposal removes the use of recover() to catch
invalid channel ids sent from the remote side. The
recover() unfortuntaly makes debugging harder as it
obscures other panic causes.

Another source of panic()s exists inside marshal.go,
which will be handled with in a later CL.

R=agl, gustav.paul
CC=golang-dev
https://golang.org/cl/6404046
Dave Cheney 13 anos atrás
pai
commit
f77e98d970
1 arquivos alterados com 52 adições e 19 exclusões
  1. 52 19
      ssh/client.go

+ 52 - 19
ssh/client.go

@@ -193,11 +193,6 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
 	defer func() {
-		// We don't check, for example, that the channel IDs from the
-		// server are valid before using them. Thus a bad server can
-		// cause us to panic, but we don't want to crash the program.
-		recover()
-
 		c.Close()
 		c.closeAll()
 	}()
@@ -224,7 +219,11 @@ func (c *ClientConn) mainLoop() {
 			if length != uint32(len(packet)) {
 				return
 			}
-			c.getChan(remoteId).stdout.write(packet)
+			ch, ok := c.getChan(remoteId)
+			if !ok {
+				return
+			}
+			ch.stdout.write(packet)
 		case msgChannelExtendedData:
 			if len(packet) < 13 {
 				// malformed data packet
@@ -242,18 +241,33 @@ func (c *ClientConn) mainLoop() {
 			// for stderr on interactive sessions. Other data types are
 			// silently discarded.
 			if datatype == 1 {
-				c.getChan(remoteId).stderr.write(packet)
+				ch, ok := c.getChan(remoteId)
+				if !ok {
+					return
+				}
+				ch.stderr.write(packet)
 			}
 		default:
 			switch msg := decode(packet).(type) {
 			case *channelOpenMsg:
 				c.handleChanOpen(msg)
 			case *channelOpenConfirmMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelOpenFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelCloseMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.theyClosed = true
 				ch.stdout.eof()
 				ch.stderr.eof()
@@ -264,19 +278,38 @@ func (c *ClientConn) mainLoop() {
 				}
 				c.chanList.remove(msg.PeersId)
 			case *channelEOFMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.stdout.eof()
 				// RFC 4254 is mute on how EOF affects dataExt messages but
 				// it is logical to signal EOF at the same time.
 				ch.stderr.eof()
 			case *channelRequestSuccessMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *windowAdjustMsg:
-				if !c.getChan(msg.PeersId).remoteWin.add(msg.AdditionalBytes) {
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				if !ch.remoteWin.add(msg.AdditionalBytes) {
 					// invalid window update
 					return
 				}
@@ -509,19 +542,19 @@ func (c *chanList) newChan(t *transport) *clientChan {
 	return ch
 }
 
-func (c *chanList) getChan(id uint32) *clientChan {
+func (c *chanList) getChan(id uint32) (*clientChan, bool) {
 	c.Lock()
 	defer c.Unlock()
 	if id >= uint32(len(c.chans)) {
-		return nil
+		return nil, false
 	}
-	return c.chans[int(id)]
+	return c.chans[id], true
 }
 
 func (c *chanList) remove(id uint32) {
 	c.Lock()
 	defer c.Unlock()
-	c.chans[int(id)] = nil
+	c.chans[id] = nil
 }
 
 func (c *chanList) closeAll() {