Browse Source

Merge pull request #1428 from mkaminski1988/master

Handle SASL/OAUTHBEARER token rejection
Vlad Gorodetsky 5 years ago
parent
commit
d1948414ad
2 changed files with 99 additions and 68 deletions
  1. 38 21
      broker.go
  2. 61 47
      broker_test.go

+ 38 - 21
broker.go

@@ -1013,7 +1013,7 @@ func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
 
 	b.correlationID++
 
-	bytesRead, err := b.receiveSASLServerResponse(correlationID)
+	bytesRead, err := b.receiveSASLServerResponse(&SaslAuthenticateResponse{}, correlationID)
 	b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))
 
 	// With v1 sasl we get an error message set in the response we can return
@@ -1037,26 +1037,53 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
 		return err
 	}
 
+	message, err := buildClientFirstMessage(token)
+	if err != nil {
+		return err
+	}
+
+	challenged, err := b.sendClientMessage(message)
+	if err != nil {
+		return err
+	}
+
+	if challenged {
+		// Abort the token exchange. The broker returns the failure code.
+		_, err = b.sendClientMessage([]byte(`\x01`))
+	}
+
+	return err
+}
+
+// sendClientMessage sends a SASL/OAUTHBEARER client message and returns true
+// if the broker responds with a challenge, in which case the token is
+// rejected.
+func (b *Broker) sendClientMessage(message []byte) (bool, error) {
+
 	requestTime := time.Now()
 	correlationID := b.correlationID
 
-	bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID)
+	bytesWritten, err := b.sendSASLOAuthBearerClientMessage(message, correlationID)
 	if err != nil {
-		return err
+		return false, err
 	}
 
 	b.updateOutgoingCommunicationMetrics(bytesWritten)
 	b.correlationID++
 
-	bytesRead, err := b.receiveSASLServerResponse(correlationID)
-	if err != nil {
-		return err
-	}
+	res := &SaslAuthenticateResponse{}
+	bytesRead, err := b.receiveSASLServerResponse(res, correlationID)
 
 	requestLatency := time.Since(requestTime)
 	b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)
 
-	return nil
+	isChallenge := len(res.SaslAuthBytes) > 0
+
+	if isChallenge && err != nil {
+		Logger.Printf("Broker rejected authentication token: %s", res.SaslAuthBytes)
+	}
+
+	return isChallenge, err
 }
 
 func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
@@ -1154,7 +1181,7 @@ func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, e
 
 // Build SASL/OAUTHBEARER initial client response as described by RFC-7628
 // https://tools.ietf.org/html/rfc7628
-func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
+func buildClientFirstMessage(token *AccessToken) ([]byte, error) {
 	var ext string
 
 	if token.Extensions != nil && len(token.Extensions) > 0 {
@@ -1200,11 +1227,7 @@ func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, erro
 	return b.conn.Write(buf)
 }
 
-func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) {
-	initialResp, err := buildClientInitialResponse(token)
-	if err != nil {
-		return 0, err
-	}
+func (b *Broker) sendSASLOAuthBearerClientMessage(initialResp []byte, correlationID int32) (int, error) {
 
 	rb := &SaslAuthenticateRequest{initialResp}
 
@@ -1222,7 +1245,7 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlati
 	return b.conn.Write(buf)
 }
 
-func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
+func (b *Broker) receiveSASLServerResponse(res *SaslAuthenticateResponse, correlationID int32) (int, error) {
 
 	buf := make([]byte, responseLengthSize+correlationIDSize)
 
@@ -1250,8 +1273,6 @@ func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
 		return bytesRead, err
 	}
 
-	res := &SaslAuthenticateResponse{}
-
 	if err := versionedDecode(buf, res, 0); err != nil {
 		return bytesRead, err
 	}
@@ -1260,10 +1281,6 @@ func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
 		return bytesRead, res.Err
 	}
 
-	if len(res.SaslAuthBytes) > 0 {
-		Logger.Printf("Received SASL auth response: %s", res.SaslAuthBytes)
-	}
-
 	return bytesRead, nil
 }
 

+ 61 - 47
broker_test.go

