Browse Source

ssh: cosmetic cleanups

These are the cosmetic cleanups from the bits of code that I
rereviewed.

1) stringLength now takes a int; the length of the string. Too many
   callers were allocating with stringLength([]byte(s)) and
   stringLength only needs to call len().

2) agent.go now has sendAndReceive to remove logic that was
   duplicated.

3) We now reject negative DH values

4) We now reject empty packets rather than crashing.

R=dave, jonathan.mark.pittman
CC=golang-dev
https://golang.org/cl/6061052
Adam Langley 13 years ago
parent
commit
63f855d724
13 changed files with 103 additions and 99 deletions
  1. 42 45
      ssh/agent.go
  2. 7 9
      ssh/certs.go
  3. 8 9
      ssh/channel.go
  4. 5 5
      ssh/cipher.go
  5. 6 6
      ssh/client.go
  6. 1 1
      ssh/client_auth.go
  7. 1 1
      ssh/client_auth_test.go
  8. 17 9
      ssh/common.go
  9. 1 1
      ssh/keys.go
  10. 2 2
      ssh/messages.go
  11. 6 7
      ssh/server.go
  12. 2 2
      ssh/session.go
  13. 5 2
      ssh/transport.go

+ 42 - 45
ssh/agent.go

@@ -10,7 +10,6 @@ package ssh
 import (
 import (
 	"encoding/base64"
 	"encoding/base64"
 	"errors"
 	"errors"
-	"fmt"
 	"io"
 	"io"
 )
 )
 
 
@@ -44,6 +43,10 @@ const (
 	agentConstrainConfirm  = 2
 	agentConstrainConfirm  = 2
 )
 )
 
 
+// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
+// is a sanity check, not a limit in the spec.
+const maxAgentResponseBytes = 16 << 20
+
 // Agent messages:
 // Agent messages:
 // These structures mirror the wire format of the corresponding ssh agent
 // These structures mirror the wire format of the corresponding ssh agent
 // messages found in PROTOCOL.agent.
 // messages found in PROTOCOL.agent.
@@ -85,18 +88,16 @@ type AgentKey struct {
 func (ak *AgentKey) String() string {
 func (ak *AgentKey) String() string {
 	algo, _, ok := parseString(ak.blob)
 	algo, _, ok := parseString(ak.blob)
 	if !ok {
 	if !ok {
-		return "malformed key"
+		return "ssh: malformed key"
 	}
 	}
 
 
-	algoName := string(algo)
-	b64EncKey := base64.StdEncoding.EncodeToString(ak.blob)
-	comment := ""
+	s := string(algo) + " " + base64.StdEncoding.EncodeToString(ak.blob)
 
 
 	if ak.Comment != "" {
 	if ak.Comment != "" {
-		comment = " " + ak.Comment
+		s += " " + ak.Comment
 	}
 	}
 
 
-	return fmt.Sprintf("%s %s%s", algoName, b64EncKey, comment)
+	return s
 }
 }
 
 
 // Key returns an agent's public key as a *rsa.PublicKey, *dsa.PublicKey, or
 // Key returns an agent's public key as a *rsa.PublicKey, *dsa.PublicKey, or
@@ -131,50 +132,51 @@ type AgentClient struct {
 	io.ReadWriter
 	io.ReadWriter
 }
 }
 
 
