Browse Source

Move extensions from configuration to access token struct

Mike Kaminski 7 years ago
parent
commit
d47ce2db92
4 changed files with 53 additions and 53 deletions
  1. 25 15
      broker.go
  2. 26 30
      broker_test.go
  3. 0 6
      config.go
  4. 2 2
      config_test.go

+ 25 - 15
broker.go

@@ -67,6 +67,17 @@ const (
 	SASLExtKeyAuth = "auth"
 )
 
+// AccessToken contains an access token used to authenticate a
+// SASL/OAUTHBEARER client along with associated metadata.
+type AccessToken struct {
+	// Token is the access token payload.
+	Token string
+	// Extensions is a optional map of arbitrary key-value pairs that can be
+	// sent with the SASL/OAUTHBEARER initial client response. These values are
+	// ignored by the SASL server if they are unexpected.
+	Extensions map[string]string
+}
+
 // AccessTokenProvider is the interface that encapsulates how implementors
 // can generate access tokens for Kafka broker authentication.
 type AccessTokenProvider interface {
@@ -75,7 +86,7 @@ type AccessTokenProvider interface {
 	// not block indefinitely. A timeout error should be returned after a short
 	// period of inactivity so that the broker connection logic can log
 	// debugging information and retry.
-	Token() (string, error)
+	Token() (*AccessToken, error)
 }
 
 type responsePromise struct {
@@ -780,7 +791,7 @@ func (b *Broker) responseReceiver() {
 
 func (b *Broker) authenticateViaSASL() error {
 	if b.conf.Net.SASL.Mechanism == SASLTypeOAuth {
-		return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider, b.conf.Net.SASL.Extensions)
+		return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
 	}
 	return b.sendAndReceiveSASLPlainAuth()
 }
@@ -897,13 +908,13 @@ func (b *Broker) sendAndReceiveSASLPlainAuth() error {
 
 // sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
 // https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
-func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider, extensions map[string]string) error {
+func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
 
 	if err := b.sendAndReceiveSASLHandshake(SASLTypeOAuth, SASLHandshakeV1); err != nil {
 		return err
 	}
 
-	token, err := tokenProvider.Token()
+	token, err := provider.Token()
 
 	if err != nil {
 		return err
@@ -913,7 +924,7 @@ func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider, exte
 
 	correlationID := b.correlationID
 
-	bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, extensions, correlationID)
+	bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID)
 
 	if err != nil {
 		return err
@@ -937,19 +948,18 @@ func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider, exte
 
 // Build SASL/OAUTHBEARER initial client response as described by RFC-7628
 // https://tools.ietf.org/html/rfc7628
-func buildClientInitialResponse(bearerToken string, extensions map[string]string) ([]byte, error) {
-
-	if _, ok := extensions[SASLExtKeyAuth]; ok {
-		return []byte{}, fmt.Errorf("The extension `%s` is invalid", SASLExtKeyAuth)
-	}
+func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
 
 	ext := ""
 
-	if len(extensions) > 0 {
-		ext = "\x01" + mapToString(extensions, "=", "\x01")
+	if token.Extensions != nil && len(token.Extensions) > 0 {
+		if _, ok := token.Extensions[SASLExtKeyAuth]; ok {
+			return []byte{}, fmt.Errorf("The extension `%s` is invalid", SASLExtKeyAuth)
+		}
+		ext = "\x01" + mapToString(token.Extensions, "=", "\x01")
 	}
 
-	resp := []byte(fmt.Sprintf("n,,\x01auth=Bearer %s%s\x01\x01", bearerToken, ext))
+	resp := []byte(fmt.Sprintf("n,,\x01auth=Bearer %s%s\x01\x01", token.Token, ext))
 
 	return resp, nil
 }
@@ -969,9 +979,9 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string)
 	return strings.Join(buf, elemSep)
 }
 
-func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string, extensions map[string]string, correlationID int32) (int, error) {
+func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) {
 
-	initialResp, err := buildClientInitialResponse(bearerToken, extensions)
+	initialResp, err := buildClientInitialResponse(token)
 
 	if err != nil {
 		return 0, err

+ 26 - 30
broker_test.go

@@ -113,17 +113,17 @@ func TestSimpleBrokerCommunication(t *testing.T) {
 var ErrTokenFailure = errors.New("Failure generating token")
 
 type TokenProvider struct {
-	accessToken string
+	accessToken *AccessToken
 	err         error
 }
 
-func (t *TokenProvider) Token() (string, error) {
+func (t *TokenProvider) Token() (*AccessToken, error) {
 	return t.accessToken, t.err
 }
 
-func newTokenProvider(accessToken string, err error) *TokenProvider {
+func newTokenProvider(token *AccessToken, err error) *TokenProvider {
 	return &TokenProvider{
-		accessToken: accessToken,
+		accessToken: token,
 		err:         err,
 	}
 }
@@ -136,44 +136,39 @@ func TestSASLOAuthBearer(t *testing.T) {
 		mockHandshakeErr KError // Mock and expect error returned from SaslHandshakeRequest
 		expectClientErr  bool   // Expect an internal client-side error
 		tokProvider      *TokenProvider
-		extensions       map[string]string
 	}{
 		{"SASL/OAUTHBEARER OK server response",
 			ErrNoError,
 			ErrNoError,
 			false,
-			newTokenProvider("access-token-123", nil),
-			map[string]string{},
+			newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
 		},
 		{"SASL/OAUTHBEARER authentication failure response",
 			ErrSASLAuthenticationFailed,
 			ErrNoError,
 			false,
-			newTokenProvider("access-token-123", nil),
-			map[string]string{},
+			newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
 		},
 		{"SASL/OAUTHBEARER handshake failure response",
 			ErrNoError,
 			ErrSASLAuthenticationFailed,
 			false,
-			newTokenProvider("access-token-123", nil),
-			map[string]string{},
+			newTokenProvider(&AccessToken{Token: "access-token-123"}, nil),
 		},
 		{"SASL/OAUTHBEARER token generation error",
 			ErrNoError,
 			ErrNoError,
 			true,
-			newTokenProvider("access-token-123", ErrTokenFailure),
-			map[string]string{},
+			newTokenProvider(&AccessToken{Token: "access-token-123"}, ErrTokenFailure),
 		},
 		{"SASL/OAUTHBEARER invalid extension",
 			ErrNoError,
 			ErrNoError,
 			true,
-			newTokenProvider("access-token-123", nil),
-			map[string]string{
-				"auth": "auth-value",
-			},
+			newTokenProvider(&AccessToken{
+				Token:      "access-token-123",
+				Extensions: map[string]string{"auth": "auth-value"},
+			}, nil),
 		},
 	}
 
@@ -213,7 +208,6 @@ func TestSASLOAuthBearer(t *testing.T) {
 		conf := NewConfig()
 		conf.Net.SASL.Mechanism = SASLTypeOAuth
 		conf.Net.SASL.TokenProvider = test.tokProvider
-		conf.Net.SASL.Extensions = test.extensions
 
 		broker.conf = conf
 
@@ -255,33 +249,35 @@ func TestBuildClientInitialResponse(t *testing.T) {
 
 	testTable := []struct {
 		name        string
-		token       string
-		extensions  map[string]string
+		token       *AccessToken
 		expected    []byte
 		expectError bool
 	}{
 		{
 			"Build SASL client initial response with two extensions",
-			"the-token",
-			map[string]string{
-				"x": "1",
-				"y": "2",
+			&AccessToken{
+				Token: "the-token",
+				Extensions: map[string]string{
+					"x": "1",
+					"y": "2",
+				},
 			},
 			[]byte("n,,\x01auth=Bearer the-token\x01x=1\x01y=2\x01\x01"),
 			false,
 		},
 		{
 			"Build SASL client initial response with no extensions",
-			"the-token",
-			map[string]string{},
+			&AccessToken{Token: "the-token"},
 			[]byte("n,,\x01auth=Bearer the-token\x01\x01"),
 			false,
 		},
 		{
 			"Build SASL client initial response using reserved extension",
-			"the-token",
-			map[string]string{
-				"auth": "auth-value",
+			&AccessToken{
+				Token: "the-token",
+				Extensions: map[string]string{
+					"auth": "auth-value",
+				},
 			},
 			[]byte(""),
 			true,
@@ -290,7 +286,7 @@ func TestBuildClientInitialResponse(t *testing.T) {
 
 	for i, test := range testTable {
 
-		actual, err := buildClientInitialResponse(test.token, test.extensions)
+		actual, err := buildClientInitialResponse(test.token)
 
 		if !reflect.DeepEqual(test.expected, actual) {
 			t.Errorf("Expected %s, got %s\n", test.expected, actual)

+ 0 - 6
config.go

@@ -69,11 +69,6 @@ type Config struct {
 			// AccessTokenProvider interface docs for proper implementation
 			// guidelines.
 			TokenProvider AccessTokenProvider
-			// Extensions is a optional map of arbitrary key-value pairs that
-			// can be sent with the SASL/OAUTHBEARER initial client response.
-			// These values are ignored by the SASL server if they are
-			// unexpected.
-			Extensions map[string]string
 		}
 
 		// KeepAlive specifies the keep-alive period for an active network connection.
@@ -363,7 +358,6 @@ func NewConfig() *Config {
 	c.Net.ReadTimeout = 30 * time.Second
 	c.Net.WriteTimeout = 30 * time.Second
 	c.Net.SASL.Handshake = true
-	c.Net.SASL.Extensions = make(map[string]string)
 
 	c.Metadata.Retry.Max = 3
 	c.Metadata.Retry.Backoff = 250 * time.Millisecond

+ 2 - 2
config_test.go

@@ -36,8 +36,8 @@ func TestEmptyClientIDConfigValidates(t *testing.T) {
 type DummyTokenProvider struct {
 }
 
-func (t *DummyTokenProvider) Token() (string, error) {
-	return "access-token-string", nil
+func (t *DummyTokenProvider) Token() (*AccessToken, error) {
+	return &AccessToken{Token: "access-token-string"}, nil
 }
 
 func TestNetConfigValidates(t *testing.T) {