|
@@ -56,6 +56,10 @@ const (
|
|
|
SASLTypeOAuth = "OAUTHBEARER"
|
|
SASLTypeOAuth = "OAUTHBEARER"
|
|
|
// SASLTypePlaintext represents the SASL/PLAIN mechanism
|
|
// SASLTypePlaintext represents the SASL/PLAIN mechanism
|
|
|
SASLTypePlaintext = "PLAIN"
|
|
SASLTypePlaintext = "PLAIN"
|
|
|
|
|
+ // SASLTypeSCRAMSHA256 represents the SCRAM-SHA-256 mechanism.
|
|
|
|
|
+ SASLTypeSCRAMSHA256 = "SCRAM-SHA-256"
|
|
|
|
|
+ // SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism.
|
|
|
|
|
+ SASLTypeSCRAMSHA512 = "SCRAM-SHA-512"
|
|
|
// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and
|
|
// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and
|
|
|
// server negotiate SASL auth using opaque packets.
|
|
// server negotiate SASL auth using opaque packets.
|
|
|
SASLHandshakeV0 = int16(0)
|
|
SASLHandshakeV0 = int16(0)
|
|
@@ -92,6 +96,20 @@ type AccessTokenProvider interface {
|
|
|
Token() (*AccessToken, error)
|
|
Token() (*AccessToken, error)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+// SCRAMClient is a an interface to a SCRAM
|
|
|
|
|
+// client implementation.
|
|
|
|
|
+type SCRAMClient interface {
|
|
|
|
|
+ // Begin prepares the client for the SCRAM exchange
|
|
|
|
|
+ // with the server with a user name and a password
|
|
|
|
|
+ Begin(userName, password, authzID string) error
|
|
|
|
|
+ // Step steps client through the SCRAM exchange. It is
|
|
|
|
|
+ // called repeatedly until it errors or `Done` returns true.
|
|
|
|
|
+ Step(challenge string) (response string, err error)
|
|
|
|
|
+ // Done should return true when the SCRAM conversation
|
|
|
|
|
+ // is over.
|
|
|
|
|
+ Done() bool
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
type responsePromise struct {
|
|
type responsePromise struct {
|
|
|
requestTime time.Time
|
|
requestTime time.Time
|
|
|
correlationID int32
|
|
correlationID int32
|
|
@@ -793,14 +811,19 @@ func (b *Broker) responseReceiver() {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (b *Broker) authenticateViaSASL() error {
|
|
func (b *Broker) authenticateViaSASL() error {
|
|
|
- if b.conf.Net.SASL.Mechanism == SASLTypeOAuth {
|
|
|
|
|
|
|
+ switch b.conf.Net.SASL.Mechanism {
|
|
|
|
|
+ case SASLTypeOAuth:
|
|
|
return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
|
|
return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider)
|
|
|
|
|
+ case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
|
|
|
|
|
+ return b.sendAndReceiveSASLSCRAMv1()
|
|
|
|
|
+ default:
|
|
|
|
|
+ return b.sendAndReceiveSASLPlainAuth()
|
|
|
}
|
|
}
|
|
|
- return b.sendAndReceiveSASLPlainAuth()
|
|
|
|
|
|
|
+
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error {
|
|
|
|
|
- rb := &SaslHandshakeRequest{Mechanism: saslType, Version: version}
|
|
|
|
|
|
|
+func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int16) error {
|
|
|
|
|
+ rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}
|
|
|
|
|
|
|
|
req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
|
|
req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
|
|
|
buf, err := encode(req, b.conf.MetricRegistry)
|
|
buf, err := encode(req, b.conf.MetricRegistry)
|
|
@@ -846,7 +869,7 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) err
|
|
|
Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
|
|
Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
|
|
|
return res.Err
|
|
return res.Err
|
|
|
}
|
|
}
|
|
|
- Logger.Print("Successful SASL handshake")
|
|
|
|
|
|
|
+ Logger.Print("Successful SASL handshake. Available mechanisms: ", res.EnabledMechanisms)
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -949,6 +972,96 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+func (b *Broker) sendAndReceiveSASLSCRAMv1() error {
|
|
|
|
|
+ if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV1); err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ scramClient := b.conf.Net.SASL.SCRAMClient
|
|
|
|
|
+ if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
|
|
|
|
|
+ return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ msg, err := scramClient.Step("")
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error())
|
|
|
|
|
+
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for !scramClient.Done() {
|
|
|
|
|
+ requestTime := time.Now()
|
|
|
|
|
+ correlationID := b.correlationID
|
|
|
|
|
+ bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, []byte(msg))
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ b.updateOutgoingCommunicationMetrics(bytesWritten)
|
|
|
|
|
+ b.correlationID++
|
|
|
|
|
+ challenge, err := b.receiveSaslAuthenticateResponse(correlationID)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ b.updateIncomingCommunicationMetrics(len(challenge), time.Since(requestTime))
|
|
|
|
|
+ msg, err = scramClient.Step(string(challenge))
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ Logger.Println("SASL authentication failed", err)
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ Logger.Println("SASL authentication succeeded")
|
|
|
|
|
+ return nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (b *Broker) sendSaslAuthenticateRequest(correlationID int32, msg []byte) (int, error) {
|
|
|
|
|
+ rb := &SaslAuthenticateRequest{msg}
|
|
|
|
|
+ req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
|
|
|
|
|
+ buf, err := encode(req, b.conf.MetricRegistry)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return 0, err
|
|
|
|
|
+ }
|
|
|
|
|
+ if err := b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout)); err != nil {
|
|
|
|
|
+ return 0, err
|
|
|
|
|
+ }
|
|
|
|
|
+ return b.conn.Write(buf)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, error) {
|
|
|
|
|
+ buf := make([]byte, responseLengthSize+correlationIDSize)
|
|
|
|
|
+ bytesRead, err := io.ReadFull(b.conn, buf)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ header := responseHeader{}
|
|
|
|
|
+ err = decode(buf, &header)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ if header.correlationID != correlationID {
|
|
|
|
|
+ return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
|
|
|
|
|
+ }
|
|
|
|
|
+ buf = make([]byte, header.length-correlationIDSize)
|
|
|
|
|
+ c, err := io.ReadFull(b.conn, buf)
|
|
|
|
|
+ bytesRead += c
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ res := &SaslAuthenticateResponse{}
|
|
|
|
|
+ if err := versionedDecode(buf, res, 0); err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ if res.Err != ErrNoError {
|
|
|
|
|
+ return nil, res.Err
|
|
|
|
|
+ }
|
|
|
|
|
+ return res.SaslAuthBytes, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
|
|
// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
|
|
|
// https://tools.ietf.org/html/rfc7628
|
|
// https://tools.ietf.org/html/rfc7628
|
|
|
func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
|
|
func buildClientInitialResponse(token *AccessToken) ([]byte, error) {
|