Selaa lähdekoodia

Add SASL SCRAM-SHA-512 and SCRAM-SHA-256 mechanismes

Iyed Bennour 6 vuotta sitten
vanhempi
commit
55578535cf
5 muutettua tiedostoa jossa 303 lisäystä ja 22 poistoa
  1. 118 5
      broker.go
  2. 135 6
      broker_test.go
  3. 27 9
      config.go
  4. 20 2
      config_test.go
  5. 3 0
      response_header.go

+ 118 - 5
broker.go

@@ -56,6 +56,10 @@ const (
 	SASLTypeOAuth = "OAUTHBEARER"
 	// SASLTypePlaintext represents the SASL/PLAIN mechanism
 	SASLTypePlaintext = "PLAIN"
+	// SASLTypeSCRAMSHA256 represents the SCRAM-SHA-256 mechanism.
+	SASLTypeSCRAMSHA256 = "SCRAM-SHA-256"
+	// SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism.
+	SASLTypeSCRAMSHA512 = "SCRAM-SHA-512"
 	// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and
 	// server negotiate SASL auth using opaque packets.
 	SASLHandshakeV0 = int16(0)
@@ -92,6 +96,20 @@ type AccessTokenProvider interface {
 	Token() (*AccessToken, error)
 }
 
+// SCRAMClient is a an interface to a SCRAM
+// client implementation.
+type SCRAMClient interface {
+	// Begin prepares the client for the SCRAM exchange
+	// with the server with a user name and a password
+	Begin(userName, password, authzID string) error
+	// Step steps client through the SCRAM exchange. It is
+	// called repeatedly until it errors or `Done` returns true.
+	Step(challenge string) (response string, err error)
+	// Done should return true when the SCRAM conversation
+	// is over.
+	Done() bool
+}
+
 type responsePromise struct {
 	requestTime   time.Time
 	correlationID int32
@@ -793,14 +811,19 @@ func (b *Broker) responseReceiver() {
 }
 
 func (b *Broker) authenticateViaSASL() error {
-	if b.conf.Net.SASL.Mechanism == SASLTypeOAuth {
+	switch b.conf.Net.SASL.Mechanism {
+	case SASLTypeOAuth:
 		return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
+	case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
+		return b.sendAndReceiveSASLSCRAMv1()
+	default:
+		return b.sendAndReceiveSASLPlainAuth()
 	}
-	return b.sendAndReceiveSASLPlainAuth()
+
 }
 
-func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error {
-	rb := &SaslHandshakeRequest{Mechanism: saslType, Version: version}
+func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int16) error {
+	rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}
 
 	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
 	buf, err := encode(req, b.conf.MetricRegistry)
@@ -846,7 +869,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) err
 		Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
 		return res.Err
 	}
-	Logger.Print("Successful SASL handshake")
+	Logger.Print("Successful SASL handshake. Available mechanisms: ", res.EnabledMechanisms)
 	return nil
 }
 
@@ -949,6 +972,96 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
 	return nil
 }
 
