Browse Source

ssh: fix support for partial success authentication responses in client

The existing client side authentication does not handle correctly
the partial success flag in SSH_MSG_USERAUTH_FAILURE authentication
responses.

This commit fixes two problems in ssh library:
1) RetryableAuthMethod() now breaks out from the retry loop and
   returns  when underlying auth method fails with partial success
   set to true.
2) Book keeping of tried (and failed) auth methods in
   clientAuthenticate() does not mark an auth method failed if it
   fails with partial success set to true.

Fixes golang/go#23461

Change-Id: Ib2e1a1d54bfe2549496199bb2f66ebbce58d130d
Reviewed-on: https://go-review.googlesource.com/88035
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
Sami Pönkänen 7 years ago
parent
commit
9334d73e5f
5 changed files with 450 additions and 57 deletions
  1. 55 40
      ssh/client_auth.go
  2. 12 12
      ssh/keys_test.go
  3. 142 0
      ssh/test/multi_auth_test.go
  4. 173 0
      ssh/test/sshd_test_pw.c
  5. 68 5
      ssh/test/test_unix_test.go

+ 55 - 40
ssh/client_auth.go

@@ -11,6 +11,14 @@ import (
 	"io"
 	"io"
 )
 )
 
 
+type authResult int
+
+const (
+	authFailure authResult = iota
+	authPartialSuccess
+	authSuccess
+)
+
 // clientAuthenticate authenticates with the remote server. See RFC 4252.
 // clientAuthenticate authenticates with the remote server. See RFC 4252.
 func (c *connection) clientAuthenticate(config *ClientConfig) error {
 func (c *connection) clientAuthenticate(config *ClientConfig) error {
 	// initiate user auth session
 	// initiate user auth session
@@ -37,11 +45,12 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		if ok {
+		if ok == authSuccess {
 			// success
 			// success
 			return nil
 			return nil
+		} else if ok == authFailure {
+			tried[auth.method()] = true
 		}
 		}
-		tried[auth.method()] = true
 		if methods == nil {
 		if methods == nil {
 			methods = lastMethods
 			methods = lastMethods
 		}
 		}
@@ -82,7 +91,7 @@ type AuthMethod interface {
 	// If authentication is not successful, a []string of alternative
 	// If authentication is not successful, a []string of alternative
 	// method names is returned. If the slice is nil, it will be ignored
 	// method names is returned. If the slice is nil, it will be ignored
 	// and the previous set of possible methods will be reused.
 	// and the previous set of possible methods will be reused.
-	auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error)
+	auth(session []byte, user string, p packetConn, rand io.Reader) (authResult, []string, error)
 
 
 	// method returns the RFC 4252 method name.
 	// method returns the RFC 4252 method name.
 	method() string
 	method() string
@@ -91,13 +100,13 @@ type AuthMethod interface {
 // "none" authentication, RFC 4252 section 5.2.
 // "none" authentication, RFC 4252 section 5.2.
 type noneAuth int
 type noneAuth int
 
 
-func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
 	if err := c.writePacket(Marshal(&userAuthRequestMsg{
 	if err := c.writePacket(Marshal(&userAuthRequestMsg{
 		User:    user,
 		User:    user,
 		Service: serviceSSH,
 		Service: serviceSSH,
 		Method:  "none",
 		Method:  "none",
 	})); err != nil {
 	})); err != nil {
-		return false, nil, err
+		return authFailure, nil, err
 	}
 	}
 
 
 	return handleAuthResponse(c)
 	return handleAuthResponse(c)
@@ -111,7 +120,7 @@ func (n *noneAuth) method() string {
 // a function call, e.g. by prompting the user.
 // a function call, e.g. by prompting the user.
 type passwordCallback func() (password string, err error)
 type passwordCallback func() (password string, err error)
 
 
-func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
 	type passwordAuthMsg struct {
 	type passwordAuthMsg struct {
 		User     string `sshtype:"50"`
 		User     string `sshtype:"50"`
 		Service  string
 		Service  string
@@ -125,7 +134,7 @@ func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand
 	// The program may only find out that the user doesn't have a password
 	// The program may only find out that the user doesn't have a password
 	// when prompting.
 	// when prompting.
 	if err != nil {
 	if err != nil {
-		return false, nil, err
+		return authFailure, nil, err
 	}
 	}
 
 
 	if err := c.writePacket(Marshal(&passwordAuthMsg{
 	if err := c.writePacket(Marshal(&passwordAuthMsg{
@@ -135,7 +144,7 @@ func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand
 		Reply:    false,
 		Reply:    false,
 		Password: pw,
 		Password: pw,
 	})); err != nil {
 	})); err != nil {
-		return false, nil, err
+		return authFailure, nil, err
 	}
 	}
 
 
 	return handleAuthResponse(c)
 	return handleAuthResponse(c)
@@ -178,7 +187,7 @@ func (cb publicKeyCallback) method() string {
 	return "publickey"
 	return "publickey"
 }
 }
 
 
-func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
 	// Authentication is performed by sending an enquiry to test if a key is
 	// Authentication is performed by sending an enquiry to test if a key is
 	// acceptable to the remote. If the key is acceptable, the client will
 	// acceptable to the remote. If the key is acceptable, the client will
 	// attempt to authenticate with the valid key.  If not the client will repeat
 	// attempt to authenticate with the valid key.  If not the client will repeat
@@ -186,13 +195,13 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
 
 
 	signers, err := cb()
 	signers, err := cb()
 	if err != nil {
 	if err != nil {
-		return false, nil, err
+		return authFailure, nil, err
 	}
 	}
 	var methods []string
 	var methods []string
 	for _, signer := range signers {
 	for _, signer := range signers {
 		ok, err := validateKey(signer.PublicKey(), user, c)
 		ok, err := validateKey(signer.PublicKey(), user, c)
 		if err != nil {
 		if err != nil {
-			return false, nil, err
+			return authFailure, nil, err
 		}
 		}
 		if !ok {
 		if !ok {
 			continue
 			continue
@@ -206,7 +215,7 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
 			Method:  cb.method(),
 			Method:  cb.method(),
 		}, []byte(pub.Type()), pubKey))
 		}, []byte(pub.Type()), pubKey))
 		if err != nil {
 		if err != nil {
-			return false, nil, err
+			return authFailure, nil, err
 		}
 		}
 
 
 		// manually wrap the serialized signature in a string
 		// manually wrap the serialized signature in a string
@@ -224,24 +233,24 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
 		}
 		}
 		p := Marshal(&msg)
 		p := Marshal(&msg)
 		if err := c.writePacket(p); err != nil {
 		if err := c.writePacket(p); err != nil {
-			return false, nil, err
+			return authFailure, nil, err
 		}
 		}
-		var success bool
+		var success authResult
 		success, methods, err = handleAuthResponse(c)
 		success, methods, err = handleAuthResponse(c)
 		if err != nil {
 		if err != nil {
-			return false, nil, err
+			return authFailure, nil, err
 		}
 		}
 
 
 		// If authentication succeeds or the list of available methods does not
 		// If authentication succeeds or the list of available methods does not
 		// contain the "publickey" method, do not attempt to authenticate with any
 		// contain the "publickey" method, do not attempt to authenticate with any
 		// other keys.  According to RFC 4252 Section 7, the latter can occur when
 		// other keys.  According to RFC 4252 Section 7, the latter can occur when
 		// additional authentication methods are required.
 		// additional authentication methods are required.
-		if success || !containsMethod(methods, cb.method()) {
+		if success == authSuccess || !containsMethod(methods, cb.method()) {
 			return success, methods, err
 			return success, methods, err
 		}
 		}
 	}
 	}
 
 
