Browse Source

Merge pull request #1354 from dnwe/sasl

feat: support SaslHandshakeRequest v1
Vlad Gorodetsky 6 years ago
parent
commit
6685e50575
3 changed files with 187 additions and 8 deletions
  1. 75 8
      broker.go
  2. 100 0
      broker_test.go
  3. 12 0
      client.go

+ 75 - 8
broker.go

@@ -13,7 +13,7 @@ import (
 	"sync/atomic"
 	"time"
 
-	"github.com/rcrowley/go-metrics"
+	metrics "github.com/rcrowley/go-metrics"
 )
 
 // Broker represents a single Kafka broker connection. All operations on this object are entirely concurrency-safe.
@@ -905,8 +905,10 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
 	return nil
 }
 
-// Kafka 0.10.0 plans to support SASL Plain and Kerberos as per PR #812 (KIP-43)/(JIRA KAFKA-3149)
-// Some hosted kafka services such as IBM Message Hub already offer SASL/PLAIN auth with Kafka 0.9
+// Kafka 0.10.x supported SASL PLAIN/Kerberos via KAFKA-3149 (KIP-43).
+// Kafka 1.x.x onward added a SaslAuthenticate request/response message which
+// wraps the SASL flow in the Kafka protocol, which allows for returning
+// meaningful errors on authentication failure.
 //
 // In SASL Plain, Kafka expects the auth header to be in the following format
 // Message format (from https://tools.ietf.org/html/rfc4616):
@@ -920,18 +922,37 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
 //   SAFE      = UTF1 / UTF2 / UTF3 / UTF4
 //                  ;; any UTF-8 encoded Unicode character except NUL
 //
+// With SASL v0 handshake and auth then:
 // When credentials are valid, Kafka returns a 4 byte array of null characters.
-// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way
-// of responding to bad credentials but thats how its being done today.
+// When credentials are invalid, Kafka closes the connection.
+//
+// With SASL v1 handshake and auth then:
+// When credentials are invalid, Kafka replies with a SaslAuthenticate response
+// containing an error code and message detailing the authentication failure.
 func (b *Broker) sendAndReceiveSASLPlainAuth() error {
+	// default to V0 to allow for backward compatability when SASL is enabled
+	// but not the handshake
+	saslHandshake := SASLHandshakeV0
 	if b.conf.Net.SASL.Handshake {
-		handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, SASLHandshakeV0)
+		if b.conf.Version.IsAtLeast(V1_0_0_0) {
+			saslHandshake = SASLHandshakeV1
+		}
+		handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, saslHandshake)
 		if handshakeErr != nil {
 			Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
 			return handshakeErr
 		}
 	}
 
+	if saslHandshake == SASLHandshakeV1 {
+		return b.sendAndReceiveV1SASLPlainAuth()
+	}
+	return b.sendAndReceiveV0SASLPlainAuth()
+}
+
+// sendAndReceiveV0SASLPlainAuth flows the v0 sasl auth NOT wrapped in the kafka protocol
+func (b *Broker) sendAndReceiveV0SASLPlainAuth() error {
+
 	length := 1 + len(b.conf.Net.SASL.User) + 1 + len(b.conf.Net.SASL.Password)
 	authBytes := make([]byte, length+4) //4 byte length header + auth data
 	binary.BigEndian.PutUint32(authBytes, uint32(length))
@@ -965,6 +986,35 @@ func (b *Broker) sendAndReceiveSASLPlainAuth() error {
 	return nil
 }
 
+// sendAndReceiveV1SASLPlainAuth flows the v1 sasl authentication using the kafka protocol
+func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
+	correlationID := b.correlationID
+
+	requestTime := time.Now()
+
+	bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID)
+
+	b.updateOutgoingCommunicationMetrics(bytesWritten)
+
+	if err != nil {
+		Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
+		return err
+	}
+
+	b.correlationID++
+
+	bytesRead, err := b.receiveSASLServerResponse(correlationID)
+	b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))
+
+	// With v1 sasl we get an error message set in the response we can return
+	if err != nil {
+		Logger.Printf("Error returned from broker during SASL flow %s: %s\n", b.addr, err.Error())
+		return err
+	}
+
+	return nil
+}
+
 // sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
 // https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
 func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
@@ -988,7 +1038,7 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
 	b.updateOutgoingCommunicationMetrics(bytesWritten)
 	b.correlationID++
 
-	bytesRead, err := b.receiveSASLOAuthBearerServerResponse(correlationID)
+	bytesRead, err := b.receiveSASLServerResponse(correlationID)
 	if err != nil {
 		return err
 	}
@@ -1123,6 +1173,23 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string)
 	return strings.Join(buf, elemSep)
 }
 
