Bläddra i källkod

x/crypto/ssh: make sure the initial key exchange happens once.

This is done by running the key exchange and setting the session ID
under mutex. If the first exchange encounters an already set session
ID, then do nothing.

This fixes a race condition:

On setting up the connection, both sides sent a kexInit to initiate
the first (mandatory) key exchange.  If one side was faster, the
faster side might have completed the key exchange, before the slow
side had a chance to send a kexInit.  The slow side would send a
kexInit which would trigger a second key exchange. The resulting
confirmation message (msgNewKeys) would confuse the authentication
loop.

This fix removes sessionID from the transport struct.

This fix also deletes the unused interface rekeyingTransport.

Fixes #15066

Change-Id: I7f303bce5d3214c9bdd58f52d21178a185871d90
Reviewed-on: https://go-review.googlesource.com/21606
Reviewed-by: Adam Langley <agl@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Han-Wen Nienhuys 9 år sedan
förälder
incheckning
d68c3ecb62
6 ändrade filer med 51 tillägg och 58 borttagningar
  1. 1 1
      ssh/client.go
  2. 41 29
      ssh/handshake.go
  3. 6 6
      ssh/handshake_test.go
  4. 1 1
      ssh/kex.go
  5. 1 1
      ssh/server.go
  6. 1 20
      ssh/transport.go

+ 1 - 1
ssh/client.go

@@ -97,7 +97,7 @@ func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) e
 	c.transport = newClientTransport(
 	c.transport = newClientTransport(
 		newTransport(c.sshConn.conn, config.Rand, true /* is client */),
 		newTransport(c.sshConn.conn, config.Rand, true /* is client */),
 		c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
 		c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
-	if err := c.transport.requestKeyChange(); err != nil {
+	if err := c.transport.requestInitialKeyChange(); err != nil {
 		return err
 		return err
 	}
 	}
 
 

+ 41 - 29
ssh/handshake.go

@@ -29,25 +29,6 @@ type keyingTransport interface {
 	// direction will be effected if a msgNewKeys message is sent
 	// direction will be effected if a msgNewKeys message is sent
 	// or received.
 	// or received.
 	prepareKeyChange(*algorithms, *kexResult) error
 	prepareKeyChange(*algorithms, *kexResult) error
-
-	// getSessionID returns the session ID. prepareKeyChange must
-	// have been called once.
-	getSessionID() []byte
-}
-
-// rekeyingTransport is the interface of handshakeTransport that we
-// (internally) expose to ClientConn and ServerConn.
-type rekeyingTransport interface {
-	packetConn
-
-	// requestKeyChange asks the remote side to change keys. All
-	// writes are blocked until the key change succeeds, which is
-	// signaled by reading a msgNewKeys.
-	requestKeyChange() error
-
-	// getSessionID returns the session ID. This is only valid
-	// after the first key change has completed.
-	getSessionID() []byte
 }
 }
 
 
 // handshakeTransport implements rekeying on top of a keyingTransport
 // handshakeTransport implements rekeying on top of a keyingTransport
@@ -86,6 +67,9 @@ type handshakeTransport struct {
 	sentInitMsg     *kexInitMsg
 	sentInitMsg     *kexInitMsg
 	writtenSinceKex uint64
 	writtenSinceKex uint64
 	writeError      error
 	writeError      error
+
+	// The session ID or nil if first kex did not complete yet.
+	sessionID []byte
 }
 }
 
 
 func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
 func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
