瀏覽代碼

go.crypto/ssh: add hook for host key checking.

R=dave, agl
CC=gobot, golang-dev
https://golang.org/cl/9922043
Han-Wen Nienhuys 12 年之前
父節點
當前提交
afdc305bc8
共有 4 個文件被更改,包括 95 次插入5 次删除
  1. 21 1
      ssh/client.go
  2. 12 0
      ssh/client_auth.go
  3. 19 0
      ssh/test/session_test.go
  4. 43 4
      ssh/test/test_unix_test.go

+ 21 - 1
ssh/client.go

@@ -26,6 +26,9 @@ type ClientConn struct {
 	chanList    // channels associated with this connection
 	forwardList // forwarded tcpip connections from the remote side
 	globalRequest
+
+	// Address as passed to the Dial function.
+	dialAddress string
 }
 
 type globalRequest struct {
@@ -35,11 +38,17 @@ type globalRequest struct {
 
 // Client returns a new SSH client connection using c as the underlying transport.
 func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
+	return clientWithAddress(c, "", config)
+}
+
+func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) {
 	conn := &ClientConn{
 		transport:     newTransport(c, config.rand()),
 		config:        config,
 		globalRequest: globalRequest{response: make(chan interface{}, 1)},
+		dialAddress:   addr,
 	}
+
 	if err := conn.handshake(); err != nil {
 		conn.Close()
 		return nil, fmt.Errorf("handshake failed: %v", err)
@@ -168,6 +177,12 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 		return nil, nil, err
 	}
 
+	if checker := c.config.HostKeyChecker; checker != nil {
+		if err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, kexDHReply.HostKey); err != nil {
+			return nil, nil, err
+		}
+	}
+
 	kInt, err := group.diffieHellman(kexDHReply.Y, x)
 	if err != nil {
 		return nil, nil, err
@@ -445,7 +460,7 @@ func Dial(network, addr string, config *ClientConfig) (*ClientConn, error) {
 	if err != nil {
 		return nil, err
 	}
-	return Client(conn, config)
+	return clientWithAddress(conn, addr, config)
 }
 
 // A ClientConfig structure is used to configure a ClientConn. After one has
@@ -463,6 +478,11 @@ type ClientConfig struct {
 	// of a particular RFC 4252 method will be used during authentication.
 	Auth []ClientAuth
 
+	// HostKeyChecker, if not nil, is called during the cryptographic
+	// handshake to validate the server's host key. A nil HostKeyChecker
+	// implies that all host keys are accepted.
+	HostKeyChecker HostKeyChecker
+
 	// Cryptographic-related configuration.
 	Crypto CryptoConfig
 }

+ 12 - 0
ssh/client_auth.go

@@ -8,6 +8,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 )
 
 // authenticate authenticates with the remote server. See RFC 4252.
@@ -63,6 +64,17 @@ func keys(m map[string]bool) (s []string) {
 	return
 }
 
+// HostKeyChecker represents a database of known server host keys.
+type HostKeyChecker interface {
+	// Check is called during the handshake to check server's
+	// public key for unexpected changes. The hostKey argument is
+	// in SSH wire format. It can be parsed using
+	// ssh.ParsePublicKey. The address before DNS resolution is
+	// passed in the addr argument, so the key can also be checked
+	// against the hostname.
+	Check(addr string, remote net.Addr, algorithm string, hostKey []byte) error
+}
+
 // A ClientAuth represents an instance of an RFC 4252 authentication method.
 type ClientAuth interface {
 	// auth authenticates user over transport t.

+ 19 - 0
ssh/test/session_test.go

@@ -33,6 +33,25 @@ func TestRunCommandSuccess(t *testing.T) {
 	}
 }
 
+func TestHostKeyCheck(t *testing.T) {
+	server := newServer(t)
+	defer server.Shutdown()
+
+	conf := clientConfig()
+	k := conf.HostKeyChecker.(*storedHostKey)
+
+	// change the key.
+	k.keys["ssh-rsa"][25]++
+
+	conn, err := server.TryDial(conf)
+	if err == nil {
+		conn.Close()
+		t.Fatalf("dial should have failed.")
+	} else if !strings.Contains(err.Error(), "host key mismatch") {
+		t.Fatalf("'host key mismatch' not found in %v", err)
+	}
+}
+
 func TestRunCommandFailed(t *testing.T) {
 	server := newServer(t)
 	defer server.Shutdown()

+ 43 - 4
ssh/test/test_unix_test.go

@@ -55,14 +55,25 @@ HostbasedAuthentication no
 `
 
 var (
-	configTmpl template.Template
-	rsakey     *rsa.PrivateKey
+	configTmpl        template.Template
+	rsakey            *rsa.PrivateKey
+	serializedHostKey []byte
 )
 
 func init() {
 	template.Must(configTmpl.Parse(sshd_config))
 	block, _ := pem.Decode([]byte(testClientPrivateKey))
 	rsakey, _ = x509.ParsePKCS1PrivateKey(block.Bytes)
+
+	block, _ = pem.Decode([]byte(keys["ssh_host_rsa_key"]))
+	if block == nil {
+		panic("pem.Decode ssh_host_rsa_key")
+	}
+	priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+	if err != nil {
+		panic("ParsePKCS1PrivateKey: " + err.Error())
+	}
+	serializedHostKey = ssh.MarshalPublicKey(&priv.PublicKey)
 }
 
 type server struct {
@@ -89,7 +100,29 @@ func username() string {
 	return username
 }
 
+type storedHostKey struct {
+	// keys map from an algorithm string to binary key data.
+	keys map[string][]byte
+}
+
+func (k *storedHostKey) Add(algo string, public []byte) {
+	if k.keys == nil {
+		k.keys = map[string][]byte{}
+	}
+	k.keys[algo] = append([]byte(nil), public...)
+}
+
+func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error {
+	if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 {
+		return errors.New("host key mismatch")
+	}
+	return nil
+}
+
 func clientConfig() *ssh.ClientConfig {
+	keyChecker := storedHostKey{}
+	keyChecker.Add("ssh-rsa", serializedHostKey)
+
 	kc := new(keychain)
 	kc.keys = append(kc.keys, rsakey)
 	config := &ssh.ClientConfig{
@@ -97,11 +130,12 @@ func clientConfig() *ssh.ClientConfig {
 		Auth: []ssh.ClientAuth{
 			ssh.ClientAuthKeyring(kc),
 		},
+		HostKeyChecker: &keyChecker,
 	}
 	return config
 }
 
-func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
+func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.ClientConn, error) {
 	sshd, err := exec.LookPath("sshd")
 	if err != nil {
 		s.t.Skipf("skipping test: %v", err)
@@ -123,7 +157,12 @@ func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
 		s.Shutdown()
 		s.t.Fatalf("s.cmd.Start: %v", err)
 	}
-	conn, err := ssh.Client(&client{wc: w2, r: r1}, config)
+
+	return ssh.Client(&client{wc: w2, r: r1}, config)
+}
+
+func (s *server) Dial(config *ssh.ClientConfig) *ssh.ClientConn {
+	conn, err := s.TryDial(config)
 	if err != nil {
 		s.t.Fail()
 		s.Shutdown()