|
@@ -13,7 +13,7 @@ import (
|
|
|
"sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
- "github.com/rcrowley/go-metrics"
|
|
|
+ metrics "github.com/rcrowley/go-metrics"
|
|
|
)
|
|
|
|
|
|
|
|
@@ -905,8 +905,10 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-
|
|
|
-
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
|
|
|
|
|
@@ -920,18 +922,37 @@ func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int
|
|
|
|
|
|
|
|
|
|
|
|
+
|
|
|
|
|
|
-
|
|
|
-
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
func (b *Broker) sendAndReceiveSASLPlainAuth() error {
|
|
|
+
|
|
|
+
|
|
|
+ saslHandshake := SASLHandshakeV0
|
|
|
if b.conf.Net.SASL.Handshake {
|
|
|
- handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, SASLHandshakeV0)
|
|
|
+ if b.conf.Version.IsAtLeast(V1_0_0_0) {
|
|
|
+ saslHandshake = SASLHandshakeV1
|
|
|
+ }
|
|
|
+ handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, saslHandshake)
|
|
|
if handshakeErr != nil {
|
|
|
Logger.Printf("Error while performing SASL handshake %s\n", b.addr)
|
|
|
return handshakeErr
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ if saslHandshake == SASLHandshakeV1 {
|
|
|
+ return b.sendAndReceiveV1SASLPlainAuth()
|
|
|
+ }
|
|
|
+ return b.sendAndReceiveV0SASLPlainAuth()
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+func (b *Broker) sendAndReceiveV0SASLPlainAuth() error {
|
|
|
+
|
|
|
length := 1 + len(b.conf.Net.SASL.User) + 1 + len(b.conf.Net.SASL.Password)
|
|
|
authBytes := make([]byte, length+4)
|
|
|
binary.BigEndian.PutUint32(authBytes, uint32(length))
|
|
@@ -965,6 +986,35 @@ func (b *Broker) sendAndReceiveSASLPlainAuth() error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+func (b *Broker) sendAndReceiveV1SASLPlainAuth() error {
|
|
|
+ correlationID := b.correlationID
|
|
|
+
|
|
|
+ requestTime := time.Now()
|
|
|
+
|
|
|
+ bytesWritten, err := b.sendSASLPlainAuthClientResponse(correlationID)
|
|
|
+
|
|
|
+ b.updateOutgoingCommunicationMetrics(bytesWritten)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ b.correlationID++
|
|
|
+
|
|
|
+ bytesRead, err := b.receiveSASLServerResponse(correlationID)
|
|
|
+ b.updateIncomingCommunicationMetrics(bytesRead, time.Since(requestTime))
|
|
|
+
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ Logger.Printf("Error returned from broker during SASL flow %s: %s\n", b.addr, err.Error())
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
|
|
|
|
|
|
func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
|
|
@@ -988,7 +1038,7 @@ func (b *Broker) sendAndReceiveSASLOAuth(provider AccessTokenProvider) error {
|
|
|
b.updateOutgoingCommunicationMetrics(bytesWritten)
|
|
|
b.correlationID++
|
|
|
|
|
|
- bytesRead, err := b.receiveSASLOAuthBearerServerResponse(correlationID)
|
|
|
+ bytesRead, err := b.receiveSASLServerResponse(correlationID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -1123,6 +1173,23 @@ func mapToString(extensions map[string]string, keyValSep string, elemSep string)
|
|
|
return strings.Join(buf, elemSep)
|
|
|
}
|
|
|
|
|
|
+func (b *Broker) sendSASLPlainAuthClientResponse(correlationID int32) (int, error) {
|
|
|
+ authBytes := []byte("\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password)
|
|
|
+ rb := &SaslAuthenticateRequest{authBytes}
|
|
|
+ req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
|
|
|
+ buf, err := encode(req, b.conf.MetricRegistry)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ err = b.conn.SetWriteDeadline(time.Now().Add(b.conf.Net.WriteTimeout))
|
|
|
+ if err != nil {
|
|
|
+ Logger.Printf("Failed to set write deadline when doing SASL auth with broker %s: %s\n", b.addr, err.Error())
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+ return b.conn.Write(buf)
|
|
|
+}
|
|
|
+
|
|
|
func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlationID int32) (int, error) {
|
|
|
initialResp, err := buildClientInitialResponse(token)
|
|
|
if err != nil {
|
|
@@ -1145,7 +1212,7 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(token *AccessToken, correlati
|
|
|
return b.conn.Write(buf)
|
|
|
}
|
|
|
|
|
|
-func (b *Broker) receiveSASLOAuthBearerServerResponse(correlationID int32) (int, error) {
|
|
|
+func (b *Broker) receiveSASLServerResponse(correlationID int32) (int, error) {
|
|
|
|
|
|
buf := make([]byte, responseLengthSize+correlationIDSize)
|
|
|
|