handshake.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "crypto/rand"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log"
  11. "net"
  12. "sync"
  13. )
// debugHandshake, if set, prints messages sent and received. Key
// exchange messages are printed as if DH were used, so the debug
// messages are wrong when using ECDH.
//
// Output goes through the standard log package (printPacket and the
// log.Printf calls guarded by this constant). Flip it locally while
// debugging; it is not configurable at runtime.
const debugHandshake = false
// keyingTransport is a packet-based transport that supports key
// changes. It need not be thread-safe. It should pass msgNewKeys
// through in both directions rather than consuming it.
type keyingTransport interface {
	packetConn

	// prepareKeyChange sets up a key change. The new keys for a
	// direction take effect once a msgNewKeys message is sent (for
	// the outgoing direction) or received (for the incoming one).
	prepareKeyChange(*algorithms, *kexResult) error
}
// handshakeTransport implements rekeying on top of a keyingTransport
// and offers a thread-safe writePacket() interface.
//
// newClientTransport and newServerTransport start two goroutines:
// readLoop reads packets from conn and hands them to readPacket callers
// through incoming, and kexLoop sends our kexInit messages and runs each
// key exchange. writePacket may be called concurrently with both.
type handshakeTransport struct {
	// conn is the underlying transport; kexLoop closes it when it exits.
	conn   keyingTransport
	config *Config

	// Version bytes of both ends, passed to the key exchange through
	// handshakeMagics in enterKeyExchange.
	serverVersion []byte
	clientVersion []byte

	// hostKeys is non-empty if we are the server. In that case,
	// it contains all host keys that can be used to sign the
	// connection.
	hostKeys []Signer

	// hostKeyAlgorithms is non-empty if we are the client. In that case,
	// we accept these key types from the server as host key.
	hostKeyAlgorithms []string

	// On read error, incoming is closed, and readError is set.
	// readLoop is the only writer of both and assigns readError before
	// closing incoming, so readPacket may read it after seeing the close.
	incoming  chan []byte
	readError error

	// mu protects writeError, sentInitPacket, sentInitMsg, pendingPackets
	// and writtenSinceKex, which writePacket shares with kexLoop.
	mu         sync.Mutex
	writeError error
	// sentInitPacket (marshaled) and sentInitMsg (parsed) hold the
	// kexInit we sent for the key exchange in progress, and are nil when
	// none is in progress. While sentInitMsg is non-nil, writePacket
	// queues outgoing packets in pendingPackets instead of sending them.
	sentInitPacket []byte
	sentInitMsg    *kexInitMsg
	pendingPackets [][]byte // Used when a key exchange is in progress.

	// requestKeyExchange pings this channel (from readLoop, readOnePacket
	// or writePacket) when a kex should be scheduled, and kexLoop answers
	// by sending out our kexInit message. Created with capacity 1.
	requestKex chan struct{}

	// If the other side requests or confirms a kex, its kexInit
	// packet is sent here (by readOnePacket) for kexLoop to find it.
	// readLoop closes it when it stops.
	startKex chan *pendingKex

	// data for host key checking
	hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
	dialAddress     string
	remoteAddr      net.Addr

	// readSinceKex counts bytes read by readOnePacket and writtenSinceKex
	// counts bytes sent by writePacket since the last key exchange. When
	// either exceeds config.RekeyThreshold, a key exchange is requested.
	// readSinceKex is only touched by the reading goroutine;
	// writtenSinceKex is guarded by mu.
	readSinceKex    uint64
	writtenSinceKex uint64

	// The session ID or nil if first kex did not complete yet.
	sessionID []byte
}
// pendingKex carries a key exchange started by the peer from the reading
// side (readOnePacket) to kexLoop, and carries its outcome back.
type pendingKex struct {
	// otherInit is the peer's raw kexInit packet.
	otherInit []byte
	// done receives the result of the exchange (nil on success).
	// readOnePacket creates it with capacity 1, so the single send
	// made for each request does not block.
	done chan error
}
  69. func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
  70. t := &handshakeTransport{
  71. conn: conn,
  72. serverVersion: serverVersion,
  73. clientVersion: clientVersion,
  74. incoming: make(chan []byte, 16),
  75. requestKex: make(chan struct{}, 1),
  76. startKex: make(chan *pendingKex, 1),
  77. config: config,
  78. }
  79. return t
  80. }
  81. func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
  82. t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
  83. t.dialAddress = dialAddr
  84. t.remoteAddr = addr
  85. t.hostKeyCallback = config.HostKeyCallback
  86. if config.HostKeyAlgorithms != nil {
  87. t.hostKeyAlgorithms = config.HostKeyAlgorithms
  88. } else {
  89. t.hostKeyAlgorithms = supportedHostKeyAlgos
  90. }
  91. go t.readLoop()
  92. go t.kexLoop()
  93. return t
  94. }
  95. func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
  96. t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
  97. t.hostKeys = config.hostKeys
  98. go t.readLoop()
  99. go t.kexLoop()
  100. return t
  101. }
  102. func (t *handshakeTransport) getSessionID() []byte {
  103. return t.sessionID
  104. }
  105. // waitSession waits for the session to be established. This should be
  106. // the first thing to call after instantiating handshakeTransport.
  107. func (t *handshakeTransport) waitSession() error {
  108. p, err := t.readPacket()
  109. if err != nil {
  110. return err
  111. }
  112. if p[0] != msgNewKeys {
  113. return fmt.Errorf("ssh: first packet should be msgNewKeys")
  114. }
  115. return nil
  116. }
  117. func (t *handshakeTransport) id() string {
  118. if len(t.hostKeys) > 0 {
  119. return "server"
  120. }
  121. return "client"
  122. }
  123. func (t *handshakeTransport) printPacket(p []byte, write bool) {
  124. action := "got"
  125. if write {
  126. action = "sent"
  127. }
  128. if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
  129. log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
  130. } else {
  131. msg, err := decode(p)
  132. log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
  133. }
  134. }
  135. func (t *handshakeTransport) readPacket() ([]byte, error) {
  136. p, ok := <-t.incoming
  137. if !ok {
  138. return nil, t.readError
  139. }
  140. return p, nil
  141. }
  142. func (t *handshakeTransport) readLoop() {
  143. // We always start with the mandatory key exchange. We use
  144. // the channel for simplicity, and this works if we can rely
  145. // on the SSH package itself not doing anything else before
  146. // waitSession has completed.
  147. t.requestKeyExchange()
  148. first := true
  149. for {
  150. p, err := t.readOnePacket(first)
  151. first = false
  152. if err != nil {
  153. t.readError = err
  154. close(t.incoming)
  155. break
  156. }
  157. if p[0] == msgIgnore || p[0] == msgDebug {
  158. continue
  159. }
  160. t.incoming <- p
  161. }
  162. // Stop writers too.
  163. t.recordWriteError(t.readError)
  164. // Unblock the writer should it wait for this.
  165. close(t.startKex)
  166. // Don't close t.requestKex; it's also written to from writePacket.
  167. }
  168. func (t *handshakeTransport) pushPacket(p []byte) error {
  169. if debugHandshake {
  170. t.printPacket(p, true)
  171. }
  172. return t.conn.writePacket(p)
  173. }
  174. func (t *handshakeTransport) getWriteError() error {
  175. t.mu.Lock()
  176. defer t.mu.Unlock()
  177. return t.writeError
  178. }
  179. func (t *handshakeTransport) recordWriteError(err error) {
  180. t.mu.Lock()
  181. defer t.mu.Unlock()
  182. if t.writeError == nil && err != nil {
  183. t.writeError = err
  184. }
  185. }
  186. func (t *handshakeTransport) requestKeyExchange() {
  187. select {
  188. case t.requestKex <- struct{}{}:
  189. default:
  190. // something already requested a kex, so do nothing.
  191. }
  192. }
  193. func (t *handshakeTransport) kexLoop() {
  194. write:
  195. for t.getWriteError() == nil {
  196. var request *pendingKex
  197. var sent bool
  198. for request == nil || !sent {
  199. var ok bool
  200. select {
  201. case request, ok = <-t.startKex:
  202. if !ok {
  203. break write
  204. }
  205. case <-t.requestKex:
  206. }
  207. if !sent {
  208. if err := t.sendKexInit(); err != nil {
  209. t.recordWriteError(err)
  210. break
  211. }
  212. sent = true
  213. }
  214. }
  215. if err := t.getWriteError(); err != nil {
  216. if request != nil {
  217. request.done <- err
  218. }
  219. break
  220. }
  221. // We're not servicing t.requestKex, but that is OK:
  222. // we never block on sending to t.requestKex.
  223. // We're not servicing t.startKex, but the remote end
  224. // has just sent us a kexInitMsg, so it can't send
  225. // another key change request.
  226. err := t.enterKeyExchange(request.otherInit)
  227. t.mu.Lock()
  228. t.writeError = err
  229. t.sentInitPacket = nil
  230. t.sentInitMsg = nil
  231. t.writtenSinceKex = 0
  232. request.done <- t.writeError
  233. // kex finished. Push packets that we received while
  234. // the kex was in progress. Don't look at t.startKex
  235. // and don't increment writtenSinceKex: if we trigger
  236. // another kex while we are still busy with the last
  237. // one, things will become very confusing.
  238. for _, p := range t.pendingPackets {
  239. t.writeError = t.pushPacket(p)
  240. if t.writeError != nil {
  241. break
  242. }
  243. }
  244. t.pendingPackets = t.pendingPackets[0:]
  245. t.mu.Unlock()
  246. }
  247. // drain startKex channel. We don't service t.requestKex
  248. // because nobody does blocking sends there.
  249. go func() {
  250. for init := range t.startKex {
  251. init.done <- t.writeError
  252. }
  253. }()
  254. // Unblock reader.
  255. t.conn.Close()
  256. }
  257. func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
  258. if t.readSinceKex > t.config.RekeyThreshold {
  259. t.requestKeyExchange()
  260. }
  261. p, err := t.conn.readPacket()
  262. if err != nil {
  263. return nil, err
  264. }
  265. t.readSinceKex += uint64(len(p))
  266. if debugHandshake {
  267. t.printPacket(p, false)
  268. }
  269. if first && p[0] != msgKexInit {
  270. return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
  271. }
  272. if p[0] != msgKexInit {
  273. return p, nil
  274. }
  275. firstKex := t.sessionID == nil
  276. kex := pendingKex{
  277. done: make(chan error, 1),
  278. otherInit: p,
  279. }
  280. t.startKex <- &kex
  281. err = <-kex.done
  282. if debugHandshake {
  283. log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
  284. }
  285. if err != nil {
  286. return nil, err
  287. }
  288. t.readSinceKex = 0
  289. // By default, a key exchange is hidden from higher layers by
  290. // translating it into msgIgnore.
  291. successPacket := []byte{msgIgnore}
  292. if firstKex {
  293. // sendKexInit() for the first kex waits for
  294. // msgNewKeys so the authentication process is
  295. // guaranteed to happen over an encrypted transport.
  296. successPacket = []byte{msgNewKeys}
  297. }
  298. return successPacket, nil
  299. }
  300. // sendKexInit sends a key change message.
  301. func (t *handshakeTransport) sendKexInit() error {
  302. t.mu.Lock()
  303. defer t.mu.Unlock()
  304. if t.sentInitMsg != nil {
  305. // kexInits may be sent either in response to the other side,
  306. // or because our side wants to initiate a key change, so we
  307. // may have already sent a kexInit. In that case, don't send a
  308. // second kexInit.
  309. return nil
  310. }
  311. msg := &kexInitMsg{
  312. KexAlgos: t.config.KeyExchanges,
  313. CiphersClientServer: t.config.Ciphers,
  314. CiphersServerClient: t.config.Ciphers,
  315. MACsClientServer: t.config.MACs,
  316. MACsServerClient: t.config.MACs,
  317. CompressionClientServer: supportedCompressions,
  318. CompressionServerClient: supportedCompressions,
  319. }
  320. io.ReadFull(rand.Reader, msg.Cookie[:])
  321. if len(t.hostKeys) > 0 {
  322. for _, k := range t.hostKeys {
  323. msg.ServerHostKeyAlgos = append(
  324. msg.ServerHostKeyAlgos, k.PublicKey().Type())
  325. }
  326. } else {
  327. msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
  328. }
  329. packet := Marshal(msg)
  330. // writePacket destroys the contents, so save a copy.
  331. packetCopy := make([]byte, len(packet))
  332. copy(packetCopy, packet)
  333. if err := t.pushPacket(packetCopy); err != nil {
  334. return err
  335. }
  336. t.sentInitMsg = msg
  337. t.sentInitPacket = packet
  338. return nil
  339. }
  340. func (t *handshakeTransport) writePacket(p []byte) error {
  341. switch p[0] {
  342. case msgKexInit:
  343. return errors.New("ssh: only handshakeTransport can send kexInit")
  344. case msgNewKeys:
  345. return errors.New("ssh: only handshakeTransport can send newKeys")
  346. }
  347. t.mu.Lock()
  348. defer t.mu.Unlock()
  349. if t.writeError != nil {
  350. return t.writeError
  351. }
  352. if t.sentInitMsg != nil {
  353. // Copy the packet so the writer can reuse the buffer.
  354. cp := make([]byte, len(p))
  355. copy(cp, p)
  356. t.pendingPackets = append(t.pendingPackets, cp)
  357. return nil
  358. }
  359. t.writtenSinceKex += uint64(len(p))
  360. if t.writtenSinceKex > t.config.RekeyThreshold {
  361. t.requestKeyExchange()
  362. }
  363. if err := t.pushPacket(p); err != nil {
  364. t.writeError = err
  365. }
  366. return nil
  367. }
  368. func (t *handshakeTransport) Close() error {
  369. return t.conn.Close()
  370. }
  371. func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
  372. if debugHandshake {
  373. log.Printf("%s entered key exchange", t.id())
  374. }
  375. otherInit := &kexInitMsg{}
  376. if err := Unmarshal(otherInitPacket, otherInit); err != nil {
  377. return err
  378. }
  379. magics := handshakeMagics{
  380. clientVersion: t.clientVersion,
  381. serverVersion: t.serverVersion,
  382. clientKexInit: otherInitPacket,
  383. serverKexInit: t.sentInitPacket,
  384. }
  385. clientInit := otherInit
  386. serverInit := t.sentInitMsg
  387. if len(t.hostKeys) == 0 {
  388. clientInit, serverInit = serverInit, clientInit
  389. magics.clientKexInit = t.sentInitPacket
  390. magics.serverKexInit = otherInitPacket
  391. }
  392. algs, err := findAgreedAlgorithms(clientInit, serverInit)
  393. if err != nil {
  394. return err
  395. }
  396. // We don't send FirstKexFollows, but we handle receiving it.
  397. //
  398. // RFC 4253 section 7 defines the kex and the agreement method for
  399. // first_kex_packet_follows. It states that the guessed packet
  400. // should be ignored if the "kex algorithm and/or the host
  401. // key algorithm is guessed wrong (server and client have
  402. // different preferred algorithm), or if any of the other
  403. // algorithms cannot be agreed upon". The other algorithms have
  404. // already been checked above so the kex algorithm and host key
  405. // algorithm are checked here.
  406. if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
  407. // other side sent a kex message for the wrong algorithm,
  408. // which we have to ignore.
  409. if _, err := t.conn.readPacket(); err != nil {
  410. return err
  411. }
  412. }
  413. kex, ok := kexAlgoMap[algs.kex]
  414. if !ok {
  415. return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
  416. }
  417. var result *kexResult
  418. if len(t.hostKeys) > 0 {
  419. result, err = t.server(kex, algs, &magics)
  420. } else {
  421. result, err = t.client(kex, algs, &magics)
  422. }
  423. if err != nil {
  424. return err
  425. }
  426. if t.sessionID == nil {
  427. t.sessionID = result.H
  428. }
  429. result.SessionID = t.sessionID
  430. t.conn.prepareKeyChange(algs, result)
  431. if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
  432. return err
  433. }
  434. if packet, err := t.conn.readPacket(); err != nil {
  435. return err
  436. } else if packet[0] != msgNewKeys {
  437. return unexpectedMessageError(msgNewKeys, packet[0])
  438. }
  439. return nil
  440. }
  441. func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
  442. var hostKey Signer
  443. for _, k := range t.hostKeys {
  444. if algs.hostKey == k.PublicKey().Type() {
  445. hostKey = k
  446. }
  447. }
  448. r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
  449. return r, err
  450. }
  451. func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
  452. result, err := kex.Client(t.conn, t.config.Rand, magics)
  453. if err != nil {
  454. return nil, err
  455. }
  456. hostKey, err := ParsePublicKey(result.HostKey)
  457. if err != nil {
  458. return nil, err
  459. }
  460. if err := verifyHostKeySignature(hostKey, result); err != nil {
  461. return nil, err
  462. }
  463. if t.hostKeyCallback != nil {
  464. err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
  465. if err != nil {
  466. return nil, err
  467. }
  468. }
  469. return result, nil
  470. }