Browse Source

go.crypto/ssh: implement ECDH.

Implement elliptic-curve Diffie-Hellman, including host key signature
verification.

Moves host key cryptographic verification to ClientConn.handshake(), so
RSA host keys are also verified.

Fixes golang/go#6158.

R=dave, agl
CC=golang-dev
https://golang.org/cl/13021045
Han-Wen Nienhuys 12 years ago
parent
commit
d7d50b0a7c
8 changed files with 534 additions and 80 deletions
  1. 15 9
      ssh/certs.go
  2. 9 0
      ssh/cipher.go
  3. 174 28
      ssh/client.go
  4. 20 0
      ssh/client_auth_test.go
  5. 35 2
      ssh/common.go
  6. 90 0
      ssh/kex_test.go
  7. 13 0
      ssh/messages.go
  8. 178 41
      ssh/server.go

+ 15 - 9
ssh/certs.go

@@ -291,22 +291,28 @@ func marshalSignature(to []byte, sig *signature) []byte {
 	return to
 	return to
 }
 }
 
 
-func parseSignature(in []byte) (out *signature, rest []byte, ok bool) {
-	var sigBytes, format []byte
-	sig := new(signature)
-
-	if sigBytes, rest, ok = parseString(in); !ok {
+func parseSignatureBody(in []byte) (out *signature, rest []byte, ok bool) {
+	var format []byte
+	if format, in, ok = parseString(in); !ok {
 		return
 		return
 	}
 	}
 
 
-	if format, sigBytes, ok = parseString(sigBytes); !ok {
+	out = &signature{
+		Format: string(format),
+	}
+
+	if out.Blob, in, ok = parseString(in); !ok {
 		return
 		return
 	}
 	}
-	sig.Format = string(format)
 
 
-	if sig.Blob, sigBytes, ok = parseString(sigBytes); !ok {
+	return out, in, ok
+}
+
+func parseSignature(in []byte) (out *signature, rest []byte, ok bool) {
+	var sigBytes []byte
+	if sigBytes, rest, ok = parseString(in); !ok {
 		return
 		return
 	}
 	}
 
 
-	return sig, rest, ok
+	return parseSignatureBody(sigBytes)
 }
 }

+ 9 - 0
ssh/cipher.go

@@ -89,3 +89,12 @@ var cipherModes = map[string]*cipherMode{
 	"arcfour128": {16, 0, 1536, newRC4},
 	"arcfour128": {16, 0, 1536, newRC4},
 	"arcfour256": {32, 0, 1536, newRC4},
 	"arcfour256": {32, 0, 1536, newRC4},
 }
 }
+
+// defaultKeyExchangeOrder specifies a default set of key exchange algorithms
+// with preferences.
+var defaultKeyExchangeOrder = []string{
+	// P384 and P521 are not constant-time yet, but since we don't
+	// reuse ephemeral keys, using them for ECDH should be OK.
+	kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
+	kexAlgoDH14SHA1, kexAlgoDH1SHA1,
+}

+ 174 - 28
ssh/client.go

