// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package ssh import ( "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/rsa" "encoding/binary" "errors" "fmt" "io" "math/big" "net" "sync" ) // clientVersion is the default identification string that the client will use. var clientVersion = []byte("SSH-2.0-Go") // ClientConn represents the client side of an SSH connection. type ClientConn struct { *transport config *ClientConfig chanList // channels associated with this connection forwardList // forwarded tcpip connections from the remote side globalRequest // Address as passed to the Dial function. dialAddress string serverVersion string } type globalRequest struct { sync.Mutex response chan interface{} } // Client returns a new SSH client connection using c as the underlying transport. func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) { return clientWithAddress(c, "", config) } func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) { conn := &ClientConn{ transport: newTransport(c, config.rand()), config: config, globalRequest: globalRequest{response: make(chan interface{}, 1)}, dialAddress: addr, } if err := conn.handshake(); err != nil { conn.Close() return nil, fmt.Errorf("handshake failed: %v", err) } go conn.mainLoop() return conn, nil } // handshake performs the client side key exchange. See RFC 4253 Section 7. func (c *ClientConn) handshake() error { var magics handshakeMagics var version []byte if len(c.config.ClientVersion) > 0 { version = []byte(c.config.ClientVersion) } else { version = clientVersion } magics.clientVersion = version version = append(version, '\r', '\n') if _, err := c.Write(version); err != nil { return err } if err := c.Flush(); err != nil { return err } // read remote server version version, err := readVersion(c) if err != nil { return err } magics.serverVersion = version c.serverVersion = string(version) clientKexInit := kexInitMsg{ KexAlgos: c.config.Crypto.kexes(), ServerHostKeyAlgos: supportedHostKeyAlgos, CiphersClientServer: c.config.Crypto.ciphers(), CiphersServerClient: c.config.Crypto.ciphers(), MACsClientServer: c.config.Crypto.macs(), MACsServerClient: c.config.Crypto.macs(), CompressionClientServer: supportedCompressions, CompressionServerClient: supportedCompressions, } kexInitPacket := marshal(msgKexInit, clientKexInit) magics.clientKexInit = kexInitPacket if err := c.writePacket(kexInitPacket); err != nil { return err } packet, err := c.readPacket() if err != nil { return err } magics.serverKexInit = packet var serverKexInit kexInitMsg if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil { return err } kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit) if !ok { return errors.New("ssh: no common algorithms") } if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] { // The server sent a Kex message for the wrong algorithm, // which we have to ignore. if _, err := c.readPacket(); err != nil { return err } } var result *kexResult switch kexAlgo { case kexAlgoECDH256: result, err = c.kexECDH(elliptic.P256(), &magics, hostKeyAlgo) case kexAlgoECDH384: result, err = c.kexECDH(elliptic.P384(), &magics, hostKeyAlgo) case kexAlgoECDH521: result, err = c.kexECDH(elliptic.P521(), &magics, hostKeyAlgo) case kexAlgoDH14SHA1: dhGroup14Once.Do(initDHGroup14) result, err = c.kexDH(crypto.SHA1, dhGroup14, &magics, hostKeyAlgo) case kexAlgoDH1SHA1: dhGroup1Once.Do(initDHGroup1) result, err = c.kexDH(crypto.SHA1, dhGroup1, &magics, hostKeyAlgo) default: err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo) } if err != nil { return err } err = verifyHostKeySignature(hostKeyAlgo, result.HostKey, result.H, result.Signature) if err != nil { return err } if checker := c.config.HostKeyChecker; checker != nil { err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, result.HostKey) if err != nil { return err } } if err = c.writePacket([]byte{msgNewKeys}); err != nil { return err } if err = c.transport.writer.setupKeys(clientKeys, result.K, result.H, result.H, result.Hash); err != nil { return err } if packet, err = c.readPacket(); err != nil { return err } if packet[0] != msgNewKeys { return UnexpectedMessageError{msgNewKeys, packet[0]} } if err := c.transport.reader.setupKeys(serverKeys, result.K, result.H, result.H, result.Hash); err != nil { return err } return c.authenticate(result.H) } // kexECDH performs Elliptic Curve Diffie-Hellman key exchange as // described in RFC 5656, section 4. func (c *ClientConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, hostKeyAlgo string) (*kexResult, error) { ephKey, err := ecdsa.GenerateKey(curve, c.config.rand()) if err != nil { return nil, err } kexInit := kexECDHInitMsg{ ClientPubKey: elliptic.Marshal(curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), } serialized := marshal(msgKexECDHInit, kexInit) if err := c.writePacket(serialized); err != nil { return nil, err } packet, err := c.readPacket() if err != nil { return nil, err } var reply kexECDHReplyMsg if err = unmarshal(&reply, packet, msgKexECDHReply); err != nil { return nil, err } x, y := elliptic.Unmarshal(curve, reply.EphemeralPubKey) if x == nil { return nil, errors.New("ssh: elliptic.Unmarshal failure") } if !validateECPublicKey(curve, x, y) { return nil, errors.New("ssh: ephemeral server key not on curve") } // generate shared secret secret, _ := curve.ScalarMult(x, y, ephKey.D.Bytes()) hashFunc := ecHash(curve) h := hashFunc.New() writeString(h, magics.clientVersion) writeString(h, magics.serverVersion) writeString(h, magics.clientKexInit) writeString(h, magics.serverKexInit) writeString(h, reply.HostKey) writeString(h, kexInit.ClientPubKey) writeString(h, reply.EphemeralPubKey) K := make([]byte, intLength(secret)) marshalInt(K, secret) h.Write(K) return &kexResult{ H: h.Sum(nil), K: K, HostKey: reply.HostKey, Signature: reply.Signature, Hash: hashFunc, }, nil } // Verify the host key obtained in the key exchange. func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte, signature []byte) error { hostKey, rest, ok := ParsePublicKey(hostKeyBytes) if len(rest) > 0 || !ok { return errors.New("ssh: could not parse hostkey") } // Select hash function to match the hostkey algorithm, as per // RFC 4253, section 6.1 (for RSA/DSS) and RFC 5656, section // 6.2.1 (for ECDSA). var hashFunc crypto.Hash switch hostKeyAlgo { case KeyAlgoRSA: hashFunc = crypto.SHA1 case KeyAlgoDSA: hashFunc = crypto.SHA1 case KeyAlgoECDSA256: hashFunc = crypto.SHA256 case KeyAlgoECDSA384: hashFunc = crypto.SHA384 case KeyAlgoECDSA521: hashFunc = crypto.SHA512 default: return errors.New("ssh: unknown key algorithm: " + hostKeyAlgo) } signed := hashFunc.New() signed.Write(data) digest := signed.Sum(nil) sig, rest, ok := parseSignatureBody(signature) if len(rest) > 0 || !ok { return errors.New("ssh: signature parse error") } if sig.Format != hostKeyAlgo { return fmt.Errorf("ssh: unexpected signature type %q", sig.Format) } return verifySignature(digest, sig, hostKey) } func verifySignature(hash []byte, sig *signature, key interface{}) error { switch pubKey := key.(type) { case *rsa.PublicKey: return verifyRSASignature(hash, sig, pubKey) } return fmt.Errorf("ssh: unknown key type %T", key) } func verifyRSASignature(hash []byte, sig *signature, key *rsa.PublicKey) error { return rsa.VerifyPKCS1v15(key, crypto.SHA1, hash, sig.Blob) } // kexResult captures the outcome of a key exchange. type kexResult struct { // Session hash. See also RFC 4253, section 8. H []byte // Shared secret. See also RFC 4253, section 8. K []byte // Host key as hashed into H HostKey []byte // Signature of H Signature []byte // Hash function that was used. Hash crypto.Hash } // kexDH performs Diffie-Hellman key agreement on a ClientConn. func (c *ClientConn) kexDH(hashFunc crypto.Hash, group *dhGroup, magics *handshakeMagics, hostKeyAlgo string) (*kexResult, error) { x, err := rand.Int(c.config.rand(), group.p) if err != nil { return nil, err } X := new(big.Int).Exp(group.g, x, group.p) kexDHInit := kexDHInitMsg{ X: X, } if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil { return nil, err } packet, err := c.readPacket() if err != nil { return nil, err } var kexDHReply kexDHReplyMsg if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil { return nil, err } kInt, err := group.diffieHellman(kexDHReply.Y, x) if err != nil { return nil, err } h := hashFunc.New() writeString(h, magics.clientVersion) writeString(h, magics.serverVersion) writeString(h, magics.clientKexInit) writeString(h, magics.serverKexInit) writeString(h, kexDHReply.HostKey) writeInt(h, X) writeInt(h, kexDHReply.Y) K := make([]byte, intLength(kInt)) marshalInt(K, kInt) h.Write(K) return &kexResult{ H: h.Sum(nil), K: K, HostKey: kexDHReply.HostKey, Signature: kexDHReply.Signature, Hash: hashFunc, }, nil } // mainLoop reads incoming messages and routes channel messages // to their respective ClientChans. func (c *ClientConn) mainLoop() { defer func() { c.Close() c.chanList.closeAll() c.forwardList.closeAll() }() for { packet, err := c.readPacket() if err != nil { break } // TODO(dfc) A note on blocking channel use. // The msg, data and dataExt channels of a clientChan can // cause this loop to block indefinately if the consumer does // not service them. switch packet[0] { case msgChannelData: if len(packet) < 9 { // malformed data packet return } remoteId := binary.BigEndian.Uint32(packet[1:5]) length := binary.BigEndian.Uint32(packet[5:9]) packet = packet[9:] if length != uint32(len(packet)) { return } ch, ok := c.getChan(remoteId) if !ok { return } ch.stdout.write(packet) case msgChannelExtendedData: if len(packet) < 13 { // malformed data packet return } remoteId := binary.BigEndian.Uint32(packet[1:5]) datatype := binary.BigEndian.Uint32(packet[5:9]) length := binary.BigEndian.Uint32(packet[9:13]) packet = packet[13:] if length != uint32(len(packet)) { return } // RFC 4254 5.2 defines data_type_code 1 to be data destined // for stderr on interactive sessions. Other data types are // silently discarded. if datatype == 1 { ch, ok := c.getChan(remoteId) if !ok { return } ch.stderr.write(packet) } default: decoded, err := decode(packet) if err != nil { if _, ok := err.(UnexpectedMessageError); ok { fmt.Printf("mainLoop: unexpected message: %v\n", err) continue } return } switch msg := decoded.(type) { case *channelOpenMsg: c.handleChanOpen(msg) case *channelOpenConfirmMsg: ch, ok := c.getChan(msg.PeersId) if !ok { return } ch.msg <- msg case *channelOpenFailureMsg: ch, ok := c.getChan(msg.PeersId) if !ok { return } ch.msg <- msg case *channelCloseMsg: ch, ok := c.getChan(msg.PeersId) if !ok { return } ch.Close() close(ch.msg) c.chanList.remove(msg.PeersId) case *channelEOFMsg: ch, ok := c.getChan(msg.PeersId) if !ok { return } ch.stdout.eof() // RFC 4254 is mute on how EOF affects dataExt messages but // it is logical to signal EOF at the same time. ch.stderr.eof() case *channelRequestSuccessMsg: ch, ok := c.getChan(msg.PeersId) if !ok { return } ch.msg <- msg case *channelRequestFailureMsg: ch, ok := c.getChan(msg.PeersId) if !ok { return } ch.msg <- msg case *channelRequestMsg: ch, ok := c.getChan(msg.PeersId) if !ok { return } ch.msg <- msg case *windowAdjustMsg: ch, ok := c.getChan(msg.PeersId) if !ok { return } if !ch.remoteWin.add(msg.AdditionalBytes) { // invalid window update return } case *globalRequestMsg: // This handles keepalive messages and matches // the behaviour of OpenSSH. if msg.WantReply { c.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{})) } case *globalRequestSuccessMsg, *globalRequestFailureMsg: c.globalRequest.response <- msg case *disconnectMsg: return default: fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg) } } } } // Handle channel open messages from the remote side. func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) { if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { c.sendConnectionFailed(msg.PeersId) } switch msg.ChanType { case "forwarded-tcpip": laddr, rest, ok := parseTCPAddr(msg.TypeSpecificData) if !ok { // invalid request c.sendConnectionFailed(msg.PeersId) return } l, ok := c.forwardList.lookup(*laddr) if !ok { // TODO: print on a more structured log. fmt.Println("could not find forward list entry for", laddr) // Section 7.2, implementations MUST reject suprious incoming // connections. c.sendConnectionFailed(msg.PeersId) return } raddr, rest, ok := parseTCPAddr(rest) if !ok { // invalid request c.sendConnectionFailed(msg.PeersId) return } ch := c.newChan(c.transport) ch.remoteId = msg.PeersId ch.remoteWin.add(msg.PeersWindow) ch.maxPacket = msg.MaxPacketSize m := channelOpenConfirmMsg{ PeersId: ch.remoteId, MyId: ch.localId, MyWindow: 1 << 14, // As per RFC 4253 6.1, 32k is also the minimum. MaxPacketSize: 1 << 15, } c.writePacket(marshal(msgChannelOpenConfirm, m)) l <- forward{ch, raddr} default: // unknown channel type m := channelOpenFailureMsg{ PeersId: msg.PeersId, Reason: UnknownChannelType, Message: fmt.Sprintf("unknown channel type: %v", msg.ChanType), Language: "en_US.UTF-8", } c.writePacket(marshal(msgChannelOpenFailure, m)) } } // sendGlobalRequest sends a global request message as specified // in RFC4254 section 4. To correctly synchronise messages, a lock // is held internally until a response is returned. func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) { c.globalRequest.Lock() defer c.globalRequest.Unlock() if err := c.writePacket(marshal(msgGlobalRequest, m)); err != nil { return nil, err } r := <-c.globalRequest.response if r, ok := r.(*globalRequestSuccessMsg); ok { return r, nil } return nil, errors.New("request failed") } // sendConnectionFailed rejects an incoming channel identified // by remoteId. func (c *ClientConn) sendConnectionFailed(remoteId uint32) error { m := channelOpenFailureMsg{ PeersId: remoteId, Reason: ConnectionFailed, Message: "invalid request", Language: "en_US.UTF-8", } return c.writePacket(marshal(msgChannelOpenFailure, m)) } // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. // RFC 4254 section 7.2 is mute on what to do if parsing fails but the forwardlist // requires a valid *net.TCPAddr to operate, so we enforce that restriction here. func parseTCPAddr(b []byte) (*net.TCPAddr, []byte, bool) { addr, b, ok := parseString(b) if !ok { return nil, b, false } port, b, ok := parseUint32(b) if !ok { return nil, b, false } ip := net.ParseIP(string(addr)) if ip == nil { return nil, b, false } return &net.TCPAddr{IP: ip, Port: int(port)}, b, true } // Dial connects to the given network address using net.Dial and // then initiates a SSH handshake, returning the resulting client connection. func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) { conn, err := net.Dial(network, addr) if err != nil { return nil, err } return clientWithAddress(conn, addr, config) } // A ClientConfig structure is used to configure a ClientConn. After one has // been passed to an SSH function it must not be modified. type ClientConfig struct { // Rand provides the source of entropy for key exchange. If Rand is // nil, the cryptographic random reader in package crypto/rand will // be used. Rand io.Reader // The username to authenticate. User string // A slice of ClientAuth methods. Only the first instance // of a particular RFC 4252 method will be used during authentication. Auth []ClientAuth // HostKeyChecker, if not nil, is called during the cryptographic // handshake to validate the server's host key. A nil HostKeyChecker // implies that all host keys are accepted. HostKeyChecker HostKeyChecker // Cryptographic-related configuration. Crypto CryptoConfig // The identification string that will be used for the connection. // If empty, a reasonable default is used. ClientVersion string } func (c *ClientConfig) rand() io.Reader { if c.Rand == nil { return rand.Reader } return c.Rand } // Thread safe channel list. type chanList struct { // protects concurrent access to chans sync.Mutex // chans are indexed by the local id of the channel, clientChan.localId. // The PeersId value of messages received by ClientConn.mainLoop is // used to locate the right local clientChan in this slice. chans []*clientChan } // Allocate a new ClientChan with the next avail local id. func (c *chanList) newChan(t *transport) *clientChan { c.Lock() defer c.Unlock() for i := range c.chans { if c.chans[i] == nil { ch := newClientChan(t, uint32(i)) c.chans[i] = ch return ch } } i := len(c.chans) ch := newClientChan(t, uint32(i)) c.chans = append(c.chans, ch) return ch } func (c *chanList) getChan(id uint32) (*clientChan, bool) { c.Lock() defer c.Unlock() if id >= uint32(len(c.chans)) { return nil, false } return c.chans[id], true } func (c *chanList) remove(id uint32) { c.Lock() defer c.Unlock() c.chans[id] = nil } func (c *chanList) closeAll() { c.Lock() defer c.Unlock() for _, ch := range c.chans { if ch == nil { continue } ch.Close() close(ch.msg) } }