Kaynağa Gözat

Implement SASL/OAUTHBEARER support

This commit implements SASL/OAUTHBEARER client authentication as
described in KIP-255.
Mike Kaminski 7 yıl önce
ebeveyn
işleme
f5b3e0ed01

+ 150 - 4
broker.go

@@ -46,6 +46,24 @@ type Broker struct {
 	brokerResponseSize     metrics.Histogram
 }
 
+// SaslMechanism specifies the SASL mechanism the client uses to authenticate with the broker
+type SaslMechanism string
+
+const (
+	// SaslTypeOAuth represents the SASL/OAUTHBEARER mechanism (Kafka 2.0.0+)
+	SaslTypeOAuth = "OAUTHBEARER"
+	// SaslTypePlaintext represents the SASL/PLAIN mechanism
+	SaslTypePlaintext = "PLAIN"
+)
+
+// OAuthBearerTokenProvider is an interface that encapsualtes bearer token creation in an
+// unopinionated way. Users are free to decide how to construct bearer tokens and how often
+// they are be refreshed.
+type OAuthBearerTokenProvider interface {
+	// Return a valid bearer token.
+	Token() []byte
+}
+
 type responsePromise struct {
 	requestTime   time.Time
 	correlationID int32
@@ -125,7 +143,16 @@ func (b *Broker) Open(conf *Config) error {
 		}
 
 		if conf.Net.SASL.Enable {
-			b.connErr = b.sendAndReceiveSASLPlainAuth()
+
+			switch conf.Net.SASL.Mechanism {
+			case SaslTypeOAuth:
+				b.connErr = b.sendAndReceiveSASLOAuth(conf.Net.SASL.TokenProvider)
+			case SaslTypePlaintext:
+				b.connErr = b.sendAndReceiveSASLPlainAuth()
+			default:
+				b.connErr = b.sendAndReceiveSASLPlainAuth()
+			}
+
 			if b.connErr != nil {
 				err = b.conn.Close()
 				if err == nil {
@@ -744,8 +771,9 @@ func (b *Broker) responseReceiver() {
 	close(b.done)
 }
 
-func (b *Broker) sendAndReceiveSASLPlainHandshake() error {
-	rb := &SaslHandshakeRequest{"PLAIN"}
+func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error {
+	rb := &SaslHandshakeRequest{saslType, version}
+
 	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
 	buf, err := encode(req, b.conf.MetricRegistry)
 	if err != nil {
@@ -814,7 +842,7 @@ func (b *Broker) sendAndReceiveSASLPlainHandshake() error {
 // of responding to bad credentials but thats how its being done today.
 func (b *Broker) sendAndReceiveSASLPlainAuth() error {
 	if b.conf.Net.SASL.Handshake {
-		handshakeErr := b.sendAndReceiveSASLPlainHandshake()
+		handshakeErr := b.sendAndReceiveSASLHandshake(SaslTypePlaintext, 0)
 		if handshakeErr != nil {
 			Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
 			return handshakeErr
@@ -853,6 +881,124 @@ func (b *Broker) sendAndReceiveSASLPlainAuth() error {
 	return nil
 }
 
+// sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
+// https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
+func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider OAuthBearerTokenProvider) error {
+
+	// This version allows us to wrap tokens in the Kafka protocol, as opposed
+	// to sending opaque packets
+	handshakeVersion := int16(1)
+
+	if err := b.sendAndReceiveSASLHandshake(SaslTypeOAuth, handshakeVersion); err != nil {
+		Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
+		return err
+	}
+
+	requestTime := time.Now()
+
+	var bytesWritten int
+	var err error
+
+	if bytesWritten, err = b.sendSASLOAuthBearerClientResponse(tokenProvider.Token()); err != nil {
+		return err
+	}
+
+	b.updateOutgoingCommunicationMetrics(bytesWritten)
+
+	var bytesRead int
+
+	if bytesRead, err = b.receiveSASLOAuthBearerServerResponse(); err != nil {
+		return err
+	}
+
+	requestLatency := time.Since(requestTime)
+	b.updateIncomingCommunicationMetrics(bytesRead, requestLatency)
+
+	return nil
+}
+
+func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken []byte) (int, error) {
+
+	// Initial client response as described by RFC-7628
+	// https://tools.ietf.org/html/rfc7628
+	oauthRequest := []byte(`n,,`)
+	oauthRequest = append(oauthRequest, '\x01')
+	oauthRequest = append(oauthRequest, []byte(`auth=Bearer `)...)
+	oauthRequest = append(oauthRequest, bearerToken...)
+	oauthRequest = append(oauthRequest, '\x01')
+	oauthRequest = append(oauthRequest, '\x01')
+
+	rb := &SaslAuthenticateRequest{oauthRequest}
+
+	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
+
+	var buf []byte
+	var err error
+
+	buf, err = encode(req, b.conf.MetricRegistry)
+
+	if err != nil {
+		return 0, err
+	}
+
+	if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
+		return 0, err
+	}
+
+	var bytesWritten int
+
+	if bytesWritten, err = b.conn.Write(buf); err != nil {
+		Logger.Printf("Failed to send SASL/OAUTHBEARER initial request %s: %s\n", b.addr, err.Error())
+		return 0, err
+	}
+
+	b.correlationID++
+
+	return bytesWritten, nil
+}
+
+func (b *Broker) receiveSASLOAuthBearerServerResponse() (int, error) {
+
+	var bytesRead int
+	var err error
+	totalBytesRead := 0
+
+	header := make([]byte, 8)
+
+	if bytesRead, err = io.ReadFull(b.conn, header); err != nil {
+		Logger.Printf("Failed to read SASL/OAUTHBEARER intitial response header : %s\n", err.Error())
+		return 0, err
+	}
+
+	totalBytesRead += bytesRead
+
+	length := binary.BigEndian.Uint32(header[:4])
+	payload := make([]byte, length-4)
+
+	if bytesRead, err = io.ReadFull(b.conn, payload); err != nil {
+		Logger.Printf("Failed to read SASL/OAUTHBEARER intitial response payload : %s\n", err.Error())
+		return bytesRead, err
+	}
+
+	totalBytesRead += bytesRead
+
+	res := &SaslAuthenticateResponse{}
+
+	if err := versionedDecode(payload, res, 0); err != nil {
+		Logger.Printf("Failed to parse SASL/OAUTHBEARER intitial response : %s\n", err.Error())
+		return bytesRead, err
+	}
+
+	if res.Err != ErrNoError {
+		Logger.Printf("Invalid SASL/OAUTHBEARER request : %s\n", res.Err.Error())
+		return bytesRead, res.Err
+	}
+
+	Logger.Print("Successfully authenticated via SASL/OAUTHBEARER")
+
+	return totalBytesRead, nil
+}
+
 func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) {
 	b.updateRequestLatencyMetrics(requestLatency)
 	b.responseRate.Mark(1)

+ 28 - 4
config.go

@@ -54,6 +54,8 @@ type Config struct {
 			// Whether or not to use SASL authentication when connecting to the broker
 			// (defaults to false).
 			Enable bool
+			// The type of SASL mechanism to enable. Possible values: OAUTHBEARER, PLAIN (defaults to PLAIN)
+			Mechanism SaslMechanism
 			// Whether or not to send the Kafka SASL handshake first if enabled
 			// (defaults to true). You should only set this to false if you're using
 			// a non-Kafka SASL proxy.
@@ -61,6 +63,10 @@ type Config struct {
 			//username and password for SASL/PLAIN authentication
 			User     string
 			Password string
+			// TokenProvider is a bearer token generator for the OAUTHBEARER flow. You can define an instance of
+			// OAuthBearerTokenProvider that generates authentication tokens according to your Kafka cluster's
+			// configuration.
+			TokenProvider OAuthBearerTokenProvider
 		}
 
 		// KeepAlive specifies the keep-alive period for an active network connection.
@@ -454,10 +460,28 @@ func (c *Config) Validate() error {
 		return ConfigurationError("Net.WriteTimeout must be > 0")
 	case c.Net.KeepAlive < 0:
 		return ConfigurationError("Net.KeepAlive must be >= 0")
-	case c.Net.SASL.Enable == true && c.Net.SASL.User == "":
-		return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
-	case c.Net.SASL.Enable == true && c.Net.SASL.Password == "":
-		return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
+	case c.Net.SASL.Enable == true:
+		// For backwards compatibility, empty mechanism value defaults to PLAIN
+		isSaslPlain := len(c.Net.SASL.Mechanism) == 0 || c.Net.SASL.Mechanism == SaslTypePlaintext
+		if isSaslPlain {
+			if c.Net.SASL.User == "" {
+				return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
+			}
+			if c.Net.SASL.Password == "" {
+				return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
+			}
+		} else if c.Net.SASL.Mechanism == SaslTypeOAuth {
+			if c.Net.SASL.TokenProvider == nil {
+				return ConfigurationError("A OAuthBearerTokenProvider instance must be provided to Net.SASL.User.TokenProvider")
+			}
+			if !c.Net.SASL.Handshake {
+				Logger.Println("A SASL hanshake is required for SASL/OAUTHBEARER, ignoring disabled handshake config")
+			}
+		} else {
+			msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s` and `%s`",
+				SaslTypeOAuth, SaslTypePlaintext)
+			return ConfigurationError(msg)
+		}
 	}
 
 	// validate the Admin values

+ 2 - 0
request.go

@@ -140,6 +140,8 @@ func allocateBody(key, version int16) protocolBody {
 		return &DescribeConfigsRequest{}
 	case 33:
 		return &AlterConfigsRequest{}
+	case 36:
+		return &SaslAuthenticateRequest{}
 	case 37:
 		return &CreatePartitionsRequest{}
 	case 42:

+ 33 - 0
sasl_authenticate_request.go

@@ -0,0 +1,33 @@
+package sarama
+
+type SaslAuthenticateRequest struct {
+	SaslAuthBytes []byte
+}
+
+func (r *SaslAuthenticateRequest) encode(pe packetEncoder) error {
+	if err := pe.putBytes(r.SaslAuthBytes); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (r *SaslAuthenticateRequest) decode(pd packetDecoder, version int16) (err error) {
+	if r.SaslAuthBytes, err = pd.getBytes(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (r *SaslAuthenticateRequest) key() int16 {
+	return 36
+}
+
+func (r *SaslAuthenticateRequest) version() int16 {
+	return 0
+}
+
+func (r *SaslAuthenticateRequest) requiredVersion() KafkaVersion {
+	return V1_0_0_0
+}

+ 17 - 0
sasl_authenticate_request_test.go

@@ -0,0 +1,17 @@
+package sarama
+
+import "testing"
+
+var (
+	saslAuthenticateRequest = []byte{
+		0, 3, 'f', 'o', 'o',
+	}
+)
+
+func TestSaslAuthenticateRequest(t *testing.T) {
+	var request *SaslHandshakeRequest
+
+	request = new(SaslHandshakeRequest)
+	request.Mechanism = "foo"
+	testRequest(t, "basic", request, saslAuthenticateRequest)
+}

+ 46 - 0
sasl_authenticate_response.go

@@ -0,0 +1,46 @@
+package sarama
+
+type SaslAuthenticateResponse struct {
+	Err           KError
+	ErrorMessage  *string
+	SaslAuthBytes []byte
+}
+
+func (r *SaslAuthenticateResponse) encode(pe packetEncoder) error {
+	pe.putInt16(int16(r.Err))
+	if err := pe.putNullableString(r.ErrorMessage); err != nil {
+		return err
+	}
+	return pe.putBytes(r.SaslAuthBytes)
+}
+
+func (r *SaslAuthenticateResponse) decode(pd packetDecoder, version int16) error {
+	kerr, err := pd.getInt16()
+	if err != nil {
+		return err
+	}
+
+	r.Err = KError(kerr)
+
+	if r.ErrorMessage, err = pd.getNullableString(); err != nil {
+		return err
+	}
+
+	if r.SaslAuthBytes, err = pd.getBytes(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (r *SaslAuthenticateResponse) key() int16 {
+	return 36
+}
+
+func (r *SaslAuthenticateResponse) version() int16 {
+	return 0
+}
+
+func (r *SaslAuthenticateResponse) requiredVersion() KafkaVersion {
+	return V1_0_0_0
+}

+ 27 - 0
sasl_authenticate_response_test.go

@@ -0,0 +1,27 @@
+package sarama
+
+import "testing"
+
+var (
+	saslAuthenticatResponse = []byte{
+		0, 0,
+		0, 3, 'e', 'r', 'r',
+		0, 0, 0, 3, 'm', 's', 'g',
+	}
+)
+
+func TestSaslAuthenticateResponse(t *testing.T) {
+	var response *SaslAuthenticateResponse
+
+	response = new(SaslAuthenticateResponse)
+	testVersionDecodable(t, "no error", response, saslAuthenticatResponse, 0)
+	if response.Err != ErrNoError {
+		t.Error("Decoding error failed: no error expected but found", response.Err)
+	}
+	if *response.ErrorMessage != "err" {
+		t.Error("Decoding error failed: expected 'err' but found", *response.ErrorMessage)
+	}
+	if string(response.SaslAuthBytes) != "msg" {
+		t.Error("Decoding error failed: expected 'msg' but found", string(response.SaslAuthBytes))
+	}
+}

+ 2 - 1
sasl_handshake_request.go

@@ -2,6 +2,7 @@ package sarama
 
 type SaslHandshakeRequest struct {
 	Mechanism string
+	Version   int16
 }
 
 func (r *SaslHandshakeRequest) encode(pe packetEncoder) error {
@@ -25,7 +26,7 @@ func (r *SaslHandshakeRequest) key() int16 {
 }
 
 func (r *SaslHandshakeRequest) version() int16 {
-	return 0
+	return r.Version
 }
 
 func (r *SaslHandshakeRequest) requiredVersion() KafkaVersion {