Ver Fonte

Implement SASL/OAUTHBEARER support

Implement unit tests
Mike Kaminski há 7 anos atrás
pai
commit
980468064d
3 ficheiros alterados com 154 adições e 19 exclusões
  1. 31 19
      broker.go
  2. 102 0
      broker_test.go
  3. 21 0
      config_test.go

+ 31 - 19
broker.go

@@ -11,6 +11,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/pkg/errors"
 	"github.com/rcrowley/go-metrics"
 )
 
@@ -906,15 +907,21 @@ func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider) erro
 
 	requestTime := time.Now()
 
-	bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token)
+	correlationID := b.correlationID
+
+	bytesWritten, err := b.sendSASLOAuthBearerClientResponse(token, correlationID)
 
 	if err != nil {
 		return err
 	}
 
+	Logger.Printf("Correlation ID %d ", b.correlationID)
+
 	b.updateOutgoingCommunicationMetrics(bytesWritten)
 
-	bytesRead, err := b.receiveSASLOAuthBearerServerResponse()
+	b.correlationID++
+
+	bytesRead, err := b.receiveSASLOAuthBearerServerResponse(correlationID)
 
 	if err != nil {
 		return err
@@ -926,7 +933,7 @@ func (b *Broker) sendAndReceiveSASLOAuth(tokenProvider AccessTokenProvider) erro
 	return nil
 }
 
-func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string) (int, error) {
+func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string, correlationID int32) (int, error) {
 
 	// Initial client response as described by RFC-7628
 	// https://tools.ietf.org/html/rfc7628
@@ -934,7 +941,7 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string) (int, err
 
 	rb := &SaslAuthenticateRequest{oauthRequest}
 
-	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
+	req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
 
 	buf, err := encode(req, b.conf.MetricRegistry)
 
@@ -952,39 +959,44 @@ func (b *Broker) sendSASLOAuthBearerClientResponse(bearerToken string) (int, err
 		return 0, err
 	}
 
-	b.correlationID++
-
 	return bytesWritten, nil
 }
 
