|
|
@@ -93,7 +93,7 @@ type cachedPubKey struct {
|
|
|
|
|
|
const maxCachedPubKeys = 16
|
|
|
|
|
|
-// A ServerConn represents an incomming connection.
|
|
|
+// A ServerConn represents an incoming connection.
|
|
|
type ServerConn struct {
|
|
|
*transport
|
|
|
config *ServerConfig
|
|
|
@@ -115,6 +115,14 @@ type ServerConn struct {
|
|
|
// It is empty if no authentication is used. It is populated before
|
|
|
// any authentication callback is called and not assigned to after that.
|
|
|
User string
|
|
|
+
|
|
|
+ // ClientVersion is the client's version, populated after
|
|
|
+ // Handshake is called. It should not be modified.
|
|
|
+ ClientVersion []byte
|
|
|
+
|
|
|
+ // Initial H used for the session ID. Once assigned this must not change
|
|
|
+ // even during subsequent key exchanges.
|
|
|
+ sessionId []byte
|
|
|
}
|
|
|
|
|
|
// Server returns a new SSH server connection
|
|
|
@@ -187,7 +195,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
|
|
|
return nil, nil, errors.New("ssh: internal error")
|
|
|
}
|
|
|
|
|
|
- serializedSig := serializeSignature(hostAlgoRSA, sig)
|
|
|
+ serializedSig := serializeSignature(hostKeyAlgo, sig)
|
|
|
|
|
|
kexDHReply := kexDHReplyMsg{
|
|
|
HostKey: serializedHostKey,
|
|
|
@@ -204,22 +212,47 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
|
|
|
var serverVersion = []byte("SSH-2.0-Go\r\n")
|
|
|
|
|
|
// Handshake performs an SSH transport and client authentication on the given ServerConn.
|
|
|
-func (s *ServerConn) Handshake() error {
|
|
|
- var magics handshakeMagics
|
|
|
- if _, err := s.Write(serverVersion); err != nil {
|
|
|
- return err
|
|
|
+func (s *ServerConn) Handshake() (err error) {
|
|
|
+ if _, err = s.Write(serverVersion); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
- if err := s.Flush(); err != nil {
|
|
|
- return err
|
|
|
+ if err = s.Flush(); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
- magics.serverVersion = serverVersion[:len(serverVersion)-2]
|
|
|
|
|
|
- version, err := readVersion(s)
|
|
|
+ s.ClientVersion, err = readVersion(s)
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if err = s.clientInitHandshake(nil, nil); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
- magics.clientVersion = version
|
|
|
|
|
|
+ var packet []byte
|
|
|
+ if packet, err = s.readPacket(); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ var serviceRequest serviceRequestMsg
|
|
|
+ if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if serviceRequest.Service != serviceUserAuth {
|
|
|
+ return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
|
|
|
+ }
|
|
|
+ serviceAccept := serviceAcceptMsg{
|
|
|
+ Service: serviceUserAuth,
|
|
|
+ }
|
|
|
+ if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if err = s.authenticate(s.sessionId); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
+func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
|
|
|
serverKexInit := kexInitMsg{
|
|
|
KexAlgos: supportedKexAlgos,
|
|
|
ServerHostKeyAlgos: supportedHostKeyAlgos,
|
|
|
@@ -230,26 +263,23 @@ func (s *ServerConn) Handshake() error {
|
|
|
CompressionClientServer: supportedCompressions,
|
|
|
CompressionServerClient: supportedCompressions,
|
|
|
}
|
|
|
- kexInitPacket := marshal(msgKexInit, serverKexInit)
|
|
|
- magics.serverKexInit = kexInitPacket
|
|
|
+ serverKexInitPacket := marshal(msgKexInit, serverKexInit)
|
|
|
|
|
|
- if err := s.writePacket(kexInitPacket); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
- packet, err := s.readPacket()
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
+ if err = s.writePacket(serverKexInitPacket); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- magics.clientKexInit = packet
|
|
|
-
|
|
|
- var clientKexInit kexInitMsg
|
|
|
- if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil {
|
|
|
- return err
|
|
|
+ if clientKexInitPacket == nil {
|
|
|
+ clientKexInit = new(kexInitMsg)
|
|
|
+ if clientKexInitPacket, err = s.readPacket(); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, &clientKexInit, &serverKexInit)
|
|
|
+ kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, clientKexInit, &serverKexInit)
|
|
|
if !ok {
|
|
|
return errors.New("ssh: no common algorithms")
|
|
|
}
|
|
|
@@ -257,11 +287,17 @@ func (s *ServerConn) Handshake() error {
|
|
|
if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
|
|
|
// The client sent a Kex message for the wrong algorithm,
|
|
|
// which we have to ignore.
|
|
|
- if _, err := s.readPacket(); err != nil {
|
|
|
- return err
|
|
|
+ if _, err = s.readPacket(); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ var magics handshakeMagics
|
|
|
+ magics.serverVersion = serverVersion[:len(serverVersion)-2]
|
|
|
+ magics.clientVersion = s.ClientVersion
|
|
|
+ magics.serverKexInit = marshal(msgKexInit, serverKexInit)
|
|
|
+ magics.clientKexInit = clientKexInitPacket
|
|
|
+
|
|
|
var H, K []byte
|
|
|
var hashFunc crypto.Hash
|
|
|
switch kexAlgo {
|
|
|
@@ -277,47 +313,33 @@ func (s *ServerConn) Handshake() error {
|
|
|
err = errors.New("ssh: unexpected key exchange algorithm " + kexAlgo)
|
|
|
}
|
|
|
if err != nil {
|
|
|
- return err
|
|
|
+ return
|
|
|
+ }
|
|
|
+ // sessionId must only be assigned during initial handshake.
|
|
|
+ if s.sessionId == nil {
|
|
|
+ s.sessionId = H
|
|
|
}
|
|
|
|
|
|
+ var packet []byte
|
|
|
+
|
|
|
if err = s.writePacket([]byte{msgNewKeys}); err != nil {
|
|
|
- return err
|
|
|
+ return
|
|
|
}
|
|
|
- if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
|
|
|
- return err
|
|
|
+ if err = s.transport.writer.setupKeys(serverKeys, K, H, s.sessionId, hashFunc); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
+
|
|
|
if packet, err = s.readPacket(); err != nil {
|
|
|
- return err
|
|
|
+ return
|
|
|
}
|
|
|
-
|
|
|
if packet[0] != msgNewKeys {
|
|
|
return UnexpectedMessageError{msgNewKeys, packet[0]}
|
|
|
}
|
|
|
- if err = s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- if packet, err = s.readPacket(); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
- var serviceRequest serviceRequestMsg
|
|
|
- if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- if serviceRequest.Service != serviceUserAuth {
|
|
|
- return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
|
|
|
- }
|
|
|
- serviceAccept := serviceAcceptMsg{
|
|
|
- Service: serviceUserAuth,
|
|
|
- }
|
|
|
- if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
|
|
|
- return err
|
|
|
+ if err = s.transport.reader.setupKeys(clientKeys, K, H, s.sessionId, hashFunc); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
- if err = s.authenticate(H); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- return nil
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
func isAcceptableAlgo(algo string) bool {
|
|
|
@@ -498,7 +520,7 @@ const defaultWindowSize = 32768
|
|
|
// Accept reads and processes messages on a ServerConn. It must be called
|
|
|
// in order to demultiplex messages to any resulting Channels.
|
|
|
func (s *ServerConn) Accept() (Channel, error) {
|
|
|
- // TODO(dfc) s.lock is not held here so visibility of s.err is not guarenteed.
|
|
|
+ // TODO(dfc) s.lock is not held here so visibility of s.err is not guaranteed.
|
|
|
if s.err != nil {
|
|
|
return nil, s.err
|
|
|
}
|
|
|
@@ -610,6 +632,14 @@ func (s *ServerConn) Accept() (Channel, error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ case *kexInitMsg:
|
|
|
+ s.lock.Lock()
|
|
|
+ if err := s.clientInitHandshake(msg, packet); err != nil {
|
|
|
+ s.lock.Unlock()
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ s.lock.Unlock()
|
|
|
+
|
|
|
case UnexpectedMessageError:
|
|
|
return nil, msg
|
|
|
case *disconnectMsg:
|