+func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) {
+	authBytes := []byte("\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password)
+	rb := &SaslAuthenticateRequest{authBytes}
+	req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
+	buf, err := encode(req, b.conf.MetricRegistry)
+	if err != nil {
+		return 0, err
+	}
+
+	err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
+	if err != nil {
+		Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
+		return 0, err
+	}
+	return b.conn.Write(buf)
+}
+
 func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) {
 	initialResp, err := buildClientInitialResponse(token)
 	if err != nil {
@@ -1145,7 +1212,7 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlati
 	return b.conn.Write(buf)
 }
 
-func (b *Broker) receiveSASLOAuthBearerServerResponse(correlationID int32) (int, error) {
+func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
 
 	buf := make([]byte, responseLengthSize+correlationIDSize)
 

+ 100 - 0
broker_test.go

@@ -377,6 +377,106 @@ func TestSASLSCRAMSHAXXX(t *testing.T) {
 	}
 }
 
+func TestSASLPlainAuth(t *testing.T) {
+
+	testTable := []struct {
+		name             string
+		mockAuthErr      KError // Mock and expect error returned from SaslAuthenticateRequest
+		mockHandshakeErr KError // Mock and expect error returned from SaslHandshakeRequest
+		expectClientErr  bool   // Expect an internal client-side error
+	}{
+		{
+			name:             "SASL Plain OK server response",
+			mockAuthErr:      ErrNoError,
+			mockHandshakeErr: ErrNoError,
+		},
+		{
+			name:             "SASL Plain authentication failure response",
+			mockAuthErr:      ErrSASLAuthenticationFailed,
+			mockHandshakeErr: ErrNoError,
+		},
+		{
+			name:             "SASL Plain handshake failure response",
+			mockAuthErr:      ErrNoError,
+			mockHandshakeErr: ErrSASLAuthenticationFailed,
+		},
+	}
+
+	for i, test := range testTable {
+
+		// mockBroker mocks underlying network logic and broker responses
+		mockBroker := NewMockBroker(t, 0)
+
+		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
+			SetAuthBytes([]byte(`response_payload`))
+
+		if test.mockAuthErr != ErrNoError {
+			mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
+		}
+
+		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
+			SetEnabledMechanisms([]string{SASLTypePlaintext})
+
+		if test.mockHandshakeErr != ErrNoError {
+			mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
+		}
+
+		mockBroker.SetHandlerByMap(map[string]MockResponse{
+			"SaslAuthenticateRequest": mockSASLAuthResponse,
+			"SaslHandshakeRequest":    mockSASLHandshakeResponse,
+		})
+
+		// broker executes SASL requests against mockBroker
+		broker := NewBroker(mockBroker.Addr())
+		broker.requestRate = metrics.NilMeter{}
+		broker.outgoingByteRate = metrics.NilMeter{}
+		broker.incomingByteRate = metrics.NilMeter{}
+		broker.requestSize = metrics.NilHistogram{}
+		broker.responseSize = metrics.NilHistogram{}
+		broker.responseRate = metrics.NilMeter{}
+		broker.requestLatency = metrics.NilHistogram{}
+
+		conf := NewConfig()
+		conf.Net.SASL.Mechanism = SASLTypePlaintext
+		conf.Net.SASL.User = "token"
+		conf.Net.SASL.Password = "password"
+
+		broker.conf = conf
+		broker.conf.Version = V1_0_0_0
+		dialer := net.Dialer{
+			Timeout:   conf.Net.DialTimeout,
+			KeepAlive: conf.Net.KeepAlive,
+			LocalAddr: conf.Net.LocalAddr,
+		}
+
+		conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
+
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		broker.conn = conn
+
+		err = broker.authenticateViaSASL()
+
+		if test.mockAuthErr != ErrNoError {
+			if test.mockAuthErr != err {
+				t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err)
+			}
+		} else if test.mockHandshakeErr != ErrNoError {
+			if test.mockHandshakeErr != err {
+				t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+			}
+		} else if test.expectClientErr && err == nil {
+			t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
+		} else if !test.expectClientErr && err != nil {
+			t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
+		}
+
+		mockBroker.Close()
+	}
+}
+
 func TestBuildClientInitialResponse(t *testing.T) {
 
 	testTable := []struct {

+ 12 - 0
client.go

@@ -780,6 +780,18 @@ func (client *client) tryRefreshMetadata(topics []string, attemptsRemaining int)
 		case PacketEncodingError:
 			// didn't even send, return the error
 			return err
+
+		case KError:
+			// if SASL auth error return as this _should_ be a non retryable err for all brokers
+			if err.(KError) == ErrSASLAuthenticationFailed {
+				Logger.Println("client/metadata failed SASL authentication")
+				return err
+			}
+			// else remove that broker and try again
+			Logger.Printf("client/metadata got error from broker %d while fetching metadata: %v\n", broker.ID(), err)
+			_ = broker.Close()
+			client.deregisterBroker(broker)
+
 		default:
 			// some other error, remove that broker and try again
 			Logger.Printf("client/metadata got error from broker %d while fetching metadata: %v\n", broker.ID(), err)