common_test.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "reflect"
  7. "testing"
  8. )
  9. func TestFindAgreedAlgorithms(t *testing.T) {
  10. initKex := func(k *kexInitMsg) {
  11. if k.KexAlgos == nil {
  12. k.KexAlgos = []string{"kex1"}
  13. }
  14. if k.ServerHostKeyAlgos == nil {
  15. k.ServerHostKeyAlgos = []string{"hostkey1"}
  16. }
  17. if k.CiphersClientServer == nil {
  18. k.CiphersClientServer = []string{"cipher1"}
  19. }
  20. if k.CiphersServerClient == nil {
  21. k.CiphersServerClient = []string{"cipher1"}
  22. }
  23. if k.MACsClientServer == nil {
  24. k.MACsClientServer = []string{"mac1"}
  25. }
  26. if k.MACsServerClient == nil {
  27. k.MACsServerClient = []string{"mac1"}
  28. }
  29. if k.CompressionClientServer == nil {
  30. k.CompressionClientServer = []string{"compression1"}
  31. }
  32. if k.CompressionServerClient == nil {
  33. k.CompressionServerClient = []string{"compression1"}
  34. }
  35. if k.LanguagesClientServer == nil {
  36. k.LanguagesClientServer = []string{"language1"}
  37. }
  38. if k.LanguagesServerClient == nil {
  39. k.LanguagesServerClient = []string{"language1"}
  40. }
  41. }
  42. initDirAlgs := func(a *directionAlgorithms) {
  43. if a.Cipher == "" {
  44. a.Cipher = "cipher1"
  45. }
  46. if a.MAC == "" {
  47. a.MAC = "mac1"
  48. }
  49. if a.Compression == "" {
  50. a.Compression = "compression1"
  51. }
  52. }
  53. initAlgs := func(a *algorithms) {
  54. if a.kex == "" {
  55. a.kex = "kex1"
  56. }
  57. if a.hostKey == "" {
  58. a.hostKey = "hostkey1"
  59. }
  60. initDirAlgs(&a.r)
  61. initDirAlgs(&a.w)
  62. }
  63. type testcase struct {
  64. name string
  65. clientIn, serverIn kexInitMsg
  66. wantClient, wantServer algorithms
  67. wantErr bool
  68. }
  69. cases := []testcase{
  70. testcase{
  71. name: "standard",
  72. },
  73. testcase{
  74. name: "no common hostkey",
  75. serverIn: kexInitMsg{
  76. ServerHostKeyAlgos: []string{"hostkey2"},
  77. },
  78. wantErr: true,
  79. },
  80. testcase{
  81. name: "no common kex",
  82. serverIn: kexInitMsg{
  83. KexAlgos: []string{"kex2"},
  84. },
  85. wantErr: true,
  86. },
  87. testcase{
  88. name: "no common cipher",
  89. serverIn: kexInitMsg{
  90. CiphersClientServer: []string{"cipher2"},
  91. },
  92. wantErr: true,
  93. },
  94. testcase{
  95. name: "client decides cipher",
  96. serverIn: kexInitMsg{
  97. CiphersClientServer: []string{"cipher1", "cipher2"},
  98. CiphersServerClient: []string{"cipher2", "cipher3"},
  99. },
  100. clientIn: kexInitMsg{
  101. CiphersClientServer: []string{"cipher2", "cipher1"},
  102. CiphersServerClient: []string{"cipher3", "cipher2"},
  103. },
  104. wantClient: algorithms{
  105. r: directionAlgorithms{
  106. Cipher: "cipher3",
  107. },
  108. w: directionAlgorithms{
  109. Cipher: "cipher2",
  110. },
  111. },
  112. wantServer: algorithms{
  113. w: directionAlgorithms{
  114. Cipher: "cipher3",
  115. },
  116. r: directionAlgorithms{
  117. Cipher: "cipher2",
  118. },
  119. },
  120. },
  121. // TODO(hanwen): fix and add tests for AEAD ignoring
  122. // the MACs field
  123. }
  124. for i := range cases {
  125. initKex(&cases[i].clientIn)
  126. initKex(&cases[i].serverIn)
  127. initAlgs(&cases[i].wantClient)
  128. initAlgs(&cases[i].wantServer)
  129. }
  130. for _, c := range cases {
  131. t.Run(c.name, func(t *testing.T) {
  132. serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn)
  133. clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn)
  134. serverHasErr := serverErr != nil
  135. clientHasErr := clientErr != nil
  136. if c.wantErr != serverHasErr || c.wantErr != clientHasErr {
  137. t.Fatalf("got client/server error (%v, %v), want hasError %v",
  138. clientErr, serverErr, c.wantErr)
  139. }
  140. if c.wantErr {
  141. return
  142. }
  143. if !reflect.DeepEqual(serverAlgs, &c.wantServer) {
  144. t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer)
  145. }
  146. if !reflect.DeepEqual(clientAlgs, &c.wantClient) {
  147. t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient)
  148. }
  149. })
  150. }
  151. }