|
@@ -128,24 +128,52 @@ func newTokenProvider(accessToken string, err error) *TokenProvider {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
|
|
|
+func TestSASLOAuthBearer(t *testing.T) {
|
|
|
|
|
|
testTable := []struct {
|
|
|
- name string
|
|
|
- err error
|
|
|
- tokProvider *TokenProvider
|
|
|
+ name string
|
|
|
+ mockAuthErr KError
|
|
|
+ mockHandshakeErr KError
|
|
|
+ expectClientErr bool
|
|
|
+ tokProvider *TokenProvider
|
|
|
+ extensions map[string]string
|
|
|
}{
|
|
|
- {"OK server response",
|
|
|
- nil,
|
|
|
+ {"SASL/OAUTHBEARER OK server response",
|
|
|
+ ErrNoError,
|
|
|
+ ErrNoError,
|
|
|
+ false,
|
|
|
+ newTokenProvider("access-token-123", nil),
|
|
|
+ map[string]string{},
|
|
|
+ },
|
|
|
+ {"SASL/OAUTHBEARER authentication failure response",
|
|
|
+ ErrSASLAuthenticationFailed,
|
|
|
+ ErrNoError,
|
|
|
+ false,
|
|
|
newTokenProvider("access-token-123", nil),
|
|
|
+ map[string]string{},
|
|
|
},
|
|
|
- {"SASL authentication failure response",
|
|
|
+ {"SASL/OAUTHBEARER handshake failure response",
|
|
|
+ ErrNoError,
|
|
|
ErrSASLAuthenticationFailed,
|
|
|
+ false,
|
|
|
newTokenProvider("access-token-123", nil),
|
|
|
+ map[string]string{},
|
|
|
},
|
|
|
- {"Token generation error",
|
|
|
- ErrTokenFailure,
|
|
|
+ {"SASL/OAUTHBEARER token generation error",
|
|
|
+ ErrNoError,
|
|
|
+ ErrNoError,
|
|
|
+ true,
|
|
|
newTokenProvider("access-token-123", ErrTokenFailure),
|
|
|
+ map[string]string{},
|
|
|
+ },
|
|
|
+ {"SASL/OAUTHBEARER invalid extension",
|
|
|
+ ErrNoError,
|
|
|
+ ErrNoError,
|
|
|
+ true,
|
|
|
+ newTokenProvider("access-token-123", nil),
|
|
|
+ map[string]string{
|
|
|
+ "auth": "auth-value",
|
|
|
+ },
|
|
|
},
|
|
|
}
|
|
|
|
|
@@ -156,14 +184,20 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
|
|
|
|
|
|
mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t)
|
|
|
|
|
|
- if e, ok := test.err.(KError); ok {
|
|
|
- mockSASLAuthResponse = mockSASLAuthResponse.SetError(e)
|
|
|
+ if test.mockAuthErr != ErrNoError {
|
|
|
+ mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
|
|
|
+ }
|
|
|
+
|
|
|
+ mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
|
|
|
+ SetEnabledMechanisms([]string{SASLTypeOAuth})
|
|
|
+
|
|
|
+ if test.mockHandshakeErr != ErrNoError {
|
|
|
+ mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
|
|
|
}
|
|
|
|
|
|
mockBroker.SetHandlerByMap(map[string]MockResponse{
|
|
|
"SaslAuthenticateRequest": mockSASLAuthResponse,
|
|
|
- "SaslHandshakeRequest": NewMockSaslHandshakeResponse(t).
|
|
|
- SetEnabledMechanisms([]string{SASLTypeOAuth}),
|
|
|
+ "SaslHandshakeRequest": mockSASLHandshakeResponse,
|
|
|
})
|
|
|
|
|
|
|
|
@@ -179,6 +213,7 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
|
|
|
conf := NewConfig()
|
|
|
conf.Net.SASL.Mechanism = SASLTypeOAuth
|
|
|
conf.Net.SASL.TokenProvider = test.tokProvider
|
|
|
+ conf.Net.SASL.Extensions = test.extensions
|
|
|
|
|
|
broker.conf = conf
|
|
|
|
|
@@ -198,8 +233,18 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
|
|
|
|
|
|
err = broker.authenticateViaSASL()
|
|
|
|
|
|
- if test.err != err {
|
|
|
- t.Errorf("[%d]:[%s] Expected %s error, got %s\n", i, test.name, test.err, err)
|
|
|
+ if test.mockAuthErr != ErrNoError {
|
|
|
+ if test.mockAuthErr != err {
|
|
|
+ t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err)
|
|
|
+ }
|
|
|
+ } else if test.mockHandshakeErr != ErrNoError {
|
|
|
+ if test.mockHandshakeErr != err {
|
|
|
+ t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
|
|
|
+ }
|
|
|
+ } else if test.expectClientErr && err == nil {
|
|
|
+ t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
|
|
|
+ } else if !test.expectClientErr && err != nil {
|
|
|
+ t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
|
|
|
}
|
|
|
|
|
|
mockBroker.Close()
|