-	return false, methods, nil
+	return authFailure, methods, nil
 }
 }
 
 
 func containsMethod(methods []string, method string) bool {
 func containsMethod(methods []string, method string) bool {
@@ -318,28 +327,31 @@ func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMet
 // handleAuthResponse returns whether the preceding authentication request succeeded
 // handleAuthResponse returns whether the preceding authentication request succeeded
 // along with a list of remaining authentication methods to try next and
 // along with a list of remaining authentication methods to try next and
 // an error if an unexpected response was received.
 // an error if an unexpected response was received.
-func handleAuthResponse(c packetConn) (bool, []string, error) {
+func handleAuthResponse(c packetConn) (authResult, []string, error) {
 	for {
 	for {
 		packet, err := c.readPacket()
 		packet, err := c.readPacket()
 		if err != nil {
 		if err != nil {
-			return false, nil, err
+			return authFailure, nil, err
 		}
 		}
 
 
 		switch packet[0] {
 		switch packet[0] {
 		case msgUserAuthBanner:
 		case msgUserAuthBanner:
 			if err := handleBannerResponse(c, packet); err != nil {
 			if err := handleBannerResponse(c, packet); err != nil {
-				return false, nil, err
+				return authFailure, nil, err
 			}
 			}
 		case msgUserAuthFailure:
 		case msgUserAuthFailure:
 			var msg userAuthFailureMsg
 			var msg userAuthFailureMsg
 			if err := Unmarshal(packet, &msg); err != nil {
 			if err := Unmarshal(packet, &msg); err != nil {
-				return false, nil, err
+				return authFailure, nil, err
 			}
 			}
-			return false, msg.Methods, nil
+			if msg.PartialSuccess {
+				return authPartialSuccess, msg.Methods, nil
+			}
+			return authFailure, msg.Methods, nil
 		case msgUserAuthSuccess:
 		case msgUserAuthSuccess:
-			return true, nil, nil
+			return authSuccess, nil, nil
 		default:
 		default:
-			return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
+			return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
 		}
 		}
 	}
 	}
 }
 }
@@ -381,7 +393,7 @@ func (cb KeyboardInteractiveChallenge) method() string {
 	return "keyboard-interactive"
 	return "keyboard-interactive"
 }
 }
 
 
-func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
+func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
 	type initiateMsg struct {
 	type initiateMsg struct {
 		User       string `sshtype:"50"`
 		User       string `sshtype:"50"`
 		Service    string
 		Service    string
@@ -395,20 +407,20 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 		Service: serviceSSH,
 		Service: serviceSSH,
 		Method:  "keyboard-interactive",
 		Method:  "keyboard-interactive",
 	})); err != nil {
 	})); err != nil {
-		return false, nil, err
+		return authFailure, nil, err
 	}
 	}
 
 
 	for {
 	for {
 		packet, err := c.readPacket()
 		packet, err := c.readPacket()
 		if err != nil {
 		if err != nil {
-			return false, nil, err
+			return authFailure, nil, err
 		}
 		}
 
 
 		// like handleAuthResponse, but with less options.
 		// like handleAuthResponse, but with less options.
 		switch packet[0] {
 		switch packet[0] {
 		case msgUserAuthBanner:
 		case msgUserAuthBanner:
 			if err := handleBannerResponse(c, packet); err != nil {
 			if err := handleBannerResponse(c, packet); err != nil {
-				return false, nil, err
+				return authFailure, nil, err
 			}
 			}
 			continue
 			continue
 		case msgUserAuthInfoRequest:
 		case msgUserAuthInfoRequest:
@@ -416,18 +428,21 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 		case msgUserAuthFailure:
 		case msgUserAuthFailure:
 			var msg userAuthFailureMsg
 			var msg userAuthFailureMsg
 			if err := Unmarshal(packet, &msg); err != nil {
 			if err := Unmarshal(packet, &msg); err != nil {
-				return false, nil, err
+				return authFailure, nil, err
+			}
+			if msg.PartialSuccess {
+				return authPartialSuccess, msg.Methods, nil
 			}
 			}
-			return false, msg.Methods, nil
+			return authFailure, msg.Methods, nil
 		case msgUserAuthSuccess:
 		case msgUserAuthSuccess:
-			return true, nil, nil
+			return authSuccess, nil, nil
 		default:
 		default:
-			return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
+			return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
 		}
 		}
 
 
 		var msg userAuthInfoRequestMsg
 		var msg userAuthInfoRequestMsg
 		if err := Unmarshal(packet, &msg); err != nil {
 		if err := Unmarshal(packet, &msg); err != nil {
-			return false, nil, err
+			return authFailure, nil, err
 		}
 		}
 
 
 		// Manually unpack the prompt/echo pairs.
 		// Manually unpack the prompt/echo pairs.
@@ -437,7 +452,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 		for i := 0; i < int(msg.NumPrompts); i++ {
 		for i := 0; i < int(msg.NumPrompts); i++ {
 			prompt, r, ok := parseString(rest)
 			prompt, r, ok := parseString(rest)
 			if !ok || len(r) == 0 {
 			if !ok || len(r) == 0 {
-				return false, nil, errors.New("ssh: prompt format error")
+				return authFailure, nil, errors.New("ssh: prompt format error")
 			}
 			}
 			prompts = append(prompts, string(prompt))
 			prompts = append(prompts, string(prompt))
 			echos = append(echos, r[0] != 0)
 			echos = append(echos, r[0] != 0)
@@ -445,16 +460,16 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 		}
 		}
 
 
 		if len(rest) != 0 {
 		if len(rest) != 0 {
-			return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
+			return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
 		}
 		}
 
 
 		answers, err := cb(msg.User, msg.Instruction, prompts, echos)
 		answers, err := cb(msg.User, msg.Instruction, prompts, echos)
 		if err != nil {
 		if err != nil {
-			return false, nil, err
+			return authFailure, nil, err
 		}
 		}
 
 
 		if len(answers) != len(prompts) {
 		if len(answers) != len(prompts) {
-			return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
+			return authFailure, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
 		}
 		}
 		responseLength := 1 + 4
 		responseLength := 1 + 4
 		for _, a := range answers {
 		for _, a := range answers {
@@ -470,7 +485,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
 		}
 		}
 
 
 		if err := c.writePacket(serialized); err != nil {
 		if err := c.writePacket(serialized); err != nil {
-			return false, nil, err
+			return authFailure, nil, err
 		}
 		}
 	}
 	}
 }
 }
