Kaynağa Gözat

ssh/agent: add checking for empty SSH requests

Previously empty SSH requests would cause a panic.

Change-Id: I8443fee50891b3d2b3b62ac01fb0b9e96244241f
GitHub-Last-Rev: 64f00d2bf2ee722f53e68b6bd4f70c722d7694bd
GitHub-Pull-Request: golang/crypto#58
Reviewed-on: https://go-review.googlesource.com/c/140237
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Ian Haken 6 yıl önce
ebeveyn
işleme
7f87c0fbb8
2 değiştirilmiş dosya ile 76 ekleme ve 5 silme
  1. 73 5
      ssh/agent/client_test.go
  2. 3 0
      ssh/agent/server.go

+ 73 - 5
ssh/agent/client_test.go

@@ -8,11 +8,13 @@ import (
 	"bytes"
 	"bytes"
 	"crypto/rand"
 	"crypto/rand"
 	"errors"
 	"errors"
+	"io"
 	"net"
 	"net"
 	"os"
 	"os"
 	"os/exec"
 	"os/exec"
 	"path/filepath"
 	"path/filepath"
 	"strconv"
 	"strconv"
+	"sync"
 	"testing"
 	"testing"
 	"time"
 	"time"
 
 
@@ -196,6 +198,63 @@ func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert
 
 
 }
 }
 
 
+func TestMalformedRequests(t *testing.T) {
+	keyringAgent := NewKeyring()
+	listener, err := netListener()
+	if err != nil {
+		t.Fatalf("netListener: %v", err)
+	}
+	defer listener.Close()
+
+	testCase := func(t *testing.T, requestBytes []byte, wantServerErr bool) {
+		var wg sync.WaitGroup
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			c, err := listener.Accept()
+			if err != nil {
+				t.Errorf("listener.Accept: %v", err)
+				return
+			}
+			defer c.Close()
+
+			err = ServeAgent(keyringAgent, c)
+			if err == nil {
+				t.Error("ServeAgent should have returned an error to malformed input")
+			} else {
+				if (err != io.EOF) != wantServerErr {
+					t.Errorf("ServeAgent returned expected error: %v", err)
+				}
+			}
+		}()
+
+		c, err := net.Dial("tcp", listener.Addr().String())
+		if err != nil {
+			t.Fatalf("net.Dial: %v", err)
+		}
+		_, err = c.Write(requestBytes)
+		if err != nil {
+			t.Errorf("Unexpected error writing raw bytes on connection: %v", err)
+		}
+		c.Close()
+		wg.Wait()
+	}
+
+	var testCases = []struct {
+		name          string
+		requestBytes  []byte
+		wantServerErr bool
+	}{
+		{"Empty request", []byte{}, false},
+		{"Short header", []byte{0x00}, true},
+		{"Empty body", []byte{0x00, 0x00, 0x00, 0x00}, true},
+		{"Short body", []byte{0x00, 0x00, 0x00, 0x01}, false},
+	}
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) { testCase(t, tc.requestBytes, tc.wantServerErr) })
+	}
+}
+
 func TestAgent(t *testing.T) {
 func TestAgent(t *testing.T) {
 	for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
 	for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
 		testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
 		testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
@@ -215,17 +274,26 @@ func TestCert(t *testing.T) {
 	testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
 	testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
 }
 }
 
 
-// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
-// therefore is buffered (net.Pipe deadlocks if both sides start with
-// a write.)
-func netPipe() (net.Conn, net.Conn, error) {
+// netListener creates a localhost network listener.
+func netListener() (net.Listener, error) {
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	listener, err := net.Listen("tcp", "127.0.0.1:0")
 	if err != nil {
 	if err != nil {
 		listener, err = net.Listen("tcp", "[::1]:0")
 		listener, err = net.Listen("tcp", "[::1]:0")
 		if err != nil {
 		if err != nil {
-			return nil, nil, err
+			return nil, err
 		}
 		}
 	}
 	}
+	return listener, nil
+}
+
+// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
+// therefore is buffered (net.Pipe deadlocks if both sides start with
+// a write.)
+func netPipe() (net.Conn, net.Conn, error) {
+	listener, err := netListener()
+	if err != nil {
+		return nil, nil, err
+	}
 	defer listener.Close()
 	defer listener.Close()
 	c1, err := net.Dial("tcp", listener.Addr().String())
 	c1, err := net.Dial("tcp", listener.Addr().String())
 	if err != nil {
 	if err != nil {

+ 3 - 0
ssh/agent/server.go

@@ -541,6 +541,9 @@ func ServeAgent(agent Agent, c io.ReadWriter) error {
 			return err
 			return err
 		}
 		}
 		l := binary.BigEndian.Uint32(length[:])
 		l := binary.BigEndian.Uint32(length[:])
+		if l == 0 {
+			return fmt.Errorf("agent: request size is 0")
+		}
 		if l > maxAgentResponseBytes {
 		if l > maxAgentResponseBytes {
 			// We also cap requests.
 			// We also cap requests.
 			return fmt.Errorf("agent: request too large: %d", l)
 			return fmt.Errorf("agent: request too large: %d", l)