Przeglądaj źródła

x/crypto/ssh: support more keytypes in the agent.

This allows the golang ssh-agent to support the full suite of keys
the library accepts.

Currently constraints are ignored.

Change-Id: I7d48c78e9a355582eb54788571a483a736c3d3ef
Reviewed-on: https://go-review.googlesource.com/21536
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Peter Moody 9 lat temu
rodzic
commit
e84a34b1ee
6 zmienionych plików z 390 dodań i 48 usunięć
  1. 12 9
      ssh/agent/client.go
  2. 222 18
      ssh/agent/server.go
  3. 77 0
      ssh/agent/server_test.go
  4. 6 0
      ssh/keys.go
  5. 39 21
      ssh/messages.go
  6. 34 0
      ssh/messages_test.go

+ 12 - 9
ssh/agent/client.go

@@ -184,10 +184,13 @@ func (k *Key) Marshal() []byte {
 	return k.Blob
 }
 
-// Verify satisfies the ssh.PublicKey interface, but is not
-// implemented for agent keys.
+// Verify satisfies the ssh.PublicKey interface.
 func (k *Key) Verify(data []byte, sig *ssh.Signature) error {
-	return errors.New("agent: agent key does not know how to verify")
+	pubKey, err := ssh.ParsePublicKey(k.Blob)
+	if err != nil {
+		return fmt.Errorf("agent: bad public key")
+	}
+	return pubKey.Verify(data, sig)
 }
 
 type wireKey struct {
@@ -389,7 +392,7 @@ func unmarshal(packet []byte) (interface{}, error) {
 }
 
 type rsaKeyMsg struct {
-	Type        string `sshtype:"17"`
+	Type        string `sshtype:"17|25"`
 	N           *big.Int
 	E           *big.Int
 	D           *big.Int
@@ -401,7 +404,7 @@ type rsaKeyMsg struct {
 }
 
 type dsaKeyMsg struct {
-	Type        string `sshtype:"17"`
+	Type        string `sshtype:"17|25"`
 	P           *big.Int
 	Q           *big.Int
 	G           *big.Int
@@ -412,7 +415,7 @@ type dsaKeyMsg struct {
 }
 
 type ecdsaKeyMsg struct {
-	Type        string `sshtype:"17"`
+	Type        string `sshtype:"17|25"`
 	Curve       string
 	KeyBytes    []byte
 	D           *big.Int
@@ -481,7 +484,7 @@ func (c *client) insertKey(s interface{}, comment string, constraints []byte) er
 }
 
 type rsaCertMsg struct {
-	Type        string `sshtype:"17"`
+	Type        string `sshtype:"17|25"`
 	CertBytes   []byte
 	D           *big.Int
 	Iqmp        *big.Int // IQMP = Inverse Q Mod P
@@ -492,7 +495,7 @@ type rsaCertMsg struct {
 }
 
 type dsaCertMsg struct {
-	Type        string `sshtype:"17"`
+	Type        string `sshtype:"17|25"`
 	CertBytes   []byte
 	X           *big.Int
 	Comments    string
@@ -500,7 +503,7 @@ type dsaCertMsg struct {
 }
 
 type ecdsaCertMsg struct {
-	Type        string `sshtype:"17"`
+	Type        string `sshtype:"17|25"`
 	CertBytes   []byte
 	D           *big.Int
 	Comments    string

+ 222 - 18
ssh/agent/server.go

@@ -5,8 +5,12 @@
 package agent
 
 import (
+	"crypto/dsa"
+	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/rsa"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"log"
@@ -128,6 +132,7 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
 			return nil, err
 		}
 		return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
+
 	case agentRequestIdentities:
 		keys, err := s.agent.List()
 		if err != nil {
@@ -141,42 +146,241 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
 			rep.Keys = append(rep.Keys, marshalKey(k)...)
 		}
 		return rep, nil
-	case agentAddIdentity:
+
+	case agentAddIdConstrained, agentAddIdentity:
 		return nil, s.insertIdentity(data)
 	}
 
 	return nil, fmt.Errorf("unknown opcode %d", data[0])
 }
 
+func parseRSAKey(req []byte) (*AddedKey, error) {
+	var k rsaKeyMsg
+	if err := ssh.Unmarshal(req, &k); err != nil {
+		return nil, err
+	}
+	if k.E.BitLen() > 30 {
+		return nil, errors.New("agent: RSA public exponent too large")
+	}
+	priv := &rsa.PrivateKey{
+		PublicKey: rsa.PublicKey{
+			E: int(k.E.Int64()),
+			N: k.N,
+		},
+		D:      k.D,
+		Primes: []*big.Int{k.P, k.Q},
+	}
+	priv.Precompute()
+
+	return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil
+}
+
+func parseDSAKey(req []byte) (*AddedKey, error) {
+	var k dsaKeyMsg
+	if err := ssh.Unmarshal(req, &k); err != nil {
+		return nil, err
+	}
+	priv := &dsa.PrivateKey{
+		PublicKey: dsa.PublicKey{
+			Parameters: dsa.Parameters{
+				P: k.P,
+				Q: k.Q,
+				G: k.G,
+			},
+			Y: k.Y,
+		},
+		X: k.X,
+	}
+
+	return &AddedKey{PrivateKey: priv, Comment: k.Comments}, nil
+}
+
+func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
+	priv = &ecdsa.PrivateKey{
+		D: privScalar,
+	}
+
+	switch curveName {
+	case "nistp256":
+		priv.Curve = elliptic.P256()
+	case "nistp384":
+		priv.Curve = elliptic.P384()
+	case "nistp521":
+		priv.Curve = elliptic.P521()
+	default:
+		return nil, fmt.Errorf("agent: unknown curve %q", curveName)
+	}
+
+	priv.X, priv.Y = elliptic.Unmarshal(priv.Curve, keyBytes)
+	if priv.X == nil || priv.Y == nil {
+		return nil, errors.New("agent: point not on curve")
+	}
+
+	return priv, nil
+}
+
+func parseECDSAKey(req []byte) (*AddedKey, error) {
+	var k ecdsaKeyMsg
+	if err := ssh.Unmarshal(req, &k); err != nil {
+		return nil, err
+	}
+
+	priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D)
+	if err != nil {
+		return nil, err
+	}
+
+	return &AddedKey{PrivateKey: &priv, Comment: k.Comments}, nil
+}
+
+func parseRSACert(req []byte) (*AddedKey, error) {
+	var k rsaCertMsg
+	if err := ssh.Unmarshal(req, &k); err != nil {
+		return nil, err
+	}
+
+	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	cert, ok := pubKey.(*ssh.Certificate)
+	if !ok {
+		return nil, errors.New("agent: bad RSA certificate")
+	}
+
+	// An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go
+	var rsaPub struct {
+		Name string
+		E    *big.Int
+		N    *big.Int
+	}
+	if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil {
+		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
+	}
+
+	if rsaPub.E.BitLen() > 30 {
+		return nil, errors.New("agent: RSA public exponent too large")
+	}
+
+	priv := rsa.PrivateKey{
+		PublicKey: rsa.PublicKey{
+			E: int(rsaPub.E.Int64()),
+			N: rsaPub.N,
+		},
+		D:      k.D,
+		Primes: []*big.Int{k.Q, k.P},
+	}
+	priv.Precompute()
+
+	return &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}, nil
+}
+
+func parseDSACert(req []byte) (*AddedKey, error) {
+	var k dsaCertMsg
+	if err := ssh.Unmarshal(req, &k); err != nil {
+		return nil, err
+	}
+	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
+	if err != nil {
+		return nil, err
+	}
+	cert, ok := pubKey.(*ssh.Certificate)
+	if !ok {
+		return nil, errors.New("agent: bad DSA certificate")
+	}
+
+	// A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go
+	var w struct {
+		Name       string
+		P, Q, G, Y *big.Int
+	}
+	if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil {
+		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
+	}
+
+	priv := &dsa.PrivateKey{
+		PublicKey: dsa.PublicKey{
+			Parameters: dsa.Parameters{
+				P: w.P,
+				Q: w.Q,
+				G: w.G,
+			},
+			Y: w.Y,
+		},
+		X: k.X,
+	}
+
+	return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil
+}
+
+func parseECDSACert(req []byte) (*AddedKey, error) {
+	var k ecdsaCertMsg
+	if err := ssh.Unmarshal(req, &k); err != nil {
+		return nil, err
+	}
+
+	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
+	if err != nil {
+		return nil, err
+	}
+	cert, ok := pubKey.(*ssh.Certificate)
+	if !ok {
+		return nil, errors.New("agent: bad ECDSA certificate")
+	}
+
+	// An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go
+	var ecdsaPub struct {
+		Name string
+		ID   string
+		Key  []byte
+	}
+	if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil {
+		return nil, err
+	}
+
+	priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D)
+	if err != nil {
+		return nil, err
+	}
+
+	return &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}, nil
+}
+
 func (s *server) insertIdentity(req []byte) error {
 	var record struct {
-		Type string `sshtype:"17"`
+		Type string `sshtype:"17|25"`
 		Rest []byte `ssh:"rest"`
 	}
+
 	if err := ssh.Unmarshal(req, &record); err != nil {
 		return err
 	}
 
+	var addedKey *AddedKey
+	var err error
+
 	switch record.Type {
 	case ssh.KeyAlgoRSA:
-		var k rsaKeyMsg
-		if err := ssh.Unmarshal(req, &k); err != nil {
-			return err
-		}
-
-		priv := rsa.PrivateKey{
-			PublicKey: rsa.PublicKey{
-				E: int(k.E.Int64()),
-				N: k.N,
-			},
-			D:      k.D,
-			Primes: []*big.Int{k.P, k.Q},
-		}
-		priv.Precompute()
+		addedKey, err = parseRSAKey(req)
+	case ssh.KeyAlgoDSA:
+		addedKey, err = parseDSAKey(req)
+	case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
+		addedKey, err = parseECDSACert(req)
+	case ssh.CertAlgoRSAv01:
+		addedKey, err = parseRSACert(req)
+	case ssh.CertAlgoDSAv01:
+		addedKey, err = parseDSACert(req)
+	case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01:
+		addedKey, err = parseECDSACert(req)
+	default:
+		return fmt.Errorf("agent: not implemented: %q", record.Type)
+	}
 
-		return s.agent.Add(AddedKey{PrivateKey: &priv, Comment: k.Comments})
+	if err != nil {
+		return err
 	}
-	return fmt.Errorf("not implemented: %s", record.Type)
+	return s.agent.Add(*addedKey)
 }
 
 // ServeAgent serves the agent protocol on the given connection. It