@@ -3,13 +3,13 @@ package sarama
 import (
 	"errors"
 	"fmt"
-	"gopkg.in/jcmturner/gokrb5.v7/krberror"
 	"net"
 	"reflect"
 	"testing"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
+	"gopkg.in/jcmturner/gokrb5.v7/krberror"
 )
 
 func ExampleBroker() {
@@ -132,42 +132,66 @@ func newTokenProvider(token *AccessToken, err error) *TokenProvider {
 func TestSASLOAuthBearer(t *testing.T) {
 
 	testTable := []struct {
-		name             string
-		mockAuthErr      KError // Mock and expect error returned from SaslAuthenticateRequest
-		mockHandshakeErr KError // Mock and expect error returned from SaslHandshakeRequest
-		expectClientErr  bool   // Expect an internal client-side error
-		tokProvider      *TokenProvider
+		name                      string
+		mockSASLHandshakeResponse MockResponse // Mock SaslHandshakeRequest response from broker
+		mockSASLAuthResponse      MockResponse // Mock SaslAuthenticateRequest response from broker
+		expectClientErr           bool         // Expect an internal client-side error
+		expectedBrokerError       KError       // Expected Kafka error returned by client
+		tokProvider               *TokenProvider
 	}{
 		{
-			name:             "SASL/OAUTHBEARER OK server response",
-			mockAuthErr:      ErrNoError,
-			mockHandshakeErr: ErrNoError,
-			tokProvider:      newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
+			name: "SASL/OAUTHBEARER OK server response",
+			mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
+				SetEnabledMechanisms([]string{SASLTypeOAuth}),
+			mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
+			expectClientErr:      false,
+			expectedBrokerError:  ErrNoError,
+			tokProvider:          newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
 		},
 		{
-			name:             "SASL/OAUTHBEARER authentication failure response",
-			mockAuthErr:      ErrSASLAuthenticationFailed,
-			mockHandshakeErr: ErrNoError,
-			tokProvider:      newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
+			name: "SASL/OAUTHBEARER authentication failure response",
+			mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
+				SetEnabledMechanisms([]string{SASLTypeOAuth}),
+			mockSASLAuthResponse: NewMockSequence(
+				// First, the broker response with a challenge
+				NewMockSaslAuthenticateResponse(t).
+					SetAuthBytes([]byte(`{"status":"invalid_request1"}`)),
+				// Next, the client terminates the token exchange. Finally, the
+				// broker responds with an error message.
+				NewMockSaslAuthenticateResponse(t).
+					SetAuthBytes([]byte(`{"status":"invalid_request2"}`)).
+					SetError(ErrSASLAuthenticationFailed),
+			),
+			expectClientErr:     true,
+			expectedBrokerError: ErrSASLAuthenticationFailed,
+			tokProvider:         newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
 		},
 		{
-			name:             "SASL/OAUTHBEARER handshake failure response",
-			mockAuthErr:      ErrNoError,
-			mockHandshakeErr: ErrSASLAuthenticationFailed,
-			tokProvider:      newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
+			name: "SASL/OAUTHBEARER handshake failure response",
+			mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
+				SetEnabledMechanisms([]string{SASLTypeOAuth}).
+				SetError(ErrSASLAuthenticationFailed),
+			mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
+			expectClientErr:      true,
+			expectedBrokerError:  ErrSASLAuthenticationFailed,
+			tokProvider:          newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
 		},
 		{
-			name:             "SASL/OAUTHBEARER token generation error",
-			mockAuthErr:      ErrNoError,
-			mockHandshakeErr: ErrNoError,
-			expectClientErr:  true,
-			tokProvider:      newTokenProvider(&AccessToken{Token: "access-token-123"}, ErrTokenFailure),
+			name: "SASL/OAUTHBEARER token generation error",
+			mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
+				SetEnabledMechanisms([]string{SASLTypeOAuth}),
+			mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
+			expectClientErr:      true,
+			expectedBrokerError:  ErrNoError,
+			tokProvider:          newTokenProvider(&AccessToken{Token: "access-token-123"}, ErrTokenFailure),
 		},
 		{
-			name:             "SASL/OAUTHBEARER invalid extension",
-			mockAuthErr:      ErrNoError,
-			mockHandshakeErr: ErrNoError,
-			expectClientErr:  true,
+			name: "SASL/OAUTHBEARER invalid extension",
+			mockSASLHandshakeResponse: NewMockSaslHandshakeResponse(t).
+				SetEnabledMechanisms([]string{SASLTypeOAuth}),
+			mockSASLAuthResponse: NewMockSaslAuthenticateResponse(t),
+			expectClientErr:      true,
+			expectedBrokerError:  ErrNoError,
 			tokProvider: newTokenProvider(&AccessToken{
 				Token:      "access-token-123",
 				Extensions: map[string]string{"auth": "auth-value"},
@@ -180,19 +204,9 @@ func TestSASLOAuthBearer(t *testing.T) {
 		// mockBroker mocks underlying network logic and broker responses
 		mockBroker := NewMockBroker(t, 0)
 
-		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte("response_payload"))
-		if test.mockAuthErr != ErrNoError {
-			mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
-		}
-
-		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeOAuth})
-		if test.mockHandshakeErr != ErrNoError {
-			mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
-		}
-
 		mockBroker.SetHandlerByMap(map[string]MockResponse{
-			"SaslAuthenticateRequest": mockSASLAuthResponse,
-			"SaslHandshakeRequest":    mockSASLHandshakeResponse,
+			"SaslAuthenticateRequest": test.mockSASLAuthResponse,
+			"SaslHandshakeRequest":    test.mockSASLHandshakeResponse,
 		})
 
 		// broker executes SASL requests against mockBroker
@@ -227,13 +241,13 @@ func TestSASLOAuthBearer(t *testing.T) {
 
 		err = broker.authenticateViaSASL()
 
-		if test.mockAuthErr != ErrNoError {
-			if test.mockAuthErr != err {
-				t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err)
+		if test.expectedBrokerError != ErrNoError {
+			if test.expectedBrokerError != err {
+				t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.expectedBrokerError, err)
 			}
-		} else if test.mockHandshakeErr != ErrNoError {
-			if test.mockHandshakeErr != err {
-				t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+		} else if test.expectedBrokerError != ErrNoError {
+			if test.expectedBrokerError != err {
+				t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.expectedBrokerError, err)
 			}
 		} else if test.expectClientErr && err == nil {
 			t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
@@ -599,7 +613,7 @@ func TestGSSAPIKerberosAuth_Authorize(t *testing.T) {
 
 }
 
-func TestBuildClientInitialResponse(t *testing.T) {
+func TestBuildClientFirstMessage(t *testing.T) {
 
 	testTable := []struct {
 		name        string
@@ -638,7 +652,7 @@ func TestBuildClientInitialResponse(t *testing.T) {
 
 	for i, test := range testTable {
 
-		actual, err := buildClientInitialResponse(test.token)
+		actual, err := buildClientFirstMessage(test.token)
 
 		if !reflect.DeepEqual(test.expected, actual) {
 			t.Errorf("Expected %s, got %s\n", test.expected, actual)