Browse Source

Add support for extensions

Mike Kaminski 7 years ago
parent
commit
741ad00c46
3 changed files with 106 additions and 9 deletions
  1. 45 8
      broker.go
  2. 55 1
      broker_test.go
  3. 6 0
      config.go

+ 45 - 8
broker.go

@@ -6,7 +6,9 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"sort"
 	"strconv"
+	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -61,6 +63,9 @@ const (
 	// SASLHandshakeV1 is v1 of the Kafka SASL handshake protocol. Client and
 	// server negotiate SASL by wrapping tokens with Kafka protocol headers.
 	SASLHandshakeV1 = int16(1)
+	// SASLExtKeyAuth is the reserved extension key name sent as part of the
+	// SASL/OAUTHBEARER intial client response
+	SASLExtKeyAuth = "auth"
 )
 
 // AccessTokenProvider is the interface that encapsulates how implementors
@@ -156,7 +161,7 @@ func (b *Broker) Open(conf *Config) error {
 
 			switch conf.Net.SASL.Mechanism {
 			case SASLTypeOAuth:
-				b.connErr = b.sendAndReceiveSASLOAuth(conf.Net.SASL.TokenProvider)
+				b.connErr = b.sendAndReceiveSASLOAuth(conf.Net.SASL.TokenProvider, conf.Net.SASL.Extensions)
 			case SASLTypePlaintext:
 				b.connErr = b.sendAndReceiveSASLPlainAuth()
 			default:
@@ -893,7 +898,7 @@ func (b *Broker) sendAndReceiveSASLPlainAuth() error {
 
 // sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
 // https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
-func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider) error {
+func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider, extensions map[string]string) error {
 
 	if err := b.sendAndReceiveSASLHandshake(SASLTypeOAuth, SASLHandshakeV1); err != nil {
 		return err
@@ -909,7 +914,7 @@ func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider) erro
 
 	correlationID := b.correlationID
 
-	bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID)
+	bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, extensions, correlationID)
 
 	if err != nil {
 		return err
@@ -931,13 +936,45 @@ func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider) erro
 	return nil
 }
 
-func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string, correlationID int32) (int, error) {
+// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
+// https://tools.ietf.org/html/rfc7628
+func buildClientInitialResponse(bearerToken string, extensions map[string]string) ([]byte, error) {
 
-	// Initial client response as described by RFC-7628.
-	// https://tools.ietf.org/html/rfc7628
-	oauthRequest := []byte(fmt.Sprintf("n,,\x01auth=Bearer %s\x01\x01", bearerToken))
+	if _, ok := extensions[SASLExtKeyAuth]; ok {
+		return []byte{}, fmt.Errorf("The extension `%s` is invalid", SASLExtKeyAuth)
+	}
+
+	extensions[SASLExtKeyAuth] = "Bearer " + bearerToken
+
+	resp := []byte(fmt.Sprintf("n,,\x01%s\x01\x01", mapToString(extensions, "=", "\x01")))
+
+	return resp, nil
+}
+
+// mapToString returns a list of key-value pairs ordered by key.
+// keyValSep separates the key from the value. elemSep separates each pair.
+func mapToString(extensions map[string]string, keyValSep string, elemSep string) string {
+
+	buf := make([]string, 0, len(extensions))
+
+	for k, v := range extensions {
+		buf = append(buf, k+keyValSep+v)
+	}
+
+	sort.Strings(buf)
+
+	return strings.Join(buf, elemSep)
+}
+
+func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string, extensions map[string]string, correlationID int32) (int, error) {
+
+	initialResp, err := buildClientInitialResponse(bearerToken, extensions)
+
+	if err != nil {
+		return 0, err
+	}
 
-	rb := &SaslAuthenticateRequest{oauthRequest}
+	rb := &SaslAuthenticateRequest{initialResp}
 
 	req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
 

+ 55 - 1
broker_test.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"fmt"
 	"net"
+	"reflect"
 	"testing"
 	"time"
 
@@ -192,7 +193,7 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
 
 		broker.conn = conn
 
-		err = broker.sendAndReceiveSASLOAuth(test.tokProvider)
+		err = broker.sendAndReceiveSASLOAuth(test.tokProvider, make(map[string]string))
 
 		if test.err != err {
 			t.Errorf("[%d]:[%s] Expected %s error, got %s\n", i, test.name, test.err, err)
@@ -202,6 +203,59 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
 	}
 }
 
+func TestBuildClientInitialResponse(t *testing.T) {
+
+	testTable := []struct {
+		name        string
+		token       string
+		extensions  map[string]string
+		expected    []byte
+		expectError bool
+	}{
+		{
+			"Build SASL client initial response with two extensions",
+			"the-token",
+			map[string]string{
+				"x": "1",
+				"y": "2",
+			},
+			[]byte("n,,\x01auth=Bearer the-token\x01x=1\x01y=2\x01\x01"),
+			false,
+		},
+		{
+			"Build SASL client initial response with no extensions",
+			"the-token",
+			map[string]string{},
+			[]byte("n,,\x01auth=Bearer the-token\x01\x01"),
+			false,
+		},
+		{
+			"Build SASL client initial response using reserved extension",
+			"the-token",
+			map[string]string{
+				"auth": "auth-value",
+			},
+			[]byte(""),
+			true,
+		},
+	}
+
+	for i, test := range testTable {
+
+		actual, err := buildClientInitialResponse(test.token, test.extensions)
+
+		if !reflect.DeepEqual(test.expected, actual) {
+			t.Errorf("Expected %s, got %s\n", test.expected, actual)
+		}
+		if test.expectError && err == nil {
+			t.Errorf("[%d]:[%s] Expected an error but did not get one", i, test.name)
+		}
+		if !test.expectError && err != nil {
+			t.Errorf("[%d]:[%s] Expected no error but got %s\n", i, test.name, err)
+		}
+	}
+}
+
 // We're not testing encoding/decoding here, so most of the requests/responses will be empty for simplicity's sake
 var brokerTestTable = []struct {
 	version  KafkaVersion

+ 6 - 0
config.go

@@ -69,6 +69,11 @@ type Config struct {
 			// AccessTokenProvider interface docs for proper implementation
 			// guidelines.
 			TokenProvider AccessTokenProvider
+			// Extensions is a optional map of arbitrary key-value pairs that
+			// can be sent with the SASL/OAUTHBEARER initial client response.
+			// These values are ignored by the SASL server if they are
+			// unexpected.
+			Extensions map[string]string
 		}
 
 		// KeepAlive specifies the keep-alive period for an active network connection.
@@ -358,6 +363,7 @@ func NewConfig() *Config {
 	c.Net.ReadTimeout = 30 * time.Second
 	c.Net.WriteTimeout = 30 * time.Second
 	c.Net.SASL.Handshake = true
+	c.Net.SASL.Extensions = make(map[string]string)
 
 	c.Metadata.Retry.Max = 3
 	c.Metadata.Retry.Backoff = 250 * time.Millisecond