-func (ac *AgentClient) sendRequest(req []byte) error {
-	msg := make([]byte, stringLength(req))
+// sendAndReceive sends req to the agent and waits for a reply. On success,
+// the reply is unmarshaled into reply and replyType is set to the first byte of
+// the reply, which contains the type of the message.
+func (ac *AgentClient) sendAndReceive(req []byte) (reply interface{}, replyType uint8, err error) {
+	msg := make([]byte, stringLength(len(req)))
 	marshalString(msg, req)
 	marshalString(msg, req)
-	if _, err := ac.Write(msg); err != nil {
-		return err
+	if _, err = ac.Write(msg); err != nil {
+		return
 	}
 	}
-	return nil
-}
 
 
-func (ac *AgentClient) readResponse() ([]byte, error) {
 	var respSizeBuf [4]byte
 	var respSizeBuf [4]byte
-	if _, err := io.ReadFull(ac, respSizeBuf[:]); err != nil {
-		return nil, err
+	if _, err = io.ReadFull(ac, respSizeBuf[:]); err != nil {
+		return
 	}
 	}
+	respSize, _, _ := parseUint32(respSizeBuf[:])
 
 
-	respSize, _, ok := parseUint32(respSizeBuf[:])
-	if !ok {
-		return nil, errors.New("ssh: failure to parse response size")
+	if respSize > maxAgentResponseBytes {
+		err = errors.New("ssh: agent reply too large")
+		return
 	}
 	}
 
 
 	buf := make([]byte, respSize)
 	buf := make([]byte, respSize)
-	if _, err := io.ReadFull(ac, buf); err != nil {
-		return nil, err
+	if _, err = io.ReadFull(ac, buf); err != nil {
+		return
 	}
 	}
-	return buf, nil
+	return unmarshalAgentMsg(buf)
 }
 }
 
 
 // RequestIdentities queries the agent for protocol 2 keys as defined in
 // RequestIdentities queries the agent for protocol 2 keys as defined in
 // PROTOCOL.agent section 2.5.2.
 // PROTOCOL.agent section 2.5.2.
 func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
 func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
 	req := marshal(agentRequestIdentities, requestIdentitiesAgentMsg{})
 	req := marshal(agentRequestIdentities, requestIdentitiesAgentMsg{})
-	if err := ac.sendRequest(req); err != nil {
-		return nil, err
-	}
 
 
-	resp, err := ac.readResponse()
+	msg, msgType, err := ac.sendAndReceive(req)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	switch msg := decodeAgentMsg(resp).(type) {
+	switch msg := msg.(type) {
 	case *identitiesAnswerAgentMsg:
 	case *identitiesAnswerAgentMsg:
+		if msg.NumKeys > maxAgentResponseBytes/8 {
+			return nil, errors.New("ssh: too many keys in agent reply")
+		}
 		keys := make([]*AgentKey, msg.NumKeys)
 		keys := make([]*AgentKey, msg.NumKeys)
-		data := msg.Keys[:]
+		data := msg.Keys
 		for i := uint32(0); i < msg.NumKeys; i++ {
 		for i := uint32(0); i < msg.NumKeys; i++ {
 			var key *AgentKey
 			var key *AgentKey
 			var ok bool
 			var ok bool
@@ -185,11 +187,9 @@ func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
 		}
 		}
 		return keys, nil
 		return keys, nil
 	case *failureAgentMsg:
 	case *failureAgentMsg:
-		return nil, errors.New("ssh: failed to list keys.")
-	case ParseError, UnexpectedMessageError:
-		return nil, msg.(error)
+		return nil, errors.New("ssh: failed to list keys")
 	}
 	}
-	return nil, UnexpectedMessageError{agentIdentitiesAnswer, resp[0]}
+	return nil, UnexpectedMessageError{agentIdentitiesAnswer, msgType}
 }
 }
 
 
 // SignRequest requests the signing of data by the agent using a protocol 2 key
 // SignRequest requests the signing of data by the agent using a protocol 2 key
@@ -200,29 +200,26 @@ func (ac *AgentClient) SignRequest(key interface{}, data []byte) ([]byte, error)
 		KeyBlob: serializePublickey(key),
 		KeyBlob: serializePublickey(key),
 		Data:    data,
 		Data:    data,
 	})
 	})
-	if err := ac.sendRequest(req); err != nil {
-		return nil, err
-	}
 
 
-	resp, err := ac.readResponse()
+	msg, msgType, err := ac.sendAndReceive(req)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	switch msg := decodeAgentMsg(resp).(type) {
+	switch msg := msg.(type) {
 	case *signResponseAgentMsg:
 	case *signResponseAgentMsg:
 		return msg.SigBlob, nil
 		return msg.SigBlob, nil
 	case *failureAgentMsg:
 	case *failureAgentMsg:
 		return nil, errors.New("ssh: failed to sign challenge")
 		return nil, errors.New("ssh: failed to sign challenge")
-	case ParseError, UnexpectedMessageError:
-		return nil, msg.(error)
 	}
 	}
-	return nil, UnexpectedMessageError{agentSignResponse, resp[0]}
+	return nil, UnexpectedMessageError{agentSignResponse, msgType}
 }
 }
 
 
