handshake_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "bytes"
  7. "crypto/rand"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "net"
  12. "reflect"
  13. "runtime"
  14. "strings"
  15. "sync"
  16. "testing"
  17. )
  18. type testChecker struct {
  19. calls []string
  20. }
  21. func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
  22. if dialAddr == "bad" {
  23. return fmt.Errorf("dialAddr is bad")
  24. }
  25. if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
  26. return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
  27. }
  28. t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
  29. return nil
  30. }
  // netPipe returns both ends of a freshly connected loopback TCP
// connection. It is analogous to net.Pipe, but because it uses a real
// net.Conn it is buffered, so both sides may start with a write (which
// deadlocks with net.Pipe). IPv4 loopback is tried first, then IPv6; if
// both fail, the IPv6 error is returned.
func netPipe() (net.Conn, net.Conn, error) {
	var (
		ln  net.Listener
		err error
	)
	for _, local := range []string{"127.0.0.1:0", "[::1]:0"} {
		if ln, err = net.Listen("tcp", local); err == nil {
			break
		}
	}
	if err != nil {
		return nil, nil, err
	}
	// Exactly one connection is needed, so the listener is not kept.
	defer ln.Close()

	dialed, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		return nil, nil, err
	}
	accepted, err := ln.Accept()
	if err != nil {
		dialed.Close()
		return nil, nil, err
	}
	return dialed, accepted, nil
}
  54. // noiseTransport inserts ignore messages to check that the read loop
  55. // and the key exchange filters out these messages.
  56. type noiseTransport struct {
  57. keyingTransport
  58. }
  59. func (t *noiseTransport) writePacket(p []byte) error {
  60. ignore := []byte{msgIgnore}
  61. if err := t.keyingTransport.writePacket(ignore); err != nil {
  62. return err
  63. }
  64. debug := []byte{msgDebug, 1, 2, 3}
  65. if err := t.keyingTransport.writePacket(debug); err != nil {
  66. return err
  67. }
  68. return t.keyingTransport.writePacket(p)
  69. }
  70. func addNoiseTransport(t keyingTransport) keyingTransport {
  71. return &noiseTransport{t}
  72. }
  73. // handshakePair creates two handshakeTransports connected with each
  74. // other. If the noise argument is true, both transports will try to
  75. // confuse the other side by sending ignore and debug messages.
  76. func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
  77. a, b, err := netPipe()
  78. if err != nil {
  79. return nil, nil, err
  80. }
  81. var trC, trS keyingTransport
  82. trC = newTransport(a, rand.Reader, true)
  83. trS = newTransport(b, rand.Reader, false)
  84. if noise {
  85. trC = addNoiseTransport(trC)
  86. trS = addNoiseTransport(trS)
  87. }
  88. clientConf.SetDefaults()
  89. v := []byte("version")
  90. client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
  91. serverConf := &ServerConfig{}
  92. serverConf.AddHostKey(testSigners["ecdsa"])
  93. serverConf.AddHostKey(testSigners["rsa"])
  94. serverConf.SetDefaults()
  95. server = newServerTransport(trS, v, v, serverConf)
  96. if err := server.waitSession(); err != nil {
  97. return nil, nil, fmt.Errorf("server.waitSession: %v", err)
  98. }
  99. if err := client.waitSession(); err != nil {
  100. return nil, nil, fmt.Errorf("client.waitSession: %v", err)
  101. }
  102. return client, server, nil
  103. }
  104. func TestHandshakeBasic(t *testing.T) {
  105. if runtime.GOOS == "plan9" {
  106. t.Skip("see golang.org/issue/7237")
  107. }
  108. checker := &syncChecker{
  109. waitCall: make(chan int, 10),
  110. called: make(chan int, 10),
  111. }
  112. checker.waitCall <- 1
  113. trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
  114. if err != nil {
  115. t.Fatalf("handshakePair: %v", err)
  116. }
  117. defer trC.Close()
  118. defer trS.Close()
  119. // Let first kex complete normally.
  120. <-checker.called
  121. clientDone := make(chan int, 0)
  122. gotHalf := make(chan int, 0)
  123. const N = 20
  124. go func() {
  125. defer close(clientDone)
  126. // Client writes a bunch of stuff, and does a key
  127. // change in the middle. This should not confuse the
  128. // handshake in progress. We do this twice, so we test
  129. // that the packet buffer is reset correctly.
  130. for i := 0; i < N; i++ {
  131. p := []byte{msgRequestSuccess, byte(i)}
  132. if err := trC.writePacket(p); err != nil {
  133. t.Fatalf("sendPacket: %v", err)
  134. }
  135. if (i % 10) == 5 {
  136. <-gotHalf
  137. // halfway through, we request a key change.
  138. trC.requestKeyExchange()
  139. // Wait until we can be sure the key
  140. // change has really started before we
  141. // write more.
  142. <-checker.called
  143. }
  144. if (i % 10) == 7 {
  145. // write some packets until the kex
  146. // completes, to test buffering of
  147. // packets.
  148. checker.waitCall <- 1
  149. }
  150. }
  151. }()
  152. // Server checks that client messages come in cleanly
  153. i := 0
  154. err = nil
  155. for ; i < N; i++ {
  156. var p []byte
  157. p, err = trS.readPacket()
  158. if err != nil {
  159. break
  160. }
  161. if (i % 10) == 5 {
  162. gotHalf <- 1
  163. }
  164. want := []byte{msgRequestSuccess, byte(i)}
  165. if bytes.Compare(p, want) != 0 {
  166. t.Errorf("message %d: got %v, want %v", i, p, want)
  167. }
  168. }
  169. <-clientDone
  170. if err != nil && err != io.EOF {
  171. t.Fatalf("server error: %v", err)
  172. }
  173. if i != N {
  174. t.Errorf("received %d messages, want 10.", i)
  175. }
  176. close(checker.called)
  177. if _, ok := <-checker.called; ok {
  178. // If all went well, we registered exactly 2 key changes: one
  179. // that establishes the session, and one that we requested
  180. // additionally.
  181. t.Fatalf("got another host key checks after 2 handshakes")
  182. }
  183. }
  184. func TestForceFirstKex(t *testing.T) {
  185. // like handshakePair, but must access the keyingTransport.
  186. checker := &testChecker{}
  187. clientConf := &ClientConfig{HostKeyCallback: checker.Check}
  188. a, b, err := netPipe()
  189. if err != nil {
  190. t.Fatalf("netPipe: %v", err)
  191. }
  192. var trC, trS keyingTransport
  193. trC = newTransport(a, rand.Reader, true)
  194. // This is the disallowed packet:
  195. trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
  196. // Rest of the setup.
  197. trS = newTransport(b, rand.Reader, false)
  198. clientConf.SetDefaults()
  199. v := []byte("version")
  200. client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
  201. serverConf := &ServerConfig{}
  202. serverConf.AddHostKey(testSigners["ecdsa"])
  203. serverConf.AddHostKey(testSigners["rsa"])
  204. serverConf.SetDefaults()
  205. server := newServerTransport(trS, v, v, serverConf)
  206. defer client.Close()
  207. defer server.Close()
  208. // We setup the initial key exchange, but the remote side
  209. // tries to send serviceRequestMsg in cleartext, which is
  210. // disallowed.
  211. if err := server.waitSession(); err == nil {
  212. t.Errorf("server first kex init should reject unexpected packet")
  213. }
  214. }
  215. func TestHandshakeAutoRekeyWrite(t *testing.T) {
  216. checker := &syncChecker{
  217. called: make(chan int, 10),
  218. waitCall: nil,
  219. }
  220. clientConf := &ClientConfig{HostKeyCallback: checker.Check}
  221. clientConf.RekeyThreshold = 500
  222. trC, trS, err := handshakePair(clientConf, "addr", false)
  223. if err != nil {
  224. t.Fatalf("handshakePair: %v", err)
  225. }
  226. defer trC.Close()
  227. defer trS.Close()
  228. input := make([]byte, 251)
  229. input[0] = msgRequestSuccess
  230. done := make(chan int, 1)
  231. const numPacket = 5
  232. go func() {
  233. defer close(done)
  234. j := 0
  235. for ; j < numPacket; j++ {
  236. if p, err := trS.readPacket(); err != nil {
  237. break
  238. } else if !bytes.Equal(input, p) {
  239. t.Errorf("got packet type %d, want %d", p[0], input[0])
  240. }
  241. }
  242. if j != numPacket {
  243. t.Errorf("got %d, want 5 messages", j)
  244. }
  245. }()
  246. <-checker.called
  247. for i := 0; i < numPacket; i++ {
  248. p := make([]byte, len(input))
  249. copy(p, input)
  250. if err := trC.writePacket(p); err != nil {
  251. t.Errorf("writePacket: %v", err)
  252. }
  253. if i == 2 {
  254. // Make sure the kex is in progress.
  255. <-checker.called
  256. }
  257. }
  258. <-done
  259. }
  260. type syncChecker struct {
  261. waitCall chan int
  262. called chan int
  263. }
  264. func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
  265. c.called <- 1
  266. if c.waitCall != nil {
  267. <-c.waitCall
  268. }
  269. return nil
  270. }
  271. func TestHandshakeAutoRekeyRead(t *testing.T) {
  272. sync := &syncChecker{
  273. called: make(chan int, 2),
  274. waitCall: nil,
  275. }
  276. clientConf := &ClientConfig{
  277. HostKeyCallback: sync.Check,
  278. }
  279. clientConf.RekeyThreshold = 500
  280. trC, trS, err := handshakePair(clientConf, "addr", false)
  281. if err != nil {
  282. t.Fatalf("handshakePair: %v", err)
  283. }
  284. defer trC.Close()
  285. defer trS.Close()
  286. packet := make([]byte, 501)
  287. packet[0] = msgRequestSuccess
  288. if err := trS.writePacket(packet); err != nil {
  289. t.Fatalf("writePacket: %v", err)
  290. }
  291. // While we read out the packet, a key change will be
  292. // initiated.
  293. done := make(chan int, 1)
  294. go func() {
  295. defer close(done)
  296. if _, err := trC.readPacket(); err != nil {
  297. t.Fatalf("readPacket(client): %v", err)
  298. }
  299. }()
  300. <-done
  301. <-sync.called
  302. }
  303. // errorKeyingTransport generates errors after a given number of
  304. // read/write operations.
  305. type errorKeyingTransport struct {
  306. packetConn
  307. readLeft, writeLeft int
  308. }
  309. func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
  310. return nil
  311. }
  312. func (n *errorKeyingTransport) getSessionID() []byte {
  313. return nil
  314. }
  315. func (n *errorKeyingTransport) writePacket(packet []byte) error {
  316. if n.writeLeft == 0 {
  317. n.Close()
  318. return errors.New("barf")
  319. }
  320. n.writeLeft--
  321. return n.packetConn.writePacket(packet)
  322. }
  323. func (n *errorKeyingTransport) readPacket() ([]byte, error) {
  324. if n.readLeft == 0 {
  325. n.Close()
  326. return nil, errors.New("barf")
  327. }
  328. n.readLeft--
  329. return n.packetConn.readPacket()
  330. }
  331. func TestHandshakeErrorHandlingRead(t *testing.T) {
  332. for i := 0; i < 20; i++ {
  333. testHandshakeErrorHandlingN(t, i, -1, false)
  334. }
  335. }
  336. func TestHandshakeErrorHandlingWrite(t *testing.T) {
  337. for i := 0; i < 20; i++ {
  338. testHandshakeErrorHandlingN(t, -1, i, false)
  339. }
  340. }
  341. func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
  342. for i := 0; i < 20; i++ {
  343. testHandshakeErrorHandlingN(t, i, -1, true)
  344. }
  345. }
  346. func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
  347. for i := 0; i < 20; i++ {
  348. testHandshakeErrorHandlingN(t, -1, i, true)
  349. }
  350. }
  // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
