瀏覽代碼

go.crypto/ssh: move interpretation of msgNewKeys into
transport.

Sending the msgNewKeys packet and setting up the key material
now happen under a lock, preventing races with concurrent
writers.

R=kardianos, agl, jpsugar, hanwenn
CC=golang-dev
https://golang.org/cl/14476043

Han-Wen Nienhuys 12 年之前
父節點
當前提交
4147256c9c
共有 5 個文件被更改,包括 138 次插入80 次删除
  1. 10 15
      ssh/client.go
  2. 23 11
      ssh/common.go
  3. 13 14
      ssh/kex.go
  4. 14 28
      ssh/server.go
  5. 78 12
      ssh/transport.go

+ 10 - 15
ssh/client.go

@@ -43,7 +43,7 @@ func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
 
 func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) {
 	conn := &ClientConn{
-		transport:     newTransport(c, config.rand()),
+		transport:     newTransport(c, config.rand(), true /* is client */),
 		config:        config,
 		globalRequest: globalRequest{response: make(chan interface{}, 1)},
 		dialAddress:   addr,
@@ -104,12 +104,12 @@ func (c *ClientConn) handshake() error {
 		return err
 	}
 
-	kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
-	if !ok {
+	algs := findAgreedAlgorithms(&clientKexInit, &serverKexInit)
+	if algs == nil {
 		return errors.New("ssh: no common algorithms")
 	}
 
-	if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
+	if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
 		// The server sent a Kex message for the wrong algorithm,
 		// which we have to ignore.
 		if _, err := c.readPacket(); err != nil {
@@ -117,9 +117,9 @@ func (c *ClientConn) handshake() error {
 		}
 	}
 
-	kex, ok := kexAlgoMap[kexAlgo]
+	kex, ok := kexAlgoMap[algs.kex]
 	if !ok {
-		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
+		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
 	}
 
 	magics := handshakeMagics{
@@ -133,23 +133,21 @@ func (c *ClientConn) handshake() error {
 		return err
 	}
 
-	err = verifyHostKeySignature(hostKeyAlgo, result.HostKey, result.H, result.Signature)
+	err = verifyHostKeySignature(algs.hostKey, result.HostKey, result.H, result.Signature)
 	if err != nil {
 		return err
 	}
 
 	if checker := c.config.HostKeyChecker; checker != nil {
-		err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, result.HostKey)
+		err = checker.Check(c.dialAddress, c.RemoteAddr(), algs.hostKey, result.HostKey)
 		if err != nil {
 			return err
 		}
 	}
 
-	if err = c.writePacket([]byte{msgNewKeys}); err != nil {
-		return err
-	}
+	c.transport.prepareKeyChange(algs, result)
 
-	if err = c.transport.writer.setupKeys(clientKeys, result.K, result.H, result.H, kex.Hash()); err != nil {
+	if err = c.writePacket([]byte{msgNewKeys}); err != nil {
 		return err
 	}
 	if packet, err = c.readPacket(); err != nil {
@@ -158,9 +156,6 @@ func (c *ClientConn) handshake() error {
 	if packet[0] != msgNewKeys {
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 	}
-	if err := c.transport.reader.setupKeys(serverKeys, result.K, result.H, result.H, kex.Hash()); err != nil {
-		return err
-	}
 	return c.authenticate(result.H)
 }
 

+ 23 - 11
ssh/common.go

@@ -90,49 +90,61 @@ func findCommonCipher(clientCiphers []string, serverCiphers []string) (commonCip
 	return
 }
 
-func findAgreedAlgorithms(transport *transport, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) {
-	kexAlgo, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
+type algorithms struct {
+	kex          string
+	hostKey      string
+	wCipher      string
+	rCipher      string
+	rMAC         string
+	wMAC         string
+	rCompression string
+	wCompression string
+}
+
+func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) {
+	var ok bool
+	result := &algorithms{}
+	result.kex, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
 	if !ok {
 		return
 	}
 
-	hostKeyAlgo, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
+	result.hostKey, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
 	if !ok {
 		return
 	}
 
-	transport.writer.cipherAlgo, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
+	result.wCipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
 	if !ok {
 		return
 	}
 
-	transport.reader.cipherAlgo, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
+	result.rCipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
 	if !ok {
 		return
 	}
 
-	transport.writer.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
+	result.wMAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
 	if !ok {
 		return
 	}
 
-	transport.reader.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
+	result.rMAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
 	if !ok {
 		return
 	}
 
-	transport.writer.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
+	result.wCompression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
 	if !ok {
 		return
 	}
 
-	transport.reader.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
+	result.rCompression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
 	if !ok {
 		return
 	}
 
-	ok = true
-	return
+	return result
 }
 
 // Cryptographic configuration common to both ServerConfig and ClientConfig.

+ 13 - 14
ssh/kex.go

@@ -35,6 +35,15 @@ type kexResult struct {
 
 	// Signature of H
 	Signature []byte
+
+	// A cryptographic hash function that matches the security
+	// level of the key exchange algorithm. It is used for
+	// calculating H, and for deriving keys from H and K.
+	Hash crypto.Hash
+
+	// The session ID, which is the first H computed. This is used
+	// to signal data inside transport.
+	SessionID []byte
 }
 
 // handshakeMagics contains data that is always included in the
@@ -60,12 +69,6 @@ type kexAlgorithm interface {
 	// Client runs the client-side key agreement. Caller is
 	// responsible for verifying the host key signature.
 	Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error)
-
-	// Hash returns a cryptographic hash function that matches the
-	// security level of the key exchange algorithm. It is used
-	// for calculating kexResult.H, and for deriving keys from
-	// data in kexResult.
-	Hash() crypto.Hash
 }
 
 // dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
@@ -73,10 +76,6 @@ type dhGroup struct {
 	g, p *big.Int
 }
 
-func (group *dhGroup) Hash() crypto.Hash {
-	return crypto.SHA1
-}
-
 func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
 	if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
 		return nil, errors.New("ssh: DH parameter out of bounds")
@@ -128,6 +127,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
 		K:         K,
 		HostKey:   kexDHReply.HostKey,
 		Signature: kexDHReply.Signature,
+		Hash:      crypto.SHA1,
 	}, nil
 }
 
@@ -187,6 +187,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
 		K:         K,
 		HostKey:   hostKeyBytes,
 		Signature: sig,
+		Hash:      crypto.SHA1,
 	}, nil
 }
 
