Explorar o código

go.crypto/ssh: introduce PublicKey interface type.

Public functions affected:
-AgentKey.Key
-AgentClient.SignRequest
-ClientKeyring.Key
-MarshalPublicKey
-ParsePublicKey

R=agl, jpsugar, jmpittman
CC=golang-dev
https://golang.org/cl/13642043
Han-Wen Nienhuys %!s(int64=12) %!d(string=hai) anos
pai
achega
e62b2aead4
Modificáronse 12 ficheiros con 448 adicións e 439 borrados
  1. 3 3
      ssh/agent.go
  2. 40 41
      ssh/certs.go
  3. 4 13
      ssh/client.go
  4. 14 15
      ssh/client_auth.go
  5. 7 6
      ssh/client_auth_test.go
  6. 15 119
      ssh/common.go
  7. 0 51
      ssh/common_test.go
  8. 282 146
      ssh/keys.go
  9. 60 0
      ssh/keys_test.go
  10. 16 38
      ssh/server.go
  11. 3 3
      ssh/test/keys_test.go
  12. 4 4
      ssh/test/test_unix_test.go

+ 3 - 3
ssh/agent.go

@@ -99,7 +99,7 @@ func (ak *AgentKey) String() string {
 }
 
 // Key returns an agent's public key as one of the supported key or certificate types.
