|
|
@@ -55,14 +55,25 @@ HostbasedAuthentication no
|
|
|
`
|
|
|
|
|
|
var (
|
|
|
- configTmpl template.Template
|
|
|
- rsakey *rsa.PrivateKey
|
|
|
+ configTmpl template.Template
|
|
|
+ rsakey *rsa.PrivateKey
|
|
|
+ serializedHostKey []byte
|
|
|
)
|
|
|
|
|
|
func init() {
|
|
|
template.Must(configTmpl.Parse(sshd_config))
|
|
|
block, _ := pem.Decode([]byte(testClientPrivateKey))
|
|
|
rsakey, _ = x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
|
+
|
|
|
+ block, _ = pem.Decode([]byte(keys["ssh_host_rsa_key"]))
|
|
|
+ if block == nil {
|
|
|
+ panic("pem.Decode ssh_host_rsa_key")
|
|
|
+ }
|
|
|
+ priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|
|
+ if err != nil {
|
|
|
+ panic("ParsePKCS1PrivateKey: " + err.Error())
|
|
|
+ }
|
|
|
+ serializedHostKey = ssh.MarshalPublicKey(&priv.PublicKey)
|
|
|
}
|
|
|
|
|
|
type server struct {
|
|
|
@@ -89,7 +100,29 @@ func username() string {
|
|
|
return username
|
|
|
}
|
|
|
|
|
|
+type storedHostKey struct {
|
|
|
+ // keys map from an algorithm string to binary key data.
|
|
|
+ keys map[string][]byte
|
|
|
+}
|
|
|
+
|
|
|
+func (k *storedHostKey) Add(algo string, public []byte) {
|
|
|
+ if k.keys == nil {
|
|
|
+ k.keys = map[string][]byte{}
|
|
|
+ }
|
|
|
+ k.keys[algo] = append([]byte(nil), public...)
|
|
|
+}
|
|
|
+
|
|
|
+func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error {
|
|
|
+ if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 {
|
|
|
+ return errors.New("host key mismatch")
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func clientConfig() *ssh.ClientConfig {
|
|
|
+ keyChecker := storedHostKey{}
|
|
|
+ keyChecker.Add("ssh-rsa", serializedHostKey)
|
|
|
+
|
|
|
kc := new(keychain)
|
|
|
kc.keys = append(kc.keys, rsakey)
|
|
|
config := &ssh.ClientConfig{
|
|
|
@@ -97,11 +130,12 @@ func clientConfig() *ssh.ClientConfig {
|
|
|
Auth: []ssh.ClientAuth{
|
|
|
ssh.ClientAuthKeyring(kc),
|
|
|
},
|
|
|
+ HostKeyChecker: &keyChecker,
|
|
|
}
|
|
|
return config
|
|
|
}
|
|
|
|
|
|
-func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
|
|
|
+func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) {
|
|
|
sshd, err := exec.LookPath("sshd")
|
|
|
if err != nil {
|
|
|
s.t.Skipf("skipping test: %v", err)
|
|
|
@@ -123,7 +157,12 @@ func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
|
|
|
s.Shutdown()
|
|
|
s.t.Fatalf("s.cmd.Start: %v", err)
|
|
|
}
|
|
|
- conn, err := ssh.Client(&client{wc: w2, r: r1}, config)
|
|
|
+
|
|
|
+ return ssh.Client(&client{wc: w2, r: r1}, config)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
|
|
|
+ conn, err := s.TryDial(config)
|
|
|
if err != nil {
|
|
|
s.t.Fail()
|
|
|
s.Shutdown()
|