Ver código fonte

ssh: return specific error for invalid signature algorithm

Previously, this would return the default error "no auth passed yet".

Not only is the new error more specific, it makes it easier to verify
the control flow of server authentication code.

Change-Id: I6c8de4e3f91da74274acbe9d87ec4f6158b4a94f
Reviewed-on: https://go-review.googlesource.com/c/142897
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Han-Wen Nienhuys 7 anos atrás
pai
commit
e4dc69e5b2
2 arquivos alterados com 53 adições e 2 exclusões
  1. 52 2
      ssh/client_auth_test.go
  2. 1 0
      ssh/server.go

+ 52 - 2
ssh/client_auth_test.go

@@ -9,6 +9,7 @@ import (
 	"crypto/rand"
 	"crypto/rand"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"io"
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
@@ -28,8 +29,14 @@ func (cr keyboardInteractive) Challenge(user string, instruction string, questio
 var clientPassword = "tiger"
 var clientPassword = "tiger"
 
 
 // tryAuth runs a handshake with a given config against an SSH server
 // tryAuth runs a handshake with a given config against an SSH server
-// with config serverConfig
+// with config serverConfig. Returns both client and server side errors.
 func tryAuth(t *testing.T, config *ClientConfig) error {
 func tryAuth(t *testing.T, config *ClientConfig) error {
+	err, _ := tryAuthBothSides(t, config)
+	return err
+}
+
+// tryAuthBothSides runs the handshake and returns the resulting errors from both sides of the connection.
+func tryAuthBothSides(t *testing.T, config *ClientConfig) (clientError error, serverAuthErrors []error) {
 	c1, c2, err := netPipe()
 	c1, c2, err := netPipe()
 	if err != nil {
 	if err != nil {
 		t.Fatalf("netPipe: %v", err)
 		t.Fatalf("netPipe: %v", err)
@@ -79,9 +86,13 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
 	}
 	}
 	serverConfig.AddHostKey(testSigners["rsa"])
 	serverConfig.AddHostKey(testSigners["rsa"])
 
 
+	serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
+		serverAuthErrors = append(serverAuthErrors, err)
+	}
+
 	go newServer(c1, serverConfig)
 	go newServer(c1, serverConfig)
 	_, _, _, err = NewClientConn(c2, "", config)
 	_, _, _, err = NewClientConn(c2, "", config)
-	return err
+	return err, serverAuthErrors
 }
 }
 
 
 func TestClientAuthPublicKey(t *testing.T) {
 func TestClientAuthPublicKey(t *testing.T) {
@@ -213,6 +224,45 @@ func TestAuthMethodRSAandDSA(t *testing.T) {
 	}
 	}
 }
 }
 
 
+type invalidAlgSigner struct {
+	Signer
+}
+
+func (s *invalidAlgSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
+	sig, err := s.Signer.Sign(rand, data)
+	if sig != nil {
+		sig.Format = "invalid"
+	}
+	return sig, err
+}
+
+func TestMethodInvalidAlgorithm(t *testing.T) {
+	config := &ClientConfig{
+		User: "testuser",
+		Auth: []AuthMethod{
+			PublicKeys(&invalidAlgSigner{testSigners["rsa"]}),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	err, serverErrors := tryAuthBothSides(t, config)
+	if err == nil {
+		t.Fatalf("login succeeded")
+	}
+
+	found := false
+	want := "algorithm \"invalid\""
+
+	var errStrings []string
+	for _, err := range serverErrors {
+		found = found || (err != nil && strings.Contains(err.Error(), want))
+		errStrings = append(errStrings, err.Error())
+	}
+	if !found {
+		t.Errorf("server got error %q, want substring %q", errStrings, want)
+	}
+}
+
 func TestClientHMAC(t *testing.T) {
 func TestClientHMAC(t *testing.T) {
 	for _, mac := range supportedMACs {
 	for _, mac := range supportedMACs {
 		config := &ClientConfig{
 		config := &ClientConfig{

+ 1 - 0
ssh/server.go

@@ -484,6 +484,7 @@ userAuthLoop:
 				// sig.Format.  This is usually the same, but
 				// sig.Format.  This is usually the same, but
 				// for certs, the names differ.
 				// for certs, the names differ.
 				if !isAcceptableAlgo(sig.Format) {
 				if !isAcceptableAlgo(sig.Format) {
+					authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
 					break
 					break
 				}
 				}
 				signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)
 				signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)