浏览代码

go.crypto/ssh: avoid recover() when handling invalid channel ids

This proposal removes the use of recover() to catch
invalid channel ids sent from the remote side. The
recover() unfortuntaly makes debugging harder as it
obscures other panic causes.

Another source of panic()s exists inside marshal.go,
which will be handled with in a later CL.

R=agl, gustav.paul
CC=golang-dev
https://golang.org/cl/6404046
Dave Cheney 13 年之前
父节点
当前提交
f77e98d970
共有 1 个文件被更改,包括 52 次插入19 次删除
  1. 52 19
      ssh/client.go

+ 52 - 19
ssh/client.go

@@ -193,11 +193,6 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 // to their respective ClientChans.
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
 func (c *ClientConn) mainLoop() {
 	defer func() {
 	defer func() {
-		// We don't check, for example, that the channel IDs from the
-		// server are valid before using them. Thus a bad server can
-		// cause us to panic, but we don't want to crash the program.
-		recover()
-
 		c.Close()
 		c.Close()
 		c.closeAll()
 		c.closeAll()
 	}()
 	}()
@@ -224,7 +219,11 @@ func (c *ClientConn) mainLoop() {
 			if length != uint32(len(packet)) {
 			if length != uint32(len(packet)) {
 				return
 				return
 			}
 			}
-			c.getChan(remoteId).stdout.write(packet)
+			ch, ok := c.getChan(remoteId)
+			if !ok {
+				return
+			}
+			ch.stdout.write(packet)
 		case msgChannelExtendedData:
 		case msgChannelExtendedData:
 			if len(packet) < 13 {
 			if len(packet) < 13 {
 				// malformed data packet
 				// malformed data packet
@@ -242,18 +241,33 @@ func (c *ClientConn) mainLoop() {
 			// for stderr on interactive sessions. Other data types are
 			// for stderr on interactive sessions. Other data types are
 			// silently discarded.
 			// silently discarded.
 			if datatype == 1 {
 			if datatype == 1 {
-				c.getChan(remoteId).stderr.write(packet)
+				ch, ok := c.getChan(remoteId)
+				if !ok {
+					return
+				}
+				ch.stderr.write(packet)
 			}
 			}
 		default:
 		default:
 			switch msg := decode(packet).(type) {
 			switch msg := decode(packet).(type) {
 			case *channelOpenMsg:
 			case *channelOpenMsg:
 				c.handleChanOpen(msg)
 				c.handleChanOpen(msg)
 			case *channelOpenConfirmMsg:
 			case *channelOpenConfirmMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelOpenFailureMsg:
 			case *channelOpenFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelCloseMsg:
 			case *channelCloseMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.theyClosed = true
 				ch.theyClosed = true
 				ch.stdout.eof()
 				ch.stdout.eof()
 				ch.stderr.eof()
 				ch.stderr.eof()
@@ -264,19 +278,38 @@ func (c *ClientConn) mainLoop() {
 				}
 				}
 				c.chanList.remove(msg.PeersId)
 				c.chanList.remove(msg.PeersId)
 			case *channelEOFMsg:
 			case *channelEOFMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.stdout.eof()
 				ch.stdout.eof()
 				// RFC 4254 is mute on how EOF affects dataExt messages but
 				// RFC 4254 is mute on how EOF affects dataExt messages but
 				// it is logical to signal EOF at the same time.
 				// it is logical to signal EOF at the same time.
 				ch.stderr.eof()
 				ch.stderr.eof()
 			case *channelRequestSuccessMsg:
 			case *channelRequestSuccessMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestFailureMsg:
 			case *channelRequestFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestMsg:
 			case *channelRequestMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *windowAdjustMsg:
 			case *windowAdjustMsg:
-				if !c.getChan(msg.PeersId).remoteWin.add(msg.AdditionalBytes) {
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				if !ch.remoteWin.add(msg.AdditionalBytes) {
 					// invalid window update
 					// invalid window update
 					return
 					return
 				}
 				}
@@ -509,19 +542,19 @@ func (c *chanList) newChan(t *transport) *clientChan {
 	return ch
 	return ch
 }
 }
 
 
-func (c *chanList) getChan(id uint32) *clientChan {
+func (c *chanList) getChan(id uint32) (*clientChan, bool) {
 	c.Lock()
 	c.Lock()
 	defer c.Unlock()
 	defer c.Unlock()
 	if id >= uint32(len(c.chans)) {
 	if id >= uint32(len(c.chans)) {
-		return nil
+		return nil, false
 	}
 	}
-	return c.chans[int(id)]
+	return c.chans[id], true
 }
 }
 
 
 func (c *chanList) remove(id uint32) {
 func (c *chanList) remove(id uint32) {
 	c.Lock()
 	c.Lock()
 	defer c.Unlock()
 	defer c.Unlock()
-	c.chans[int(id)] = nil
+	c.chans[id] = nil
 }
 }
 
 
 func (c *chanList) closeAll() {
 func (c *chanList) closeAll() {