@@ -480,10 +495,10 @@ type retryableAuthMethod struct {
 	maxTries   int
 	maxTries   int
 }
 }
 
 
-func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok bool, methods []string, err error) {
+func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok authResult, methods []string, err error) {
 	for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
 	for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
 		ok, methods, err = r.authMethod.auth(session, user, c, rand)
 		ok, methods, err = r.authMethod.auth(session, user, c, rand)
-		if ok || err != nil { // either success or error terminate
+		if ok != authFailure || err != nil { // either success, partial success or error terminate
 			return ok, methods, err
 			return ok, methods, err
 		}
 		}
 	}
 	}

+ 12 - 12
ssh/keys_test.go

@@ -234,7 +234,7 @@ func TestMarshalParsePublicKey(t *testing.T) {
 	}
 	}
 }
 }
 
 
-type authResult struct {
+type testAuthResult struct {
 	pubKey   PublicKey
 	pubKey   PublicKey
 	options  []string
 	options  []string
 	comments string
 	comments string
@@ -242,11 +242,11 @@ type authResult struct {
 	ok       bool
 	ok       bool
 }
 }
 
 
-func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
+func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []testAuthResult) {
 	rest := authKeys
 	rest := authKeys
-	var values []authResult
+	var values []testAuthResult
 	for len(rest) > 0 {
 	for len(rest) > 0 {
-		var r authResult
+		var r testAuthResult
 		var err error
 		var err error
 		r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest)
 		r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest)
 		r.ok = (err == nil)
 		r.ok = (err == nil)
@@ -264,7 +264,7 @@ func TestAuthorizedKeyBasic(t *testing.T) {
 	pub, pubSerialized := getTestKey()
 	pub, pubSerialized := getTestKey()
 	line := "ssh-rsa " + pubSerialized + " user@host"
 	line := "ssh-rsa " + pubSerialized + " user@host"
 	testAuthorizedKeys(t, []byte(line),
 	testAuthorizedKeys(t, []byte(line),
-		[]authResult{
+		[]testAuthResult{
 			{pub, nil, "user@host", "", true},
 			{pub, nil, "user@host", "", true},
 		})
 		})
 }
 }
