|
@@ -184,8 +184,16 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
|
|
|
|
|
|
|
|
|
func (c *ClientConn) mainLoop() {
|
|
|
-
|
|
|
- defer c.Close()
|
|
|
+ defer func() {
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ recover()
|
|
|
+
|
|
|
+ c.Close()
|
|
|
+ c.closeAll()
|
|
|
+ }()
|
|
|
+
|
|
|
for {
|
|
|
packet, err := c.readPacket()
|
|
|
if err != nil {
|
|
@@ -199,28 +207,34 @@ func (c *ClientConn) mainLoop() {
|
|
|
case msgChannelData:
|
|
|
if len(packet) < 9 {
|
|
|
|
|
|
- break
|
|
|
+ return
|
|
|
}
|
|
|
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
|
|
|
- if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
|
|
|
- packet = packet[9:]
|
|
|
- c.getChan(peersId).stdout.handleData(packet[:length])
|
|
|
+ length := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
|
|
|
+ packet = packet[9:]
|
|
|
+
|
|
|
+ if length != uint32(len(packet)) {
|
|
|
+ return
|
|
|
}
|
|
|
+ c.getChan(peersId).stdout.handleData(packet)
|
|
|
case msgChannelExtendedData:
|
|
|
if len(packet) < 13 {
|
|
|
|
|
|
- break
|
|
|
+ return
|
|
|
}
|
|
|
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
|
|
|
datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
|
|
|
- if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
|
|
|
- packet = packet[13:]
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- if datatype == 1 {
|
|
|
- c.getChan(peersId).stderr.handleData(packet[:length])
|
|
|
- }
|
|
|
+ length := uint32(packet[9])<<24 | uint32(packet[10])<<16 | uint32(packet[11])<<8 | uint32(packet[12])
|
|
|
+ packet = packet[13:]
|
|
|
+
|
|
|
+ if length != uint32(len(packet)) {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ if datatype == 1 {
|
|
|
+ c.getChan(peersId).stderr.handleData(packet)
|
|
|
}
|
|
|
default:
|
|
|
switch msg := decode(packet).(type) {
|
|
@@ -256,10 +270,10 @@ func (c *ClientConn) mainLoop() {
|
|
|
case *windowAdjustMsg:
|
|
|
if !c.getChan(msg.PeersId).stdin.win.add(msg.AdditionalBytes) {
|
|
|
|
|
|
- break
|
|
|
+ return
|
|
|
}
|
|
|
case *disconnectMsg:
|
|
|
- break
|
|
|
+ return
|
|
|
default:
|
|
|
fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
|
|
|
}
|
|
@@ -408,6 +422,9 @@ func (c *chanlist) newChan(t *transport) *clientChan {
|
|
|
func (c *chanlist) getChan(id uint32) *clientChan {
|
|
|
c.Lock()
|
|
|
defer c.Unlock()
|
|
|
+ if id >= uint32(len(c.chans)) {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
return c.chans[int(id)]
|
|
|
}
|
|
|
|
|
@@ -417,6 +434,22 @@ func (c *chanlist) remove(id uint32) {
|
|
|
c.chans[int(id)] = nil
|
|
|
}
|
|
|
|
|
|
+func (c *chanlist) closeAll() {
|
|
|
+ c.Lock()
|
|
|
+ defer c.Unlock()
|
|
|
+
|
|
|
+ for _, ch := range c.chans {
|
|
|
+ if ch == nil {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ ch.theyClosed = true
|
|
|
+ ch.stdout.eof()
|
|
|
+ ch.stderr.eof()
|
|
|
+ close(ch.msg)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
|
|
|
type chanWriter struct {
|
|
|
win *window
|