@@ -6,7 +6,10 @@ package ssh
 
 
 import (
 import (
 	"crypto"
 	"crypto"
+	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/rand"
 	"crypto/rand"
+	"crypto/rsa"
 	"encoding/binary"
 	"encoding/binary"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
@@ -86,7 +89,7 @@ func (c *ClientConn) handshake() error {
 	magics.serverVersion = version
 	magics.serverVersion = version
 	c.serverVersion = string(version)
 	c.serverVersion = string(version)
 	clientKexInit := kexInitMsg{
 	clientKexInit := kexInitMsg{
-		KexAlgos:                supportedKexAlgos,
+		KexAlgos:                c.config.Crypto.kexes(),
 		ServerHostKeyAlgos:      supportedHostKeyAlgos,
 		ServerHostKeyAlgos:      supportedHostKeyAlgos,
 		CiphersClientServer:     c.config.Crypto.ciphers(),
 		CiphersClientServer:     c.config.Crypto.ciphers(),
 		CiphersServerClient:     c.config.Crypto.ciphers(),
 		CiphersServerClient:     c.config.Crypto.ciphers(),
@@ -126,17 +129,20 @@ func (c *ClientConn) handshake() error {
 		}
 		}
 	}
 	}
 
 
-	var H, K []byte
-	var hashFunc crypto.Hash
+	var result *kexResult
 	switch kexAlgo {
 	switch kexAlgo {
+	case kexAlgoECDH256:
+		result, err = c.kexECDH(elliptic.P256(), &magics, hostKeyAlgo)
+	case kexAlgoECDH384:
+		result, err = c.kexECDH(elliptic.P384(), &magics, hostKeyAlgo)
+	case kexAlgoECDH521:
+		result, err = c.kexECDH(elliptic.P521(), &magics, hostKeyAlgo)
 	case kexAlgoDH14SHA1:
 	case kexAlgoDH14SHA1:
-		hashFunc = crypto.SHA1
 		dhGroup14Once.Do(initDHGroup14)
 		dhGroup14Once.Do(initDHGroup14)
-		H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
-	case keyAlgoDH1SHA1:
-		hashFunc = crypto.SHA1
+		result, err = c.kexDH(crypto.SHA1, dhGroup14, &magics, hostKeyAlgo)
+	case kexAlgoDH1SHA1:
 		dhGroup1Once.Do(initDHGroup1)
 		dhGroup1Once.Do(initDHGroup1)
-		H, K, err = c.kexDH(dhGroup1, hashFunc, &magics, hostKeyAlgo)
+		result, err = c.kexDH(crypto.SHA1, dhGroup1, &magics, hostKeyAlgo)
 	default:
 	default:
 		err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
 		err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
 	}
 	}
@@ -144,10 +150,22 @@ func (c *ClientConn) handshake() error {
 		return err
 		return err
 	}
 	}
 
 
+	err = verifyHostKeySignature(hostKeyAlgo, result.HostKey, result.H, result.Signature)
+	if err != nil {
+		return err
+	}
+
+	if checker := c.config.HostKeyChecker; checker != nil {
+		err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, result.HostKey)
+		if err != nil {
+			return err
+		}
+	}
+
 	if err = c.writePacket([]byte{msgNewKeys}); err != nil {
 	if err = c.writePacket([]byte{msgNewKeys}); err != nil {
 		return err
 		return err
 	}
 	}
-	if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
+	if err = c.transport.writer.setupKeys(clientKeys, result.K, result.H, result.H, result.Hash); err != nil {
 		return err
 		return err
 	}
 	}
 	if packet, err = c.readPacket(); err != nil {
 	if packet, err = c.readPacket(); err != nil {
@@ -156,46 +174,170 @@ func (c *ClientConn) handshake() error {
 	if packet[0] != msgNewKeys {
 	if packet[0] != msgNewKeys {
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 	}
 	}
-	if err := c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
+	if err := c.transport.reader.setupKeys(serverKeys, result.K, result.H, result.H, result.Hash); err != nil {
 		return err
 		return err
 	}
 	}
-	return c.authenticate(H)
+	return c.authenticate(result.H)
+}
+
+// kexECDH performs Elliptic Curve Diffie-Hellman key exchange as
+// described in RFC 5656, section 4.
+func (c *ClientConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, hostKeyAlgo string) (*kexResult, error) {
+	ephKey, err := ecdsa.GenerateKey(curve, c.config.rand())
+	if err != nil {
+		return nil, err
+	}
+
+	kexInit := kexECDHInitMsg{
+		ClientPubKey: elliptic.Marshal(curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
+	}
+
+	serialized := marshal(msgKexECDHInit, kexInit)
+	if err := c.writePacket(serialized); err != nil {
+		return nil, err
+	}
+
+	packet, err := c.readPacket()
+	if err != nil {
+		return nil, err
+	}
+
+	var reply kexECDHReplyMsg
+	if err = unmarshal(&reply, packet, msgKexECDHReply); err != nil {
+		return nil, err
+	}
+
+	x, y := elliptic.Unmarshal(curve, reply.EphemeralPubKey)
+	if x == nil {
+		return nil, errors.New("ssh: elliptic.Unmarshal failure")
+	}
+	if !validateECPublicKey(curve, x, y) {
+		return nil, errors.New("ssh: ephemeral server key not on curve")
+	}
+
+	// generate shared secret
+	secret, _ := curve.ScalarMult(x, y, ephKey.D.Bytes())
+
+	hashFunc := ecHash(curve)
+	h := hashFunc.New()
+	writeString(h, magics.clientVersion)
+	writeString(h, magics.serverVersion)
+	writeString(h, magics.clientKexInit)
+	writeString(h, magics.serverKexInit)
+	writeString(h, reply.HostKey)
+	writeString(h, kexInit.ClientPubKey)
+	writeString(h, reply.EphemeralPubKey)
+	K := make([]byte, intLength(secret))
+	marshalInt(K, secret)
+	h.Write(K)
+
+	return &kexResult{
+		H:         h.Sum(nil),
+		K:         K,
+		HostKey:   reply.HostKey,
+		Signature: reply.Signature,
+		Hash:      hashFunc,
+	}, nil
+}
+
+// Verify the host key obtained in the key exchange.
+func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte, signature []byte) error {
+	hostKey, rest, ok := ParsePublicKey(hostKeyBytes)
+	if len(rest) > 0 || !ok {
+		return errors.New("ssh: could not parse hostkey")
+	}
+
+	// Select hash function to match the hostkey algorithm, as per
+	// RFC 4253, section 6.1 (for RSA/DSS) and RFC 5656, section
+	// 6.2.1 (for ECDSA).
+	var hashFunc crypto.Hash
+	switch hostKeyAlgo {
+	case KeyAlgoRSA:
+		hashFunc = crypto.SHA1
+	case KeyAlgoDSA:
+		hashFunc = crypto.SHA1
+	case KeyAlgoECDSA256:
+		hashFunc = crypto.SHA256
+	case KeyAlgoECDSA384:
+		hashFunc = crypto.SHA384
+	case KeyAlgoECDSA521:
+		hashFunc = crypto.SHA512
+	default:
+		return errors.New("ssh: unknown key algorithm: " + hostKeyAlgo)
+	}
+
+	signed := hashFunc.New()
+	signed.Write(data)
+	digest := signed.Sum(nil)
+
+	sig, rest, ok := parseSignatureBody(signature)
+	if len(rest) > 0 || !ok {
+		return errors.New("ssh: signature parse error")
+	}
+	if sig.Format != hostKeyAlgo {
+		return fmt.Errorf("ssh: unexpected signature type %q", sig.Format)
+	}
+
+	return verifySignature(digest, sig, hostKey)
 }
 }
 
 
