| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630 |
- // Copyright 2011 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package ssh
- import (
- "crypto"
- "crypto/rand"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "math/big"
- "net"
- "sync"
- )
// clientVersion is the fixed identification string that the client will use.
// It carries the trailing CR LF that is written to the wire; handshake drops
// those two bytes before recording the version for the exchange hash.
var clientVersion = []byte("SSH-2.0-Go\r\n")
// ClientConn represents the client side of an SSH connection.
// It is created by Client or Dial; after a successful handshake its
// mainLoop goroutine reads and dispatches every incoming packet.
type ClientConn struct {
	*transport              // created by newTransport from the net.Conn passed to Client
	config      *ClientConfig // configuration supplied to Client
	chanList    // channels associated with this connection
	forwardList // forwarded tcpip connections from the remote side
	globalRequest // serialises global requests and carries their replies
}
// globalRequest holds the state used by sendGlobalRequest. The mutex is
// held for the whole request/response round trip so that only one global
// request is outstanding at a time; response receives the success or
// failure message that mainLoop reads from the wire.
type globalRequest struct {
	sync.Mutex
	response chan interface{}
}
- // Client returns a new SSH client connection using c as the underlying transport.
- func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
- conn := &ClientConn{
- transport: newTransport(c, config.rand()),
- config: config,
- globalRequest: globalRequest{response: make(chan interface{}, 1)},
- }
- if err := conn.handshake(); err != nil {
- conn.Close()
- return nil, err
- }
- go conn.mainLoop()
- return conn, nil
- }
- // handshake performs the client side key exchange. See RFC 4253 Section 7.
- func (c *ClientConn) handshake() error {
- var magics handshakeMagics
- if _, err := c.Write(clientVersion); err != nil {
- return err
- }
- if err := c.Flush(); err != nil {
- return err
- }
- magics.clientVersion = clientVersion[:len(clientVersion)-2]
- // read remote server version
- version, err := readVersion(c)
- if err != nil {
- return err
- }
- magics.serverVersion = version
- clientKexInit := kexInitMsg{
- KexAlgos: supportedKexAlgos,
- ServerHostKeyAlgos: supportedHostKeyAlgos,
- CiphersClientServer: c.config.Crypto.ciphers(),
- CiphersServerClient: c.config.Crypto.ciphers(),
- MACsClientServer: c.config.Crypto.macs(),
- MACsServerClient: c.config.Crypto.macs(),
- CompressionClientServer: supportedCompressions,
- CompressionServerClient: supportedCompressions,
- }
- kexInitPacket := marshal(msgKexInit, clientKexInit)
- magics.clientKexInit = kexInitPacket
- if err := c.writePacket(kexInitPacket); err != nil {
- return err
- }
- packet, err := c.readPacket()
- if err != nil {
- return err
- }
- magics.serverKexInit = packet
- var serverKexInit kexInitMsg
- if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
- return err
- }
- kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
- if !ok {
- return errors.New("ssh: no common algorithms")
- }
- if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
- // The server sent a Kex message for the wrong algorithm,
- // which we have to ignore.
- if _, err := c.readPacket(); err != nil {
- return err
- }
- }
- var H, K []byte
- var hashFunc crypto.Hash
- switch kexAlgo {
- case kexAlgoDH14SHA1:
- hashFunc = crypto.SHA1
- dhGroup14Once.Do(initDHGroup14)
- H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
- case keyAlgoDH1SHA1:
- hashFunc = crypto.SHA1
- dhGroup1Once.Do(initDHGroup1)
- H, K, err = c.kexDH(dhGroup1, hashFunc, &magics, hostKeyAlgo)
- default:
- err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
- }
- if err != nil {
- return err
- }
- if err = c.writePacket([]byte{msgNewKeys}); err != nil {
- return err
- }
- if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
- return err
- }
- if packet, err = c.readPacket(); err != nil {
- return err
- }
- if packet[0] != msgNewKeys {
- return UnexpectedMessageError{msgNewKeys, packet[0]}
- }
- if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
- return err
- }
- return c.authenticate(H)
- }
- // kexDH performs Diffie-Hellman key agreement on a ClientConn. The
- // returned values are given the same names as in RFC 4253, section 8.
- func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, error) {
- x, err := rand.Int(c.config.rand(), group.p)
- if err != nil {
- return nil, nil, err
- }
- X := new(big.Int).Exp(group.g, x, group.p)
- kexDHInit := kexDHInitMsg{
- X: X,
- }
- if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
- return nil, nil, err
- }
- packet, err := c.readPacket()
- if err != nil {
- return nil, nil, err
- }
- var kexDHReply kexDHReplyMsg
- if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
- return nil, nil, err
- }
- kInt, err := group.diffieHellman(kexDHReply.Y, x)
- if err != nil {
- return nil, nil, err
- }
- h := hashFunc.New()
- writeString(h, magics.clientVersion)
- writeString(h, magics.serverVersion)
- writeString(h, magics.clientKexInit)
- writeString(h, magics.serverKexInit)
- writeString(h, kexDHReply.HostKey)
- writeInt(h, X)
- writeInt(h, kexDHReply.Y)
- K := make([]byte, intLength(kInt))
- marshalInt(K, kInt)
- h.Write(K)
- H := h.Sum(nil)
- return H, K, nil
- }
- // mainLoop reads incoming messages and routes channel messages
- // to their respective ClientChans.
- func (c *ClientConn) mainLoop() {
- defer func() {
- c.Close()
- c.closeAll()
- }()
- for {
- packet, err := c.readPacket()
- if err != nil {
- break
- }
- // TODO(dfc) A note on blocking channel use.
- // The msg, data and dataExt channels of a clientChan can
- // cause this loop to block indefinately if the consumer does
- // not service them.
- switch packet[0] {
- case msgChannelData:
- if len(packet) < 9 {
- // malformed data packet
- return
- }
- remoteId := binary.BigEndian.Uint32(packet[1:5])
- length := binary.BigEndian.Uint32(packet[5:9])
- packet = packet[9:]
- if length != uint32(len(packet)) {
- return
- }
- ch, ok := c.getChan(remoteId)
- if !ok {
- return
- }
- ch.stdout.write(packet)
- case msgChannelExtendedData:
- if len(packet) < 13 {
- // malformed data packet
- return
- }
- remoteId := binary.BigEndian.Uint32(packet[1:5])
- datatype := binary.BigEndian.Uint32(packet[5:9])
- length := binary.BigEndian.Uint32(packet[9:13])
- packet = packet[13:]
- if length != uint32(len(packet)) {
- return
- }
- // RFC 4254 5.2 defines data_type_code 1 to be data destined
- // for stderr on interactive sessions. Other data types are
- // silently discarded.
- if datatype == 1 {
- ch, ok := c.getChan(remoteId)
- if !ok {
- return
- }
- ch.stderr.write(packet)
- }
- default:
- switch msg := decode(packet).(type) {
- case *channelOpenMsg:
- c.handleChanOpen(msg)
- case *channelOpenConfirmMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *channelOpenFailureMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *channelCloseMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.theyClosed = true
- ch.stdout.eof()
- ch.stderr.eof()
- close(ch.msg)
- if !ch.weClosed {
- ch.weClosed = true
- ch.sendClose()
- }
- c.chanList.remove(msg.PeersId)
- case *channelEOFMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.stdout.eof()
- // RFC 4254 is mute on how EOF affects dataExt messages but
- // it is logical to signal EOF at the same time.
- ch.stderr.eof()
- case *channelRequestSuccessMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *channelRequestFailureMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *channelRequestMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- ch.msg <- msg
- case *windowAdjustMsg:
- ch, ok := c.getChan(msg.PeersId)
- if !ok {
- return
- }
- if !ch.remoteWin.add(msg.AdditionalBytes) {
- // invalid window update
- return
- }
- case *globalRequestSuccessMsg, *globalRequestFailureMsg:
- c.globalRequest.response <- msg
- case *disconnectMsg:
- return
- default:
- fmt.Printf("mainLoop: unhandled message %T: %v\n", msg, msg)
- }
- }
- }
- }
- // Handle channel open messages from the remote side.
- func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
- switch msg.ChanType {
- case "forwarded-tcpip":
- laddr, rest, ok := parseTCPAddr(msg.TypeSpecificData)
- if !ok {
- // invalid request
- c.sendConnectionFailed(msg.PeersId)
- return
- }
- l, ok := c.forwardList.lookup(laddr)
- if !ok {
- fmt.Println("could not find forward list entry for", laddr)
- // Section 7.2, implementations MUST reject suprious incoming
- // connections.
- c.sendConnectionFailed(msg.PeersId)
- return
- }
- raddr, rest, ok := parseTCPAddr(rest)
- if !ok {
- // invalid request
- c.sendConnectionFailed(msg.PeersId)
- return
- }
- ch := c.newChan(c.transport)
- ch.remoteId = msg.PeersId
- ch.remoteWin.add(msg.PeersWindow)
- m := channelOpenConfirmMsg{
- PeersId: ch.remoteId,
- MyId: ch.localId,
- MyWindow: 1 << 14,
- MaxPacketSize: 1 << 15, // RFC 4253 6.1
- }
- c.writePacket(marshal(msgChannelOpenConfirm, m))
- l <- forward{ch, raddr}
- default:
- // unknown channel type
- m := channelOpenFailureMsg{
- PeersId: msg.PeersId,
- Reason: UnknownChannelType,
- Message: fmt.Sprintf("unknown channel type: %v", msg.ChanType),
- Language: "en_US.UTF-8",
- }
- c.writePacket(marshal(msgChannelOpenFailure, m))
- }
- }
- // sendGlobalRequest sends a global request message as specified
- // in RFC4254 section 4. To correctly synchronise messages, a lock
- // is held internally until a response is returned.
- func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
- c.globalRequest.Lock()
- defer c.globalRequest.Unlock()
- if err := c.writePacket(marshal(msgGlobalRequest, m)); err != nil {
- return nil, err
- }
- r := <-c.globalRequest.response
- if r, ok := r.(*globalRequestSuccessMsg); ok {
- return r, nil
- }
- return nil, errors.New("request failed")
- }
- // sendConnectionFailed rejects an incoming channel identified
- // by remoteId.
- func (c *ClientConn) sendConnectionFailed(remoteId uint32) error {
- m := channelOpenFailureMsg{
- PeersId: remoteId,
- Reason: ConnectionFailed,
- Message: "invalid request",
- Language: "en_US.UTF-8",
- }
- return c.writePacket(marshal(msgChannelOpenFailure, m))
- }
- // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
- // RFC 4254 section 7.2 is mute on what to do if parsing fails but the forwardlist
- // requires a valid *net.TCPAddr to operate, so we enforce that restriction here.
- func parseTCPAddr(b []byte) (*net.TCPAddr, []byte, bool) {
- addr, b, ok := parseString(b)
- if !ok {
- return nil, b, false
- }
- port, b, ok := parseUint32(b)
- if !ok {
- return nil, b, false
- }
- ip := net.ParseIP(string(addr))
- if ip == nil {
- return nil, b, false
- }
- return &net.TCPAddr{ip, int(port)}, b, true
- }
- // Dial connects to the given network address using net.Dial and
- // then initiates a SSH handshake, returning the resulting client connection.
- func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
- conn, err := net.Dial(network, addr)
- if err != nil {
- return nil, err
- }
- return Client(conn, config)
- }
// A ClientConfig structure is used to configure a ClientConn. After one has
// been passed to an SSH function it must not be modified.
type ClientConfig struct {
	// Rand provides the source of entropy for key exchange. If Rand is
	// nil, the cryptographic random reader in package crypto/rand will
	// be used.
	Rand io.Reader

	// User is the username to authenticate.
	User string

	// Auth is a slice of ClientAuth methods. Only the first instance
	// of a particular RFC 4252 method will be used during authentication.
	Auth []ClientAuth

	// Crypto holds cryptographic-related configuration. The handshake
	// takes the cipher and MAC lists it offers to the server from here.
	Crypto CryptoConfig
}
- func (c *ClientConfig) rand() io.Reader {
- if c.Rand == nil {
- return rand.Reader
- }
- return c.Rand
- }
// A clientChan represents a single RFC 4254 channel that is multiplexed
// over a single SSH connection.
type clientChan struct {
	channel
	stdin  *chanWriter // writes channel data to the remote side
	stdout *chanReader // receives channel data routed by ClientConn.mainLoop
	stderr *chanReader // receives extended data of type 1 routed by ClientConn.mainLoop
	// msg carries decoded channel messages from ClientConn.mainLoop; it is
	// closed when the channel or the connection is closed.
	msg chan interface{}
}
- // newClientChan returns a partially constructed *clientChan
- // using the local id provided. To be usable clientChan.remoteId
- // needs to be assigned once known.
- func newClientChan(cc conn, id uint32) *clientChan {
- c := &clientChan{
- channel: channel{
- conn: cc,
- localId: id,
- remoteWin: window{Cond: newCond()},
- },
- msg: make(chan interface{}, 16),
- }
- c.stdin = &chanWriter{
- channel: &c.channel,
- }
- c.stdout = &chanReader{
- channel: &c.channel,
- buffer: newBuffer(),
- }
- c.stderr = &chanReader{
- channel: &c.channel,
- buffer: newBuffer(),
- }
- return c
- }
- // waitForChannelOpenResponse, if successful, fills out
- // the remoteId and records any initial window advertisement.
- func (c *clientChan) waitForChannelOpenResponse() error {
- switch msg := (<-c.msg).(type) {
- case *channelOpenConfirmMsg:
- // fixup remoteId field
- c.remoteId = msg.MyId
- c.remoteWin.add(msg.MyWindow)
- return nil
- case *channelOpenFailureMsg:
- return errors.New(safeString(msg.Message))
- }
- return errors.New("ssh: unexpected packet")
- }
- // Close closes the channel. This does not close the underlying connection.
- func (c *clientChan) Close() error {
- if !c.weClosed {
- c.weClosed = true
- return c.sendClose()
- }
- return nil
- }
// chanList is a thread safe channel list. Slots freed by remove are set
// to nil and reused by newChan.
type chanList struct {
	// protects concurrent access to chans
	sync.Mutex
	// chans are indexed by the local id of the channel, clientChan.localId.
	// The PeersId value of messages received by ClientConn.mainLoop is
	// used to locate the right local clientChan in this slice.
	chans []*clientChan
}
- // Allocate a new ClientChan with the next avail local id.
- func (c *chanList) newChan(t *transport) *clientChan {
- c.Lock()
- defer c.Unlock()
- for i := range c.chans {
- if c.chans[i] == nil {
- ch := newClientChan(t, uint32(i))
- c.chans[i] = ch
- return ch
- }
- }
- i := len(c.chans)
- ch := newClientChan(t, uint32(i))
- c.chans = append(c.chans, ch)
- return ch
- }
- func (c *chanList) getChan(id uint32) (*clientChan, bool) {
- c.Lock()
- defer c.Unlock()
- if id >= uint32(len(c.chans)) {
- return nil, false
- }
- return c.chans[id], true
- }
- func (c *chanList) remove(id uint32) {
- c.Lock()
- defer c.Unlock()
- c.chans[id] = nil
- }
- func (c *chanList) closeAll() {
- c.Lock()
- defer c.Unlock()
- for _, ch := range c.chans {
- if ch == nil {
- continue
- }
- ch.theyClosed = true
- ch.stdout.eof()
- ch.stderr.eof()
- close(ch.msg)
- }
- }
// A chanWriter represents the stdin of a remote process.
// It shares the channel state of the clientChan it belongs to.
type chanWriter struct {
	*channel
}
- // Write writes data to the remote process's standard input.
- func (w *chanWriter) Write(data []byte) (written int, err error) {
- for len(data) > 0 {
- // n cannot be larger than 2^31 as len(data) cannot
- // be larger than 2^31
- n := int(w.remoteWin.reserve(uint32(len(data))))
- remoteId := w.remoteId
- packet := []byte{
- msgChannelData,
- byte(remoteId >> 24), byte(remoteId >> 16), byte(remoteId >> 8), byte(remoteId),
- byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
- }
- if err = w.writePacket(append(packet, data[:n]...)); err != nil {
- break
- }
- data = data[n:]
- written += n
- }
- return
- }
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// Close signals the end of input by calling sendEOF and returns its
// result.
func (w *chanWriter) Close() error {
	return w.sendEOF()
}
// A chanReader represents stdout or stderr of a remote process.
// ClientConn.mainLoop writes incoming data to it, and Read hands that
// data to the consumer.
type chanReader struct {
	*channel // the channel backing this reader
	*buffer
}
- // Read reads data from the remote process's stdout or stderr.
- func (r *chanReader) Read(buf []byte) (int, error) {
- n, err := r.buffer.Read(buf)
- if err != nil {
- if err == io.EOF {
- return n, err
- }
- return 0, err
- }
- return n, r.sendWindowAdj(n)
- }
|