@@ -122,7 +106,7 @@ func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byt
 }
 }
 
 
 func (t *handshakeTransport) getSessionID() []byte {
 func (t *handshakeTransport) getSessionID() []byte {
-	return t.conn.getSessionID()
+	return t.sessionID
 }
 }
 
 
 func (t *handshakeTransport) id() string {
 func (t *handshakeTransport) id() string {
@@ -183,9 +167,9 @@ func (t *handshakeTransport) readOnePacket() ([]byte, error) {
 	if p[0] != msgKexInit {
 	if p[0] != msgKexInit {
 		return p, nil
 		return p, nil
 	}
 	}
-	err = t.enterKeyExchange(p)
 
 
 	t.mu.Lock()
 	t.mu.Lock()
+	err = t.enterKeyExchangeLocked(p)
 	if err != nil {
 	if err != nil {
 		// drop connection
 		// drop connection
 		t.conn.Close()
 		t.conn.Close()
@@ -211,25 +195,39 @@ func (t *handshakeTransport) readOnePacket() ([]byte, error) {
 	return []byte{msgNewKeys}, nil
 	return []byte{msgNewKeys}, nil
 }
 }
 
 
+// keyChangeCategory describes whether a key exchange is the first on a
+// connection, or a subsequent one.
+type keyChangeCategory bool
+
+const (
+	firstKeyExchange      keyChangeCategory = true
+	subsequentKeyExchange keyChangeCategory = false
+)
+
 // sendKexInit sends a key change message, and returns the message
 // sendKexInit sends a key change message, and returns the message
 // that was sent. After initiating the key change, all writes will be
 // that was sent. After initiating the key change, all writes will be
 // blocked until the change is done, and a failed key change will
 // blocked until the change is done, and a failed key change will
 // close the underlying transport. This function is safe for
 // close the underlying transport. This function is safe for
 // concurrent use by multiple goroutines.
 // concurrent use by multiple goroutines.
-func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
+func (t *handshakeTransport) sendKexInit(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
 	t.mu.Lock()
 	t.mu.Lock()
 	defer t.mu.Unlock()
 	defer t.mu.Unlock()
-	return t.sendKexInitLocked()
+	return t.sendKexInitLocked(isFirst)
+}
+
+func (t *handshakeTransport) requestInitialKeyChange() error {
+	_, _, err := t.sendKexInit(firstKeyExchange)
+	return err
 }
 }
 
 
 func (t *handshakeTransport) requestKeyChange() error {
 func (t *handshakeTransport) requestKeyChange() error {
-	_, _, err := t.sendKexInit()
+	_, _, err := t.sendKexInit(subsequentKeyExchange)
 	return err
 	return err
 }
 }
 
 
 // sendKexInitLocked sends a key change message. t.mu must be locked
 // sendKexInitLocked sends a key change message. t.mu must be locked
 // while this happens.
 // while this happens.
-func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
+func (t *handshakeTransport) sendKexInitLocked(isFirst keyChangeCategory) (*kexInitMsg, []byte, error) {
 	// kexInits may be sent either in response to the other side,
 	// kexInits may be sent either in response to the other side,
 	// or because our side wants to initiate a key change, so we
 	// or because our side wants to initiate a key change, so we
 	// may have already sent a kexInit. In that case, don't send a
 	// may have already sent a kexInit. In that case, don't send a
@@ -237,6 +235,14 @@ func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
 	if t.sentInitMsg != nil {
 	if t.sentInitMsg != nil {
 		return t.sentInitMsg, t.sentInitPacket, nil
 		return t.sentInitMsg, t.sentInitPacket, nil
 	}
 	}
+
+	// If this is the initial key change, but we already have a sessionID,
+	// then do nothing because the key exchange has already completed
+	// asynchronously.
+	if isFirst && t.sessionID != nil {
+		return nil, nil, nil
+	}
+
 	msg := &kexInitMsg{
 	msg := &kexInitMsg{
 		KexAlgos:                t.config.KeyExchanges,
 		KexAlgos:                t.config.KeyExchanges,
 		CiphersClientServer:     t.config.Ciphers,
 		CiphersClientServer:     t.config.Ciphers,
@@ -276,7 +282,7 @@ func (t *handshakeTransport) writePacket(p []byte) error {
 	defer t.mu.Unlock()
 	defer t.mu.Unlock()
 
 
 	if t.writtenSinceKex > t.config.RekeyThreshold {
 	if t.writtenSinceKex > t.config.RekeyThreshold {
-		t.sendKexInitLocked()
+		t.sendKexInitLocked(subsequentKeyExchange)
 	}
 	}
 	for t.sentInitMsg != nil && t.writeError == nil {
 	for t.sentInitMsg != nil && t.writeError == nil {
 		t.cond.Wait()
 		t.cond.Wait()
@@ -300,12 +306,12 @@ func (t *handshakeTransport) Close() error {
 	return t.conn.Close()
 	return t.conn.Close()
 }
 }
 
 
-// enterKeyExchange runs the key exchange.
-func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
+// enterKeyExchange runs the key exchange. t.mu must be held while running this.
+func (t *handshakeTransport) enterKeyExchangeLocked(otherInitPacket []byte) error {
 	if debugHandshake {
 	if debugHandshake {
 		log.Printf("%s entered key exchange", t.id())
 		log.Printf("%s entered key exchange", t.id())
 	}
 	}
-	myInit, myInitPacket, err := t.sendKexInit()
+	myInit, myInitPacket, err := t.sendKexInitLocked(subsequentKeyExchange)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -362,6 +368,11 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 		return err
 		return err
 	}
 	}
 
 
+	if t.sessionID == nil {
+		t.sessionID = result.H
+		result.SessionID = result.H
+	}
+
 	t.conn.prepareKeyChange(algs, result)
 	t.conn.prepareKeyChange(algs, result)
 	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
 	if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
 		return err
 		return err
@@ -371,6 +382,7 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
 	} else if packet[0] != msgNewKeys {
 	} else if packet[0] != msgNewKeys {
 		return unexpectedMessageError(msgNewKeys, packet[0])
 		return unexpectedMessageError(msgNewKeys, packet[0])
 	}
 	}
+
 	return nil
 	return nil
 }
 }
 
 

+ 6 - 6
ssh/handshake_test.go

@@ -104,7 +104,7 @@ func TestHandshakeBasic(t *testing.T) {
 			}
 			}
 			if i == 5 {
 			if i == 5 {
 				// halfway through, we request a key change.
 				// halfway through, we request a key change.
-				_, _, err := trC.sendKexInit()
+				_, _, err := trC.sendKexInit(subsequentKeyExchange)
 				if err != nil {
 				if err != nil {
 					t.Fatalf("sendKexInit: %v", err)
 					t.Fatalf("sendKexInit: %v", err)
 				}
 				}
@@ -161,7 +161,7 @@ func TestHandshakeError(t *testing.T) {
 	}
 	}
 
 
 	// Now request a key change.
 	// Now request a key change.
-	_, _, err = trC.sendKexInit()
+	_, _, err = trC.sendKexInit(subsequentKeyExchange)
 	if err != nil {
 	if err != nil {
 		t.Errorf("sendKexInit: %v", err)
 		t.Errorf("sendKexInit: %v", err)
 	}
 	}
@@ -202,7 +202,7 @@ func TestHandshakeTwice(t *testing.T) {
 	}
 	}
 
 
 	// Now request a key change.
 	// Now request a key change.
-	_, _, err = trC.sendKexInit()
+	_, _, err = trC.sendKexInit(subsequentKeyExchange)
 	if err != nil {
 	if err != nil {
 		t.Errorf("sendKexInit: %v", err)
 		t.Errorf("sendKexInit: %v", err)
 	}
 	}
@@ -215,7 +215,7 @@ func TestHandshakeTwice(t *testing.T) {
 	}
 	}
 
 
 	// 2nd key change.
 	// 2nd key change.
-	_, _, err = trC.sendKexInit()
+	_, _, err = trC.sendKexInit(subsequentKeyExchange)
 	if err != nil {
 	if err != nil {
 		t.Errorf("sendKexInit: %v", err)
 		t.Errorf("sendKexInit: %v", err)
 	}
 	}
@@ -430,7 +430,7 @@ func TestDisconnect(t *testing.T) {
 
 
 	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
 	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
 	errMsg := &disconnectMsg{
 	errMsg := &disconnectMsg{
-		Reason: 42,
+		Reason:  42,
 		Message: "such is life",
 		Message: "such is life",
 	}
 	}
 	trC.writePacket(Marshal(errMsg))
 	trC.writePacket(Marshal(errMsg))
@@ -441,7 +441,7 @@ func TestDisconnect(t *testing.T) {
 		t.Fatalf("readPacket 1: %v", err)
 		t.Fatalf("readPacket 1: %v", err)
 	}
 	}
 	if packet[0] != msgRequestSuccess {
 	if packet[0] != msgRequestSuccess {
-		t.Errorf("got packet %v, want packet type %d", packet,  msgRequestSuccess)
+		t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
 	}
 	}
 
 
 	_, err = trS.readPacket()
 	_, err = trS.readPacket()

+ 1 - 1
ssh/kex.go

@@ -46,7 +46,7 @@ type kexResult struct {
 	Hash crypto.Hash
 	Hash crypto.Hash
 
 
 	// The session ID, which is the first H computed. This is used
 	// The session ID, which is the first H computed. This is used
-	// to signal data inside transport.
+	// to derive key material inside the transport.
 	SessionID []byte
 	SessionID []byte
 }
 }
 
 