-func decodeAgentMsg(packet []byte) interface{} {
+// unmarshalAgentMsg parses an agent message in packet, returning the parsed
+// form and the message type of packet.
+func unmarshalAgentMsg(packet []byte) (interface{}, uint8, error) {
 	if len(packet) < 1 {
 	if len(packet) < 1 {
-		return ParseError{0}
+		return nil, 0, ParseError{0}
 	}
 	}
 	var msg interface{}
 	var msg interface{}
 	switch packet[0] {
 	switch packet[0] {
@@ -235,10 +232,10 @@ func decodeAgentMsg(packet []byte) interface{} {
 	case agentSignResponse:
 	case agentSignResponse:
 		msg = new(signResponseAgentMsg)
 		msg = new(signResponseAgentMsg)
 	default:
 	default:
-		return UnexpectedMessageError{0, packet[0]}
+		return nil, 0, UnexpectedMessageError{0, packet[0]}
 	}
 	}
 	if err := unmarshal(msg, packet, packet[0]); err != nil {
 	if err := unmarshal(msg, packet, packet[0]); err != nil {
-		return err
+		return nil, 0, err
 	}
 	}
-	return msg
+	return msg, packet[0], nil
 }
 }

+ 7 - 9
ssh/certs.go

@@ -154,18 +154,18 @@ func marshalOpenSSHCertV01(cert *OpenSSHCertV01) []byte {
 
 
 	sigKey := serializePublickey(cert.SignatureKey)
 	sigKey := serializePublickey(cert.SignatureKey)
 
 
-	length := stringLength(cert.Nonce)
+	length := stringLength(len(cert.Nonce))
 	length += len(pubKey)
 	length += len(pubKey)
 	length += 8 // Length of Serial
 	length += 8 // Length of Serial
 	length += 4 // Length of Type
 	length += 4 // Length of Type
-	length += stringLength([]byte(cert.KeyId))
+	length += stringLength(len(cert.KeyId))
 	length += lengthPrefixedNameListLength(cert.ValidPrincipals)
 	length += lengthPrefixedNameListLength(cert.ValidPrincipals)
 	length += 8 // Length of ValidAfter
 	length += 8 // Length of ValidAfter
 	length += 8 // Length of ValidBefore
 	length += 8 // Length of ValidBefore
 	length += tupleListLength(cert.CriticalOptions)
 	length += tupleListLength(cert.CriticalOptions)
 	length += tupleListLength(cert.Extensions)
 	length += tupleListLength(cert.Extensions)
-	length += stringLength(cert.Reserved)
-	length += stringLength(sigKey)
+	length += stringLength(len(cert.Reserved))
+	length += stringLength(len(sigKey))
 	length += signatureLength(cert.Signature)
 	length += signatureLength(cert.Signature)
 
 
 	ret := make([]byte, length)
 	ret := make([]byte, length)
@@ -215,9 +215,7 @@ func parseLengthPrefixedNameList(in []byte) (out []string, rest []byte, ok bool)
 
 
 	for len(list) > 0 {
 	for len(list) > 0 {
 		var next []byte
 		var next []byte
-		var ok bool
-		next, list, ok = parseString(list)
-		if !ok {
+		if next, list, ok = parseString(list); !ok {
 			return nil, nil, false
 			return nil, nil, false
 		}
 		}
 		out = append(out, string(next))
 		out = append(out, string(next))
@@ -272,8 +270,8 @@ func parseTupleList(in []byte) (out []tuple, rest []byte, ok bool) {
 
 
 func signatureLength(sig *signature) int {
 func signatureLength(sig *signature) int {
 	length := 4 // length prefix for signature
 	length := 4 // length prefix for signature
-	length += stringLength([]byte(sig.Format))
-	length += stringLength(sig.Blob)
+	length += stringLength(len(sig.Format))
+	length += stringLength(len(sig.Blob))
 	return length
 	return length
 }
 }
 
 

+ 8 - 9
ssh/channel.go

@@ -56,7 +56,7 @@ type ChannelRequest struct {
 }
 }
 
 
 func (c ChannelRequest) Error() string {
 func (c ChannelRequest) Error() string {
-	return "channel request received"
+	return "ssh: channel request received"
 }
 }
 
 
 // RejectionReason is an enumeration used when rejecting channel creation
 // RejectionReason is an enumeration used when rejecting channel creation
@@ -255,7 +255,7 @@ func (c *channel) Read(data []byte) (n int, err error) {
 		}
 		}
 
 
 		if c.length > 0 {
 		if c.length > 0 {
-			tail := min(c.head + c.length, len(c.pendingData))
+			tail := min(c.head+c.length, len(c.pendingData))
 			n = copy(data, c.pendingData[c.head:tail])
 			n = copy(data, c.pendingData[c.head:tail])
 			c.head += n
 			c.head += n
 			c.length -= n
 			c.length -= n
@@ -374,18 +374,17 @@ func (c *channel) AckRequest(ok bool) error {
 		return c.serverConn.err
 		return c.serverConn.err
 	}
 	}
 
 
-	if ok {
-		ack := channelRequestSuccessMsg{
-			PeersId: c.theirId,
-		}
-		return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
-	} else {
+	if !ok {
 		ack := channelRequestFailureMsg{
 		ack := channelRequestFailureMsg{
 			PeersId: c.theirId,
 			PeersId: c.theirId,
 		}
 		}
 		return c.serverConn.writePacket(marshal(msgChannelFailure, ack))
 		return c.serverConn.writePacket(marshal(msgChannelFailure, ack))
 	}
 	}
-	panic("unreachable")
+
+	ack := channelRequestSuccessMsg{
+		PeersId: c.theirId,
+	}
+	return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
 }
 }
 
 
 func (c *channel) ChannelType() string {
 func (c *channel) ChannelType() string {

+ 5 - 5
ssh/cipher.go

@@ -35,10 +35,10 @@ func newRC4(key, iv []byte) (cipher.Stream, error) {
 }
 }
 
 
 type cipherMode struct {
 type cipherMode struct {
-	keySize  int
-	ivSize   int
-	skip     int
-	createFn func(key, iv []byte) (cipher.Stream, error)
+	keySize    int
+	ivSize     int
+	skip       int
+	createFunc func(key, iv []byte) (cipher.Stream, error)
 }
 }
 
 
 func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
 func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
@@ -49,7 +49,7 @@ func (c *cipherMode) createCipher(key, iv []byte) (cipher.Stream, error) {
 		panic("ssh: iv too small for cipher")
 		panic("ssh: iv too small for cipher")
 	}
 	}
 
 
-	stream, err := c.createFn(key[:c.keySize], iv[:c.ivSize])
+	stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 6 - 6
ssh/client.go

@@ -154,16 +154,16 @@ func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 		return nil, nil, err
 		return nil, nil, err
 	}
 	}
 
 
