|
|
@@ -29,25 +29,6 @@ type keyingTransport interface {
|
|
|
// direction will be effected if a msgNewKeys message is sent
|
|
|
// or received.
|
|
|
prepareKeyChange(*algorithms, *kexResult) error
|
|
|
-
|
|
|
- // getSessionID returns the session ID. prepareKeyChange must
|
|
|
- // have been called once.
|
|
|
- getSessionID() []byte
|
|
|
-}
|
|
|
-
|
|
|
-// rekeyingTransport is the interface of handshakeTransport that we
|
|
|
-// (internally) expose to ClientConn and ServerConn.
|
|
|
-type rekeyingTransport interface {
|
|
|
- packetConn
|
|
|
-
|
|
|
- // requestKeyChange asks the remote side to change keys. All
|
|
|
- // writes are blocked until the key change succeeds, which is
|
|
|
- // signaled by reading a msgNewKeys.
|
|
|
- requestKeyChange() error
|
|
|
-
|
|
|
- // getSessionID returns the session ID. This is only valid
|
|
|
- // after the first key change has completed.
|
|
|
- getSessionID() []byte
|
|
|
}
|
|
|
|
|
|
// handshakeTransport implements rekeying on top of a keyingTransport
|
|
|
@@ -86,6 +67,9 @@ type handshakeTransport struct {
|
|
|
sentInitMsg *kexInitMsg
|
|
|
writtenSinceKex uint64
|
|
|
writeError error
|
|
|
+
|
|
|
+ // The session ID or nil if first kex did not complete yet.
|
|
|
+ sessionID []byte
|
|
|
}
|
|
|
|
|
|
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
|
|
|
@@ -122,7 +106,7 @@ func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byt
|
|
|
}
|
|
|
|
|
|
func (t *handshakeTransport) getSessionID() []byte {
|
|
|
- return t.conn.getSessionID()
|
|
|
+ return t.sessionID
|
|
|
}
|
|
|
|
|
|
func (t *handshakeTransport) id() string {
|
|
|
@@ -183,9 +167,9 @@ func (t *handshakeTransport) readOnePacket() ([]byte, error) {
|
|
|
if p[0] != msgKexInit {
|
|
|
return p, nil
|
|
|
}
|
|
|
- err = t.enterKeyExchange(p)
|
|
|
|
|
|
t.mu.Lock()
|
|
|
+ err = t.enterKeyExchangeLocked(p)
|
|
|
if err != nil {
|
|
|
// drop connection
|
|
|
t.conn.Close()
|
|
|
@@ -211,25 +195,39 @@ func (t *handshakeTransport) readOnePacket() ([]byte, error) {
|
|
|
return []byte{msgNewKeys}, nil
|
|
|
}
|
|
|
|
|
|
+// keyChangeCategory describes whether a key exchange is the first on a
|
|
|
+// connection, or a subsequent one.
|
|
|
+type keyChangeCategory bool
|
|
|
+
|
|
|
+const (
|
|
|
+ firstKeyExchange keyChangeCategory = true
|
|
|
+ subsequentKeyExchange keyChangeCategory = false
|
|
|
+)
|
|
|
+
|
|
|
// sendKexInit sends a key change message, and returns the message
|
|
|
// that was sent. After initiating the key change, all writes will be
|
|
|
// blocked until the change is done, and a failed key change will
|
|
|
// close the underlying transport. This function is safe for
|
|
|
// concurrent use by multiple goroutines.
|
|
|
-func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
|
|
|
+func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
|
|
|
t.mu.Lock()
|
|
|
defer t.mu.Unlock()
|
|
|
- return t.sendKexInitLocked()
|
|
|
+ return t.sendKexInitLocked(isFirst)
|
|
|
+}
|
|
|
+
|
|
|
+func (t *handshakeTransport) requestInitialKeyChange() error {
|
|
|
+ _, _, err := t.sendKexInit(firstKeyExchange)
|
|
|
+ return err
|
|
|
}
|
|
|
|
|
|
func (t *handshakeTransport) requestKeyChange() error {
|
|
|
- _, _, err := t.sendKexInit()
|
|
|
+ _, _, err := t.sendKexInit(subsequentKeyExchange)
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
// sendKexInitLocked sends a key change message. t.mu must be locked
|
|
|
// while this happens.
|
|
|
-func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
|
|
|
+func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
|
|
|
// kexInits may be sent either in response to the other side,
|
|
|
// or because our side wants to initiate a key change, so we
|
|
|
// may have already sent a kexInit. In that case, don't send a
|
|
|
@@ -237,6 +235,14 @@ func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
|
|
|
if t.sentInitMsg != nil {
|
|
|
return t.sentInitMsg, t.sentInitPacket, nil
|
|
|
}
|
|
|
+
|
|
|
+ // If this is the initial key change, but we already have a sessionID,
|
|
|
+ // then do nothing because the key exchange has already completed
|
|
|
+ // asynchronously.
|
|
|
+ if isFirst && t.sessionID != nil {
|
|
|
+ return nil, nil, nil
|
|
|
+ }
|
|
|
+
|
|
|
msg := &kexInitMsg{
|
|
|
KexAlgos: t.config.KeyExchanges,
|
|
|
CiphersClientServer: t.config.Ciphers,
|
|
|
@@ -276,7 +282,7 @@ func (t *handshakeTransport) writePacket(p []byte) error {
|
|
|
defer t.mu.Unlock()
|
|
|
|
|
|
if t.writtenSinceKex > t.config.RekeyThreshold {
|
|
|
- t.sendKexInitLocked()
|
|
|
+ t.sendKexInitLocked(subsequentKeyExchange)
|
|
|
}
|
|
|
for t.sentInitMsg != nil && t.writeError == nil {
|
|
|
t.cond.Wait()
|
|
|
@@ -300,12 +306,12 @@ func (t *handshakeTransport) Close() error {
|
|
|
return t.conn.Close()
|
|
|
}
|
|
|
|
|
|
-// enterKeyExchange runs the key exchange.
|
|
|
-func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
|
|
|
+// enterKeyExchange runs the key exchange. t.mu must be held while running this.
|
|
|
+func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error {
|
|
|
if debugHandshake {
|
|
|
log.Printf("%s entered key exchange", t.id())
|
|
|
}
|
|
|
- myInit, myInitPacket, err := t.sendKexInit()
|
|
|
+ myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -362,6 +368,11 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
+ if t.sessionID == nil {
|
|
|
+ t.sessionID = result.H
|
|
|
+ result.SessionID = result.H
|
|
|
+ }
|
|
|
+
|
|
|
t.conn.prepareKeyChange(algs, result)
|
|
|
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
|
|
|
return err
|
|
|
@@ -371,6 +382,7 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
|
|
|
} else if packet[0] != msgNewKeys {
|
|
|
return unexpectedMessageError(msgNewKeys, packet[0])
|
|
|
}
|
|
|
+
|
|
|
return nil
|
|
|
}
|
|
|
|