+ 1 - 1
ssh/server.go

@@ -188,7 +188,7 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
 	tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
 	tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
 	s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
 	s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
 
 
-	if err := s.transport.requestKeyChange(); err != nil {
+	if err := s.transport.requestInitialKeyChange(); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 

+ 1 - 20
ssh/transport.go

@@ -39,19 +39,6 @@ type transport struct {
 	rand      io.Reader
 	rand      io.Reader
 
 
 	io.Closer
 	io.Closer
-
-	// Initial H used for the session ID. Once assigned this does
-	// not change, even during subsequent key exchanges.
-	sessionID []byte
-}
-
-// getSessionID returns the ID of the SSH connection. The return value
-// should not be modified.
-func (t *transport) getSessionID() []byte {
-	if t.sessionID == nil {
-		panic("session ID not set yet")
-	}
-	return t.sessionID
 }
 }
 
 
 // packetCipher represents a combination of SSH encryption/MAC
 // packetCipher represents a combination of SSH encryption/MAC
@@ -81,12 +68,6 @@ type connectionState struct {
 // both directions are triggered by reading and writing a msgNewKey packet
 // both directions are triggered by reading and writing a msgNewKey packet
 // respectively.
 // respectively.
 func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
 func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
-	if t.sessionID == nil {
-		t.sessionID = kexResult.H
-	}
-
-	kexResult.SessionID = t.sessionID
-
 	if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil {
 	if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil {
 		return err
 		return err
 	} else {
 	} else {
@@ -119,7 +100,7 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
 		case msgNewKeys:
 		case msgNewKeys:
 			select {
 			select {
 			case cipher := <-s.pendingKeyChange:
 			case cipher := <-s.pendingKeyChange:
-			s.packetCipher = cipher
+				s.packetCipher = cipher
 			default:
 			default:
 				return nil, errors.New("ssh: got bogus newkeys message.")
 				return nil, errors.New("ssh: got bogus newkeys message.")
 			}
 			}