-	var kexDHReply = new(kexDHReplyMsg)
-	if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
+	var kexDHReply kexDHReplyMsg
+	if err = unmarshal(&kexDHReply, packet, msgKexDHReply); err != nil {
 		return nil, nil, err
 		return nil, nil, err
 	}
 	}
 
 
-	if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
-		return nil, nil, errors.New("server DH parameter out of bounds")
+	kInt, err := group.diffieHellman(kexDHReply.Y, x)
+	if err != nil {
+		return nil, nil, err
 	}
 	}
 
 
-	kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
 	h := hashFunc.New()
 	h := hashFunc.New()
 	writeString(h, magics.clientVersion)
 	writeString(h, magics.clientVersion)
 	writeString(h, magics.serverVersion)
 	writeString(h, magics.serverVersion)
@@ -352,7 +352,7 @@ func (c *clientChan) waitForChannelOpenResponse() error {
 	case *channelOpenFailureMsg:
 	case *channelOpenFailureMsg:
 		return errors.New(safeString(msg.Message))
 		return errors.New(safeString(msg.Message))
 	}
 	}
-	return errors.New("unexpected packet")
+	return errors.New("ssh: unexpected packet")
 }
 }
 
 
 // sendEOF sends EOF to the server. RFC 4254 Section 5.3
 // sendEOF sends EOF to the server. RFC 4254 Section 5.3

+ 1 - 1
ssh/client_auth.go

@@ -213,7 +213,7 @@ func (p *publickeyAuth) auth(session []byte, user string, t *transport, rand io.
 		}
 		}
 		// manually wrap the serialized signature in a string
 		// manually wrap the serialized signature in a string
 		s := serializeSignature(algoname, sign)
 		s := serializeSignature(algoname, sign)