@@ -286,7 +286,7 @@ func TestAuth(t *testing.T) {
 		authOptions := strings.Join(authWithOptions, eol)
 		authOptions := strings.Join(authWithOptions, eol)
 		rest2 := strings.Join(authWithOptions[3:], eol)
 		rest2 := strings.Join(authWithOptions[3:], eol)
 		rest3 := strings.Join(authWithOptions[6:], eol)
 		rest3 := strings.Join(authWithOptions[6:], eol)
-		testAuthorizedKeys(t, []byte(authOptions), []authResult{
+		testAuthorizedKeys(t, []byte(authOptions), []testAuthResult{
 			{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
 			{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
 			{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
 			{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
 			{nil, nil, "", "", false},
 			{nil, nil, "", "", false},
@@ -297,7 +297,7 @@ func TestAuth(t *testing.T) {
 func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
 func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
 	pub, pubSerialized := getTestKey()
 	pub, pubSerialized := getTestKey()
 	authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
 	authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
-	testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{
+	testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []testAuthResult{
 		{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
 		{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
 	})
 	})
 }
 }
@@ -305,7 +305,7 @@ func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
 func TestAuthWithQuotedCommaInEnv(t *testing.T) {
 func TestAuthWithQuotedCommaInEnv(t *testing.T) {
 	pub, pubSerialized := getTestKey()
 	pub, pubSerialized := getTestKey()
 	authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + `   user@host`)
 	authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + `   user@host`)
-	testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{
+	testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []testAuthResult{
 		{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
 		{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
 	})
 	})
 }
 }
@@ -314,11 +314,11 @@ func TestAuthWithQuotedQuoteInEnv(t *testing.T) {
 	pub, pubSerialized := getTestKey()
 	pub, pubSerialized := getTestKey()
 	authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + `   user@host`)
 	authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + `   user@host`)
 	authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`)
 	authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`)
-	testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{
+	testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []testAuthResult{
 		{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
 		{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
 	})
 	})
 
 
-	testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{
+	testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []testAuthResult{
 		{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
 		{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
 	})
 	})
 }
 }
@@ -327,7 +327,7 @@ func TestAuthWithInvalidSpace(t *testing.T) {
 	_, pubSerialized := getTestKey()
 	_, pubSerialized := getTestKey()
 	authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
 	authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
 #more to follow but still no valid keys`)
 #more to follow but still no valid keys`)
-	testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{
+	testAuthorizedKeys(t, []byte(authWithInvalidSpace), []testAuthResult{
 		{nil, nil, "", "", false},
 		{nil, nil, "", "", false},
 	})
 	})
 }
 }