-// kexDH performs Diffie-Hellman key agreement on a ClientConn. The
-// returned values are given the same names as in RFC 4253, section 8.
-func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, error) {
+func verifySignature(hash []byte, sig *signature, key interface{}) error {
+	switch pubKey := key.(type) {
+	case *rsa.PublicKey:
+		return verifyRSASignature(hash, sig, pubKey)
+	}
+	return fmt.Errorf("ssh: unknown key type %T", key)
+}
+
+func verifyRSASignature(hash []byte, sig *signature, key *rsa.PublicKey) error {
+	return rsa.VerifyPKCS1v15(key, crypto.SHA1, hash, sig.Blob)
+}
+
+// kexResult captures the outcome of a key exchange.
+type kexResult struct {
+	// Session hash. See also RFC 4253, section 8.
+	H []byte
+
+	// Shared secret. See also RFC 4253, section 8.
+	K []byte
+
+	// Host key as hashed into H
+	HostKey []byte
+
+	// Signature of H
+	Signature []byte
+
+	// Hash function that was used.
+	Hash crypto.Hash
+}
+
+// kexDH performs Diffie-Hellman key agreement on a ClientConn.
+func (c *ClientConn) kexDH(hashFunc crypto.Hash, group *dhGroup, magics *handshakeMagics, hostKeyAlgo string) (*kexResult, error) {
 	x, err := rand.Int(c.config.rand(), group.p)
 	x, err := rand.Int(c.config.rand(), group.p)
 	if err != nil {
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	}
 	X := new(big.Int).Exp(group.g, x, group.p)
 	X := new(big.Int).Exp(group.g, x, group.p)
 	kexDHInit := kexDHInitMsg{
 	kexDHInit := kexDHInitMsg{
 		X: X,
 		X: X,
 	}
 	}
 	if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
 	if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	}
 
 
 	packet, err := c.readPacket()
 	packet, err := c.readPacket()
 	if err != nil {
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	}
 
 
 	var kexDHReply kexDHReplyMsg
 	var kexDHReply kexDHReplyMsg
 	if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
 	if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
-		return nil, nil, err
-	}
-
-	if checker := c.config.HostKeyChecker; checker != nil {
-		if err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, kexDHReply.HostKey); err != nil {
-			return nil, nil, err
-		}
+		return nil, err
 	}
 	}
 
 
 	kInt, err := group.diffieHellman(kexDHReply.Y, x)
 	kInt, err := group.diffieHellman(kexDHReply.Y, x)
 	if err != nil {
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	}
 
 
 	h := hashFunc.New()
 	h := hashFunc.New()