+ 77 - 0
ssh/agent/server_test.go

@@ -5,6 +5,9 @@
 package agent
 
 import (
+	"crypto"
+	"crypto/rand"
+	"fmt"
 	"testing"
 
 	"golang.org/x/crypto/ssh"
@@ -107,3 +110,77 @@ func testV1ProtocolMessages(t *testing.T, c *client) {
 		t.Fatalf("invalid remove all response: %#v", reply)
 	}
 }
+
+func verifyKey(sshAgent Agent) error {
+	keys, err := sshAgent.List()
+	if err != nil {
+		return fmt.Errorf("listing keys: %v", err)
+	}
+
+	if len(keys) != 1 {
+		return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
+	}
+
+	buf := make([]byte, 128)
+	if _, err := rand.Read(buf); err != nil {
+		return fmt.Errorf("rand: %v", err)
+	}
+
+	sig, err := sshAgent.Sign(keys[0], buf)
+	if err != nil {
+		return fmt.Errorf("sign: %v", err)
+	}
+
+	if err := keys[0].Verify(buf, sig); err != nil {
+		return fmt.Errorf("verify: %v", err)
+	}
+	return nil
+}
+
+func addKeyToAgent(key crypto.PrivateKey) error {
+	sshAgent := NewKeyring()
+	if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
+		return fmt.Errorf("add: %v", err)
+	}
+	return verifyKey(sshAgent)
+}
+
+func TestKeyTypes(t *testing.T) {
+	for k, v := range testPrivateKeys {
+		if err := addKeyToAgent(v); err != nil {
+			t.Errorf("error adding key type %s, %v", k, err)
+		}
+	}
+}
+
+func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
+	sshAgent := NewKeyring()
+	if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
+		return fmt.Errorf("add: %v", err)
+	}
+	return verifyKey(sshAgent)
+}
+
+func TestCertTypes(t *testing.T) {
+	for keyType, key := range testPublicKeys {
+		cert := &ssh.Certificate{
+			ValidPrincipals: []string{"gopher1"},
+			ValidAfter:      0,
+			ValidBefore:     ssh.CertTimeInfinity,
+			Key:             key,
+			Serial:          1,
+			CertType:        ssh.UserCert,
+			SignatureKey:    testPublicKeys["rsa"],
+			Permissions: ssh.Permissions{
+				CriticalOptions: map[string]string{},
+				Extensions:      map[string]string{},
+			},
+		}
+		if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
+			t.Fatalf("signcert: %v", err)
+		}
+		if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
+			t.Fatalf("%v", err)
+		}
+	}
+}

