瀏覽代碼

go.crypto/ssh: avoid recover() when handling invalid channel ids

This proposal removes the use of recover() to catch
invalid channel ids sent from the remote side. The
recover() unfortuntaly makes debugging harder as it
obscures other panic causes.

Another source of panic()s exists inside marshal.go,
which will be handled with in a later CL.

R=agl, gustav.paul
CC=golang-dev
https://golang.org/cl/6404046
Dave Cheney 13 年之前
父節點
當前提交
f77e98d970
共有 1 個文件被更改,包括 52 次插入19 次删除
  1. 52 19
      ssh/client.go

+ 52 - 19
ssh/client.go

@@ -193,11 +193,6 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
 	defer func() {
-		// We don't check, for example, that the channel IDs from the
-		// server are valid before using them. Thus a bad server can
-		// cause us to panic, but we don't want to crash the program.
-		recover()
-
 		c.Close()
 		c.closeAll()
 	}()
@@ -224,7 +219,11 @@ func (c *ClientConn) mainLoop() {
 			if length != uint32(len(packet)) {
 				return
 			}
-			c.getChan(remoteId).stdout.write(packet)
+			ch, ok := c.getChan(remoteId)
+			if !ok {
+				return
+			}
+			ch.stdout.write(packet)
 		case msgChannelExtendedData:
 			if len(packet) < 13 {
 				// malformed data packet
@@ -242,18 +241,33 @@ func (c *ClientConn) mainLoop() {
 			// for stderr on interactive sessions. Other data types are
 			// silently discarded.
 			if datatype == 1 {
-				c.getChan(remoteId).stderr.write(packet)
+				ch, ok := c.getChan(remoteId)
+				if !ok {
+					return
+				}
+				ch.stderr.write(packet)
 			}
 		default:
 			switch msg := decode(packet).(type) {
 			case *channelOpenMsg:
 				c.handleChanOpen(msg)
 			case *channelOpenConfirmMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelOpenFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelCloseMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.theyClosed = true
 				ch.stdout.eof()
 				ch.stderr.eof()
@@ -264,19 +278,38 @@ func (c *ClientConn) mainLoop() {
 				}
 				c.chanList.remove(msg.PeersId)
 			case *channelEOFMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.stdout.eof()
 				// RFC 4254 is mute on how EOF affects dataExt messages but
 				// it is logical to signal EOF at the same time.
 				ch.stderr.eof()
 			case *channelRequestSuccessMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *windowAdjustMsg:
-				if !c.getChan(msg.PeersId).remoteWin.add(msg.AdditionalBytes) {
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				if !ch.remoteWin.add(msg.AdditionalBytes) {
 					// invalid window update
 					return
 				}
@@ -509,19 +542,19 @@ func (c *chanList) newChan(t *transport) *clientChan {
 	return ch
 }
 
-func (c *chanList) getChan(id uint32) *clientChan {
+func (c *chanList) getChan(id uint32) (*clientChan, bool) {
 	c.Lock()
 	defer c.Unlock()
 	if id >= uint32(len(c.chans)) {
-		return nil
+		return nil, false
 	}
-	return c.chans[int(id)]
+	return c.chans[id], true
 }
 
 func (c *chanList) remove(id uint32) {
 	c.Lock()
 	defer c.Unlock()
-	c.chans[int(id)] = nil
+	c.chans[id] = nil
 }
 
 func (c *chanList) closeAll() {