-func (b *Broker) receiveSASLOAuthBearerServerResponse() (int, error) {
+func (b *Broker) receiveSASLOAuthBearerServerResponse(correlationID int32) (int, error) {
 
-	var totalBytesRead int
+	buf := make([]byte, 8)
 
-	header := make([]byte, 8)
+	bytesRead, err := io.ReadFull(b.conn, buf)
+
+	if err != nil {
+		return bytesRead, err
+	}
+
+	header := responseHeader{}
 
-	bytesRead, err := io.ReadFull(b.conn, header)
+	err = decode(buf, &header)
 
 	if err != nil {
-		return 0, err
+		return bytesRead, err
 	}
 
-	totalBytesRead += bytesRead
+	if header.correlationID != correlationID {
+		return bytesRead, errors.Errorf("correlation ID didn't match, wanted %d, got %d", b.correlationID, header.correlationID)
+	}
 
-	length := binary.BigEndian.Uint32(header[:4])
-	payload := make([]byte, length-4)
+	buf = make([]byte, header.length-4)
 
-	bytesRead, err = io.ReadFull(b.conn, payload)
+	c, err := io.ReadFull(b.conn, buf)
+
+	bytesRead += c
 
 	if err != nil {
 		return bytesRead, err
 	}
 
-	totalBytesRead += bytesRead
-
 	res := &SaslAuthenticateResponse{}
 
-	if err := versionedDecode(payload, res, 0); err != nil {
+	if err := versionedDecode(buf, res, 0); err != nil {
 		return bytesRead, err
 	}
 
@@ -992,7 +1004,7 @@ func (b *Broker) receiveSASLOAuthBearerServerResponse() (int, error) {
 		return bytesRead, res.Err
 	}
 
-	return totalBytesRead, nil
+	return bytesRead, nil
 }
 
 func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) {

+ 102 - 0
broker_test.go

@@ -2,6 +2,8 @@ package sarama
 
 import (
 	"fmt"
+	"io"
+	"net"
 	"testing"
 	"time"
 )
@@ -105,6 +107,106 @@ func TestSimpleBrokerCommunication(t *testing.T) {
 
 }
 
+type Conn struct {
+	times int
+}
+
+func (c *Conn) Read(b []byte) (n int, err error) {
+	return 20, nil
+}
+
+func (c *Conn) Write(b []byte) (n int, err error) {
+	return 10, nil
+}
+
+func (c *Conn) Close() error {
+	return nil
+}
+
+func (c *Conn) LocalAddr() net.Addr {
+	return nil
+}
+
+func (c *Conn) RemoteAddr() net.Addr {
+	return nil
+}
+
+func (c *Conn) SetDeadline(t time.Time) error {
+	return nil
+}
+
+func (c *Conn) SetReadDeadline(t time.Time) error {
+	return nil
+}
+
+func (c *Conn) SetWriteDeadline(t time.Time) error {
+	return nil
+}
+
+func TestReceiveSASLOAuthBearerServerResponse(t *testing.T) {
+
+	testTable := []struct {
+		name string
+		buf  []byte
+		err  error
+	}{
+		{"OK server response",
+			[]byte{
+				0, 0, 0, 14,
+				0, 0, 0, 0,
+				0, 0,
+				255, 255, // no error message
+				0, 0, 0, 2, 'o', 'k',
+			},
+			nil},
+		{"SASL authentication failure response",
+			[]byte{
+				0, 0, 0, 19,
+				0, 0, 0, 0,
+				0, 58,
+				0, 3, 'e', 'r', 'r',
+				0, 0, 0, 4, 'f', 'a', 'i', 'l',
+			},
+			ErrSASLAuthenticationFailed},
+		{"Truncated header",
+			[]byte{
+				0, 0, 0, 12,
+			},
+			io.ErrUnexpectedEOF},
+		{"Truncated response message",
+			[]byte{
+				0, 0, 0, 12,
+				0, 0, 0, 0,
+				0, 0,
+			},
+			io.ErrUnexpectedEOF},
+	}
+
+	for _, test := range testTable {
+
+		in, out := net.Pipe()
+
+		b := &Broker{conn: out}
+
+		go func() {
+			in.Write(test.buf)
+			in.Close()
+		}()
+
+		bytesRead, err := b.receiveSASLOAuthBearerServerResponse(0)
+
+		out.Close()
+
+		if len(test.buf) != bytesRead {
+			t.Errorf("[%s] Expected %d bytes read, got %d", test.name, len(test.buf), bytesRead)
+		}
+
+		if test.err != err {
+			t.Errorf("[%s] Expected error %s, got %s", test.name, test.err, err)
+		}
+	}
+}
+
 // We're not testing encoding/decoding here, so most of the requests/responses will be empty for simplicity's sake
 var brokerTestTable = []struct {
 	version  KafkaVersion

+ 21 - 0
config_test.go

@@ -33,6 +33,13 @@ func TestEmptyClientIDConfigValidates(t *testing.T) {
 	}
 }
 
+type DummyTokenProvider struct {
+}
+
+func (t *DummyTokenProvider) Token() (string, error) {
+	return "access-token-string", nil
+}
+
 func TestNetConfigValidates(t *testing.T) {
 	tests := []struct {
 		name string
@@ -78,6 +85,20 @@ func TestNetConfigValidates(t *testing.T) {
 				cfg.Net.SASL.Password = ""
 			},
 			"Net.SASL.Password must not be empty when SASL is enabled"},
+		{"SASL.Mechanism - Invalid mechanism type",
+			func(cfg *Config) {
+				cfg.Net.SASL.Enable = true
+				cfg.Net.SASL.Mechanism = "AnIncorrectSASLMechanism"
+				cfg.Net.SASL.TokenProvider = &DummyTokenProvider{}
+			},
+			"The SASL mechanism configuration is invalid. Possible values are `OAUTHBEARER` and `PLAIN`"},
+		{"SASL.Mechanism.OAUTHBEARER - Missing token provider",
+			func(cfg *Config) {
+				cfg.Net.SASL.Enable = true
+				cfg.Net.SASL.Mechanism = SASLTypeOAuth
+				cfg.Net.SASL.TokenProvider = nil
+			},
+			"An AccessTokenProvider instance must be provided to Net.SASL.User.TokenProvider"},
 	}
 
 	for i, test := range tests {