@@ -243,6 +244,7 @@ func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (
 		K:         K,
 		HostKey:   reply.HostKey,
 		Signature: reply.Signature,
+		Hash:      ecHash(kex.curve),
 	}, nil
 }
 
@@ -354,13 +356,10 @@ func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, p
 		K:         K,
 		HostKey:   reply.HostKey,
 		Signature: sig,
+		Hash:      ecHash(kex.curve),
 	}, nil
 }
 
-func (kex *ecdh) Hash() crypto.Hash {
-	return ecHash(kex.curve)
-}
-
 var kexAlgoMap = map[string]kexAlgorithm{}
 
 func init() {

+ 14 - 28
ssh/server.go

@@ -121,17 +121,13 @@ type ServerConn struct {
 	// ClientVersion is the client's version, populated after
 	// Handshake is called. It should not be modified.
 	ClientVersion []byte
-
-	// Initial H used for the session ID. Once assigned this must not change
-	// even during subsequent key exchanges.
-	sessionId []byte
 }
 
 // Server returns a new SSH server connection
 // using c as the underlying transport.
 func Server(c net.Conn, config *ServerConfig) *ServerConn {
 	return &ServerConn{
-		transport: newTransport(c, config.rand()),
+		transport: newTransport(c, config.rand(), false /* not client */),
 		channels:  make(map[uint32]*serverChan),
 		config:    config,
 	}
@@ -186,7 +182,7 @@ func (s *ServerConn) Handshake() (err error) {
 		return
 	}
 
-	if err = s.authenticate(s.sessionId); err != nil {
+	if err = s.authenticate(s.transport.sessionID); err != nil {
 		return
 	}
 	return
@@ -222,11 +218,12 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 		}
 	}
 
-	kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, clientKexInit, &serverKexInit)
-	if !ok {
+	algs := findAgreedAlgorithms(clientKexInit, &serverKexInit)
+	if algs == nil {
 		return errors.New("ssh: no common algorithms")
 	}
-	if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
+
+	if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
 		// The client sent a Kex message for the wrong algorithm,
 		// which we have to ignore.
 		if _, err = s.readPacket(); err != nil {
@@ -236,14 +233,14 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 
 	var hostKey Signer
 	for _, k := range s.config.hostKeys {
-		if hostKeyAlgo == k.PublicKey().PublicKeyAlgo() {
+		if algs.hostKey == k.PublicKey().PublicKeyAlgo() {
 			hostKey = k
 		}
 	}
 
-	kex, ok := kexAlgoMap[kexAlgo]
+	kex, ok := kexAlgoMap[algs.kex]
 	if !ok {
-		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
+		return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
 	}
 
 	magics := handshakeMagics{
@@ -257,29 +254,18 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 		return err
 	}
 
-	// sessionId must only be assigned during initial handshake.
-	if s.sessionId == nil {
-		s.sessionId = result.H
+	if err = s.transport.prepareKeyChange(algs, result); err != nil {
+		return err
 	}
 
-	var packet []byte
-
 	if err = s.writePacket([]byte{msgNewKeys}); err != nil {
 		return
 	}
-	if err = s.transport.writer.setupKeys(serverKeys, result.K, result.H, s.sessionId, kex.Hash()); err != nil {
-		return
-	}
-
-	if packet, err = s.readPacket(); err != nil {
-		return
-	}
-	if packet[0] != msgNewKeys {
+	if packet, err := s.readPacket(); err != nil {
+		return err
+	} else if packet[0] != msgNewKeys {
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 	}
-	if err = s.transport.reader.setupKeys(clientKeys, result.K, result.H, s.sessionId, kex.Hash()); err != nil {
-		return
-	}
 
 	return
 }

+ 78 - 12
ssh/transport.go

@@ -6,7 +6,6 @@ package ssh
 
 import (
 	"bufio"
-	"crypto"
 	"crypto/cipher"
 	"crypto/subtle"
 	"encoding/binary"
@@ -48,6 +47,10 @@ type transport struct {
 	writer
 
 	net.Conn
+
+	// Initial H used for the session ID. Once assigned this does
+	// not change, even during subsequent key exchanges.
+	sessionID []byte
 }
 
 // reader represents the incoming connection state.
@@ -64,6 +67,28 @@ type writer struct {
 	common
 }
 
+// prepareKeyChange sets up key material for a keychange. The key changes in
+// both directions are triggered by reading and writing a msgNewKey packet
+// respectively.
+func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
+	t.writer.cipherAlgo = algs.wCipher
+	t.writer.macAlgo = algs.wMAC
+	t.writer.compressionAlgo = algs.wCompression
+
+	t.reader.cipherAlgo = algs.rCipher
+	t.reader.macAlgo = algs.rMAC
+	t.reader.compressionAlgo = algs.rCompression
+
+	if t.sessionID == nil {
+		t.sessionID = kexResult.H
+	}
+
+	kexResult.SessionID = t.sessionID
+	t.reader.pendingKeyChange <- kexResult
+	t.writer.pendingKeyChange <- kexResult
+	return nil
+}
+
 // common represents the cipher state needed to process messages in a single
 // direction.
 type common struct {
@@ -74,6 +99,9 @@ type common struct {
 	cipherAlgo      string
 	macAlgo         string
 	compressionAlgo string
+
+	dir              direction
+	pendingKeyChange chan *kexResult
 }
 
 // Read and decrypt a single packet from the remote peer.
@@ -125,7 +153,19 @@ func (r *reader) readPacket() ([]byte, error) {
 	}
 
 	r.seqNum++
-	return packet[:length-paddingLength-1], nil
+	packet = packet[:length-paddingLength-1]
+
+	if len(packet) > 0 && packet[0] == msgNewKeys {
+		select {
+		case k := <-r.pendingKeyChange:
+			if err := r.setupKeys(r.dir, k); err != nil {
+				return nil, err
+			}
+		default:
+			return nil, errors.New("ssh: got bogus newkeys message.")
+		}
+	}
+	return packet, nil
 }
 
 // Read and decrypt next packet discarding debug and noop messages.
@@ -138,6 +178,7 @@ func (t *transport) readPacket() ([]byte, error) {
 		if len(packet) == 0 {
 			return nil, errors.New("ssh: zero length packet")
 		}
+
 		if packet[0] != msgIgnore && packet[0] != msgDebug {
 			return packet, nil
 		}
@@ -147,6 +188,8 @@ func (t *transport) readPacket() ([]byte, error) {
 
 // Encrypt and send a packet of data to the remote peer.
 func (w *writer) writePacket(packet []byte) error {
+	changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
+
 	if len(packet) > maxPacket {
 		return errors.New("ssh: packet too large")
 	}
@@ -209,26 +252,49 @@ func (w *writer) writePacket(packet []byte) error {
 	}
 
 	w.seqNum++
-	return w.Flush()
+	if err = w.Flush(); err != nil {
+		return err
+	}
+
+	if changeKeys {
+		select {
+		case k := <-w.pendingKeyChange:
+			err = w.setupKeys(w.dir, k)
+		default:
+			panic("ssh: no key material for msgNewKeys")
+		}
+	}
+	return err
 }
 
-func newTransport(conn net.Conn, rand io.Reader) *transport {
-	return &transport{
+func newTransport(conn net.Conn, rand io.Reader, isClient bool) *transport {
+	t := &transport{
 		reader: reader{
 			Reader: bufio.NewReader(conn),
 			common: common{
-				cipher: noneCipher{},
+				cipher:           noneCipher{},
+				pendingKeyChange: make(chan *kexResult, 1),
 			},
 		},
 		writer: writer{
 			Writer: bufio.NewWriter(conn),
 			rand:   rand,
 			common: common{
-				cipher: noneCipher{},
+				cipher:           noneCipher{},
+				pendingKeyChange: make(chan *kexResult, 1),
 			},
 		},
 		Conn: conn,
 	}
+	if isClient {
+		t.reader.dir = serverKeys
+		t.writer.dir = clientKeys
+	} else {
+		t.reader.dir = clientKeys
+		t.writer.dir = serverKeys
+	}
+
+	return t
 }
 
 type direction struct {
@@ -246,7 +312,7 @@ var (
 // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
 // described in RFC 4253, section 6.4. direction should either be serverKeys
 // (to setup server->client keys) or clientKeys (for client->server keys).
-func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
+func (c *common) setupKeys(d direction, r *kexResult) error {
 	cipherMode := cipherModes[c.cipherAlgo]
 	macMode := macModes[c.macAlgo]
 
@@ -254,10 +320,10 @@ func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.
 	key := make([]byte, cipherMode.keySize)
 	macKey := make([]byte, macMode.keySize)
 
-	h := hashFunc.New()
-	generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
-	generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
-	generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
+	h := r.Hash.New()
+	generateKeyMaterial(iv, d.ivTag, r.K, r.H, r.SessionID, h)
+	generateKeyMaterial(key, d.keyTag, r.K, r.H, r.SessionID, h)
+	generateKeyMaterial(macKey, d.macKeyTag, r.K, r.H, r.SessionID, h)
 
 	c.mac = macMode.new(macKey)