// handshakeTransport deadlocks, the go runtime will detect it and
// panic.
//
// Errors are injected only on the server side. Its transport fails after
// readLimit reads or writeLimit writes; -1 disables injection for that
// direction. The client side never fails by itself.
//
// With coupled == false, each side gets an independent writer goroutine
// and reader goroutine. With coupled == true, each side gets a single
// goroutine that alternates reading a packet and writing msg. The function
// returns once every goroutine has stopped on an error.
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
	if runtime.GOOS == "js" && runtime.GOARCH == "wasm" {
		t.Skip("skipping on js/wasm; see golang.org/issue/32840")
	}
	// Each payload is about a quarter of minRekeyThreshold, so the server
	// (whose threshold is minRekeyThreshold) should cross its rekey
	// threshold roughly every four messages.
	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})

	a, b := memPipe()
	defer a.Close()
	defer b.Close()

	key := testSigners["ecdsa"]

	// Server: small rekey threshold plus error injection.
	serverConf := Config{RekeyThreshold: minRekeyThreshold}
	serverConf.SetDefaults()
	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
	serverConn.hostKeys = []Signer{key}
	go serverConn.readLoop()
	go serverConn.kexLoop()

	// Client: ten times the server's threshold and no error injection. It
	// only accepts the server's key type and, via InsecureIgnoreHostKey,
	// accepts any host key.
	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
	clientConf.SetDefaults()
	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
	clientConn.hostKeyCallback = InsecureIgnoreHostKey()
	go clientConn.readLoop()
	go clientConn.kexLoop()

	var wg sync.WaitGroup

	// hs is passed as an argument, so every goroutine gets its own conn.
	for _, hs := range []packetConn{serverConn, clientConn} {
		if !coupled {
			wg.Add(2)
			// Writer: sends distinct (hex-counter-prefixed) messages of
			// about the same size as msg until a write fails, then closes
			// its side.
			go func(c packetConn) {
				for i := 0; ; i++ {
					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
					if err != nil {
						break
					}
				}
				wg.Done()
				c.Close()
			}(hs)
			// Reader: drains packets until a read fails.
			go func(c packetConn) {
				for {
					_, err := c.readPacket()
					if err != nil {
						break
					}
				}
				wg.Done()
			}(hs)
		} else {
			wg.Add(1)
			// Coupled: answer every packet read with msg, until either
			// operation fails. NOTE(review): neither side writes first in
			// this mode; confirm what delivers the first packet.
			go func(c packetConn) {
				for {
					_, err := c.readPacket()
					if err != nil {
						break
					}
					if err := c.writePacket(msg); err != nil {
						break
					}

				}
				wg.Done()
			}(hs)
		}
	}
	wg.Wait()
}
  418. func TestDisconnect(t *testing.T) {
  419. if runtime.GOOS == "plan9" {
  420. t.Skip("see golang.org/issue/7237")
  421. }
  422. checker := &testChecker{}
  423. trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
  424. if err != nil {
  425. t.Fatalf("handshakePair: %v", err)
  426. }
  427. defer trC.Close()
  428. defer trS.Close()
  429. trC.writePacket([]byte{msgRequestSuccess, 0, 0})
  430. errMsg := &disconnectMsg{
  431. Reason: 42,
  432. Message: "such is life",
  433. }
  434. trC.writePacket(Marshal(errMsg))
  435. trC.writePacket([]byte{msgRequestSuccess, 0, 0})
  436. packet, err := trS.readPacket()
  437. if err != nil {
  438. t.Fatalf("readPacket 1: %v", err)
  439. }
  440. if packet[0] != msgRequestSuccess {
  441. t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
  442. }
  443. _, err = trS.readPacket()
  444. if err == nil {
  445. t.Errorf("readPacket 2 succeeded")
  446. } else if !reflect.DeepEqual(err, errMsg) {
  447. t.Errorf("got error %#v, want %#v", err, errMsg)
  448. }
  449. _, err = trS.readPacket()
  450. if err == nil {
  451. t.Errorf("readPacket 3 succeeded")
  452. }
  453. }
  454. func TestHandshakeRekeyDefault(t *testing.T) {
  455. clientConf := &ClientConfig{
  456. Config: Config{
  457. Ciphers: []string{"aes128-ctr"},
  458. },
  459. HostKeyCallback: InsecureIgnoreHostKey(),
  460. }
  461. trC, trS, err := handshakePair(clientConf, "addr", false)
  462. if err != nil {
  463. t.Fatalf("handshakePair: %v", err)
  464. }
  465. defer trC.Close()
  466. defer trS.Close()
  467. trC.writePacket([]byte{msgRequestSuccess, 0, 0})
  468. trC.Close()
  469. rgb := (1024 + trC.readBytesLeft) >> 30
  470. wgb := (1024 + trC.writeBytesLeft) >> 30
  471. if rgb != 64 {
  472. t.Errorf("got rekey after %dG read, want 64G", rgb)
  473. }
  474. if wgb != 64 {
  475. t.Errorf("got rekey after %dG write, want 64G", wgb)
  476. }
  477. }