Browse Source

go.crypto/ssh: allow server to respond to client init key exchange.

A windows SSH client, PuTTY, by default re-keys after every 60 minutes
or 1G of data transfer.

R=dave, agl
CC=golang-dev
https://golang.org/cl/6301071
Daniel Theophanes 13 năm trước cách đây
mục cha
commit
f8bd48becc
1 tập tin đã thay đổi với 89 bổ sung59 xóa
  1. 89 59
      ssh/server.go

+ 89 - 59
ssh/server.go

@@ -93,7 +93,7 @@ type cachedPubKey struct {
 
 const maxCachedPubKeys = 16
 
-// A ServerConn represents an incomming connection.
+// A ServerConn represents an incoming connection.
 type ServerConn struct {
 	*transport
 	config *ServerConfig
@@ -115,6 +115,14 @@ type ServerConn struct {
 	// It is empty if no authentication is used.  It is populated before
 	// any authentication callback is called and not assigned to after that.
 	User string
+
+	// ClientVersion is the client's version, populated after
+	// Handshake is called. It should not be modified.
+	ClientVersion []byte
+
+	// Initial H used for the session ID. Once assigned this must not change
+	// even during subsequent key exchanges.
+	sessionId []byte
 }
 
 // Server returns a new SSH server connection
@@ -187,7 +195,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 		return nil, nil, errors.New("ssh: internal error")
 	}
 
-	serializedSig := serializeSignature(hostAlgoRSA, sig)
+	serializedSig := serializeSignature(hostKeyAlgo, sig)
 
 	kexDHReply := kexDHReplyMsg{
 		HostKey:   serializedHostKey,
@@ -204,22 +212,47 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 var serverVersion = []byte("SSH-2.0-Go\r\n")
 
 // Handshake performs an SSH transport and client authentication on the given ServerConn.
-func (s *ServerConn) Handshake() error {
-	var magics handshakeMagics
-	if _, err := s.Write(serverVersion); err != nil {
-		return err
+func (s *ServerConn) Handshake() (err error) {
+	if _, err = s.Write(serverVersion); err != nil {
+		return
 	}
-	if err := s.Flush(); err != nil {
-		return err
+	if err = s.Flush(); err != nil {
+		return
 	}
-	magics.serverVersion = serverVersion[:len(serverVersion)-2]
 
-	version, err := readVersion(s)
+	s.ClientVersion, err = readVersion(s)
 	if err != nil {
-		return err
+		return
+	}
+	if err = s.clientInitHandshake(nil, nil); err != nil {
+		return
 	}
-	magics.clientVersion = version
 
+	var packet []byte
+	if packet, err = s.readPacket(); err != nil {
+		return
+	}
+	var serviceRequest serviceRequestMsg
+	if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
+		return
+	}
+	if serviceRequest.Service != serviceUserAuth {
+		return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
+	}
+	serviceAccept := serviceAcceptMsg{
+		Service: serviceUserAuth,
+	}
+	if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
+		return
+	}
+
+	if err = s.authenticate(s.sessionId); err != nil {
+		return
+	}
+	return
+}
+
+func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
 	serverKexInit := kexInitMsg{
 		KexAlgos:                supportedKexAlgos,
 		ServerHostKeyAlgos:      supportedHostKeyAlgos,
@@ -230,26 +263,23 @@ func (s *ServerConn) Handshake() error {
 		CompressionClientServer: supportedCompressions,
 		CompressionServerClient: supportedCompressions,
 	}
-	kexInitPacket := marshal(msgKexInit, serverKexInit)
-	magics.serverKexInit = kexInitPacket
+	serverKexInitPacket := marshal(msgKexInit, serverKexInit)
 
-	if err := s.writePacket(kexInitPacket); err != nil {
-		return err
-	}
-
-	packet, err := s.readPacket()
-	if err != nil {
-		return err
+	if err = s.writePacket(serverKexInitPacket); err != nil {
+		return
 	}
 
-	magics.clientKexInit = packet
-
-	var clientKexInit kexInitMsg
-	if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil {
-		return err
+	if clientKexInitPacket == nil {
+		clientKexInit = new(kexInitMsg)
+		if clientKexInitPacket, err = s.readPacket(); err != nil {
+			return
+		}
+		if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
+			return
+		}
 	}
 
-	kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, &clientKexInit, &serverKexInit)
+	kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, clientKexInit, &serverKexInit)
 	if !ok {
 		return errors.New("ssh: no common algorithms")
 	}
@@ -257,11 +287,17 @@ func (s *ServerConn) Handshake() error {
 	if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
 		// The client sent a Kex message for the wrong algorithm,
 		// which we have to ignore.
-		if _, err := s.readPacket(); err != nil {
-			return err
+		if _, err = s.readPacket(); err != nil {
+			return
 		}
 	}
 
+	var magics handshakeMagics
+	magics.serverVersion = serverVersion[:len(serverVersion)-2]
+	magics.clientVersion = s.ClientVersion
+	magics.serverKexInit = marshal(msgKexInit, serverKexInit)
+	magics.clientKexInit = clientKexInitPacket
+
 	var H, K []byte
 	var hashFunc crypto.Hash
 	switch kexAlgo {
@@ -277,47 +313,33 @@ func (s *ServerConn) Handshake() error {
 		err = errors.New("ssh: unexpected key exchange algorithm " + kexAlgo)
 	}
 	if err != nil {
-		return err
+		return
+	}
+	// sessionId must only be assigned during initial handshake.
+	if s.sessionId == nil {
+		s.sessionId = H
 	}
 
+	var packet []byte
+
 	if err = s.writePacket([]byte{msgNewKeys}); err != nil {
-		return err
+		return
 	}
-	if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
-		return err
+	if err = s.transport.writer.setupKeys(serverKeys, K, H, s.sessionId, hashFunc); err != nil {
+		return
 	}
+
 	if packet, err = s.readPacket(); err != nil {
-		return err
+		return
 	}
-
 	if packet[0] != msgNewKeys {
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 	}
-	if err = s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
-		return err
-	}
-	if packet, err = s.readPacket(); err != nil {
-		return err
-	}
-
-	var serviceRequest serviceRequestMsg
-	if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
-		return err
-	}
-	if serviceRequest.Service != serviceUserAuth {
-		return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
-	}
-	serviceAccept := serviceAcceptMsg{
-		Service: serviceUserAuth,
-	}
-	if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
-		return err
+	if err = s.transport.reader.setupKeys(clientKeys, K, H, s.sessionId, hashFunc); err != nil {
+		return
 	}
 
-	if err = s.authenticate(H); err != nil {
-		return err
-	}
-	return nil
+	return
 }
 
 func isAcceptableAlgo(algo string) bool {
@@ -498,7 +520,7 @@ const defaultWindowSize = 32768
 // Accept reads and processes messages on a ServerConn. It must be called
 // in order to demultiplex messages to any resulting Channels.
 func (s *ServerConn) Accept() (Channel, error) {
-	// TODO(dfc) s.lock is not held here so visibility of s.err is not guarenteed.
+	// TODO(dfc) s.lock is not held here so visibility of s.err is not guaranteed.
 	if s.err != nil {
 		return nil, s.err
 	}
@@ -610,6 +632,14 @@ func (s *ServerConn) Accept() (Channel, error) {
 					}
 				}
 
+			case *kexInitMsg:
+				s.lock.Lock()
+				if err := s.clientInitHandshake(msg, packet); err != nil {
+					s.lock.Unlock()
+					return nil, err
+				}
+				s.lock.Unlock()
+
 			case UnexpectedMessageError:
 				return nil, msg
 			case *disconnectMsg: