Bläddra i källkod

ssh: cosmetic cleanups

These are the cosmetic cleanups from the bits of code that I
rereviewed.

1) stringLength now takes a int; the length of the string. Too many
   callers were allocating with stringLength([]byte(s)) and
   stringLength only needs to call len().

2) agent.go now has sendAndReceive to remove logic that was
   duplicated.

3) We now reject negative DH values

4) We now reject empty packets rather than crashing.

R=dave, jonathan.mark.pittman
CC=golang-dev
https://golang.org/cl/6061052
Adam Langley 13 år sedan
förälder
incheckning
63f855d724
13 ändrade filer med 103 tillägg och 99 borttagningar
  1. 42 45
      ssh/agent.go
  2. 7 9
      ssh/certs.go
  3. 8 9
      ssh/channel.go
  4. 5 5
      ssh/cipher.go
  5. 6 6
      ssh/client.go
  6. 1 1
      ssh/client_auth.go
  7. 1 1
      ssh/client_auth_test.go
  8. 17 9
      ssh/common.go
  9. 1 1
      ssh/keys.go
  10. 2 2
      ssh/messages.go
  11. 6 7
      ssh/server.go
  12. 2 2
      ssh/session.go
  13. 5 2
      ssh/transport.go

+ 42 - 45
ssh/agent.go

@@ -10,7 +10,6 @@ package ssh
 import (
 	"encoding/base64"
 	"errors"
-	"fmt"
 	"io"
 )
 
@@ -44,6 +43,10 @@ const (
 	agentConstrainConfirm  = 2
 )
 
+// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
+// is a sanity check, not a limit in the spec.
+const maxAgentResponseBytes = 16 << 20
+
 // Agent messages:
 // These structures mirror the wire format of the corresponding ssh agent
 // messages found in PROTOCOL.agent.
