|
|
@@ -66,8 +66,8 @@ type handshakeTransport struct {
|
|
|
|
|
|
// If the read loop wants to schedule a kex, it pings this
|
|
|
// channel, and the write loop will send out a kex
|
|
|
- // message. The boolean is whether this is the first request or not.
|
|
|
- requestKex chan bool
|
|
|
+ // message.
|
|
|
+ requestKex chan struct{}
|
|
|
|
|
|
// If the other side requests or confirms a kex, its kexInit
|
|
|
// packet is sent here for the write loop to find it.
|
|
|
@@ -102,14 +102,14 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
|
|
|
serverVersion: serverVersion,
|
|
|
clientVersion: clientVersion,
|
|
|
incoming: make(chan []byte, chanSize),
|
|
|
- requestKex: make(chan bool, 1),
|
|
|
+ requestKex: make(chan struct{}, 1),
|
|
|
startKex: make(chan *pendingKex, 1),
|
|
|
|
|
|
config: config,
|
|
|
}
|
|
|
|
|
|
// We always start with a mandatory key exchange.
|
|
|
- t.requestKex <- true
|
|
|
+ t.requestKex <- struct{}{}
|
|
|
return t
|
|
|
}
|
|
|
|
|
|
@@ -166,6 +166,7 @@ func (t *handshakeTransport) printPacket(p []byte, write bool) {
|
|
|
if write {
|
|
|
action = "sent"
|
|
|
}
|
|
|
+
|
|
|
if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
|
|
|
log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
|
|
|
} else {
|
|
|
@@ -230,14 +231,13 @@ func (t *handshakeTransport) recordWriteError(err error) {
|
|
|
|
|
|
func (t *handshakeTransport) requestKeyExchange() {
|
|
|
select {
|
|
|
- case t.requestKex <- false:
|
|
|
+ case t.requestKex <- struct{}{}:
|
|
|
default:
|
|
|
// something already requested a kex, so do nothing.
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func (t *handshakeTransport) kexLoop() {
|
|
|
- firstSent := false
|
|
|
|
|
|
write:
|
|
|
for t.getWriteError() == nil {
|
|
|
@@ -251,18 +251,8 @@ write:
|
|
|
if !ok {
|
|
|
break write
|
|
|
}
|
|
|
- case requestFirst := <-t.requestKex:
|
|
|
- // For the first key exchange, both
|
|
|
- // sides will initiate a key exchange,
|
|
|
- // and both channels will fire. To
|
|
|
- // avoid doing two key exchanges in a
|
|
|
- // row, ignore our own request for an
|
|
|
- // initial kex if we have already sent
|
|
|
- // it out.
|
|
|
- if firstSent && requestFirst {
|
|
|
-
|
|
|
- continue
|
|
|
- }
|
|
|
+ case <-t.requestKex:
|
|
|
+ break
|
|
|
}
|
|
|
|
|
|
if !sent {
|
|
|
@@ -270,7 +260,6 @@ write:
|
|
|
t.recordWriteError(err)
|
|
|
break
|
|
|
}
|
|
|
- firstSent = true
|
|
|
sent = true
|
|
|
}
|
|
|
}
|
|
|
@@ -287,7 +276,8 @@ write:
|
|
|
|
|
|
// We're not servicing t.startKex, but the remote end
|
|
|
// has just sent us a kexInitMsg, so it can't send
|
|
|
- // another key change request.
|
|
|
+ // another key change request, until we close the done
|
|
|
+ // channel on the pendingKex request.
|
|
|
|
|
|
err := t.enterKeyExchange(request.otherInit)
|
|
|
|
|
|
@@ -301,6 +291,23 @@ write:
|
|
|
} else if t.algorithms != nil {
|
|
|
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
|
|
|
}
|
|
|
+
|
|
|
+ // we have completed the key exchange. Since the
|
|
|
+ // reader is still blocked, it is safe to clear out
|
|
|
+ // the requestKex channel. This avoids the situation
|
|
|
+ // where: 1) we consumed our own request for the
|
|
|
+ // initial kex, and 2) the kex from the remote side
|
|
|
+ // caused another send on the requestKex channel,
|
|
|
+ clear:
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-t.requestKex:
|
|
|
+ //
|
|
|
+ default:
|
|
|
+ break clear
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
request.done <- t.writeError
|
|
|
|
|
|
// kex finished. Push packets that we received while
|