cipher_test.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "bytes"
  7. "crypto"
  8. "crypto/rand"
  9. "testing"
  10. )
  11. func TestDefaultCiphersExist(t *testing.T) {
  12. for _, cipherAlgo := range supportedCiphers {
  13. if _, ok := cipherModes[cipherAlgo]; !ok {
  14. t.Errorf("supported cipher %q is unknown", cipherAlgo)
  15. }
  16. }
  17. for _, cipherAlgo := range preferredCiphers {
  18. if _, ok := cipherModes[cipherAlgo]; !ok {
  19. t.Errorf("preferred cipher %q is unknown", cipherAlgo)
  20. }
  21. }
  22. }
  23. func TestPacketCiphers(t *testing.T) {
  24. defaultMac := "hmac-sha2-256"
  25. defaultCipher := "aes128-ctr"
  26. for cipher := range cipherModes {
  27. t.Run("cipher="+cipher,
  28. func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) })
  29. }
  30. for mac := range macModes {
  31. t.Run("mac="+mac,
  32. func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) })
  33. }
  34. }
  35. func testPacketCipher(t *testing.T, cipher, mac string) {
  36. kr := &kexResult{Hash: crypto.SHA1}
  37. algs := directionAlgorithms{
  38. Cipher: cipher,
  39. MAC: mac,
  40. Compression: "none",
  41. }
  42. client, err := newPacketCipher(clientKeys, algs, kr)
  43. if err != nil {
  44. t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
  45. }
  46. server, err := newPacketCipher(clientKeys, algs, kr)
  47. if err != nil {
  48. t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
  49. }
  50. want := "bla bla"
  51. input := []byte(want)
  52. buf := &bytes.Buffer{}
  53. if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil {
  54. t.Fatalf("writeCipherPacket(%q, %q): %v", cipher, mac, err)
  55. }
  56. packet, err := server.readCipherPacket(0, buf)
  57. if err != nil {
  58. t.Fatalf("readCipherPacket(%q, %q): %v", cipher, mac, err)
  59. }
  60. if string(packet) != want {
  61. t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want)
  62. }
  63. }
  64. func TestCBCOracleCounterMeasure(t *testing.T) {
  65. kr := &kexResult{Hash: crypto.SHA1}
  66. algs := directionAlgorithms{
  67. Cipher: aes128cbcID,
  68. MAC: "hmac-sha1",
  69. Compression: "none",
  70. }
  71. client, err := newPacketCipher(clientKeys, algs, kr)
  72. if err != nil {
  73. t.Fatalf("newPacketCipher(client): %v", err)
  74. }
  75. want := "bla bla"
  76. input := []byte(want)
  77. buf := &bytes.Buffer{}
  78. if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil {
  79. t.Errorf("writeCipherPacket: %v", err)
  80. }
  81. packetSize := buf.Len()
  82. buf.Write(make([]byte, 2*maxPacket))
  83. // We corrupt each byte, but this usually will only test the
  84. // 'packet too large' or 'MAC failure' cases.
  85. lastRead := -1
  86. for i := 0; i < packetSize; i++ {
  87. server, err := newPacketCipher(clientKeys, algs, kr)
  88. if err != nil {
  89. t.Fatalf("newPacketCipher(client): %v", err)
  90. }
  91. fresh := &bytes.Buffer{}
  92. fresh.Write(buf.Bytes())
  93. fresh.Bytes()[i] ^= 0x01
  94. before := fresh.Len()
  95. _, err = server.readCipherPacket(0, fresh)
  96. if err == nil {
  97. t.Errorf("corrupt byte %d: readCipherPacket succeeded ", i)
  98. continue
  99. }
  100. if _, ok := err.(cbcError); !ok {
  101. t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
  102. continue
  103. }
  104. after := fresh.Len()
  105. bytesRead := before - after
  106. if bytesRead < maxPacket {
  107. t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
  108. continue
  109. }
  110. if i > 0 && bytesRead != lastRead {
  111. t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead)
  112. }
  113. lastRead = bytesRead
  114. }
  115. }