@@ -337,7 +337,7 @@ func TestAuthWithMissingQuote(t *testing.T) {
 	authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
 	authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
 env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`)
 env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`)
 
 
-	testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{
+	testAuthorizedKeys(t, []byte(authWithMissingQuote), []testAuthResult{
 		{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
 		{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
 	})
 	})
 }
 }

+ 142 - 0
ssh/test/multi_auth_test.go

@@ -0,0 +1,142 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for ssh client multi-auth
+//
+// These tests run a simple go ssh client against OpenSSH server
+// over unix domain sockets. The tests use multiple combinations
+// of password, keyboard-interactive and publickey authentication
+// methods.
+//
+// A wrapper library for making sshd PAM authentication use test
+// passwords is required in ./sshd_test_pw.so. If the library does
+// not exist these tests will be skipped. See compile instructions
+// (for linux) in file ./sshd_test_pw.c.
+
+package test
+
+import (
+	"fmt"
+	"strings"
+	"testing"
+
+	"golang.org/x/crypto/ssh"
+)
+
+// test cases
+type multiAuthTestCase struct {
+	authMethods         []string
+	expectedPasswordCbs int
+	expectedKbdIntCbs   int
+}
+
+// test context
+type multiAuthTestCtx struct {
+	password       string
+	numPasswordCbs int
+	numKbdIntCbs   int
+}
+
+// create test context
+func newMultiAuthTestCtx(t *testing.T) *multiAuthTestCtx {
+	password, err := randomPassword()
+	if err != nil {
+		t.Fatalf("Failed to generate random test password: %s", err.Error())
+	}
+
+	return &multiAuthTestCtx{
+		password: password,
+	}
+}
+
+// password callback
+func (ctx *multiAuthTestCtx) passwordCb() (secret string, err error) {
+	ctx.numPasswordCbs++
+	return ctx.password, nil
+}
+
+// keyboard-interactive callback
+func (ctx *multiAuthTestCtx) kbdIntCb(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
+	if len(questions) == 0 {
+		return nil, nil
+	}
+
+	ctx.numKbdIntCbs++
+	if len(questions) == 1 {
+		return []string{ctx.password}, nil
+	}
+
+	return nil, fmt.Errorf("unsupported keyboard-interactive flow")
+}
+
+// TestMultiAuth runs several subtests for different combinations of password, keyboard-interactive and publickey authentication methods
+func TestMultiAuth(t *testing.T) {
+	testCases := []multiAuthTestCase{
+		// Test password,publickey authentication, assert that password callback is called 1 time
+		multiAuthTestCase{
+			authMethods:         []string{"password", "publickey"},
+			expectedPasswordCbs: 1,
+		},
+		// Test keyboard-interactive,publickey authentication, assert that keyboard-interactive callback is called 1 time
+		multiAuthTestCase{
+			authMethods:       []string{"keyboard-interactive", "publickey"},
+			expectedKbdIntCbs: 1,
+		},
+		// Test publickey,password authentication, assert that password callback is called 1 time
+		multiAuthTestCase{
+			authMethods:         []string{"publickey", "password"},
+			expectedPasswordCbs: 1,
+		},
+		// Test publickey,keyboard-interactive authentication, assert that keyboard-interactive callback is called 1 time
+		multiAuthTestCase{
+			authMethods:       []string{"publickey", "keyboard-interactive"},
+			expectedKbdIntCbs: 1,
+		},
+		// Test password,password authentication, assert that password callback is called 2 times
+		multiAuthTestCase{
+			authMethods:         []string{"password", "password"},
+			expectedPasswordCbs: 2,
+		},
+	}
+
+	for _, testCase := range testCases {
+		t.Run(strings.Join(testCase.authMethods, ","), func(t *testing.T) {
+			ctx := newMultiAuthTestCtx(t)
+
+			server := newServerForConfig(t, "MultiAuth", map[string]string{"AuthMethods": strings.Join(testCase.authMethods, ",")})
+			defer server.Shutdown()
+
+			clientConfig := clientConfig()
+			server.setTestPassword(clientConfig.User, ctx.password)
+
+			publicKeyAuthMethod := clientConfig.Auth[0]
+			clientConfig.Auth = nil
+			for _, authMethod := range testCase.authMethods {
+				switch authMethod {
+				case "publickey":
+					clientConfig.Auth = append(clientConfig.Auth, publicKeyAuthMethod)
+				case "password":
+					clientConfig.Auth = append(clientConfig.Auth,
+						ssh.RetryableAuthMethod(ssh.PasswordCallback(ctx.passwordCb), 5))
+				case "keyboard-interactive":
+					clientConfig.Auth = append(clientConfig.Auth,
+						ssh.RetryableAuthMethod(ssh.KeyboardInteractive(ctx.kbdIntCb), 5))
+				default:
+					t.Fatalf("Unknown authentication method %s", authMethod)
+				}
+			}
+
+			conn := server.Dial(clientConfig)
+			defer conn.Close()
+
+			if ctx.numPasswordCbs != testCase.expectedPasswordCbs {
+				t.Fatalf("passwordCallback was called %d times, expected %d times", ctx.numPasswordCbs, testCase.expectedPasswordCbs)
+			}
+
+			if ctx.numKbdIntCbs != testCase.expectedKbdIntCbs {
+				t.Fatalf("keyboardInteractiveCallback was called %d times, expected %d times", ctx.numKbdIntCbs, testCase.expectedKbdIntCbs)
+			}
+		})
+	}
+}