@@ -210,9 +352,13 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 	marshalInt(K, kInt)
 	marshalInt(K, kInt)
 	h.Write(K)
 	h.Write(K)
 
 
-	H := h.Sum(nil)
-
-	return H, K, nil
+	return &kexResult{
+		H:         h.Sum(nil),
+		K:         K,
+		HostKey:   kexDHReply.HostKey,
+		Signature: kexDHReply.Signature,
+		Hash:      hashFunc,
+	}, nil
 }
 }
 
 
 // mainLoop reads incoming messages and routes channel messages
 // mainLoop reads incoming messages and routes channel messages

+ 20 - 0
ssh/client_auth_test.go

@@ -16,6 +16,7 @@ import (
 	"io"
 	"io"
 	"io/ioutil"
 	"io/ioutil"
 	"math/big"
 	"math/big"
+	"strings"
 	"testing"
 	"testing"
 )
 )
 
 
@@ -356,3 +357,22 @@ func TestClientUnsupportedCipher(t *testing.T) {
 		c.Close()
 		c.Close()
 	}
 	}
 }
 }
+
+func TestClientUnsupportedKex(t *testing.T) {
+	kc := new(keychain)
+	kc.keys = append(kc.keys, rsakey)
+	config := &ClientConfig{
+		User: "testuser",
+		Auth: []ClientAuth{
+			ClientAuthKeyring(kc),
+		},
+		Crypto: CryptoConfig{
+			KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
+		},
+	}
+	c, err := Dial("tcp", newMockAuthServer(t), config)
+	if err == nil || !strings.Contains(err.Error(), "no common algorithms") {
+		t.Errorf("got %v, expected 'no common algorithms'", err)
+		c.Close()
+	}
+}

+ 35 - 2
ssh/common.go