+func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
+	if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV1); err != nil {
+		return err
+	}
+
+	scramClient := b.conf.Net.SASL.SCRAMClient
+	if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
+		return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
+	}
+
+	msg, err := scramClient.Step("")
+	if err != nil {
+		return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error())
+
+	}
+
+	for !scramClient.Done() {
+		requestTime := time.Now()
+		correlationID := b.correlationID
+		bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, []byte(msg))
+		if err != nil {
+			Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
+			return err
+		}
+
+		b.updateOutgoingCommunicationMetrics(bytesWritten)
+		b.correlationID++
+		challenge, err := b.receiveSaslAuthenticateResponse(correlationID)
+		if err != nil {
+			Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
+			return err
+		}
+
+		b.updateIncomingCommunicationMetrics(len(challenge), time.Since(requestTime))
+		msg, err = scramClient.Step(string(challenge))
+		if err != nil {
+			Logger.Println("SASL authentication failed", err)
+			return err
+		}
+	}
+	Logger.Println("SASL authentication succeeded")
+	return nil
+}
+
+func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
+	rb := &SaslAuthenticateRequest{msg}
+	req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
+	buf, err := encode(req, b.conf.MetricRegistry)
+	if err != nil {
+		return 0, err
+	}
+	if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
+		return 0, err
+	}
+	return b.conn.Write(buf)
+}
+
+func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
+	buf := make([]byte, responseLengthSize+correlationIDSize)
+	bytesRead, err := io.ReadFull(b.conn, buf)
+	if err != nil {
+		return nil, err
+	}
+	header := responseHeader{}
+	err = decode(buf, &header)
+	if err != nil {
+		return nil, err
+	}
+	if header.correlationID != correlationID {
+		return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
+	}
+	buf = make([]byte, header.length-correlationIDSize)
+	c, err := io.ReadFull(b.conn, buf)
+	bytesRead += c
+	if err != nil {
+		return nil, err
+	}
+	res := &SaslAuthenticateResponse{}
+	if err := versionedDecode(buf, res, 0); err != nil {
+		return nil, err
+	}
+	if err != nil {
+		return nil, err
+	}
+	if res.Err != ErrNoError {
+		return nil, res.Err
+	}
+	return res.SaslAuthBytes, nil
+}
+
 // Build SASL/OAUTHBEARER initial client response as described by RFC-7628
 // https://tools.ietf.org/html/rfc7628
 func buildClientInitialResponse(token *AccessToken) ([]byte, error) {

+ 135 - 6
broker_test.go

@@ -179,16 +179,12 @@ func TestSASLOAuthBearer(t *testing.T) {
 		// mockBroker mocks underlying network logic and broker responses
 		mockBroker := NewMockBroker(t, 0)
 
-		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
-			SetAuthBytes([]byte(`response_payload`))
-
+		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte("response_payload"))
 		if test.mockAuthErr != ErrNoError {
 			mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
 		}
 
-		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
-			SetEnabledMechanisms([]string{SASLTypeOAuth})
-
+		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeOAuth})
 		if test.mockHandshakeErr != ErrNoError {
 			mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
 		}
@@ -248,6 +244,139 @@ func TestSASLOAuthBearer(t *testing.T) {
 	}
 }
 