+ 6 - 0
ssh/keys.go

@@ -319,6 +319,8 @@ func parseRSA(in []byte) (out PublicKey, rest []byte, err error) {
 
 func (r *rsaPublicKey) Marshal() []byte {
 	e := new(big.Int).SetInt64(int64(r.E))
+	// RSA publickey struct layout should match the struct used by
+	// parseRSACert in the x/crypto/ssh/agent package.
 	wirekey := struct {
 		Name string
 		E    *big.Int
@@ -369,6 +371,8 @@ func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
 }
 
 func (k *dsaPublicKey) Marshal() []byte {
+	// DSA publickey struct layout should match the struct used by
+	// parseDSACert in the x/crypto/ssh/agent package.
 	w := struct {
 		Name       string
 		P, Q, G, Y *big.Int
@@ -507,6 +511,8 @@ func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
 func (key *ecdsaPublicKey) Marshal() []byte {
 	// See RFC 5656, section 3.1.
 	keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
+	// ECDSA publickey struct layout should match the struct used by
+	// parseECDSACert in the x/crypto/ssh/agent package.
 	w := struct {
 		Name string
 		ID   string

+ 39 - 21
ssh/messages.go

@@ -13,6 +13,7 @@ import (
 	"math/big"
 	"reflect"
 	"strconv"
+	"strings"
 )
 
 // These are SSH message type numbers. They are scattered around several
@@ -266,17 +267,19 @@ type userAuthPubKeyOkMsg struct {
 	PubKey []byte
 }
 
-// typeTag returns the type byte for the given type. The type should
-// be struct.
-func typeTag(structType reflect.Type) byte {
-	var tag byte
-	var tagStr string
-	tagStr = structType.Field(0).Tag.Get("sshtype")
-	i, err := strconv.Atoi(tagStr)
-	if err == nil {
-		tag = byte(i)
+// typeTags returns the possible type bytes for the given reflect.Type, which
+// should be a struct. The possible values are separated by a '|' character.
+func typeTags(structType reflect.Type) (tags []byte) {
+	tagStr := structType.Field(0).Tag.Get("sshtype")
+
+	for _, tag := range strings.Split(tagStr, "|") {
+		i, err := strconv.Atoi(tag)
+		if err == nil {
+			tags = append(tags, byte(i))
+		}
 	}
-	return tag
+
+	return tags
 }
 
 func fieldError(t reflect.Type, field int, problem string) error {
@@ -290,19 +293,34 @@ var errShortRead = errors.New("ssh: short read")
 
 // Unmarshal parses data in SSH wire format into a structure. The out
 // argument should be a pointer to struct. If the first member of the
-// struct has the "sshtype" tag set to a number in decimal, the packet
-// must start that number.  In case of error, Unmarshal returns a
-// ParseError or UnexpectedMessageError.
+// struct has the "sshtype" tag set to a '|'-separated set of numbers
+// in decimal, the packet must start with one of those numbers. In
+// case of error, Unmarshal returns a ParseError or
+// UnexpectedMessageError.
 func Unmarshal(data []byte, out interface{}) error {
 	v := reflect.ValueOf(out).Elem()
 	structType := v.Type()
-	expectedType := typeTag(structType)
+	expectedTypes := typeTags(structType)
+
+	var expectedType byte
+	if len(expectedTypes) > 0 {
+		expectedType = expectedTypes[0]
+	}
+
 	if len(data) == 0 {
 		return parseError(expectedType)
 	}
-	if expectedType > 0 {
-		if data[0] != expectedType {
-			return unexpectedMessageError(expectedType, data[0])
+
+	if len(expectedTypes) > 0 {
+		goodType := false
+		for _, e := range expectedTypes {
+			if e > 0 && data[0] == e {
+				goodType = true
+				break
+			}
+		}
+		if !goodType {
+			return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes)
 		}
 		data = data[1:]
 	}
@@ -386,7 +404,7 @@ func Unmarshal(data []byte, out interface{}) error {
 				return fieldError(structType, i, "pointer to unsupported type")
 			}
 		default:
-			return fieldError(structType, i, "unsupported type")
+			return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t))
 		}
 	}
 
@@ -409,9 +427,9 @@ func Marshal(msg interface{}) []byte {
 
 func marshalStruct(out []byte, msg interface{}) []byte {
 	v := reflect.Indirect(reflect.ValueOf(msg))
-	msgType := typeTag(v.Type())
-	if msgType > 0 {
-		out = append(out, msgType)
+	msgTypes := typeTags(v.Type())
+	if len(msgTypes) > 0 {
+		out = append(out, msgTypes[0])
 	}
 
 	for i, n := 0, v.NumField(); i < n; i++ {

+ 34 - 0
ssh/messages_test.go

@@ -172,6 +172,40 @@ func TestUnmarshalShortKexInitPacket(t *testing.T) {
 	}
 }
 
+func TestMarshalMultiTag(t *testing.T) {
+	var res struct {
+		A uint32 `sshtype:"1|2"`
+	}
+
+	good1 := struct {
+		A uint32 `sshtype:"1"`
+	}{
+		1,
+	}
+	good2 := struct {
+		A uint32 `sshtype:"2"`
+	}{
+		1,
+	}
+
+	if e := Unmarshal(Marshal(good1), &res); e != nil {
+		t.Errorf("error unmarshaling multipart tag: %v", e)
+	}
+
+	if e := Unmarshal(Marshal(good2), &res); e != nil {
+		t.Errorf("error unmarshaling multipart tag: %v", e)
+	}
+
+	bad1 := struct {
+		A uint32 `sshtype:"3"`
+	}{
+		1,
+	}
+	if e := Unmarshal(Marshal(bad1), &res); e == nil {
+		t.Errorf("bad struct unmarshaled without error")
+	}
+}
+
 func randomBytes(out []byte, rand *rand.Rand) {
 	for i := 0; i < len(out); i++ {
 		out[i] = byte(rand.Int31())