|
|
@@ -10,7 +10,6 @@ package ssh
|
|
|
import (
|
|
|
"encoding/base64"
|
|
|
"errors"
|
|
|
- "fmt"
|
|
|
"io"
|
|
|
)
|
|
|
|
|
|
@@ -44,6 +43,10 @@ const (
|
|
|
agentConstrainConfirm = 2
|
|
|
)
|
|
|
|
|
|
+// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
|
|
|
+// is a sanity check, not a limit in the spec.
|
|
|
+const maxAgentResponseBytes = 16 << 20
|
|
|
+
|
|
|
// Agent messages:
|
|
|
// These structures mirror the wire format of the corresponding ssh agent
|
|
|
// messages found in PROTOCOL.agent.
|
|
|
@@ -85,18 +88,16 @@ type AgentKey struct {
|
|
|
func (ak *AgentKey) String() string {
|
|
|
algo, _, ok := parseString(ak.blob)
|
|
|
if !ok {
|
|
|
- return "malformed key"
|
|
|
+ return "ssh: malformed key"
|
|
|
}
|
|
|
|
|
|
- algoName := string(algo)
|
|
|
- b64EncKey := base64.StdEncoding.EncodeToString(ak.blob)
|
|
|
- comment := ""
|
|
|
+ s := string(algo) + " " + base64.StdEncoding.EncodeToString(ak.blob)
|
|
|
|
|
|
if ak.Comment != "" {
|
|
|
- comment = " " + ak.Comment
|
|
|
+ s += " " + ak.Comment
|
|
|
}
|
|
|
|
|
|
- return fmt.Sprintf("%s %s%s", algoName, b64EncKey, comment)
|
|
|
+ return s
|
|
|
}
|
|
|
|
|
|
// Key returns an agent's public key as a *rsa.PublicKey, *dsa.PublicKey, or
|
|
|
@@ -131,50 +132,51 @@ type AgentClient struct {
|
|
|
io.ReadWriter
|
|
|
}
|
|
|
|
|
|
-func (ac *AgentClient) sendRequest(req []byte) error {
|
|
|
- msg := make([]byte, stringLength(req))
|
|
|
+// sendAndReceive sends req to the agent and waits for a reply. On success,
|
|
|
+// the reply is unmarshaled into reply and replyType is set to the first byte of
|
|
|
+// the reply, which contains the type of the message.
|
|
|
+func (ac *AgentClient) sendAndReceive(req []byte) (reply interface{}, replyType uint8, err error) {
|
|
|
+ msg := make([]byte, stringLength(len(req)))
|
|
|
marshalString(msg, req)
|
|
|
- if _, err := ac.Write(msg); err != nil {
|
|
|
- return err
|
|
|
+ if _, err = ac.Write(msg); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
- return nil
|
|
|
-}
|
|
|
|
|
|
-func (ac *AgentClient) readResponse() ([]byte, error) {
|
|
|
var respSizeBuf [4]byte
|
|
|
- if _, err := io.ReadFull(ac, respSizeBuf[:]); err != nil {
|
|
|
- return nil, err
|
|
|
+ if _, err = io.ReadFull(ac, respSizeBuf[:]); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
+ respSize, _, _ := parseUint32(respSizeBuf[:])
|
|
|
|
|
|
- respSize, _, ok := parseUint32(respSizeBuf[:])
|
|
|
- if !ok {
|
|
|
- return nil, errors.New("ssh: failure to parse response size")
|
|
|
+ if respSize > maxAgentResponseBytes {
|
|
|
+ err = errors.New("ssh: agent reply too large")
|
|
|
+ return
|
|
|
}
|
|
|
|
|
|
buf := make([]byte, respSize)
|
|
|
- if _, err := io.ReadFull(ac, buf); err != nil {
|
|
|
- return nil, err
|
|
|
+ if _, err = io.ReadFull(ac, buf); err != nil {
|
|
|
+ return
|
|
|
}
|
|
|
- return buf, nil
|
|
|
+ return unmarshalAgentMsg(buf)
|
|
|
}
|
|
|
|
|
|
// RequestIdentities queries the agent for protocol 2 keys as defined in
|
|
|
// PROTOCOL.agent section 2.5.2.
|
|
|
func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
|
|
|
req := marshal(agentRequestIdentities, requestIdentitiesAgentMsg{})
|
|
|
- if err := ac.sendRequest(req); err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
|
|
|
- resp, err := ac.readResponse()
|
|
|
+ msg, msgType, err := ac.sendAndReceive(req)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- switch msg := decodeAgentMsg(resp).(type) {
|
|
|
+ switch msg := msg.(type) {
|
|
|
case *identitiesAnswerAgentMsg:
|
|
|
+ if msg.NumKeys > maxAgentResponseBytes/8 {
|
|
|
+ return nil, errors.New("ssh: too many keys in agent reply")
|
|
|
+ }
|
|
|
keys := make([]*AgentKey, msg.NumKeys)
|
|
|
- data := msg.Keys[:]
|
|
|
+ data := msg.Keys
|
|
|
for i := uint32(0); i < msg.NumKeys; i++ {
|
|
|
var key *AgentKey
|
|
|
var ok bool
|
|
|
@@ -185,11 +187,9 @@ func (ac *AgentClient) RequestIdentities() ([]*AgentKey, error) {
|
|
|
}
|
|
|
return keys, nil
|
|
|
case *failureAgentMsg:
|
|
|
- return nil, errors.New("ssh: failed to list keys.")
|
|
|
- case ParseError, UnexpectedMessageError:
|
|
|
- return nil, msg.(error)
|
|
|
+ return nil, errors.New("ssh: failed to list keys")
|
|
|
}
|
|
|
- return nil, UnexpectedMessageError{agentIdentitiesAnswer, resp[0]}
|
|
|
+ return nil, UnexpectedMessageError{agentIdentitiesAnswer, msgType}
|
|
|
}
|
|
|
|
|
|
// SignRequest requests the signing of data by the agent using a protocol 2 key
|
|
|
@@ -200,29 +200,26 @@ func (ac *AgentClient) SignRequest(key interface{}, data []byte) ([]byte, error)
|
|
|
KeyBlob: serializePublickey(key),
|
|
|
Data: data,
|
|
|
})
|
|
|
- if err := ac.sendRequest(req); err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
|
|
|
- resp, err := ac.readResponse()
|
|
|
+ msg, msgType, err := ac.sendAndReceive(req)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
- switch msg := decodeAgentMsg(resp).(type) {
|
|
|
+ switch msg := msg.(type) {
|
|
|
case *signResponseAgentMsg:
|
|
|
return msg.SigBlob, nil
|
|
|
case *failureAgentMsg:
|
|
|
return nil, errors.New("ssh: failed to sign challenge")
|
|
|
- case ParseError, UnexpectedMessageError:
|
|
|
- return nil, msg.(error)
|
|
|
}
|
|
|
- return nil, UnexpectedMessageError{agentSignResponse, resp[0]}
|
|
|
+ return nil, UnexpectedMessageError{agentSignResponse, msgType}
|
|
|
}
|
|
|
|
|
|
-func decodeAgentMsg(packet []byte) interface{} {
|
|
|
+// unmarshalAgentMsg parses an agent message in packet, returning the parsed
|
|
|
+// form and the message type of packet.
|
|
|
+func unmarshalAgentMsg(packet []byte) (interface{}, uint8, error) {
|
|
|
if len(packet) < 1 {
|
|
|
- return ParseError{0}
|
|
|
+ return nil, 0, ParseError{0}
|
|
|
}
|
|
|
var msg interface{}
|
|
|
switch packet[0] {
|
|
|
@@ -235,10 +232,10 @@ func decodeAgentMsg(packet []byte) interface{} {
|
|
|
case agentSignResponse:
|
|
|
msg = new(signResponseAgentMsg)
|
|
|
default:
|
|
|
- return UnexpectedMessageError{0, packet[0]}
|
|
|
+ return nil, 0, UnexpectedMessageError{0, packet[0]}
|
|
|
}
|
|
|
if err := unmarshal(msg, packet, packet[0]); err != nil {
|
|
|
- return err
|
|
|
+ return nil, 0, err
|
|
|
}
|
|
|
- return msg
|
|
|
+ return msg, packet[0], nil
|
|
|
}
|