+ 173 - 0
ssh/test/sshd_test_pw.c

@@ -0,0 +1,173 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// sshd_test_pw.c
+// Wrapper to inject test password data for sshd PAM authentication
+//
+// This wrapper implements custom versions of getpwnam, getpwnam_r,
+// getspnam and getspnam_r. These functions first call their real
+// libc versions, then check if the requested user matches test user
+// specified in env variable TEST_USER and if so replace the password
+// with crypted() value of TEST_PASSWD env variable.
+//
+// Compile:
+// gcc -Wall -shared -o sshd_test_pw.so -fPIC sshd_test_pw.c
+//
+// Compile with debug:
+// gcc -DVERBOSE -Wall -shared -o sshd_test_pw.so -fPIC sshd_test_pw.c
+//
+// Run sshd:
+// LD_PRELOAD="sshd_test_pw.so" TEST_USER="..." TEST_PASSWD="..." sshd ...
+
+// +build ignore
+
+#define _GNU_SOURCE
+#include <string.h>
+#include <pwd.h>
+#include <shadow.h>
+#include <dlfcn.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <stdio.h>
+
+#ifdef VERBOSE
+#define DEBUG(X...) fprintf(stderr, X)
+#else
+#define DEBUG(X...) while (0) { }
+#endif
+
+/* crypt() password */
+static char *
+pwhash(char *passwd) {
+  return strdup(crypt(passwd, "$6$"));
+}
+
+/* Pointers to real functions in libc */
+static struct passwd * (*real_getpwnam)(const char *) = NULL;
+static int (*real_getpwnam_r)(const char *, struct passwd *, char *, size_t, struct passwd **) = NULL;
+static struct spwd * (*real_getspnam)(const char *) = NULL;
+static int (*real_getspnam_r)(const char *, struct spwd *, char *, size_t, struct spwd **) = NULL;
+
+/* Cached test user and test password */
+static char *test_user = NULL;
+static char *test_passwd_hash = NULL;
+
+static void
+init(void) {
+  /* Fetch real libc function pointers */
+  real_getpwnam = dlsym(RTLD_NEXT, "getpwnam");
+  real_getpwnam_r = dlsym(RTLD_NEXT, "getpwnam_r");
+  real_getspnam = dlsym(RTLD_NEXT, "getspnam");
+  real_getspnam_r = dlsym(RTLD_NEXT, "getspnam_r");
+  
+  /* abort if env variables are not defined */
+  if (getenv("TEST_USER") == NULL || getenv("TEST_PASSWD") == NULL) {
+    fprintf(stderr, "env variables TEST_USER and TEST_PASSWD are missing\n");
+    abort();
+  }
+
+  /* Fetch test user and test password from env */
+  test_user = strdup(getenv("TEST_USER"));
+  test_passwd_hash = pwhash(getenv("TEST_PASSWD"));
+
+  DEBUG("sshd_test_pw init():\n");
+  DEBUG("\treal_getpwnam: %p\n", real_getpwnam);
+  DEBUG("\treal_getpwnam_r: %p\n", real_getpwnam_r);
+  DEBUG("\treal_getspnam: %p\n", real_getspnam);
+  DEBUG("\treal_getspnam_r: %p\n", real_getspnam_r);
+  DEBUG("\tTEST_USER: '%s'\n", test_user);
+  DEBUG("\tTEST_PASSWD: '%s'\n", getenv("TEST_PASSWD"));
+  DEBUG("\tTEST_PASSWD_HASH: '%s'\n", test_passwd_hash);
+}
+
+static int
+is_test_user(const char *name) {
+  if (test_user != NULL && strcmp(test_user, name) == 0)
+    return 1;
+  return 0;
+}
+
+/* getpwnam */
+
+struct passwd *
+getpwnam(const char *name) {
+  struct passwd *pw;
+
+  DEBUG("sshd_test_pw getpwnam(%s)\n", name);
+  
+  if (real_getpwnam == NULL)
+    init();
+  if ((pw = real_getpwnam(name)) == NULL)
+    return NULL;
+
+  if (is_test_user(name))
+    pw->pw_passwd = strdup(test_passwd_hash);
+      
+  return pw;
+}
+
+/* getpwnam_r */
+
+int
+getpwnam_r(const char *name,
+	   struct passwd *pwd,
+	   char *buf,
+	   size_t buflen,
+	   struct passwd **result) {
+  int r;
+
+  DEBUG("sshd_test_pw getpwnam_r(%s)\n", name);
+  
+  if (real_getpwnam_r == NULL)
+    init();
+  if ((r = real_getpwnam_r(name, pwd, buf, buflen, result)) != 0 || *result == NULL)
+    return r;
+
+  if (is_test_user(name))
+    pwd->pw_passwd = strdup(test_passwd_hash);
+  
+  return 0;
+}
+
+/* getspnam */
+
+struct spwd *
+getspnam(const char *name) {
+  struct spwd *sp;
+
+  DEBUG("sshd_test_pw getspnam(%s)\n", name);
+  
+  if (real_getspnam == NULL)
+    init();
+  if ((sp = real_getspnam(name)) == NULL)
+    return NULL;
+
+  if (is_test_user(name))
+    sp->sp_pwdp = strdup(test_passwd_hash);
+  
+  return sp;
+}
+
+/* getspnam_r */
+
+int
+getspnam_r(const char *name,
+	   struct spwd *spbuf,
+	   char *buf,
+	   size_t buflen,
+	   struct spwd **spbufp) {
+  int r;
+
+  DEBUG("sshd_test_pw getspnam_r(%s)\n", name);
+  
+  if (real_getspnam_r == NULL)
+    init();
+  if ((r = real_getspnam_r(name, spbuf, buf, buflen, spbufp)) != 0)
+    return r;
+
+  if (is_test_user(name))
+    spbuf->sp_pwdp = strdup(test_passwd_hash);
+  
+  return r;
+}

