Browse Source

go.crypto/ssh: avoid recover() when handling invalid channel ids

This proposal removes the use of recover() to catch
invalid channel ids sent from the remote side. The
recover() unfortuntaly makes debugging harder as it
obscures other panic causes.

Another source of panic()s exists inside marshal.go,
which will be handled with in a later CL.

R=agl, gustav.paul
CC=golang-dev
https://golang.org/cl/6404046
Dave Cheney 13 năm trước cách đây
mục cha
commit
f77e98d970
1 tập tin đã thay đổi với 52 bổ sung19 xóa
  1. 52 19
      ssh/client.go

+ 52 - 19
ssh/client.go

@@ -193,11 +193,6 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
 	defer func() {
-		// We don't check, for example, that the channel IDs from the
-		// server are valid before using them. Thus a bad server can
-		// cause us to panic, but we don't want to crash the program.
-		recover()
-
 		c.Close()
 		c.closeAll()
 	}()
@@ -224,7 +219,11 @@ func (c *ClientConn) mainLoop() {
 			if length != uint32(len(packet)) {
 				return
 			}
-			c.getChan(remoteId).stdout.write(packet)
+			ch, ok := c.getChan(remoteId)
+			if !ok {
+				return
+			}
+			ch.stdout.write(packet)
 		case msgChannelExtendedData:
 			if len(packet) < 13 {
 				// malformed data packet
@@ -242,18 +241,33 @@ func (c *ClientConn) mainLoop() {
 			// for stderr on interactive sessions. Other data types are
 			// silently discarded.
 			if datatype == 1 {
-				c.getChan(remoteId).stderr.write(packet)
+				ch, ok := c.getChan(remoteId)
+				if !ok {
+					return
+				}
+				ch.stderr.write(packet)
 			}
 		default:
 			switch msg := decode(packet).(type) {
 			case *channelOpenMsg:
 				c.handleChanOpen(msg)
 			case *channelOpenConfirmMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelOpenFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelCloseMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.theyClosed = true
 				ch.stdout.eof()
 				ch.stderr.eof()
@@ -264,19 +278,38 @@ func (c *ClientConn) mainLoop() {
 				}
 				c.chanList.remove(msg.PeersId)
 			case *channelEOFMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.stdout.eof()
 				// RFC 4254 is mute on how EOF affects dataExt messages but
 				// it is logical to signal EOF at the same time.
 				ch.stderr.eof()
 			case *channelRequestSuccessMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *windowAdjustMsg:
-				if !c.getChan(msg.PeersId).remoteWin.add(msg.AdditionalBytes) {
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				if !ch.remoteWin.add(msg.AdditionalBytes) {
 					// invalid window update
 					return
 				}
@@ -509,19 +542,19 @@ func (c *chanList) newChan(t *transport) *clientChan {
 	return ch
 }
 
-func (c *chanList) getChan(id uint32) *clientChan {
+func (c *chanList) getChan(id uint32) (*clientChan, bool) {
 	c.Lock()
 	defer c.Unlock()
 	if id >= uint32(len(c.chans)) {
-		return nil
+		return nil, false
 	}
-	return c.chans[int(id)]
+	return c.chans[id], true
 }
 
 func (c *chanList) remove(id uint32) {
 	c.Lock()
 	defer c.Unlock()
-	c.chans[int(id)] = nil
+	c.chans[id] = nil
 }
 
 func (c *chanList) closeAll() {