|
@@ -8,11 +8,13 @@ import (
|
|
|
"bytes"
|
|
"bytes"
|
|
|
"crypto/rand"
|
|
"crypto/rand"
|
|
|
"errors"
|
|
"errors"
|
|
|
|
|
+ "io"
|
|
|
"net"
|
|
"net"
|
|
|
"os"
|
|
"os"
|
|
|
"os/exec"
|
|
"os/exec"
|
|
|
"path/filepath"
|
|
"path/filepath"
|
|
|
"strconv"
|
|
"strconv"
|
|
|
|
|
+ "sync"
|
|
|
"testing"
|
|
"testing"
|
|
|
"time"
|
|
"time"
|
|
|
|
|
|
|
@@ -196,6 +198,63 @@ func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert
|
|
|
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func TestMalformedRequests(t *testing.T) {
|
|
|
|
|
+ keyringAgent := NewKeyring()
|
|
|
|
|
+ listener, err := netListener()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatalf("netListener: %v", err)
|
|
|
|
|
+ }
|
|
|
|
|
+ defer listener.Close()
|
|
|
|
|
+
|
|
|
|
|
+ testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) {
|
|
|
|
|
+ var wg sync.WaitGroup
|
|
|
|
|
+ wg.Add(1)
|
|
|
|
|
+ go func() {
|
|
|
|
|
+ defer wg.Done()
|
|
|
|
|
+ c, err := listener.Accept()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Errorf("listener.Accept: %v", err)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ defer c.Close()
|
|
|
|
|
+
|
|
|
|
|
+ err = ServeAgent(keyringAgent, c)
|
|
|
|
|
+ if err == nil {
|
|
|
|
|
+ t.Error("ServeAgent should have returned an error to malformed input")
|
|
|
|
|
+ } else {
|
|
|
|
|
+ if (err != io.EOF) != wantServerErr {
|
|
|
|
|
+ t.Errorf("ServeAgent returned expected error: %v", err)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }()
|
|
|
|
|
+
|
|
|
|
|
+ c, err := net.Dial("tcp", listener.Addr().String())
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Fatalf("net.Dial: %v", err)
|
|
|
|
|
+ }
|
|
|
|
|
+ _, err = c.Write(requestBytes)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
|
|
|
|
|
+ }
|
|
|
|
|
+ c.Close()
|
|
|
|
|
+ wg.Wait()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ var testCases = []struct {
|
|
|
|
|
+ name string
|
|
|
|
|
+ requestBytes []byte
|
|
|
|
|
+ wantServerErr bool
|
|
|
|
|
+ }{
|
|
|
|
|
+ {"Empty request", []byte{}, false},
|
|
|
|
|
+ {"Short header", []byte{0x00}, true},
|
|
|
|
|
+ {"Empty body", []byte{0x00, 0x00, 0x00, 0x00}, true},
|
|
|
|
|
+ {"Short body", []byte{0x00, 0x00, 0x00, 0x01}, false},
|
|
|
|
|
+ }
|
|
|
|
|
+ for _, tc := range testCases {
|
|
|
|
|
+ t.Run(tc.name, func(t *testing.T) { testCase(t, tc.requestBytes, tc.wantServerErr) })
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
func TestAgent(t *testing.T) {
|
|
func TestAgent(t *testing.T) {
|
|
|
for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
|
|
for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
|
|
|
testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
|
|
testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
|
|
@@ -215,17 +274,26 @@ func TestCert(t *testing.T) {
|
|
|
testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
|
|
testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
|
|
|
|
-// therefore is buffered (net.Pipe deadlocks if both sides start with
|
|
|
|
|
-// a write.)
|
|
|
|
|
-func netPipe() (net.Conn, net.Conn, error) {
|
|
|
|
|
|
|
+// netListener creates a localhost network listener.
|
|
|
|
|
+func netListener() (net.Listener, error) {
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
listener, err = net.Listen("tcp", "[::1]:0")
|
|
listener, err = net.Listen("tcp", "[::1]:0")
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return nil, nil, err
|
|
|
|
|
|
|
+ return nil, err
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ return listener, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
|
|
|
|
+// therefore is buffered (net.Pipe deadlocks if both sides start with
|
|
|
|
|
+// a write.)
|
|
|
|
|
+func netPipe() (net.Conn, net.Conn, error) {
|
|
|
|
|
+ listener, err := netListener()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, nil, err
|
|
|
|
|
+ }
|
|
|
defer listener.Close()
|
|
defer listener.Close()
|
|
|
c1, err := net.Dial("tcp", listener.Addr().String())
|
|
c1, err := net.Dial("tcp", listener.Addr().String())
|
|
|
if err != nil {
|
|
if err != nil {
|