|
|
@@ -184,8 +184,16 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
|
|
|
// mainLoop reads incoming messages and routes channel messages
|
|
|
// to their respective ClientChans.
|
|
|
func (c *ClientConn) mainLoop() {
|
|
|
- // TODO(dfc) signal the underlying close to all channels
|
|
|
- defer c.Close()
|
|
|
+ defer func() {
|
|
|
+ // We don't check, for example, that the channel IDs from the
|
|
|
+ // server are valid before using them. Thus a bad server can
|
|
|
+ // cause us to panic, but we don't want to crash the program.
|
|
|
+ recover()
|
|
|
+
|
|
|
+ c.Close()
|
|
|
+ c.closeAll()
|
|
|
+ }()
|
|
|
+
|
|
|
for {
|
|
|
packet, err := c.readPacket()
|
|
|
if err != nil {
|
|
|
@@ -199,28 +207,34 @@ func (c *ClientConn) mainLoop() {
|
|
|
case msgChannelData:
|
|
|
if len(packet) < 9 {
|
|
|
// malformed data packet
|
|
|
- break
|
|
|
+ return
|
|
|
}
|
|
|
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
|
|
|
- if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
|
|
|
- packet = packet[9:]
|
|
|
- c.getChan(peersId).stdout.handleData(packet[:length])
|
|
|
+ length := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
|
|
|
+ packet = packet[9:]
|
|
|
+
|
|
|
+ if length != uint32(len(packet)) {
|
|
|
+ return
|
|
|
}
|
|
|
+ c.getChan(peersId).stdout.handleData(packet)
|
|
|
case msgChannelExtendedData:
|
|
|
if len(packet) < 13 {
|
|
|
// malformed data packet
|
|
|
- break
|
|
|
+ return
|
|
|
}
|
|
|
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
|
|
|
datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
|
|
|
- if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
|
|
|
- packet = packet[13:]
|
|
|
- // RFC 4254 5.2 defines data_type_code 1 to be data destined
|
|
|
- // for stderr on interactive sessions. Other data types are
|
|
|
- // silently discarded.
|
|
|
- if datatype == 1 {
|
|
|
- c.getChan(peersId).stderr.handleData(packet[:length])
|
|
|
- }
|
|
|
+ length := uint32(packet[9])<<24 | uint32(packet[10])<<16 | uint32(packet[11])<<8 | uint32(packet[12])
|
|
|
+ packet = packet[13:]
|
|
|
+
|
|
|
+ if length != uint32(len(packet)) {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ // RFC 4254 5.2 defines data_type_code 1 to be data destined
|
|
|
+ // for stderr on interactive sessions. Other data types are
|
|
|
+ // silently discarded.
|
|
|
+ if datatype == 1 {
|
|
|
+ c.getChan(peersId).stderr.handleData(packet)
|
|
|
}
|
|
|
default:
|
|
|
switch msg := decode(packet).(type) {
|
|
|
@@ -256,10 +270,10 @@ func (c *ClientConn) mainLoop() {
|
|
|
case *windowAdjustMsg:
|
|
|
if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) {
|
|
|
// invalid window update
|
|
|
- break
|
|
|
+ return
|
|
|
}
|
|
|
case *disconnectMsg:
|
|
|
- break
|
|
|
+ return
|
|
|
default:
|
|
|
fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
|
|
|
}
|
|
|
@@ -408,6 +422,9 @@ func (c *chanlist) newChan(t *transport) *clientChan {
|
|
|
func (c *chanlist) getChan(id uint32) *clientChan {
|
|
|
c.Lock()
|
|
|
defer c.Unlock()
|
|
|
+ if id >= uint32(len(c.chans)) {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
return c.chans[int(id)]
|
|
|
}
|
|
|
|
|
|
@@ -417,6 +434,22 @@ func (c *chanlist) remove(id uint32) {
|
|
|
c.chans[int(id)] = nil
|
|
|
}
|
|
|
|
|
|
+func (c *chanlist) closeAll() {
|
|
|
+ c.Lock()
|
|
|
+ defer c.Unlock()
|
|
|
+
|
|
|
+ for _, ch := range c.chans {
|
|
|
+ if ch == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ ch.theyClosed = true
|
|
|
+ ch.stdout.eof()
|
|
|
+ ch.stderr.eof()
|
|
|
+ close(ch.msg)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// A chanWriter represents the stdin of a remote process.
|
|
|
type chanWriter struct {
|
|
|
win *window
|