+ 68 - 5
ssh/test/test_unix_test.go

@@ -10,6 +10,8 @@ package test
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"crypto/rand"
+	"encoding/base64"
 	"fmt"
 	"fmt"
 	"io/ioutil"
 	"io/ioutil"
 	"log"
 	"log"
@@ -25,7 +27,8 @@ import (
 	"golang.org/x/crypto/ssh/testdata"
 	"golang.org/x/crypto/ssh/testdata"
 )
 )
 
 
-const sshdConfig = `
+const (
+	defaultSshdConfig = `
 Protocol 2
 Protocol 2
 Banner {{.Dir}}/banner
 Banner {{.Dir}}/banner
 HostKey {{.Dir}}/id_rsa
 HostKey {{.Dir}}/id_rsa
@@ -50,8 +53,17 @@ RhostsRSAAuthentication no
 HostbasedAuthentication no
 HostbasedAuthentication no
 PubkeyAcceptedKeyTypes=*
 PubkeyAcceptedKeyTypes=*
 `
 `
+	multiAuthSshdConfigTail = `
+UsePAM yes
+PasswordAuthentication yes
+ChallengeResponseAuthentication yes
+AuthenticationMethods {{.AuthMethods}}
+`
+)
 
 
-var configTmpl = template.Must(template.New("").Parse(sshdConfig))
+var configTmpl = map[string]*template.Template{
+	"default":   template.Must(template.New("").Parse(defaultSshdConfig)),
+	"MultiAuth": template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail))}
 
 
 type server struct {
 type server struct {
 	t          *testing.T
 	t          *testing.T
@@ -60,6 +72,10 @@ type server struct {
 	cmd        *exec.Cmd
 	cmd        *exec.Cmd
 	output     bytes.Buffer // holds stderr from sshd process
 	output     bytes.Buffer // holds stderr from sshd process
 
 
+	testUser     string // test username for sshd
+	testPasswd   string // test password for sshd
+	sshdTestPwSo string // dynamic library to inject a custom password into sshd
+
 	// Client half of the network connection.
 	// Client half of the network connection.
 	clientConn net.Conn
 	clientConn net.Conn
 }
 }
@@ -186,6 +202,20 @@ func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Cl
 	s.cmd.Stdin = f
 	s.cmd.Stdin = f
 	s.cmd.Stdout = f
 	s.cmd.Stdout = f
 	s.cmd.Stderr = &s.output
 	s.cmd.Stderr = &s.output
+
+	if s.sshdTestPwSo != "" {
+		if s.testUser == "" {
+			s.t.Fatal("user missing from sshd_test_pw.so config")
+		}
+		if s.testPasswd == "" {
+			s.t.Fatal("password missing from sshd_test_pw.so config")
+		}
+		s.cmd.Env = append(os.Environ(),
+			fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
+			fmt.Sprintf("TEST_USER=%s", s.testUser),
+			fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
+	}
+
 	if err := s.cmd.Start(); err != nil {
 	if err := s.cmd.Start(); err != nil {
 		s.t.Fail()
 		s.t.Fail()
 		s.Shutdown()
 		s.Shutdown()
@@ -236,8 +266,39 @@ func writeFile(path string, contents []byte) {
 	}
 	}
 }
 }
 
 
+// generate random password
+func randomPassword() (string, error) {
+	b := make([]byte, 12)
+	_, err := rand.Read(b)
+	if err != nil {
+		return "", err
+	}
+	return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// setTestPassword is used for setting user and password data for sshd_test_pw.so
+// This function also checks that ./sshd_test_pw.so exists and if not calls s.t.Skip()
+func (s *server) setTestPassword(user, passwd string) error {
+	wd, _ := os.Getwd()
+	wrapper := filepath.Join(wd, "sshd_test_pw.so")
+	if _, err := os.Stat(wrapper); err != nil {
+		s.t.Skip(fmt.Errorf("sshd_test_pw.so is not available"))
+		return err
+	}
+
+	s.sshdTestPwSo = wrapper
+	s.testUser = user
+	s.testPasswd = passwd
+	return nil
+}
+
 // newServer returns a new mock ssh server.
 // newServer returns a new mock ssh server.
 func newServer(t *testing.T) *server {
 func newServer(t *testing.T) *server {
+	return newServerForConfig(t, "default", map[string]string{})
+}
+
+// newServerForConfig returns a new mock ssh server.
+func newServerForConfig(t *testing.T, config string, configVars map[string]string) *server {
 	if testing.Short() {
 	if testing.Short() {
 		t.Skip("skipping test due to -short")
 		t.Skip("skipping test due to -short")
 	}
 	}
@@ -249,9 +310,11 @@ func newServer(t *testing.T) *server {
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	err = configTmpl.Execute(f, map[string]string{
-		"Dir": dir,
-	})
+	if _, ok := configTmpl[config]; ok == false {
+		t.Fatal(fmt.Errorf("Invalid server config '%s'", config))
+	}
+	configVars["Dir"] = dir
+	err = configTmpl[config].Execute(f, configVars)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}