-func (ak *AgentKey) Key() (interface{}, error) {
+func (ak *AgentKey) Key() (PublicKey, error) {
 	if key, _, ok := parsePubKey(ak.blob); ok {
 		return key, nil
 	}
@@ -204,9 +204,9 @@ func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
 
 // SignRequest requests the signing of data by the agent using a protocol 2 key
 // as defined in [PROTOCOL.agent] section 2.6.2.
-func (ac *AgentClient) SignRequest(key interface{}, data []byte) ([]byte, error) {
+func (ac *AgentClient) SignRequest(key PublicKey, data []byte) ([]byte, error) {
 	req := marshal(agentSignRequest, signRequestAgentMsg{
-		KeyBlob: serializePublicKey(key),
+		KeyBlob: MarshalPublicKey(key),
 		Data:    data,
 	})
 

+ 40 - 41
ssh/certs.go

@@ -5,9 +5,6 @@
 package ssh
 
 import (
-	"crypto/dsa"
-	"crypto/ecdsa"
-	"crypto/rsa"
 	"time"
 )
 
@@ -42,7 +39,7 @@ type tuple struct {
 // [PROTOCOL.certkeys]?rev=1.8.
 type OpenSSHCertV01 struct {
 	Nonce                   []byte
-	Key                     interface{} // rsa, dsa, or ecdsa *PublicKey
+	Key                     PublicKey
 	Serial                  uint64
 	Type                    uint32
 	KeyId                   string
@@ -51,10 +48,38 @@ type OpenSSHCertV01 struct {
 	CriticalOptions         []tuple
 	Extensions              []tuple
 	Reserved                []byte
-	SignatureKey            interface{} // rsa, dsa, or ecdsa *PublicKey
+	SignatureKey            PublicKey
 	Signature               *signature
 }
 
+var certAlgoNames = map[string]string{
+	KeyAlgoRSA:      CertAlgoRSAv01,
+	KeyAlgoDSA:      CertAlgoDSAv01,
+	KeyAlgoECDSA256: CertAlgoECDSA256v01,
+	KeyAlgoECDSA384: CertAlgoECDSA384v01,
+	KeyAlgoECDSA521: CertAlgoECDSA521v01,
+}
+
+func (c *OpenSSHCertV01) PublicKeyAlgo() string {
+	algo, ok := certAlgoNames[c.Key.PublicKeyAlgo()]
+	if !ok {
+		panic("unknown cert key type")
+	}
+	return algo
+}
+
+func (c *OpenSSHCertV01) RawKey() interface{} {
+	return c.Key.RawKey()
+}
+
+func (c *OpenSSHCertV01) PrivateKeyAlgo() string {
+	return c.Key.PrivateKeyAlgo()
+}
+
+func (c *OpenSSHCertV01) Verify(data []byte, sig []byte) bool {
+	return c.Key.Verify(data, sig)
+}
+
 func parseOpenSSHCertV01(in []byte, algo string) (out *OpenSSHCertV01, rest []byte, ok bool) {
 	cert := new(OpenSSHCertV01)
 
@@ -62,26 +87,12 @@ func parseOpenSSHCertV01(in []byte, algo string) (out *OpenSSHCertV01, rest []by
 		return
 	}
 
-	switch algo {
-	case CertAlgoRSAv01:
-		var rsaPubKey *rsa.PublicKey
-		if rsaPubKey, in, ok = parseRSA(in); !ok {
-			return
-		}
-		cert.Key = rsaPubKey
-	case CertAlgoDSAv01:
-		var dsaPubKey *dsa.PublicKey
-		if dsaPubKey, in, ok = parseDSA(in); !ok {
-			return
-		}
-		cert.Key = dsaPubKey
-	case CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
-		var ecdsaPubKey *ecdsa.PublicKey
-		if ecdsaPubKey, in, ok = parseECDSA(in); !ok {
-			return
-		}
-		cert.Key = ecdsaPubKey
-	default:
+	cert.Key, in, ok = ParsePublicKey(in)
+	if !ok {
+		return
+	}
+
+	if cert.Key.PrivateKeyAlgo() != algo {
 		ok = false
 		return
 	}
@@ -144,23 +155,10 @@ func parseOpenSSHCertV01(in []byte, algo string) (out *OpenSSHCertV01, rest []by
 	return cert, in, ok
 }
 
-func marshalOpenSSHCertV01(cert *OpenSSHCertV01) []byte {
-	var pubKey []byte
-	switch cert.Key.(type) {
-	case *rsa.PublicKey:
-		k := cert.Key.(*rsa.PublicKey)
-		pubKey = marshalPubRSA(k)
-	case *dsa.PublicKey:
-		k := cert.Key.(*dsa.PublicKey)
-		pubKey = marshalPubDSA(k)
-	case *ecdsa.PublicKey:
-		k := cert.Key.(*ecdsa.PublicKey)
-		pubKey = marshalPubECDSA(k)
-	default:
-		panic("ssh: unknown public key type in cert")
-	}
+func (cert *OpenSSHCertV01) Marshal() []byte {
+	pubKey := MarshalPublicKey(cert.Key)
 
-	sigKey := serializePublicKey(cert.SignatureKey)
+	sigKey := MarshalPublicKey(cert.SignatureKey)
 
 	length := stringLength(len(cert.Nonce))
 	length += len(pubKey)
@@ -314,5 +312,6 @@ func parseSignature(in []byte) (out *signature, rest []byte, ok bool) {
 		return
 	}
 
+	// TODO(hanwen): this is a bug; 'rest' gets swallowed.
 	return parseSignatureBody(sigBytes)
 }

+ 4 - 13
ssh/client.go

@@ -246,18 +246,6 @@ func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte
 		return errors.New("ssh: could not parse hostkey")
 	}
 
-	// Select hash function to match the hostkey algorithm, as per
-	// RFC 4253, section 6.6 (for RSA/DSS) and RFC 5656, section
-	// 6.2.1 (for ECDSA).
-	hashFunc, ok := hashFuncs[hostKeyAlgo]
-	if !ok {
-		return errors.New("ssh: unknown key algorithm: " + hostKeyAlgo)
-	}
-
-	signed := hashFunc.New()
-	signed.Write(data)
-	digest := signed.Sum(nil)
-
 	sig, rest, ok := parseSignatureBody(signature)
 	if len(rest) > 0 || !ok {
 		return errors.New("ssh: signature parse error")
@@ -266,7 +254,10 @@ func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte
 		return fmt.Errorf("ssh: unexpected signature type %q", sig.Format)
 	}
 
-	return verifySignature(digest, sig, hostKey)
+	if !hostKey.Verify(data, sig.Blob) {
+		return errors.New("ssh: host key signature error")
+	}
+	return nil
 }
 
 // kexResult captures the outcome of a key exchange.

+ 14 - 15
ssh/client_auth.go

@@ -155,9 +155,8 @@ func ClientAuthPassword(impl ClientPassword) ClientAuth {
 
 // ClientKeyring implements access to a client key ring.
 type ClientKeyring interface {
-	// Key returns the i'th *rsa.Publickey or *dsa.Publickey, or nil if
-	// no key exists at i.
-	Key(i int) (key interface{}, err error)
+	// Key returns the i'th Publickey, or nil if no key exists at i.
+	Key(i int) (key PublicKey, err error)
 
 	// Sign returns a signature of the given data using the i'th key
 	// and the supplied random source.
@@ -190,7 +189,7 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.
 
 	var index int
 	// a map of public keys to their index in the keyring
-	validKeys := make(map[int]interface{})
+	validKeys := make(map[int]PublicKey)
 	for {
 		key, err := p.Key(index)
 		if err != nil {
@@ -214,8 +213,8 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.
 	// methods that may continue if this auth is not successful.
 	var methods []string
 	for i, key := range validKeys {
-		pubkey := serializePublicKey(key)
-		algoname := algoName(key)
+		pubkey := MarshalPublicKey(key)
+		algoname := key.PublicKeyAlgo()
 		sign, err := p.Sign(i, rand, buildDataSignedForAuth(session, userAuthRequestMsg{
 			User:    user,
 			Service: serviceSSH,
@@ -225,7 +224,7 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.
 			return false, nil, err
 		}
 		// manually wrap the serialized signature in a string
-		s := serializeSignature(algoname, sign)
+		s := serializeSignature(key.PublicKeyAlgo(), sign)
 		sig := make([]byte, stringLength(len(s)))
 		marshalString(sig, s)
 		msg := publickeyAuthMsg{
@@ -253,9 +252,9 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.
 }
 
 // validateKey validates the key provided it is acceptable to the server.
-func (p *publickeyAuth) validateKey(key interface{}, user string, t *transport) (bool, error) {
-	pubkey := serializePublicKey(key)
-	algoname := algoName(key)
+func (p *publickeyAuth) validateKey(key PublicKey, user string, t *transport) (bool, error) {
+	pubkey := MarshalPublicKey(key)
+	algoname := key.PublicKeyAlgo()
 	msg := publickeyAuthMsg{
 		User:     user,
 		Service:  serviceSSH,
@@ -271,9 +270,9 @@ func (p *publickeyAuth) validateKey(key interface{}, user string, t *transport)
 	return p.confirmKeyAck(key, t)
 }
 
-func (p *publickeyAuth) confirmKeyAck(key interface{}, t *transport) (bool, error) {
-	pubkey := serializePublicKey(key)
-	algoname := algoName(key)
+func (p *publickeyAuth) confirmKeyAck(key PublicKey, t *transport) (bool, error) {
+	pubkey := MarshalPublicKey(key)
+	algoname := key.PublicKeyAlgo()
 
 	for {
 		packet, err := t.readPacket()
@@ -352,7 +351,7 @@ type agentKeyring struct {
 	keys  []*AgentKey
 }
 
-func (kr *agentKeyring) Key(i int) (key interface{}, err error) {
+func (kr *agentKeyring) Key(i int) (key PublicKey, err error) {
 	if kr.keys == nil {
 		if kr.keys, err = kr.agent.RequestIdentities(); err != nil {
 			return
@@ -365,7 +364,7 @@ func (kr *agentKeyring) Key(i int) (key interface{}, err error) {
 }
 
 func (kr *agentKeyring) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
-	var key interface{}
+	var key PublicKey
 	if key, err = kr.Key(i); err != nil {
 		return
 	}

+ 7 - 6
ssh/client_auth_test.go

@@ -65,15 +65,15 @@ type keychain struct {
 	keys []interface{}
 }
 
-func (k *keychain) Key(i int) (interface{}, error) {
+func (k *keychain) Key(i int) (PublicKey, error) {
 	if i < 0 || i >= len(k.keys) {
 		return nil, nil
 	}
 	switch key := k.keys[i].(type) {
 	case *rsa.PrivateKey:
-		return &key.PublicKey, nil
+		return NewRSAPublicKey(&key.PublicKey), nil
 	case *dsa.PrivateKey:
-		return &key.PublicKey, nil
+		return NewDSAPublicKey(&key.PublicKey), nil
 	}
 	panic("unknown key type")
 }
@@ -135,9 +135,10 @@ var (
 			return user == "testuser" && pass == string(clientPassword)
 		},
 		PublicKeyCallback: func(conn *ServerConn, user, algo string, pubkey []byte) bool {
-			key := &clientKeychain.keys[0].(*rsa.PrivateKey).PublicKey
-			expected := []byte(serializePublicKey(key))
-			algoname := algoName(key)
+			rsaKey := &clientKeychain.keys[0].(*rsa.PrivateKey).PublicKey
+			key := NewRSAPublicKey(rsaKey)
+			expected := MarshalPublicKey(key)
+			algoname := key.PublicKeyAlgo()
 			return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
 		},
 		KeyboardInteractiveCallback: func(conn *ServerConn, user string, client ClientKeyboardInteractive) bool {

+ 15 - 119
ssh/common.go

@@ -6,10 +6,7 @@ package ssh
 
 import (
 	"crypto"
-	"crypto/dsa"
-	"crypto/ecdsa"
 	"crypto/elliptic"
-	"crypto/rsa"
 	"errors"
 	"fmt"
 	"math/big"
@@ -239,137 +236,36 @@ func ecHash(curve elliptic.Curve) crypto.Hash {
 	return crypto.SHA512
 }
 
-// serialize a signed slice according to RFC 4254 6.6.
-func serializeSignature(algoname string, sig []byte) []byte {
-	// The corresponding private key to a public certificate is always a normal
-	// private key.  For signature serialization purposes, ensure we use the
-	// proper key algorithm name in case the public cert algorithm name is passed.
-	algoname = pubAlgoToPrivAlgo(algoname)
-
-	length := stringLength(len(algoname))
+// serialize a signed slice according to RFC 4254 6.6. The name should
+// be a key type name, rather than a cert type name.
+func serializeSignature(name string, sig []byte) []byte {
+	length := stringLength(len(name))
 	length += stringLength(len(sig))
 
 	ret := make([]byte, length)
-	r := marshalString(ret, []byte(algoname))
+	r := marshalString(ret, []byte(name))
 	r = marshalString(r, sig)
 
 	return ret
 }
 
-func verifySignature(hash []byte, sig *signature, key interface{}) error {
-	switch pubKey := key.(type) {
-	case *rsa.PublicKey:
-		return verifyRSASignature(hash, sig, pubKey)
-	case *dsa.PublicKey:
-		return verifyDSASignature(hash, sig, pubKey)
-	case *ecdsa.PublicKey:
-		return verifyECDSASignature(hash, sig, pubKey)
-	case *OpenSSHCertV01:
-		return verifySignature(hash, sig, pubKey.Key)
-	}
-	return fmt.Errorf("ssh: unknown key type %T", key)
-}
-
-func verifyRSASignature(hash []byte, sig *signature, key *rsa.PublicKey) error {
-	return rsa.VerifyPKCS1v15(key, crypto.SHA1, hash, sig.Blob)
-}
-
-func verifyDSASignature(hash []byte, sig *signature, key *dsa.PublicKey) error {
-	// Per RFC 4253, section 6.6,
-	// The value for 'dss_signature_blob' is encoded as a string containing
-	// r, followed by s (which are 160-bit integers, without lengths or
-	// padding, unsigned, and in network byte order).
-	// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
-	if len(sig.Blob) != 40 {
-		return fmt.Errorf("ssh: improper dss signature length of %d", len(sig.Blob))
-	}
-	r := new(big.Int).SetBytes(sig.Blob[:20])
-	s := new(big.Int).SetBytes(sig.Blob[20:])
-	if !dsa.Verify(key, hash, r, s) {
-		return errors.New("ssh: unable to verify dsa signature")
-	}
-	return nil
-}
-
-func verifyECDSASignature(hash []byte, sig *signature, key *ecdsa.PublicKey) error {
-	// Per RFC 5656, section 3.1.2,
-	// The ecdsa_signature_blob value has the following specific encoding:
-	//    mpint    r
-	//    mpint    s
-	r, rest, ok := parseInt(sig.Blob)
-	if !ok {
-		return errors.New("ssh: ecdsa signature blob parse failed")
-	}
-	s, rest, ok := parseInt(rest)
-	if !ok || len(rest) > 0 {
-		return errors.New("ssh: ecdsa signature blob parse failed")
-	}
-	if !ecdsa.Verify(key, hash, r, s) {
-		return errors.New("ssh: unable to verify ecdsa signature")
-	}
-	return nil
-}
-
-// serialize a *rsa.PublicKey or *dsa.PublicKey according to RFC 4253 6.6.
-func serializePublicKey(key interface{}) []byte {
-	var pubKeyBytes []byte
-	algoname := algoName(key)
-	switch key := key.(type) {
-	case *rsa.PublicKey:
-		pubKeyBytes = marshalPubRSA(key)
-	case *dsa.PublicKey:
-		pubKeyBytes = marshalPubDSA(key)
-	case *ecdsa.PublicKey:
-		pubKeyBytes = marshalPubECDSA(key)
-	case *OpenSSHCertV01:
-		pubKeyBytes = marshalOpenSSHCertV01(key)
-	default:
-		panic("unexpected key type")
-	}
+// MarshalPublicKey serializes a supported key or certificate for use
+// by the SSH wire protocol. It can be used for comparison with the
+// pubkey argument of ServerConfig's PublicKeyCallback as well as for
+// generating an authorized_keys or host_keys file.
+func MarshalPublicKey(key PublicKey) []byte {
+	// See also RFC 4253 6.6.
+	algoname := key.PrivateKeyAlgo()
+	blob := key.Marshal()
 
 	length := stringLength(len(algoname))
-	length += len(pubKeyBytes)
+	length += len(blob)
 	ret := make([]byte, length)
 	r := marshalString(ret, []byte(algoname))
-	copy(r, pubKeyBytes)
+	copy(r, blob)
 	return ret
 }
 
-func algoName(key interface{}) string {
-	switch key.(type) {
-	case *rsa.PublicKey:
-		return KeyAlgoRSA
-	case *dsa.PublicKey:
-		return KeyAlgoDSA
-	case *ecdsa.PublicKey:
-		switch key.(*ecdsa.PublicKey).Params().BitSize {
-		case 256:
-			return KeyAlgoECDSA256
-		case 384:
-			return KeyAlgoECDSA384
-		case 521:
-			return KeyAlgoECDSA521
-		}
-	case *OpenSSHCertV01:
-		switch key.(*OpenSSHCertV01).Key.(type) {
-		case *rsa.PublicKey:
-			return CertAlgoRSAv01
-		case *dsa.PublicKey:
-			return CertAlgoDSAv01
-		case *ecdsa.PublicKey:
-			switch key.(*OpenSSHCertV01).Key.(*ecdsa.PublicKey).Params().BitSize {
-			case 256:
-				return CertAlgoECDSA256v01
-			case 384:
-				return CertAlgoECDSA384v01
-			case 521:
-				return CertAlgoECDSA521v01
-			}
-		}
-	}
-	panic(fmt.Sprintf("unexpected key type %T", key))
-}
-
 // pubAlgoToPrivAlgo returns the private key algorithm format name that
 // corresponds to a given public key algorithm format name.  For most
 // public keys, the private key algorithm name is the same.  For some

+ 0 - 51
ssh/common_test.go

@@ -5,11 +5,6 @@
 package ssh
 
 import (
-	"crypto/dsa"
-	"crypto/ecdsa"
-	"crypto/elliptic"
-	"crypto/rsa"
-	"errors"
 	"testing"
 )
 
@@ -29,49 +24,3 @@ func TestSafeString(t *testing.T) {
 		}
 	}
 }
-
-func TestAlgoNameSupported(t *testing.T) {
-	supported := map[string]interface{}{
-		KeyAlgoRSA:          new(rsa.PublicKey),
-		KeyAlgoDSA:          new(dsa.PublicKey),
-		KeyAlgoECDSA256:     &ecdsa.PublicKey{Curve: elliptic.P256()},
-		KeyAlgoECDSA384:     &ecdsa.PublicKey{Curve: elliptic.P384()},
-		KeyAlgoECDSA521:     &ecdsa.PublicKey{Curve: elliptic.P521()},
-		CertAlgoRSAv01:      &OpenSSHCertV01{Key: new(rsa.PublicKey)},
-		CertAlgoDSAv01:      &OpenSSHCertV01{Key: new(dsa.PublicKey)},
-		CertAlgoECDSA256v01: &OpenSSHCertV01{Key: &ecdsa.PublicKey{Curve: elliptic.P256()}},
-		CertAlgoECDSA384v01: &OpenSSHCertV01{Key: &ecdsa.PublicKey{Curve: elliptic.P384()}},
-		CertAlgoECDSA521v01: &OpenSSHCertV01{Key: &ecdsa.PublicKey{Curve: elliptic.P521()}},
-	}
-
-	for expected, key := range supported {
-		actual := algoName(key)
-		if expected != actual {
-			t.Errorf("expected: %s, actual: %s", expected, actual)
-		}
-	}
-
-}
-
-func TestAlgoNameNotSupported(t *testing.T) {
-	notSupported := []interface{}{
-		&ecdsa.PublicKey{Curve: elliptic.P224()},
-		&OpenSSHCertV01{Key: &ecdsa.PublicKey{Curve: elliptic.P224()}},
-	}
-
-	panicTest := func(key interface{}) (algo string, err error) {
-		defer func() {
-			if r := recover(); r != nil {
-				err = errors.New(r.(string))
-			}
-		}()
-		algo = algoName(key)
-		return
-	}
-
-	for _, unsupportedKey := range notSupported {
-		if algo, err := panicTest(unsupportedKey); err == nil {
-			t.Errorf("Expected a panic, Got: %s (for type %T)", algo, unsupportedKey)
-		}
-	}
-}

+ 282 - 146
ssh/keys.go

@@ -6,6 +6,7 @@ package ssh
 
 import (
 	"bytes"
+	"crypto"
 	"crypto/dsa"
 	"crypto/ecdsa"
 	"crypto/elliptic"
@@ -25,7 +26,7 @@ const (
 )
 
 // parsePubKey parses a public key according to RFC 4253, section 6.6.
-func parsePubKey(in []byte) (out interface{}, rest []byte, ok bool) {
+func parsePubKey(in []byte) (pubKey PublicKey, rest []byte, ok bool) {
 	algo, in, ok := parseString(in)
 	if !ok {
 		return
@@ -41,141 +42,7 @@ func parsePubKey(in []byte) (out interface{}, rest []byte, ok bool) {
 	case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
 		return parseOpenSSHCertV01(in, string(algo))
 	}
-	panic("ssh: unknown public key type")
-}
-
-// parseRSA parses an RSA key according to RFC 4253, section 6.6.
-func parseRSA(in []byte) (out *rsa.PublicKey, rest []byte, ok bool) {
-	key := new(rsa.PublicKey)
-
-	bigE, in, ok := parseInt(in)
-	if !ok || bigE.BitLen() > 24 {
-		return
-	}
-	e := bigE.Int64()
-	if e < 3 || e&1 == 0 {
-		ok = false
-		return
-	}
-	key.E = int(e)
-
-	if key.N, in, ok = parseInt(in); !ok {
-		return
-	}
-
-	ok = true
-	return key, in, ok
-}
-
-// parseDSA parses an DSA key according to RFC 4253, section 6.6.
-func parseDSA(in []byte) (out *dsa.PublicKey, rest []byte, ok bool) {
-	key := new(dsa.PublicKey)
-
-	if key.P, in, ok = parseInt(in); !ok {
-		return
-	}
-
-	if key.Q, in, ok = parseInt(in); !ok {
-		return
-	}
-
-	if key.G, in, ok = parseInt(in); !ok {
-		return
-	}
-
-	if key.Y, in, ok = parseInt(in); !ok {
-		return
-	}
-
-	ok = true
-	return key, in, ok
-}
-
-// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
-func parseECDSA(in []byte) (out *ecdsa.PublicKey, rest []byte, ok bool) {
-	var identifier []byte
-	if identifier, in, ok = parseString(in); !ok {
-		return
-	}
-
-	key := new(ecdsa.PublicKey)
-
-	switch string(identifier) {
-	case "nistp256":
-		key.Curve = elliptic.P256()
-	case "nistp384":
-		key.Curve = elliptic.P384()
-	case "nistp521":
-		key.Curve = elliptic.P521()
-	default:
-		ok = false
-		return
-	}
-
-	var keyBytes []byte
-	if keyBytes, in, ok = parseString(in); !ok {
-		return
-	}
-
-	key.X, key.Y = elliptic.Unmarshal(key.Curve, keyBytes)
-	if key.X == nil || key.Y == nil {
-		ok = false
-		return
-	}
-	return key, in, ok
-}
-
-// marshalPubRSA serializes an RSA public key according to RFC 4253, section 6.6.
-func marshalPubRSA(key *rsa.PublicKey) []byte {
-	e := new(big.Int).SetInt64(int64(key.E))
-	length := intLength(e)
-	length += intLength(key.N)
-
-	ret := make([]byte, length)
-	r := marshalInt(ret, e)
-	r = marshalInt(r, key.N)
-
-	return ret
-}
-
-// marshalPubDSA serializes an DSA public key according to RFC 4253, section 6.6.
-func marshalPubDSA(key *dsa.PublicKey) []byte {
-	length := intLength(key.P)
-	length += intLength(key.Q)
-	length += intLength(key.G)
-	length += intLength(key.Y)
-
-	ret := make([]byte, length)
-	r := marshalInt(ret, key.P)
-	r = marshalInt(r, key.Q)
-	r = marshalInt(r, key.G)
-	r = marshalInt(r, key.Y)
-
-	return ret
-}
-
-// marshalPubECDSA serializes an ECDSA public key according to RFC 5656, section 3.1.
-func marshalPubECDSA(key *ecdsa.PublicKey) []byte {
-	var identifier []byte
-	switch key.Params().BitSize {
-	case 256:
-		identifier = []byte("nistp256")
-	case 384:
-		identifier = []byte("nistp384")
-	case 521:
-		identifier = []byte("nistp521")
-	default:
-		panic("ssh: unsupported ecdsa key size")
-	}
-	keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
-
-	length := stringLength(len(identifier))
-	length += stringLength(len(keyBytes))
-
-	ret := make([]byte, length)
-	r := marshalString(ret, identifier)
-	r = marshalString(r, keyBytes)
-	return ret
+	return nil, nil, false
 }
 
 // parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
@@ -307,28 +174,297 @@ func ParseAuthorizedKey(in []byte) (out interface{}, comment string, options []s
 
 // ParsePublicKey parses an SSH public key formatted for use in
 // the SSH wire protocol.
-func ParsePublicKey(in []byte) (out interface{}, rest []byte, ok bool) {
+func ParsePublicKey(in []byte) (out PublicKey, rest []byte, ok bool) {
 	return parsePubKey(in)
 }
 
 // MarshalAuthorizedKey returns a byte stream suitable for inclusion
 // in an OpenSSH authorized_keys file following the format specified
 // in the sshd(8) manual page.
-func MarshalAuthorizedKey(key interface{}) []byte {
+func MarshalAuthorizedKey(key PublicKey) []byte {
 	b := &bytes.Buffer{}
-	b.WriteString(algoName(key))
+	b.WriteString(key.PublicKeyAlgo())
 	b.WriteByte(' ')
 	e := base64.NewEncoder(base64.StdEncoding, b)
-	e.Write(serializePublicKey(key))
+	e.Write(MarshalPublicKey(key))
 	e.Close()
 	b.WriteByte('\n')
 	return b.Bytes()
 }
 
-// MarshalPublicKey serializes a supported key or certificate for use by the
-// SSH wire protocol. It can be used for comparison with the pubkey argument
-// of ServerConfig's PublicKeyCallback as well as for generating an
-// authorized_keys or host_keys file.
-func MarshalPublicKey(key interface{}) []byte {
-	return serializePublicKey(key)
+// PublicKey is an abstraction of different types of public keys.
+type PublicKey interface {
+	// PrivateKeyAlgo returns the name of the encryption system.
+	PrivateKeyAlgo() string
+
+	// PublicKeyAlgo returns the algorithm for the public key,
+	// which may be different from PrivateKeyAlgo for certificates.
+	PublicKeyAlgo() string
+
+	// Marshal returns the serialized key data in SSH wire format,
+	// without the name prefix.  Callers should typically use
+	// MarshalPublicKey().
+	Marshal() []byte
+
+	// Verify that sig is a signature on the given data using this
+	// key. This function will hash the data appropriately first.
+	Verify(data []byte, sigBlob []byte) bool
+
+	// RawKey returns the underlying object, eg. *rsa.PublicKey.
+	RawKey() interface{}
+}
+
+// TODO(hanwen): define PrivateKey too.
+
+type rsaPublicKey rsa.PublicKey
+
+func (r *rsaPublicKey) PrivateKeyAlgo() string {
+	return "ssh-rsa"
+}
+
+func (r *rsaPublicKey) PublicKeyAlgo() string {
+	return "ssh-rsa"
+}
+
+func (r *rsaPublicKey) RawKey() interface{} {
+	return (*rsa.PublicKey)(r)
+}
+
+// parseRSA parses an RSA key according to RFC 4253, section 6.6.
+func parseRSA(in []byte) (out PublicKey, rest []byte, ok bool) {
+	key := new(rsa.PublicKey)
+
+	bigE, in, ok := parseInt(in)
+	if !ok || bigE.BitLen() > 24 {
+		return
+	}
+	e := bigE.Int64()
+	if e < 3 || e&1 == 0 {
+		ok = false
+		return
+	}
+	key.E = int(e)
+
+	if key.N, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	ok = true
+	return NewRSAPublicKey(key), in, ok
+}
+
+func (r *rsaPublicKey) Marshal() []byte {
+	// See RFC 4253, section 6.6.
+	e := new(big.Int).SetInt64(int64(r.E))
+	length := intLength(e)
+	length += intLength(r.N)
+
+	ret := make([]byte, length)
+	rest := marshalInt(ret, e)
+	marshalInt(rest, r.N)
+
+	return ret
+}
+
+func (r *rsaPublicKey) Verify(data []byte, sig []byte) bool {
+	h := crypto.SHA1.New()
+	h.Write(data)
+	digest := h.Sum(nil)
+	return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig) == nil
+}
+
+func NewRSAPublicKey(k *rsa.PublicKey) PublicKey {
+	return (*rsaPublicKey)(k)
+}
+
+type dsaPublicKey dsa.PublicKey
+
+func (r *dsaPublicKey) PrivateKeyAlgo() string {
+	return "ssh-dss"
+}
+func (r *dsaPublicKey) PublicKeyAlgo() string {
+	return "ssh-dss"
+}
+func (r *dsaPublicKey) RawKey() interface{} {
+	return (*dsa.PublicKey)(r)
+}
+
+// parseDSA parses an DSA key according to RFC 4253, section 6.6.
+func parseDSA(in []byte) (out PublicKey, rest []byte, ok bool) {
+	key := new(dsa.PublicKey)
+
+	if key.P, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	if key.Q, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	if key.G, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	if key.Y, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	ok = true
+	return NewDSAPublicKey(key), in, ok
+}
+
+func (r *dsaPublicKey) Marshal() []byte {
+	// See RFC 4253, section 6.6.
+	length := intLength(r.P)
+	length += intLength(r.Q)
+	length += intLength(r.G)
+	length += intLength(r.Y)
+
+	ret := make([]byte, length)
+	rest := marshalInt(ret, r.P)
+	rest = marshalInt(rest, r.Q)
+	rest = marshalInt(rest, r.G)
+	marshalInt(rest, r.Y)
+
+	return ret
+}
+
+func (k *dsaPublicKey) Verify(data []byte, sigBlob []byte) bool {
+	h := crypto.SHA1.New()
+	h.Write(data)
+	digest := h.Sum(nil)
+
+	// Per RFC 4253, section 6.6,
+	// The value for 'dss_signature_blob' is encoded as a string containing
+	// r, followed by s (which are 160-bit integers, without lengths or
+	// padding, unsigned, and in network byte order).
+	// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
+	if len(sigBlob) != 40 {
+		return false
+	}
+	r := new(big.Int).SetBytes(sigBlob[:20])
+	s := new(big.Int).SetBytes(sigBlob[20:])
+	return dsa.Verify((*dsa.PublicKey)(k), digest, r, s)
+}
+
+func NewDSAPublicKey(k *dsa.PublicKey) PublicKey {
+	return (*dsaPublicKey)(k)
+}
+
+type ecdsaPublicKey ecdsa.PublicKey
+
+func NewECDSAPublicKey(k *ecdsa.PublicKey) PublicKey {
+	return (*ecdsaPublicKey)(k)
+}
+func (r *ecdsaPublicKey) RawKey() interface{} {
+	return (*ecdsa.PublicKey)(r)
+}
+
+func (key *ecdsaPublicKey) PrivateKeyAlgo() string {
+	return "ecdh-sha2-" + key.nistID()
+}
+
+func (key *ecdsaPublicKey) nistID() string {
+	switch key.Params().BitSize {
+	case 256:
+		return "nistp256"
+	case 384:
+		return "nistp384"
+	case 521:
+		return "nistp521"
+	}
+	panic("ssh: unsupported ecdsa key size")
+}
+
+// RFC 5656, section 6.2.1 (for ECDSA).
+func (key *ecdsaPublicKey) hash() crypto.Hash {
+	switch key.Params().BitSize {
+	case 256:
+		return crypto.SHA256
+	case 384:
+		return crypto.SHA384
+	case 521:
+		return crypto.SHA512
+	}
+	panic("ssh: unsupported ecdsa key size")
+}
+
+func (key *ecdsaPublicKey) PublicKeyAlgo() string {
+	switch key.Params().BitSize {
+	case 256:
+		return KeyAlgoECDSA256
+	case 384:
+		return KeyAlgoECDSA384
+	case 521:
+		return KeyAlgoECDSA521
+	}
+	panic("ssh: unsupported ecdsa key size")
+}
+
+// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
+func parseECDSA(in []byte) (out PublicKey, rest []byte, ok bool) {
+	var identifier []byte
+	if identifier, in, ok = parseString(in); !ok {
+		return
+	}
+
+	key := new(ecdsa.PublicKey)
+
+	switch string(identifier) {
+	case "nistp256":
+		key.Curve = elliptic.P256()
+	case "nistp384":
+		key.Curve = elliptic.P384()
+	case "nistp521":
+		key.Curve = elliptic.P521()
+	default:
+		ok = false
+		return
+	}
+
+	var keyBytes []byte
+	if keyBytes, in, ok = parseString(in); !ok {
+		return
+	}
+
+	key.X, key.Y = elliptic.Unmarshal(key.Curve, keyBytes)
+	if key.X == nil || key.Y == nil {
+		ok = false
+		return
+	}
+	return NewECDSAPublicKey(key), in, ok
+}
+
+func (key *ecdsaPublicKey) Marshal() []byte {
+	// See RFC 5656, section 3.1.
+	keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
+
+	ID := key.nistID()
+	length := stringLength(len(ID))
+	length += stringLength(len(keyBytes))
+
+	ret := make([]byte, length)
+	r := marshalString(ret, []byte(ID))
+	r = marshalString(r, keyBytes)
+	return ret
+}
+
+func (key *ecdsaPublicKey) Verify(data []byte, sigBlob []byte) bool {
+	h := key.hash().New()
+	h.Write(data)
+	digest := h.Sum(nil)
+
+	// Per RFC 5656, section 3.1.2,
+	// The ecdsa_signature_blob value has the following specific encoding:
+	//    mpint    r
+	//    mpint    s
+	r, rest, ok := parseInt(sigBlob)
+	if !ok {
+		return false
+	}
+	s, rest, ok := parseInt(rest)
+	if !ok || len(rest) > 0 {
+		return false
+	}
+	return ecdsa.Verify((*ecdsa.PublicKey)(key), digest, r, s)
 }

+ 60 - 0
ssh/keys_test.go

@@ -0,0 +1,60 @@
+package ssh
+
+import (
+	"crypto"
+	"crypto/dsa"
+	"crypto/rand"
+	"crypto/rsa"
+	"reflect"
+	"testing"
+)
+
+func TestRSAMarshal(t *testing.T) {
+	k0 := &rsakey.PublicKey
+	k1 := NewRSAPublicKey(k0)
+	k2, rest, ok := ParsePublicKey(MarshalPublicKey(k1))
+	if !ok {
+		t.Errorf("could not parse back Blob output")
+	}
+	if len(rest) > 0 {
+		t.Errorf("trailing junk in RSA Blob() output")
+	}
+	if !reflect.DeepEqual(k0, k2.RawKey().(*rsa.PublicKey)) {
+		t.Errorf("got %#v in roundtrip, want %#v", k2.RawKey(), k0)
+	}
+}
+
+func TestRSAKeyVerify(t *testing.T) {
+	pub := NewRSAPublicKey(&rsakey.PublicKey)
+
+	data := []byte("sign me")
+	h := crypto.SHA1.New()
+	h.Write(data)
+	digest := h.Sum(nil)
+
+	sig, err := rsa.SignPKCS1v15(rand.Reader, rsakey, crypto.SHA1, digest)
+	if err != nil {
+		t.Fatalf("SignPKCS1v15: %v", err)
+	}
+
+	if !pub.Verify(data, sig) {
+		t.Errorf("publicKey.Verify failed")
+	}
+}
+
+func TestDSAMarshal(t *testing.T) {
+	k0 := &dsakey.PublicKey
+	k1 := NewDSAPublicKey(k0)
+	k2, rest, ok := ParsePublicKey(MarshalPublicKey(k1))
+	if !ok {
+		t.Errorf("could not parse back Blob output")
+	}
+	if len(rest) > 0 {
+		t.Errorf("trailing junk in DSA Blob() output")
+	}
+	if !reflect.DeepEqual(k0, k2.RawKey().(*dsa.PublicKey)) {
+		t.Errorf("got %#v in roundtrip, want %#v", k2.RawKey(), k0)
+	}
+}
+
+// TODO(hanwen): test for ECDSA marshal.

+ 16 - 38
ssh/server.go

@@ -78,13 +78,12 @@ func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) error {
 	if block == nil {
 		return errors.New("ssh: no key found")
 	}
-	var err error
-	s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes)
+	rsa, err := x509.ParsePKCS1PrivateKey(block.Bytes)
 	if err != nil {
 		return err
 	}
-
-	s.rsaSerialized = serializePublicKey(&s.rsa.PublicKey)
+	s.rsa = rsa
+	s.rsaSerialized = MarshalPublicKey(NewRSAPublicKey(&rsa.PublicKey))
 	return nil
 }
 
@@ -170,10 +169,7 @@ func (s *ServerConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, host
 		return nil, err
 	}
 
-	hostKey, err := s.serializedHostKey(hostKeyAlgo)
-	if err != nil {
-		return nil, err
-	}
+	hostKeyBytes := s.config.rsaSerialized
 
 	serializedEphKey := elliptic.Marshal(curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
 
@@ -186,7 +182,7 @@ func (s *ServerConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, host
 	writeString(h, magics.serverVersion)
 	writeString(h, magics.clientKexInit)
 	writeString(h, magics.serverKexInit)
-	writeString(h, hostKey)
+	writeString(h, hostKeyBytes)
 	writeString(h, kexECDHInit.ClientPubKey)
 	writeString(h, serializedEphKey)
 
@@ -196,15 +192,15 @@ func (s *ServerConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, host
 
 	H := h.Sum(nil)
 
-	serializedSig, err := s.serializedHostKeySignature(hostKeyAlgo, H)
+	sig, err := s.serializedHostKeySignature(hostKeyAlgo, H)
 	if err != nil {
 		return nil, err
 	}
 
 	reply := kexECDHReplyMsg{
 		EphemeralPubKey: serializedEphKey,
-		HostKey:         hostKey,
-		Signature:       serializedSig,
+		HostKey:         hostKeyBytes,
+		Signature:       sig,
 	}
 
 	serialized := marshal(msgKexECDHReply, reply)
@@ -220,14 +216,6 @@ func (s *ServerConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, host
 	}, nil
 }
 
-func (s *ServerConn) serializedHostKey(hostKeyAlgo string) ([]byte, error) {
-	switch hostKeyAlgo {
-	case hostAlgoRSA:
-		return s.config.rsaSerialized, nil
-	}
-	return nil, errors.New("ssh: internal error")
-}
-
 // validateECPublicKey checks that the point is a valid public key for
 // the given curve. See [SEC1], 3.2.2
 func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
@@ -280,17 +268,14 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 		return nil, err
 	}
 
-	hostKey, err := s.serializedHostKey(hostKeyAlgo)
-	if err != nil {
-		return nil, err
-	}
+	hostKeyBytes := s.config.rsaSerialized
 
 	h := hashFunc.New()
 	writeString(h, magics.clientVersion)
 	writeString(h, magics.serverVersion)
 	writeString(h, magics.clientKexInit)
 	writeString(h, magics.serverKexInit)
-	writeString(h, hostKey)
+	writeString(h, hostKeyBytes)
 	writeInt(h, kexDHInit.X)
 	writeInt(h, Y)
 
@@ -300,15 +285,15 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 
 	H := h.Sum(nil)
 
-	serializedSig, err := s.serializedHostKeySignature(hostKeyAlgo, H)
+	sig, err := s.serializedHostKeySignature(hostKeyAlgo, H)
 	if err != nil {
 		return nil, err
 	}
 
 	kexDHReply := kexDHReplyMsg{
-		HostKey:   hostKey,
+		HostKey:   hostKeyBytes,
 		Y:         Y,
-		Signature: serializedSig,
+		Signature: sig,
 	}
 	packet = marshal(msgKexDHReply, kexDHReply)
 
@@ -316,7 +301,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 	return &kexResult{
 		H:       H,
 		K:       K,
-		HostKey: hostKey,
+		HostKey: hostKeyBytes,
 		Hash:    hashFunc,
 	}, nil
 }
@@ -417,7 +402,6 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
 	if !ok {
 		return errors.New("ssh: no common algorithms")
 	}
-
 	if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
 		// The client sent a Kex message for the wrong algorithm,
 		// which we have to ignore.
@@ -619,14 +603,8 @@ userAuthLoop:
 				if !ok {
 					return ParseError{msgUserAuthRequest}
 				}
-				hashFunc, ok := hashFuncs[algo]
-				if !ok {
-					return errors.New("ssh: isAcceptableAlgo incorrect")
-				}
-				h := hashFunc.New()
-				h.Write(signedData)
-				digest := h.Sum(nil)
-				if verifySignature(digest, sig, key) != nil {
+
+				if !key.Verify(signedData, sig.Blob) {
 					return ParseError{msgUserAuthRequest}
 				}
 				// TODO(jmpittman): Implement full validation for certificates.

+ 3 - 3
ssh/test/keys_test.go

@@ -1,7 +1,6 @@
 package test
 
 import (
-	"crypto/rsa"
 	"crypto/x509"
 	"encoding/pem"
 	"reflect"
@@ -124,6 +123,7 @@ AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
 
 func TestMarshalParsePublicKey(t *testing.T) {
 	pub := getTestPublicKey(t)
+
 	authKeys := ssh.MarshalAuthorizedKey(pub)
 	actualFields := strings.Fields(string(authKeys))
 	if len(actualFields) == 0 {
@@ -170,7 +170,7 @@ func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
 
 }
 
-func getTestPublicKey(t *testing.T) *rsa.PublicKey {
+func getTestPublicKey(t *testing.T) ssh.PublicKey {
 	block, _ := pem.Decode([]byte(testClientPrivateKey))
 	if block == nil {
 		t.Fatalf("pem.Decode: %v", testClientPrivateKey)
@@ -180,7 +180,7 @@ func getTestPublicKey(t *testing.T) *rsa.PublicKey {
 		t.Fatalf("x509.ParsePKCS1PrivateKey: %v", err)
 	}
 
-	return &priv.PublicKey
+	return ssh.NewRSAPublicKey(&priv.PublicKey)
 }
 
 func TestAuth(t *testing.T) {

+ 4 - 4
ssh/test/test_unix_test.go

@@ -71,7 +71,7 @@ func init() {
 	if err != nil {
 		panic("ParsePKCS1PrivateKey: " + err.Error())
 	}
-	serializedHostKey = ssh.MarshalPublicKey(&priv.PublicKey)
+	serializedHostKey = ssh.MarshalPublicKey(ssh.NewRSAPublicKey(&priv.PublicKey))
 }
 
 type server struct {
@@ -266,15 +266,15 @@ type keychain struct {
 	keys []interface{}
 }
 
-func (k *keychain) Key(i int) (interface{}, error) {
+func (k *keychain) Key(i int) (ssh.PublicKey, error) {
 	if i < 0 || i >= len(k.keys) {
 		return nil, nil
 	}
 	switch key := k.keys[i].(type) {
 	case *rsa.PrivateKey:
-		return &key.PublicKey, nil
+		return ssh.NewRSAPublicKey(&key.PublicKey), nil
 	case *dsa.PrivateKey:
-		return &key.PublicKey, nil
+		return ssh.NewDSAPublicKey(&key.PublicKey), nil
 	}
 	panic("unknown key type")
 }