@@ -5,8 +5,10 @@
 package ssh
 package ssh
 
 
 import (
 import (
+	"crypto"
 	"crypto/dsa"
 	"crypto/dsa"
 	"crypto/ecdsa"
 	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/rsa"
 	"crypto/rsa"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
@@ -16,8 +18,11 @@ import (
 
 
 // These are string constants in the SSH protocol.
 // These are string constants in the SSH protocol.
 const (
 const (
-	keyAlgoDH1SHA1  = "diffie-hellman-group1-sha1"
+	kexAlgoDH1SHA1  = "diffie-hellman-group1-sha1"
 	kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
 	kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
+	kexAlgoECDH256  = "ecdh-sha2-nistp256"
+	kexAlgoECDH384  = "ecdh-sha2-nistp384"
+	kexAlgoECDH521  = "ecdh-sha2-nistp521"
 	hostAlgoRSA     = "ssh-rsa"
 	hostAlgoRSA     = "ssh-rsa"
 	hostAlgoDSA     = "ssh-dss"
 	hostAlgoDSA     = "ssh-dss"
 	compressionNone = "none"
 	compressionNone = "none"
@@ -25,7 +30,11 @@ const (
 	serviceSSH      = "ssh-connection"
 	serviceSSH      = "ssh-connection"
 )
 )
 
 
-var supportedKexAlgos = []string{kexAlgoDH14SHA1, keyAlgoDH1SHA1}
+var supportedKexAlgos = []string{
+	kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
+	kexAlgoDH14SHA1, kexAlgoDH1SHA1,
+}
+
 var supportedHostKeyAlgos = []string{hostAlgoRSA}
 var supportedHostKeyAlgos = []string{hostAlgoRSA}
 var supportedCompressions = []string{compressionNone}
 var supportedCompressions = []string{compressionNone}
 
 
@@ -165,6 +174,10 @@ func findAgreedAlgorithms(transport *transport, clientKexInit, serverKexInit *ke
 
 
 // Cryptographic configuration common to both ServerConfig and ClientConfig.
 // Cryptographic configuration common to both ServerConfig and ClientConfig.
 type CryptoConfig struct {
 type CryptoConfig struct {
+	// The allowed key exchanges algorithms. If unspecified then a
+	// default set of algorithms is used.
+	KeyExchanges []string
+
 	// The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
 	// The allowed cipher algorithms. If unspecified then DefaultCipherOrder is
 	// used.
 	// used.
 	Ciphers []string
 	Ciphers []string
@@ -180,6 +193,13 @@ func (c *CryptoConfig) ciphers() []string {
 	return c.Ciphers
 	return c.Ciphers
 }
 }
 
 
+func (c *CryptoConfig) kexes() []string {
+	if c.KeyExchanges == nil {
+		return defaultKeyExchangeOrder
+	}
+	return c.KeyExchanges
+}
+
 func (c *CryptoConfig) macs() []string {
 func (c *CryptoConfig) macs() []string {
 	if c.MACs == nil {
 	if c.MACs == nil {
 		return DefaultMACOrder
 		return DefaultMACOrder
@@ -187,6 +207,19 @@ func (c *CryptoConfig) macs() []string {
 	return c.MACs
 	return c.MACs
 }
 }
 
 
+// ecHash returns the hash to match the given elliptic curve, see RFC
+// 5656, section 6.2.1
+func ecHash(curve elliptic.Curve) crypto.Hash {
+	bitSize := curve.Params().BitSize
+	switch {
+	case bitSize <= 256:
+		return crypto.SHA256
+	case bitSize <= 384:
+		return crypto.SHA384
+	}
+	return crypto.SHA512
+}
+
 // serialize a signed slice according to RFC 4254 6.6.
 // serialize a signed slice according to RFC 4254 6.6.
 func serializeSignature(algoname string, sig []byte) []byte {
 func serializeSignature(algoname string, sig []byte) []byte {
 	switch algoname {
 	switch algoname {

+ 90 - 0
ssh/kex_test.go

@@ -0,0 +1,90 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+// Key exchange tests.
+
+import (
+	"fmt"
+	"net"
+	"testing"
+)
+
+func pipe() (net.Conn, net.Conn, error) {
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		return nil, nil, err
+	}
+	conn1, err := net.Dial("tcp", l.Addr().String())
+	if err != nil {
+		return nil, nil, err
+	}
+
+	conn2, err := l.Accept()
+	if err != nil {
+		conn1.Close()
+		return nil, nil, err
+	}
+	l.Close()
+	return conn1, conn2, nil
+}
+
+func testKexAlgorithm(algo string) error {
+	crypto := CryptoConfig{
+		KeyExchanges: []string{algo},
+	}
+	serverConfig := ServerConfig{
+		PasswordCallback: func(conn *ServerConn, user, password string) bool {
+			return password == "password"
+		},
+		Crypto: crypto,
+	}
+
+	if err := serverConfig.SetRSAPrivateKey([]byte(testServerPrivateKey)); err != nil {
+		return fmt.Errorf("SetRSAPrivateKey: %v", err)
+	}
+
+	clientConfig := ClientConfig{
+		User:   "user",
+		Auth:   []ClientAuth{ClientAuthPassword(password("password"))},
+		Crypto: crypto,
+	}
+
+	conn1, conn2, err := pipe()
+	if err != nil {
+		return err
+	}
+
+	defer conn1.Close()
+	defer conn2.Close()
+
+	server := Server(conn2, &serverConfig)
+	serverHS := make(chan error, 1)
+	go func() {
+		serverHS <- server.Handshake()
+	}()
+
+	// Client runs the handshake.
+	_, err = Client(conn1, &clientConfig)
+	if err != nil {
+		return fmt.Errorf("Client: %v", err)
+	}
+
+	if err := <-serverHS; err != nil {
+		return fmt.Errorf("server.Handshake: %v", err)
+	}
+
+	// Here we could check that we now can send data between client &
+	// server.
+	return nil
+}
+
+func TestKexAlgorithms(t *testing.T) {
+	for _, algo := range []string{kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, kexAlgoDH1SHA1, kexAlgoDH14SHA1} {
+		if err := testKexAlgorithm(algo); err != nil {
+			t.Errorf("algorithm %s: %v", algo, err)
+		}
+	}
+}

+ 13 - 0
ssh/messages.go

@@ -29,6 +29,9 @@ const (
 	msgKexDHInit  = 30
 	msgKexDHInit  = 30
 	msgKexDHReply = 31
 	msgKexDHReply = 31
 
 
+	msgKexECDHInit  = 30
+	msgKexECDHReply = 31
+
 	// Standard authentication messages
 	// Standard authentication messages
 	msgUserAuthRequest  = 50
 	msgUserAuthRequest  = 50
 	msgUserAuthFailure  = 51
 	msgUserAuthFailure  = 51
@@ -94,6 +97,16 @@ type kexDHInitMsg struct {
 	X *big.Int
 	X *big.Int
 }
 }
 
 
+type kexECDHInitMsg struct {
+	ClientPubKey []byte
+}
+
+type kexECDHReplyMsg struct {
+	HostKey         []byte
+	EphemeralPubKey []byte
+	Signature       []byte
+}
+
 type kexDHReplyMsg struct {
 type kexDHReplyMsg struct {
 	HostKey   []byte
 	HostKey   []byte
 	Y         *big.Int
 	Y         *big.Int

+ 178 - 41
ssh/server.go

@@ -7,6 +7,8 @@ package ssh
 import (
 import (
 	"bytes"
 	"bytes"
 	"crypto"
 	"crypto"
+	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/rand"
 	"crypto/rand"
 	"crypto/rsa"
 	"crypto/rsa"
 	"crypto/x509"
 	"crypto/x509"
@@ -17,6 +19,9 @@ import (
 	"math/big"
 	"math/big"
 	"net"
 	"net"
 	"sync"
 	"sync"
+
+	_ "crypto/sha256"
+	_ "crypto/sha512"
 )
 )
 
 
 type ServerConfig struct {
 type ServerConfig struct {
@@ -145,9 +150,126 @@ func Server(c net.Conn, config *ServerConfig) *ServerConn {
 	}
 	}
 }
 }
 
 
-// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
-// returned values are given the same names as in RFC 4253, section 8.
-func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err error) {
+// kexECDH performs Elliptic Curve Diffie-Hellman key agreement on a
+// ServerConnection, as documented in RFC 5656, section 4.
+func (s *ServerConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, hostKeyAlgo string) (result *kexResult, err error) {
+	packet, err := s.readPacket()
+	if err != nil {
+		return
+	}
+
+	var kexECDHInit kexECDHInitMsg
+	if err = unmarshal(&kexECDHInit, packet, msgKexECDHInit); err != nil {
+		return
+	}
+
+	clientX, clientY := elliptic.Unmarshal(curve, kexECDHInit.ClientPubKey)
+	if clientX == nil {
+		return nil, errors.New("ssh: elliptic.Unmarshal failure")
+	}
+
+	if !validateECPublicKey(curve, clientX, clientY) {
+		return nil, errors.New("ssh: not a valid EC public key")
+	}
+
+	// We could cache this key across multiple users/multiple
+	// connection attempts, but the benefit is small. OpenSSH
+	// generates a new key for each incoming connection.
+	ephKey, err := ecdsa.GenerateKey(curve, s.config.rand())
+	if err != nil {
+		return nil, err
+	}
+
+	hostKey, err := s.serializedHostKey(hostKeyAlgo)
+	if err != nil {
+		return nil, err
+	}
+
+	serializedEphKey := elliptic.Marshal(curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
+
+	// generate shared secret
+	secret, _ := curve.ScalarMult(clientX, clientY, ephKey.D.Bytes())
+
+	hashFunc := ecHash(curve)
+	h := hashFunc.New()
+	writeString(h, magics.clientVersion)
+	writeString(h, magics.serverVersion)
+	writeString(h, magics.clientKexInit)
+	writeString(h, magics.serverKexInit)
+	writeString(h, hostKey)
+	writeString(h, kexECDHInit.ClientPubKey)
+	writeString(h, serializedEphKey)
+
+	K := make([]byte, intLength(secret))
+	marshalInt(K, secret)
+	h.Write(K)
+
+	H := h.Sum(nil)
+
+	serializedSig, err := s.serializedHostKeySignature(hostKeyAlgo, H)
+	if err != nil {
+		return nil, err
+	}
+
+	reply := kexECDHReplyMsg{
+		EphemeralPubKey: serializedEphKey,
+		HostKey:         hostKey,
+		Signature:       serializedSig,
+	}
+
+	serialized := marshal(msgKexECDHReply, reply)
+	if err := s.writePacket(serialized); err != nil {
+		return nil, err
+	}
+
+	return &kexResult{
+		H:       H,
+		K:       K,
+		HostKey: reply.HostKey,
+		Hash:    hashFunc,
+	}, nil
+}
+
+func (s *ServerConn) serializedHostKey(hostKeyAlgo string) ([]byte, error) {
+	switch hostKeyAlgo {
+	case hostAlgoRSA:
+		return s.config.rsaSerialized, nil
+	}
+	return nil, errors.New("ssh: internal error")
+}
+
+// validateECPublicKey checks that the point is a valid public key for
+// the given curve. See [SEC1], 3.2.2
+func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
+	if x.Sign() == 0 && y.Sign() == 0 {
+		return false
+	}
+
+	if x.Cmp(curve.Params().P) >= 0 {
+		return false
+	}
+
+	if y.Cmp(curve.Params().P) >= 0 {
+		return false
+	}
+
+	if !curve.IsOnCurve(x, y) {
+		return false
+	}
+
+	// We don't check if N * PubKey == 0, since
+	//
+	// - the NIST curves have cofactor = 1, so this is implicit.
+	// (We don't forsee an implementation that supports non NIST
+	// curves)
+	//
+	// - for ephemeral keys, we don't need to worry about small
+	// subgroup attacks.
+	return true
+}
+
+// kexDH performs Diffie-Hellman key agreement on a ServerConnection.
+func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (result *kexResult, err error) {
 	packet, err := s.readPacket()
 	packet, err := s.readPacket()
 	if err != nil {
 	if err != nil {
 		return
 		return
@@ -165,15 +287,12 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 	Y := new(big.Int).Exp(group.g, y, group.p)
 	Y := new(big.Int).Exp(group.g, y, group.p)
 	kInt, err := group.diffieHellman(kexDHInit.X, y)
 	kInt, err := group.diffieHellman(kexDHInit.X, y)
 	if err != nil {
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
 	}
 
 
-	var serializedHostKey []byte
-	switch hostKeyAlgo {
-	case hostAlgoRSA:
-		serializedHostKey = s.config.rsaSerialized
-	default:
-		return nil, nil, errors.New("ssh: internal error")
+	hostKey, err := s.serializedHostKey(hostKeyAlgo)
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
 	h := hashFunc.New()
 	h := hashFunc.New()
@@ -181,41 +300,56 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 	writeString(h, magics.serverVersion)
 	writeString(h, magics.serverVersion)
 	writeString(h, magics.clientKexInit)
 	writeString(h, magics.clientKexInit)
 	writeString(h, magics.serverKexInit)
 	writeString(h, magics.serverKexInit)
-	writeString(h, serializedHostKey)
+	writeString(h, hostKey)
 	writeInt(h, kexDHInit.X)
 	writeInt(h, kexDHInit.X)
 	writeInt(h, Y)
 	writeInt(h, Y)
-	K = make([]byte, intLength(kInt))
+
+	K := make([]byte, intLength(kInt))
 	marshalInt(K, kInt)
 	marshalInt(K, kInt)
 	h.Write(K)
 	h.Write(K)
 
 
-	H = h.Sum(nil)
-
-	h.Reset()
-	h.Write(H)
-	hh := h.Sum(nil)
+	H := h.Sum(nil)
 
 
-	var sig []byte
-	switch hostKeyAlgo {
-	case hostAlgoRSA:
-		sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh)
-		if err != nil {
-			return
-		}
-	default:
-		return nil, nil, errors.New("ssh: internal error")
+	serializedSig, err := s.serializedHostKeySignature(hostKeyAlgo, H)
+	if err != nil {
+		return nil, err
 	}
 	}
 
 
-	serializedSig := serializeSignature(hostKeyAlgo, sig)
-
 	kexDHReply := kexDHReplyMsg{
 	kexDHReply := kexDHReplyMsg{
-		HostKey:   serializedHostKey,
+		HostKey:   hostKey,
 		Y:         Y,
 		Y:         Y,
 		Signature: serializedSig,
 		Signature: serializedSig,
 	}
 	}
 	packet = marshal(msgKexDHReply, kexDHReply)
 	packet = marshal(msgKexDHReply, kexDHReply)
 
 
 	err = s.writePacket(packet)
 	err = s.writePacket(packet)
-	return
+	return &kexResult{
+		H:       H,
+		K:       K,
+		HostKey: hostKey,
+		Hash:    hashFunc,
+	}, nil
+}
+
+// serializedHostKeySignature signs the hashed data, and serializes
+// the signature according to SSH conventions.
+func (s *ServerConn) serializedHostKeySignature(hostKeyAlgo string, hashed []byte) ([]byte, error) {
+	var sig []byte
+	switch hostKeyAlgo {
+	case hostAlgoRSA:
+		hashFunc := crypto.SHA1
+		hh := hashFunc.New()
+		hh.Write(hashed)
+		var err error
+		sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh.Sum(nil))
+		if err != nil {
+			return nil, err
+		}
+	default:
+		return nil, errors.New("ssh: internal error")
+	}
+
+	return serializeSignature(hostKeyAlgo, sig), nil
 }
 }
 
 
 // serverVersion is the fixed identification string that Server will use.
 // serverVersion is the fixed identification string that Server will use.
@@ -264,7 +398,7 @@ func (s *ServerConn) Handshake() (err error) {
 
 
 func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
 func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
 	serverKexInit := kexInitMsg{
 	serverKexInit := kexInitMsg{
-		KexAlgos:                supportedKexAlgos,
+		KexAlgos:                s.config.Crypto.kexes(),
 		ServerHostKeyAlgos:      supportedHostKeyAlgos,
 		ServerHostKeyAlgos:      supportedHostKeyAlgos,
 		CiphersClientServer:     s.config.Crypto.ciphers(),
 		CiphersClientServer:     s.config.Crypto.ciphers(),
 		CiphersServerClient:     s.config.Crypto.ciphers(),
 		CiphersServerClient:     s.config.Crypto.ciphers(),
@@ -308,17 +442,20 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	magics.serverKexInit = marshal(msgKexInit, serverKexInit)
 	magics.serverKexInit = marshal(msgKexInit, serverKexInit)
 	magics.clientKexInit = clientKexInitPacket
 	magics.clientKexInit = clientKexInitPacket
 
 
-	var H, K []byte
-	var hashFunc crypto.Hash
+	var result *kexResult
 	switch kexAlgo {
 	switch kexAlgo {
+	case kexAlgoECDH256:
+		result, err = s.kexECDH(elliptic.P256(), &magics, hostKeyAlgo)
+	case kexAlgoECDH384:
+		result, err = s.kexECDH(elliptic.P384(), &magics, hostKeyAlgo)
+	case kexAlgoECDH521:
+		result, err = s.kexECDH(elliptic.P521(), &magics, hostKeyAlgo)
 	case kexAlgoDH14SHA1:
 	case kexAlgoDH14SHA1:
-		hashFunc = crypto.SHA1
 		dhGroup14Once.Do(initDHGroup14)
 		dhGroup14Once.Do(initDHGroup14)
-		H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
-	case keyAlgoDH1SHA1:
-		hashFunc = crypto.SHA1
+		result, err = s.kexDH(dhGroup14, crypto.SHA1, &magics, hostKeyAlgo)
+	case kexAlgoDH1SHA1:
 		dhGroup1Once.Do(initDHGroup1)
 		dhGroup1Once.Do(initDHGroup1)
-		H, K, err = s.kexDH(dhGroup1, hashFunc, &magics, hostKeyAlgo)
+		result, err = s.kexDH(dhGroup1, crypto.SHA1, &magics, hostKeyAlgo)
 	default:
 	default:
 		err = errors.New("ssh: unexpected key exchange algorithm " + kexAlgo)
 		err = errors.New("ssh: unexpected key exchange algorithm " + kexAlgo)
 	}
 	}
@@ -327,7 +464,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	}
 	}
 	// sessionId must only be assigned during initial handshake.
 	// sessionId must only be assigned during initial handshake.
 	if s.sessionId == nil {
 	if s.sessionId == nil {
-		s.sessionId = H
+		s.sessionId = result.H
 	}
 	}
 
 
 	var packet []byte
 	var packet []byte
@@ -335,7 +472,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	if err = s.writePacket([]byte{msgNewKeys}); err != nil {
 	if err = s.writePacket([]byte{msgNewKeys}); err != nil {
 		return
 		return
 	}
 	}
-	if err = s.transport.writer.setupKeys(serverKeys, K, H, s.sessionId, hashFunc); err != nil {
+	if err = s.transport.writer.setupKeys(serverKeys, result.K, result.H, s.sessionId, result.Hash); err != nil {
 		return
 		return
 	}
 	}
 
 
@@ -345,7 +482,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	if packet[0] != msgNewKeys {
 	if packet[0] != msgNewKeys {
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 		return UnexpectedMessageError{msgNewKeys, packet[0]}
 	}
 	}
-	if err = s.transport.reader.setupKeys(clientKeys, K, H, s.sessionId, hashFunc); err != nil {
+	if err = s.transport.reader.setupKeys(clientKeys, result.K, result.H, s.sessionId, result.Hash); err != nil {
 		return
 		return
 	}
 	}