Browse Source

Write handshake error test case

Mike Kaminski 6 years ago
parent
commit
c0eb84c7d0
2 changed files with 65 additions and 15 deletions
  1. 60 15
      broker_test.go
  2. 5 0
      mockresponses.go

+ 60 - 15
broker_test.go

@@ -128,24 +128,52 @@ func newTokenProvider(accessToken string, err error) *TokenProvider {
 	}
 }
 
-func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
+func TestSASLOAuthBearer(t *testing.T) {
 
 	testTable := []struct {
-		name        string
-		err         error
-		tokProvider *TokenProvider
+		name             string
+		mockAuthErr      KError // Mock and expect error returned from SaslAuthenticateRequest
+		mockHandshakeErr KError // Mock and expect error returned from SaslHandshakeRequest
+		expectClientErr  bool   // Expect an internal client-side error
+		tokProvider      *TokenProvider
+		extensions       map[string]string
 	}{
-		{"OK server response",
-			nil,
+		{"SASL/OAUTHBEARER OK server response",
+			ErrNoError,
+			ErrNoError,
+			false,
+			newTokenProvider("access-token-123", nil),
+			map[string]string{},
+		},
+		{"SASL/OAUTHBEARER authentication failure response",
+			ErrSASLAuthenticationFailed,
+			ErrNoError,
+			false,
 			newTokenProvider("access-token-123", nil),
+			map[string]string{},
 		},
-		{"SASL authentication failure response",
+		{"SASL/OAUTHBEARER handshake failure response",
+			ErrNoError,
 			ErrSASLAuthenticationFailed,
+			false,
 			newTokenProvider("access-token-123", nil),
+			map[string]string{},
 		},
-		{"Token generation error",
-			ErrTokenFailure,
+		{"SASL/OAUTHBEARER token generation error",
+			ErrNoError,
+			ErrNoError,
+			true,
 			newTokenProvider("access-token-123", ErrTokenFailure),
+			map[string]string{},
+		},
+		{"SASL/OAUTHBEARER invalid extension",
+			ErrNoError,
+			ErrNoError,
+			true,
+			newTokenProvider("access-token-123", nil),
+			map[string]string{
+				"auth": "auth-value",
+			},
 		},
 	}
 
@@ -156,14 +184,20 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
 
 		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t)
 
-		if e, ok := test.err.(KError); ok {
-			mockSASLAuthResponse = mockSASLAuthResponse.SetError(e)
+		if test.mockAuthErr != ErrNoError {
+			mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
+		}
+
+		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
+			SetEnabledMechanisms([]string{SASLTypeOAuth})
+
+		if test.mockHandshakeErr != ErrNoError {
+			mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
 		}
 
 		mockBroker.SetHandlerByMap(map[string]MockResponse{
 			"SaslAuthenticateRequest": mockSASLAuthResponse,
-			"SaslHandshakeRequest": NewMockSaslHandshakeResponse(t).
-				SetEnabledMechanisms([]string{SASLTypeOAuth}),
+			"SaslHandshakeRequest":    mockSASLHandshakeResponse,
 		})
 
 		// broker executes SASL requests against mockBroker
@@ -179,6 +213,7 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
 		conf := NewConfig()
 		conf.Net.SASL.Mechanism = SASLTypeOAuth
 		conf.Net.SASL.TokenProvider = test.tokProvider
+		conf.Net.SASL.Extensions = test.extensions
 
 		broker.conf = conf
 
@@ -198,8 +233,18 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
 
 		err = broker.authenticateViaSASL()
 
-		if test.err != err {
-			t.Errorf("[%d]:[%s] Expected %s error, got %s\n", i, test.name, test.err, err)
+		if test.mockAuthErr != ErrNoError {
+			if test.mockAuthErr != err {
+				t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err)
+			}
+		} else if test.mockHandshakeErr != ErrNoError {
+			if test.mockHandshakeErr != err {
+				t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+			}
+		} else if test.expectClientErr && err == nil {
+			t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
+		} else if !test.expectClientErr && err != nil {
+			t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
 		}
 
 		mockBroker.Close()

+ 5 - 0
mockresponses.go

@@ -747,6 +747,11 @@ func (mshr *MockSaslHandshakeResponse) For(reqBody versionedDecoder) encoder {
 	return res
 }
 
+func (mshr *MockSaslHandshakeResponse) SetError(kerror KError) *MockSaslHandshakeResponse {
+	mshr.kerror = kerror
+	return mshr
+}
+
 func (mshr *MockSaslHandshakeResponse) SetEnabledMechanisms(enabledMechanisms []string) *MockSaslHandshakeResponse {
 	mshr.enabledMechanisms = enabledMechanisms
 	return mshr