+// A mock scram client.
+type MockSCRAMClient struct {
+	done bool
+}
+
+func (m *MockSCRAMClient) Begin(userName, password, authzID string) (err error) {
+	return nil
+}
+
+func (m *MockSCRAMClient) Step(challenge string) (response string, err error) {
+	if challenge == "" {
+		return "ping", nil
+	}
+	if challenge == "pong" {
+		m.done = true
+		return "", nil
+	}
+	return "", errors.New("failed to authenticate :(")
+}
+
+func (m *MockSCRAMClient) Done() bool {
+	return m.done
+}
+
+var _ SCRAMClient = &MockSCRAMClient{}
+
+func TestSASLSCRAMSHAXXX(t *testing.T) {
+	testTable := []struct {
+		name               string
+		mockHandshakeErr   KError
+		mockSASLAuthErr    KError
+		expectClientErr    bool
+		scramClient        *MockSCRAMClient
+		scramChallengeResp string
+	}{
+		{
+			name:               "SASL/SCRAMSHAXXX successfull authentication",
+			mockHandshakeErr:   ErrNoError,
+			scramClient:        &MockSCRAMClient{},
+			scramChallengeResp: "pong",
+		},
+		{
+			name:               "SASL/SCRAMSHAXXX SCRAM client step error client",
+			mockHandshakeErr:   ErrNoError,
+			mockSASLAuthErr:    ErrNoError,
+			scramClient:        &MockSCRAMClient{},
+			scramChallengeResp: "gong",
+			expectClientErr:    true,
+		},
+		{
+			name:               "SASL/SCRAMSHAXXX server authentication error",
+			mockHandshakeErr:   ErrNoError,
+			mockSASLAuthErr:    ErrSASLAuthenticationFailed,
+			scramClient:        &MockSCRAMClient{},
+			scramChallengeResp: "pong",
+		},
+		{
+			name:               "SASL/SCRAMSHAXXX unsupported SCRAM mechanism",
+			mockHandshakeErr:   ErrUnsupportedSASLMechanism,
+			mockSASLAuthErr:    ErrNoError,
+			scramClient:        &MockSCRAMClient{},
+			scramChallengeResp: "pong",
+		},
+	}
+
+	for i, test := range testTable {
+
+		// mockBroker mocks underlying network logic and broker responses
+		mockBroker := NewMockBroker(t, 0)
+		broker := NewBroker(mockBroker.Addr())
+		// broker executes SASL requests against mockBroker
+		broker.requestRate = metrics.NilMeter{}
+		broker.outgoingByteRate = metrics.NilMeter{}
+		broker.incomingByteRate = metrics.NilMeter{}
+		broker.requestSize = metrics.NilHistogram{}
+		broker.responseSize = metrics.NilHistogram{}
+		broker.responseRate = metrics.NilMeter{}
+		broker.requestLatency = metrics.NilHistogram{}
+
+		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte(test.scramChallengeResp))
+		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512})
+
+		if test.mockSASLAuthErr != ErrNoError {
+			mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockSASLAuthErr)
+		}
+		if test.mockHandshakeErr != ErrNoError {
+			mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
+		}
+
+		mockBroker.SetHandlerByMap(map[string]MockResponse{
+			"SaslAuthenticateRequest": mockSASLAuthResponse,
+			"SaslHandshakeRequest":    mockSASLHandshakeResponse,
+		})
+
+		conf := NewConfig()
+		conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
+		conf.Net.SASL.SCRAMClient = test.scramClient
+
+		broker.conf = conf
+		dialer := net.Dialer{
+			Timeout:   conf.Net.DialTimeout,
+			KeepAlive: conf.Net.KeepAlive,
+			LocalAddr: conf.Net.LocalAddr,
+		}
+
+		conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
+
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		broker.conn = conn
+
+		err = broker.authenticateViaSASL()
+
+		if test.mockSASLAuthErr != ErrNoError {
+			if test.mockSASLAuthErr != err {
+				t.Errorf("[%d]:[%s] Expected %s SASL authentication error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+			}
+		} else if test.mockHandshakeErr != ErrNoError {
+			if test.mockHandshakeErr != err {
+				t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+			}
+		} else if test.expectClientErr && err == nil {
+			t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
+		} else if !test.expectClientErr && err != nil {
+			t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
+		}
+
+		mockBroker.Close()
+	}
+}
+
 func TestBuildClientInitialResponse(t *testing.T) {
 
 	testTable := []struct {

+ 27 - 9
config.go

@@ -61,9 +61,14 @@ type Config struct {
 			// (defaults to true). You should only set this to false if you're using
 			// a non-Kafka SASL proxy.
 			Handshake bool
-			//username and password for SASL/PLAIN authentication
+			//username and password for SASL/PLAIN  or SASL/SCRAM authentication
 			User     string
 			Password string
+			// authz id used for SASL/SCRAM authentication
+			SCRAMAuthzID string
+			// SCRAMClient is a user provided implementation of a SCRAM
+			// client used to perform the SCRAM exchange with the server.
+			SCRAMClient SCRAMClient
 			// TokenProvider is a user-defined callback for generating
 			// access tokens for SASL/OAUTHBEARER auth. See the
 			// AccessTokenProvider interface docs for proper implementation
@@ -475,22 +480,35 @@ func (c *Config) Validate() error {
 	case c.Net.KeepAlive < 0:
 		return ConfigurationError("Net.KeepAlive must be >= 0")
 	case c.Net.SASL.Enable:
-		// For backwards compatibility, empty mechanism value defaults to PLAIN
-		isSASLPlain := len(c.Net.SASL.Mechanism) == 0 || c.Net.SASL.Mechanism == SASLTypePlaintext
-		if isSASLPlain {
+		if c.Net.SASL.Mechanism == "" {
+			c.Net.SASL.Mechanism = SASLTypePlaintext
+		}
+
+		switch c.Net.SASL.Mechanism {
+		case SASLTypePlaintext:
 			if c.Net.SASL.User == "" {
 				return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
 			}
 			if c.Net.SASL.Password == "" {
 				return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
 			}
-		} else if c.Net.SASL.Mechanism == SASLTypeOAuth {
+		case SASLTypeOAuth:
 			if c.Net.SASL.TokenProvider == nil {
-				return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.User.TokenProvider")
+				return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.TokenProvider")
+			}
+		case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
+			if c.Net.SASL.User == "" {
+				return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled")
+			}
+			if c.Net.SASL.Password == "" {
+				return ConfigurationError("Net.SASL.Password must not be empty when SASL is enabled")
+			}
+			if c.Net.SASL.SCRAMClient == nil {
+				return ConfigurationError("A SCRAMClient instance must be provided to Net.SASL.SCRAMClient")
 			}
-		} else {
-			msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s` and `%s`",
-				SASLTypeOAuth, SASLTypePlaintext)
+		default:
+			msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s` and `%s`",
+				SASLTypeOAuth, SASLTypePlaintext, SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512)
 			return ConfigurationError(msg)
 		}
 	}

+ 20 - 2
config_test.go

@@ -91,14 +91,32 @@ func TestNetConfigValidates(t *testing.T) {
 				cfg.Net.SASL.Mechanism = "AnIncorrectSASLMechanism"
 				cfg.Net.SASL.TokenProvider = &DummyTokenProvider{}
 			},
-			"The SASL mechanism configuration is invalid. Possible values are `OAUTHBEARER` and `PLAIN`"},
+			"The SASL mechanism configuration is invalid. Possible values are `OAUTHBEARER`, `PLAIN`, `SCRAM-SHA-256` and `SCRAM-SHA-512`"},
 		{"SASL.Mechanism.OAUTHBEARER - Missing token provider",
 			func(cfg *Config) {
 				cfg.Net.SASL.Enable = true
 				cfg.Net.SASL.Mechanism = SASLTypeOAuth
 				cfg.Net.SASL.TokenProvider = nil
 			},
-			"An AccessTokenProvider instance must be provided to Net.SASL.User.TokenProvider"},
+			"An AccessTokenProvider instance must be provided to Net.SASL.TokenProvider"},
+		{"SASL.Mechanism SCRAM-SHA-256 - Missing SCRAM client",
+			func(cfg *Config) {
+				cfg.Net.SASL.Enable = true
+				cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA256
+				cfg.Net.SASL.SCRAMClient = nil
+				cfg.Net.SASL.User = "user"
+				cfg.Net.SASL.Password = "stong_password"
+			},
+			"A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"},
+		{"SASL.Mechanism SCRAM-SHA-512 - Missing SCRAM client",
+			func(cfg *Config) {
+				cfg.Net.SASL.Enable = true
+				cfg.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
+				cfg.Net.SASL.SCRAMClient = nil
+				cfg.Net.SASL.User = "user"
+				cfg.Net.SASL.Password = "stong_password"
+			},
+			"A SCRAMClient instance must be provided to Net.SASL.SCRAMClient"},
 	}
 
 	for i, test := range tests {

+ 3 - 0
response_header.go

@@ -2,6 +2,9 @@ package sarama
 
 import "fmt"
 
+const responseLengthSize = 4
+const correlationIDSize = 4
+
 type responseHeader struct {
 	length        int32
 	correlationID int32