-		sig := make([]byte, stringLength(s))
+		sig := make([]byte, stringLength(len(s)))
 		marshalString(sig, s)
 		marshalString(sig, s)
 		msg := publickeyAuthMsg{
 		msg := publickeyAuthMsg{
 			User:     user,
 			User:     user,

+ 1 - 1
ssh/client_auth_test.go

@@ -85,7 +85,7 @@ func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err err
 	case *rsa.PrivateKey:
 	case *rsa.PrivateKey:
 		return rsa.SignPKCS1v15(rand, key, hashFunc, digest)
 		return rsa.SignPKCS1v15(rand, key, hashFunc, digest)
 	}
 	}
-	return nil, errors.New("unknown key type")
+	return nil, errors.New("ssh: unknown key type")
 }
 }
 
 
 func (k *keychain) loadPEM(file string) error {
 func (k *keychain) loadPEM(file string) error {

+ 17 - 9
ssh/common.go

@@ -7,6 +7,7 @@ package ssh
 import (
 import (
 	"crypto/dsa"
 	"crypto/dsa"
 	"crypto/rsa"
 	"crypto/rsa"
+	"errors"
 	"math/big"
 	"math/big"
 	"strconv"
 	"strconv"
 	"sync"
 	"sync"
@@ -32,6 +33,13 @@ type dhGroup struct {
 	g, p *big.Int
 	g, p *big.Int
 }
 }
 
 
+func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
+	if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
+		return nil, errors.New("ssh: DH parameter out of bounds")
+	}
+	return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
+}
+
 // dhGroup1 is the group called diffie-hellman-group1-sha1 in RFC 4253 and
 // dhGroup1 is the group called diffie-hellman-group1-sha1 in RFC 4253 and
 // Oakley Group 2 in RFC 2409.
 // Oakley Group 2 in RFC 2409.
 var dhGroup1 *dhGroup
 var dhGroup1 *dhGroup
@@ -178,8 +186,8 @@ func serializeSignature(algoname string, sig []byte) []byte {
 	case hostAlgoDSACertV01:
 	case hostAlgoDSACertV01:
 		algoname = "ssh-dss"
 		algoname = "ssh-dss"
 	}
 	}
-	length := stringLength([]byte(algoname))
-	length += stringLength(sig)
+	length := stringLength(len(algoname))
+	length += stringLength(len(sig))
 
 
 	ret := make([]byte, length)
 	ret := make([]byte, length)
 	r := marshalString(ret, []byte(algoname))
 	r := marshalString(ret, []byte(algoname))
@@ -203,7 +211,7 @@ func serializePublickey(key interface{}) []byte {
 		panic("unexpected key type")
 		panic("unexpected key type")
 	}
 	}
 
 
-	length := stringLength([]byte(algoname))
+	length := stringLength(len(algoname))
 	length += len(pubKeyBytes)
 	length += len(pubKeyBytes)
 	ret := make([]byte, length)
 	ret := make([]byte, length)
 	r := marshalString(ret, []byte(algoname))
 	r := marshalString(ret, []byte(algoname))
@@ -230,14 +238,14 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
 	service := []byte(req.Service)
 	service := []byte(req.Service)
 	method := []byte(req.Method)
 	method := []byte(req.Method)
 
 
-	length := stringLength(sessionId)
+	length := stringLength(len(sessionId))
 	length += 1
 	length += 1
-	length += stringLength(user)
-	length += stringLength(service)
-	length += stringLength(method)
+	length += stringLength(len(user))
+	length += stringLength(len(service))
+	length += stringLength(len(method))
 	length += 1
 	length += 1
-	length += stringLength(algo)
-	length += stringLength(pubKey)
+	length += stringLength(len(algo))
+	length += stringLength(len(pubKey))
 
 
 	ret := make([]byte, length)
 	ret := make([]byte, length)
 	r := marshalString(ret, sessionId)
 	r := marshalString(ret, sessionId)

+ 1 - 1
ssh/keys.go