@@ -85,18 +88,16 @@ type AgentKey struct {
 func (ak *AgentKey) String() string {
 	algo, _, ok := parseString(ak.blob)
 	if !ok {
-		return "malformed key"
+		return "ssh: malformed key"
 	}
 
-	algoName := string(algo)
-	b64EncKey := base64.StdEncoding.EncodeToString(ak.blob)
-	comment := ""
+	s := string(algo) + " " + base64.StdEncoding.EncodeToString(ak.blob)
 
 	if ak.Comment != "" {
-		comment = " " + ak.Comment
+		s += " " + ak.Comment
 	}
 
-	return fmt.Sprintf("%s %s%s", algoName, b64EncKey, comment)
+	return s
 }
 
 // Key returns an agent's public key as a *rsa.PublicKey, *dsa.PublicKey, or
@@ -131,50 +132,51 @@ type AgentClient struct {
 	io.ReadWriter
 }
 
-func (ac *AgentClient) sendRequest(req []byte) error {
-	msg := make([]byte, stringLength(req))
+// sendAndReceive sends req to the agent and waits for a reply. On success,
+// the reply is unmarshaled into reply and replyType is set to the first byte of
+// the reply, which contains the type of the message.
+func (ac *AgentClient) sendAndReceive(req []byte) (reply interface{}, replyType uint8, err error) {
+	msg := make([]byte, stringLength(len(req)))
 	marshalString(msg, req)
-	if _, err := ac.Write(msg); err != nil {
-		return err
+	if _, err = ac.Write(msg); err != nil {
+		return
 	}
-	return nil
-}
 
-func (ac *AgentClient) readResponse() ([]byte, error) {
 	var respSizeBuf [4]byte
-	if _, err := io.ReadFull(ac, respSizeBuf[:]); err != nil {
-		return nil, err
+	if _, err = io.ReadFull(ac, respSizeBuf[:]); err != nil {
+		return
 	}
+	respSize, _, _ := parseUint32(respSizeBuf[:])
 
-	respSize, _, ok := parseUint32(respSizeBuf[:])
-	if !ok {
-		return nil, errors.New("ssh: failure to parse response size")
+	if respSize > maxAgentResponseBytes {
+		err = errors.New("ssh: agent reply too large")
+		return
 	}
 
 	buf := make([]byte, respSize)
-	if _, err := io.ReadFull(ac, buf); err != nil {
-		return nil, err
+	if _, err = io.ReadFull(ac, buf); err != nil {
+		return
 	}
-	return buf, nil
+	return unmarshalAgentMsg(buf)
 }
 
 // RequestIdentities queries the agent for protocol 2 keys as defined in
 // PROTOCOL.agent section 2.5.2.
 func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
 	req := marshal(agentRequestIdentities, requestIdentitiesAgentMsg{})
-	if err := ac.sendRequest(req); err != nil {
-		return nil, err
-	}
 
-	resp, err := ac.readResponse()
+	msg, msgType, err := ac.sendAndReceive(req)
 	if err != nil {
 		return nil, err
 	}
 
-	switch msg := decodeAgentMsg(resp).(type) {
+	switch msg := msg.(type) {
 	case *identitiesAnswerAgentMsg:
+		if msg.NumKeys > maxAgentResponseBytes/8 {
+			return nil, errors.New("ssh: too many keys in agent reply")
+		}
 		keys := make([]*AgentKey, msg.NumKeys)
-		data := msg.Keys[:]
+		data := msg.Keys
 		for i := uint32(0); i < msg.NumKeys; i++ {
 			var key *AgentKey
 			var ok bool
@@ -185,11 +187,9 @@ func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
 		}
 		return keys, nil
 	case *failureAgentMsg:
-		return nil, errors.New("ssh: failed to list keys.")
-	case ParseError, UnexpectedMessageError:
-		return nil, msg.(error)
+		return nil, errors.New("ssh: failed to list keys")
 	}
-	return nil, UnexpectedMessageError{agentIdentitiesAnswer, resp[0]}
+	return nil, UnexpectedMessageError{agentIdentitiesAnswer, msgType}
 }
 
 // SignRequest requests the signing of data by the agent using a protocol 2 key
@@ -200,29 +200,26 @@ func (ac *AgentClient) SignRequest(key interface{}, data []byte) ([]byte, error)
 		KeyBlob: serializePublickey(key),
 		Data:    data,
 	})
-	if err := ac.sendRequest(req); err != nil {
-		return nil, err
-	}
 
-	resp, err := ac.readResponse()
+	msg, msgType, err := ac.sendAndReceive(req)
 	if err != nil {
 		return nil, err
 	}
 
-	switch msg := decodeAgentMsg(resp).(type) {
+	switch msg := msg.(type) {
 	case *signResponseAgentMsg:
 		return msg.SigBlob, nil
 	case *failureAgentMsg:
 		return nil, errors.New("ssh: failed to sign challenge")
-	case ParseError, UnexpectedMessageError:
-		return nil, msg.(error)
 	}
-	return nil, UnexpectedMessageError{agentSignResponse, resp[0]}
+	return nil, UnexpectedMessageError{agentSignResponse, msgType}
 }
 
-func decodeAgentMsg(packet []byte) interface{} {
+// unmarshalAgentMsg parses an agent message in packet, returning the parsed
+// form and the message type of packet.
+func unmarshalAgentMsg(packet []byte) (interface{}, uint8, error) {
 	if len(packet) < 1 {
-		return ParseError{0}
+		return nil, 0, ParseError{0}
 	}
 	var msg interface{}
 	switch packet[0] {
@@ -235,10 +232,10 @@ func decodeAgentMsg(packet []byte) interface{} {
 	case agentSignResponse:
 		msg = new(signResponseAgentMsg)
 	default:
-		return UnexpectedMessageError{0, packet[0]}
+		return nil, 0, UnexpectedMessageError{0, packet[0]}
 	}
 	if err := unmarshal(msg, packet, packet[0]); err != nil {
-		return err
+		return nil, 0, err
 	}
-	return msg
+	return msg, packet[0], nil
 }

+ 7 - 9
ssh/certs.go

@@ -154,18 +154,18 @@ func marshalOpenSSHCertV01(cert *OpenSSHCertV01) []byte {
 
 	sigKey := serializePublickey(cert.SignatureKey)
 
-	length := stringLength(cert.Nonce)
+	length := stringLength(len(cert.Nonce))
 	length += len(pubKey)
 	length += 8 // Length of Serial
 	length += 4 // Length of Type
-	length += stringLength([]byte(cert.KeyId))
+	length += stringLength(len(cert.KeyId))
 	length += lengthPrefixedNameListLength(cert.ValidPrincipals)
 	length += 8 // Length of ValidAfter
 	length += 8 // Length of ValidBefore
 	length += tupleListLength(cert.CriticalOptions)
 	length += tupleListLength(cert.Extensions)
-	length += stringLength(cert.Reserved)
-	length += stringLength(sigKey)
+	length += stringLength(len(cert.Reserved))
+	length += stringLength(len(sigKey))
 	length += signatureLength(cert.Signature)
 
 	ret := make([]byte, length)
@@ -215,9 +215,7 @@ func parseLengthPrefixedNameList(in []byte) (out []string, rest []byte, ok bool)
 
 	for len(list) > 0 {
 		var next []byte
-		var ok bool
-		next, list, ok = parseString(list)
-		if !ok {
+		if next, list, ok = parseString(list); !ok {
 			return nil, nil, false
 		}
 		out = append(out, string(next))
@@ -272,8 +270,8 @@ func parseTupleList(in []byte) (out []tuple, rest []byte, ok bool) {
 
 func signatureLength(sig *signature) int {
 	length := 4 // length prefix for signature
-	length += stringLength([]byte(sig.Format))
-	length += stringLength(sig.Blob)
+	length += stringLength(len(sig.Format))
+	length += stringLength(len(sig.Blob))
 	return length
 }
 

+ 8 - 9
ssh/channel.go

@@ -56,7 +56,7 @@ type ChannelRequest struct {
 }
 
 func (c ChannelRequest) Error() string {
-	return "channel request received"
+	return "ssh: channel request received"
 }
 
 // RejectionReason is an enumeration used when rejecting channel creation
@@ -255,7 +255,7 @@ func (c *channel) Read(data []byte) (n int, err error) {
 		}
 
 		if c.length > 0 {
-			tail := min(c.head + c.length, len(c.pendingData))
+			tail := min(c.head+c.length, len(c.pendingData))
 			n = copy(data, c.pendingData[c.head:tail])
 			c.head += n
 			c.length -= n
@@ -374,18 +374,17 @@ func (c *channel) AckRequest(ok bool) error {
 		return c.serverConn.err
 	}
 
-	if ok {
-		ack := channelRequestSuccessMsg{
-			PeersId: c.theirId,
-		}
-		return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
-	} else {
+	if !ok {
 		ack := channelRequestFailureMsg{
 			PeersId: c.theirId,
 		}
 		return c.serverConn.writePacket(marshal(msgChannelFailure, ack))
 	}
-	panic("unreachable")
+
+	ack := channelRequestSuccessMsg{
+		PeersId: c.theirId,
+	}
+	return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
 }
 
 func (c *channel) ChannelType() string {

+ 5 - 5
ssh/cipher.go

@@ -35,10 +35,10 @@ func newRC4(key, iv []byte) (cipher.Stream, error) {
 }
 
 type cipherMode struct {
-	keySize  int
-	ivSize   int
-	skip     int
-	createFn func(key, iv []byte) (cipher.Stream, error)
+	keySize    int
+	ivSize     int
+	skip       int
+	createFunc func(key, iv []byte) (cipher.Stream, error)
 }
 
 func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
@@ -49,7 +49,7 @@ func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
 		panic("ssh: iv too small for cipher")
 	}
 
-	stream, err := c.createFn(key[:c.keySize], iv[:c.ivSize])
+	stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
 	if err != nil {
 		return nil, err
 	}

+ 6 - 6
ssh/client.go

@@ -154,16 +154,16 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 		return nil, nil, err
 	}
 
-	var kexDHReply = new(kexDHReplyMsg)
-	if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
+	var kexDHReply kexDHReplyMsg
+	if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
 		return nil, nil, err
 	}
 
-	if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
-		return nil, nil, errors.New("server DH parameter out of bounds")
+	kInt, err := group.diffieHellman(kexDHReply.Y, x)
+	if err != nil {
+		return nil, nil, err
 	}
 
-	kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
 	h := hashFunc.New()
 	writeString(h, magics.clientVersion)
 	writeString(h, magics.serverVersion)
@@ -352,7 +352,7 @@ func (c *clientChan) waitForChannelOpenResponse() error {
 	case *channelOpenFailureMsg:
 		return errors.New(safeString(msg.Message))
 	}
-	return errors.New("unexpected packet")
+	return errors.New("ssh: unexpected packet")
 }
 
 // sendEOF sends EOF to the server. RFC 4254 Section 5.3

+ 1 - 1
ssh/client_auth.go

@@ -213,7 +213,7 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.
 		}
 		// manually wrap the serialized signature in a string
 		s := serializeSignature(algoname, sign)
-		sig := make([]byte, stringLength(s))
+		sig := make([]byte, stringLength(len(s)))
 		marshalString(sig, s)
 		msg := publickeyAuthMsg{
 			User:     user,

+ 1 - 1
ssh/client_auth_test.go

@@ -85,7 +85,7 @@ func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err err
 	case *rsa.PrivateKey:
 		return rsa.SignPKCS1v15(rand, key, hashFunc, digest)
 	}
-	return nil, errors.New("unknown key type")
+	return nil, errors.New("ssh: unknown key type")
 }
 
 func (k *keychain) loadPEM(file string) error {

+ 17 - 9
ssh/common.go

@@ -7,6 +7,7 @@ package ssh
 import (
 	"crypto/dsa"
 	"crypto/rsa"
+	"errors"
 	"math/big"
 	"strconv"
 	"sync"
@@ -32,6 +33,13 @@ type dhGroup struct {
 	g, p *big.Int
 }
 
+func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
+	if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
+		return nil, errors.New("ssh: DH parameter out of bounds")
+	}
+	return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
+}
+
 // dhGroup1 is the group called diffie-hellman-group1-sha1 in RFC 4253 and
 // Oakley Group 2 in RFC 2409.
 var dhGroup1 *dhGroup
@@ -178,8 +186,8 @@ func serializeSignature(algoname string, sig []byte) []byte {
 	case hostAlgoDSACertV01:
 		algoname = "ssh-dss"
 	}
-	length := stringLength([]byte(algoname))
-	length += stringLength(sig)
+	length := stringLength(len(algoname))
+	length += stringLength(len(sig))
 
 	ret := make([]byte, length)
 	r := marshalString(ret, []byte(algoname))
@@ -203,7 +211,7 @@ func serializePublickey(key interface{}) []byte {
 		panic("unexpected key type")
 	}
 
-	length := stringLength([]byte(algoname))
+	length := stringLength(len(algoname))
 	length += len(pubKeyBytes)
 	ret := make([]byte, length)
 	r := marshalString(ret, []byte(algoname))
@@ -230,14 +238,14 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
 	service := []byte(req.Service)
 	method := []byte(req.Method)
 
-	length := stringLength(sessionId)
+	length := stringLength(len(sessionId))
 	length += 1
-	length += stringLength(user)
-	length += stringLength(service)
-	length += stringLength(method)
+	length += stringLength(len(user))
+	length += stringLength(len(service))
+	length += stringLength(len(method))
 	length += 1
-	length += stringLength(algo)
-	length += stringLength(pubKey)
+	length += stringLength(len(algo))
+	length += stringLength(len(pubKey))
 
 	ret := make([]byte, length)
 	r := marshalString(ret, sessionId)

+ 1 - 1
ssh/keys.go

@@ -78,7 +78,7 @@ func parseDSA(in []byte) (out *dsa.PublicKey, rest []byte, ok bool) {
 // marshalPrivRSA serializes an RSA private key according to RFC 4253, section 6.6.
 func marshalPrivRSA(priv *rsa.PrivateKey) []byte {
 	e := new(big.Int).SetInt64(int64(priv.E))
-	length := stringLength([]byte(hostAlgoRSA))
+	length := stringLength(len(hostAlgoRSA))
 	length += intLength(e)
 	length += intLength(priv.N)
 

+ 2 - 2
ssh/messages.go

@@ -543,8 +543,8 @@ func writeString(w io.Writer, s []byte) {
 	w.Write(s)
 }
 
-func stringLength(s []byte) int {
-	return 4 + len(s)
+func stringLength(n int) int {
+	return 4 + n
 }
 
 func marshalString(to []byte, s []byte) []byte {

+ 6 - 7
ssh/server.go

@@ -141,24 +141,23 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 		return
 	}
 
-	if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 {
-		return nil, nil, errors.New("client DH parameter out of bounds")
-	}
-
 	y, err := rand.Int(s.config.rand(), group.p)
 	if err != nil {
 		return
 	}
 
 	Y := new(big.Int).Exp(group.g, y, group.p)
-	kInt := new(big.Int).Exp(kexDHInit.X, y, group.p)
+	kInt, err := group.diffieHellman(kexDHInit.X, y)
+	if err != nil {
+		return nil, nil, err
+	}
 
 	var serializedHostKey []byte
 	switch hostKeyAlgo {
 	case hostAlgoRSA:
 		serializedHostKey = s.config.rsaSerialized
 	default:
-		return nil, nil, errors.New("internal error")
+		return nil, nil, errors.New("ssh: internal error")
 	}
 
 	h := hashFunc.New()
@@ -187,7 +186,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 			return
 		}
 	default:
-		return nil, nil, errors.New("internal error")
+		return nil, nil, errors.New("ssh: internal error")
 	}
 
 	serializedSig := serializeSignature(hostAlgoRSA, sig)

+ 2 - 2
ssh/session.go

@@ -231,9 +231,9 @@ func (s *Session) waitForResponse() error {
 	case *channelRequestSuccessMsg:
 		return nil
 	case *channelRequestFailureMsg:
-		return errors.New("request failed")
+		return errors.New("ssh: request failed")
 	}
-	return fmt.Errorf("unknown packet %T received: %v", msg, msg)
+	return fmt.Errorf("ssh: unknown packet %T received: %v", msg, msg)
 }
 
 func (s *Session) start() error {

+ 5 - 2
ssh/transport.go

@@ -105,10 +105,10 @@ func (r *reader) readOnePacket() ([]byte, error) {
 	paddingLength := uint32(lengthBytes[4])
 
 	if length <= paddingLength+1 {
-		return nil, errors.New("invalid packet length")
+		return nil, errors.New("ssh: invalid packet length")
 	}
 	if length > maxPacketSize {
-		return nil, errors.New("packet too large")
+		return nil, errors.New("ssh: packet too large")
 	}
 
 	packet := make([]byte, length-1+macSize)
@@ -136,6 +136,9 @@ func (t *transport) readPacket() ([]byte, error) {
 		if err != nil {
 			return nil, err
 		}
+		if len(packet) == 0 {
+			return nil, errors.New("ssh: zero length packet")
+		}
 		if packet[0] != msgIgnore && packet[0] != msgDebug {
 			return packet, nil
 		}