|
@@ -11,6 +11,7 @@ import (
|
|
|
"sync/atomic"
|
|
"sync/atomic"
|
|
|
"time"
|
|
"time"
|
|
|
|
|
|
|
|
|
|
+ "github.com/pkg/errors"
|
|
|
"github.com/rcrowley/go-metrics"
|
|
"github.com/rcrowley/go-metrics"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -906,15 +907,21 @@ func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider) erro
|
|
|
|
|
|
|
|
requestTime := time.Now()
|
|
requestTime := time.Now()
|
|
|
|
|
|
|
|
- bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token)
|
|
|
|
|
|
|
+ correlationID := b.correlationID
|
|
|
|
|
+
|
|
|
|
|
+ bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID)
|
|
|
|
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ Logger.Printf("Correlation ID %d ", b.correlationID)
|
|
|
|
|
+
|
|
|
b.updateOutgoingCommunicationMetrics(bytesWritten)
|
|
b.updateOutgoingCommunicationMetrics(bytesWritten)
|
|
|
|
|
|
|
|
- bytesRead, err := b.receiveSASLOAuthBearerServerResponse()
|
|
|
|
|
|
|
+ b.correlationID++
|
|
|
|
|
+
|
|
|
|
|
+ bytesRead, err := b.receiveSASLOAuthBearerServerResponse(correlationID)
|
|
|
|
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
@@ -926,7 +933,7 @@ func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider) erro
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string) (int, error) {
|
|
|
|
|
|
|
+func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string, correlationID int32) (int, error) {
|
|
|
|
|
|
|
|
// Initial client response as described by RFC-7628
|
|
// Initial client response as described by RFC-7628
|
|
|
// https://tools.ietf.org/html/rfc7628
|
|
// https://tools.ietf.org/html/rfc7628
|
|
@@ -934,7 +941,7 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string) (int, err
|
|
|
|
|
|
|
|
rb := &SaslAuthenticateRequest{oauthRequest}
|
|
rb := &SaslAuthenticateRequest{oauthRequest}
|
|
|
|
|
|
|
|
- req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
|
|
|
|
|
|
|
+ req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
|
|
|
|
|
|
|
|
buf, err := encode(req, b.conf.MetricRegistry)
|
|
buf, err := encode(req, b.conf.MetricRegistry)
|
|
|
|
|
|
|
@@ -952,39 +959,44 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string) (int, err
|
|
|
return 0, err
|
|
return 0, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- b.correlationID++
|
|
|
|
|
-
|
|
|
|
|
return bytesWritten, nil
|
|
return bytesWritten, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (b *Broker) receiveSASLOAuthBearerServerResponse() (int, error) {
|
|
|
|
|
|
|
+func (b *Broker) receiveSASLOAuthBearerServerResponse(correlationID int32) (int, error) {
|
|
|
|
|
|
|
|
- var totalBytesRead int
|
|
|
|
|
|
|
+ buf := make([]byte, 8)
|
|
|
|
|
|
|
|
- header := make([]byte, 8)
|
|
|
|
|
|
|
+ bytesRead, err := io.ReadFull(b.conn, buf)
|
|
|
|
|
+
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return bytesRead, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ header := responseHeader{}
|
|
|
|
|
|
|
|
- bytesRead, err := io.ReadFull(b.conn, header)
|
|
|
|
|
|
|
+ err = decode(buf, &header)
|
|
|
|
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- return 0, err
|
|
|
|
|
|
|
+ return bytesRead, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- totalBytesRead += bytesRead
|
|
|
|
|
|
|
+ if header.correlationID != correlationID {
|
|
|
|
|
+ return bytesRead, errors.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- length := binary.BigEndian.Uint32(header[:4])
|
|
|
|
|
- payload := make([]byte, length-4)
|
|
|
|
|
|
|
+ buf = make([]byte, header.length-4)
|
|
|
|
|
|
|
|
- bytesRead, err = io.ReadFull(b.conn, payload)
|
|
|
|
|
|
|
+ c, err := io.ReadFull(b.conn, buf)
|
|
|
|
|
+
|
|
|
|
|
+ bytesRead += c
|
|
|
|
|
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return bytesRead, err
|
|
return bytesRead, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- totalBytesRead += bytesRead
|
|
|
|
|
-
|
|
|
|
|
res := &SaslAuthenticateResponse{}
|
|
res := &SaslAuthenticateResponse{}
|
|
|
|
|
|
|
|
- if err := versionedDecode(payload, res, 0); err != nil {
|
|
|
|
|
|
|
+ if err := versionedDecode(buf, res, 0); err != nil {
|
|
|
return bytesRead, err
|
|
return bytesRead, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -992,7 +1004,7 @@ func (b *Broker) receiveSASLOAuthBearerServerResponse() (int, error) {
|
|
|
return bytesRead, res.Err
|
|
return bytesRead, res.Err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- return totalBytesRead, nil
|
|
|
|
|
|
|
+ return bytesRead, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) {
|
|
func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) {
|