@@ -78,7 +78,7 @@ func parseDSA(in []byte) (out *dsa.PublicKey, rest []byte, ok bool) {
 // marshalPrivRSA serializes an RSA private key according to RFC 4253, section 6.6.
 // marshalPrivRSA serializes an RSA private key according to RFC 4253, section 6.6.
 func marshalPrivRSA(priv *rsa.PrivateKey) []byte {
 func marshalPrivRSA(priv *rsa.PrivateKey) []byte {
 	e := new(big.Int).SetInt64(int64(priv.E))
 	e := new(big.Int).SetInt64(int64(priv.E))
-	length := stringLength([]byte(hostAlgoRSA))
+	length := stringLength(len(hostAlgoRSA))
 	length += intLength(e)
 	length += intLength(e)
 	length += intLength(priv.N)
 	length += intLength(priv.N)
 
 

+ 2 - 2
ssh/messages.go

@@ -543,8 +543,8 @@ func writeString(w io.Writer, s []byte) {
 	w.Write(s)
 	w.Write(s)
 }
 }
 
 
-func stringLength(s []byte) int {
-	return 4 + len(s)
+func stringLength(n int) int {
+	return 4 + n
 }
 }
 
 
 func marshalString(to []byte, s []byte) []byte {
 func marshalString(to []byte, s []byte) []byte {

+ 6 - 7
ssh/server.go

@@ -141,24 +141,23 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 		return
 		return
 	}
 	}
 
 
-	if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 {
-		return nil, nil, errors.New("client DH parameter out of bounds")
-	}
-
 	y, err := rand.Int(s.config.rand(), group.p)
 	y, err := rand.Int(s.config.rand(), group.p)
 	if err != nil {
 	if err != nil {
 		return
 		return
 	}
 	}
 
 
 	Y := new(big.Int).Exp(group.g, y, group.p)
 	Y := new(big.Int).Exp(group.g, y, group.p)
-	kInt := new(big.Int).Exp(kexDHInit.X, y, group.p)
+	kInt, err := group.diffieHellman(kexDHInit.X, y)
+	if err != nil {
+		return nil, nil, err
+	}
 
 
 	var serializedHostKey []byte
 	var serializedHostKey []byte
 	switch hostKeyAlgo {
 	switch hostKeyAlgo {
 	case hostAlgoRSA:
 	case hostAlgoRSA:
 		serializedHostKey = s.config.rsaSerialized
 		serializedHostKey = s.config.rsaSerialized
 	default:
 	default:
-		return nil, nil, errors.New("internal error")
+		return nil, nil, errors.New("ssh: internal error")
 	}
 	}
 
 
 	h := hashFunc.New()
 	h := hashFunc.New()
@@ -187,7 +186,7 @@ func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handsha
 			return
 			return
 		}
 		}
 	default:
 	default:
-		return nil, nil, errors.New("internal error")
+		return nil, nil, errors.New("ssh: internal error")
 	}
 	}
 
 
 	serializedSig := serializeSignature(hostAlgoRSA, sig)
 	serializedSig := serializeSignature(hostAlgoRSA, sig)

+ 2 - 2
ssh/session.go

@@ -231,9 +231,9 @@ func (s *Session) waitForResponse() error {
 	case *channelRequestSuccessMsg:
 	case *channelRequestSuccessMsg:
 		return nil
 		return nil
 	case *channelRequestFailureMsg:
 	case *channelRequestFailureMsg:
-		return errors.New("request failed")
+		return errors.New("ssh: request failed")
 	}
 	}
-	return fmt.Errorf("unknown packet %T received: %v", msg, msg)
+	return fmt.Errorf("ssh: unknown packet %T received: %v", msg, msg)
 }
 }
 
 
 func (s *Session) start() error {
 func (s *Session) start() error {

+ 5 - 2
ssh/transport.go

@@ -105,10 +105,10 @@ func (r *reader) readOnePacket() ([]byte, error) {
 	paddingLength := uint32(lengthBytes[4])
 	paddingLength := uint32(lengthBytes[4])
 
 
 	if length <= paddingLength+1 {
 	if length <= paddingLength+1 {
-		return nil, errors.New("invalid packet length")
+		return nil, errors.New("ssh: invalid packet length")
 	}
 	}
 	if length > maxPacketSize {
 	if length > maxPacketSize {
-		return nil, errors.New("packet too large")
+		return nil, errors.New("ssh: packet too large")
 	}
 	}
 
 
 	packet := make([]byte, length-1+macSize)
 	packet := make([]byte, length-1+macSize)
@@ -136,6 +136,9 @@ func (t *transport) readPacket() ([]byte, error) {
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
+		if len(packet) == 0 {
+			return nil, errors.New("ssh: zero length packet")
+		}
 		if packet[0] != msgIgnore && packet[0] != msgDebug {
 		if packet[0] != msgIgnore && packet[0] != msgDebug {
 			return packet, nil
 			return packet, nil
 		}
 		}