Browse Source

go.crypto/ssh: avoid recover() when handling invalid channel ids

This proposal removes the use of recover() to catch
invalid channel ids sent from the remote side. The
recover() unfortuntaly makes debugging harder as it
obscures other panic causes.

Another source of panic()s exists inside marshal.go,
which will be handled with in a later CL.

R=agl, gustav.paul
CC=golang-dev
https://golang.org/cl/6404046
Dave Cheney 13 years ago
parent
commit
f77e98d970
1 changed files with 52 additions and 19 deletions
  1. 52 19
      ssh/client.go

+ 52 - 19
ssh/client.go

@@ -193,11 +193,6 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 // to their respective ClientChans.
 // to their respective ClientChans.
 func (c *ClientConn) mainLoop() {
 func (c *ClientConn) mainLoop() {
 	defer func() {
 	defer func() {
-		// We don't check, for example, that the channel IDs from the
-		// server are valid before using them. Thus a bad server can
-		// cause us to panic, but we don't want to crash the program.
-		recover()
-
 		c.Close()
 		c.Close()
 		c.closeAll()
 		c.closeAll()
 	}()
 	}()
@@ -224,7 +219,11 @@ func (c *ClientConn) mainLoop() {
 			if length != uint32(len(packet)) {
 			if length != uint32(len(packet)) {
 				return
 				return
 			}
 			}
-			c.getChan(remoteId).stdout.write(packet)
+			ch, ok := c.getChan(remoteId)
+			if !ok {
+				return
+			}
+			ch.stdout.write(packet)
 		case msgChannelExtendedData:
 		case msgChannelExtendedData:
 			if len(packet) < 13 {
 			if len(packet) < 13 {
 				// malformed data packet
 				// malformed data packet
@@ -242,18 +241,33 @@ func (c *ClientConn) mainLoop() {
 			// for stderr on interactive sessions. Other data types are
 			// for stderr on interactive sessions. Other data types are
 			// silently discarded.
 			// silently discarded.
 			if datatype == 1 {
 			if datatype == 1 {
-				c.getChan(remoteId).stderr.write(packet)
+				ch, ok := c.getChan(remoteId)
+				if !ok {
+					return
+				}
+				ch.stderr.write(packet)
 			}
 			}
 		default:
 		default:
 			switch msg := decode(packet).(type) {
 			switch msg := decode(packet).(type) {
 			case *channelOpenMsg:
 			case *channelOpenMsg:
 				c.handleChanOpen(msg)
 				c.handleChanOpen(msg)
 			case *channelOpenConfirmMsg:
 			case *channelOpenConfirmMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelOpenFailureMsg:
 			case *channelOpenFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelCloseMsg:
 			case *channelCloseMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.theyClosed = true
 				ch.theyClosed = true
 				ch.stdout.eof()
 				ch.stdout.eof()
 				ch.stderr.eof()
 				ch.stderr.eof()
@@ -264,19 +278,38 @@ func (c *ClientConn) mainLoop() {
 				}
 				}
 				c.chanList.remove(msg.PeersId)
 				c.chanList.remove(msg.PeersId)
 			case *channelEOFMsg:
 			case *channelEOFMsg:
-				ch := c.getChan(msg.PeersId)
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
 				ch.stdout.eof()
 				ch.stdout.eof()
 				// RFC 4254 is mute on how EOF affects dataExt messages but
 				// RFC 4254 is mute on how EOF affects dataExt messages but
 				// it is logical to signal EOF at the same time.
 				// it is logical to signal EOF at the same time.
 				ch.stderr.eof()
 				ch.stderr.eof()
 			case *channelRequestSuccessMsg:
 			case *channelRequestSuccessMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestFailureMsg:
 			case *channelRequestFailureMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *channelRequestMsg:
 			case *channelRequestMsg:
-				c.getChan(msg.PeersId).msg <- msg
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				ch.msg <- msg
 			case *windowAdjustMsg:
 			case *windowAdjustMsg:
-				if !c.getChan(msg.PeersId).remoteWin.add(msg.AdditionalBytes) {
+				ch, ok := c.getChan(msg.PeersId)
+				if !ok {
+					return
+				}
+				if !ch.remoteWin.add(msg.AdditionalBytes) {
 					// invalid window update
 					// invalid window update
 					return
 					return
 				}
 				}
@@ -509,19 +542,19 @@ func (c *chanList) newChan(t *transport) *clientChan {
 	return ch
 	return ch
 }
 }
 
 
-func (c *chanList) getChan(id uint32) *clientChan {
+func (c *chanList) getChan(id uint32) (*clientChan, bool) {
 	c.Lock()
 	c.Lock()
 	defer c.Unlock()
 	defer c.Unlock()
 	if id >= uint32(len(c.chans)) {
 	if id >= uint32(len(c.chans)) {
-		return nil
+		return nil, false
 	}
 	}
-	return c.chans[int(id)]
+	return c.chans[id], true
 }
 }
 
 
 func (c *chanList) remove(id uint32) {
 func (c *chanList) remove(id uint32) {
 	c.Lock()
 	c.Lock()
 	defer c.Unlock()
 	defer c.Unlock()
-	c.chans[int(id)] = nil
+	c.chans[id] = nil
 }
 }
 
 
 func (c *chanList) closeAll() {
 func (c *chanList) closeAll() {