Selaa lähdekoodia

Increase test coverage

Mike Kaminski 7 vuotta sitten
vanhempi
commit
dc6025a247
3 muutettua tiedostoa jossa 25 lisäystä ja 23 poistoa
  1. 12 8
      broker.go
  2. 4 1
      broker_test.go
  3. 9 14
      sasl_authenticate_response_test.go

+ 12 - 8
broker.go

@@ -159,14 +159,7 @@ func (b *Broker) Open(conf *Config) error {
 
 		if conf.Net.SASL.Enable {
 
-			switch conf.Net.SASL.Mechanism {
-			case SASLTypeOAuth:
-				b.connErr = b.sendAndReceiveSASLOAuth(conf.Net.SASL.TokenProvider, conf.Net.SASL.Extensions)
-			case SASLTypePlaintext:
-				b.connErr = b.sendAndReceiveSASLPlainAuth()
-			default:
-				b.connErr = b.sendAndReceiveSASLPlainAuth()
-			}
+			b.connErr = b.authenticateViaSASL()
 
 			if b.connErr != nil {
 				err = b.conn.Close()
@@ -786,6 +779,17 @@ func (b *Broker) responseReceiver() {
 	close(b.done)
 }
 
+func (b *Broker) authenticateViaSASL() error {
+	switch b.conf.Net.SASL.Mechanism {
+	case SASLTypeOAuth:
+		return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider, b.conf.Net.SASL.Extensions)
+	case SASLTypePlaintext:
+		return b.sendAndReceiveSASLPlainAuth()
+	default:
+		return b.sendAndReceiveSASLPlainAuth()
+	}
+}
+
 func (b *Broker) sendAndReceiveSASLHandshake(saslType string, version int16) error {
 	rb := &SaslHandshakeRequest{Mechanism: saslType, Version: version}
 

+ 4 - 1
broker_test.go

@@ -177,6 +177,9 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
 		broker.requestLatency = metrics.NilHistogram{}
 
 		conf := NewConfig()
+		conf.Net.SASL.Mechanism = SASLTypeOAuth
+		conf.Net.SASL.TokenProvider = test.tokProvider
+
 		broker.conf = conf
 
 		dialer := net.Dialer{
@@ -193,7 +196,7 @@ func TestReceiveSASLOAuthBearerClientResponse(t *testing.T) {
 
 		broker.conn = conn
 
-		err = broker.sendAndReceiveSASLOAuth(test.tokProvider, make(map[string]string))
+		err = broker.authenticateViaSASL()
 
 		if test.err != err {
 			t.Errorf("[%d]:[%s] Expected %s error, got %s\n", i, test.name, test.err, err)

+ 9 - 14
sasl_authenticate_response_test.go

@@ -3,25 +3,20 @@ package sarama
 import "testing"
 
 var (
-	saslAuthenticatResponse = []byte{
-		0, 0,
+	saslAuthenticatResponseErr = []byte{
+		0, 58,
 		0, 3, 'e', 'r', 'r',
 		0, 0, 0, 3, 'm', 's', 'g',
 	}
 )
 
 func TestSaslAuthenticateResponse(t *testing.T) {
-	var response *SaslAuthenticateResponse
 
-	response = new(SaslAuthenticateResponse)
-	testVersionDecodable(t, "no error", response, saslAuthenticatResponse, 0)
-	if response.Err != ErrNoError {
-		t.Error("Decoding error failed: no error expected but found", response.Err)
-	}
-	if *response.ErrorMessage != "err" {
-		t.Error("Decoding error failed: expected 'err' but found", *response.ErrorMessage)
-	}
-	if string(response.SaslAuthBytes) != "msg" {
-		t.Error("Decoding error failed: expected 'msg' but found", string(response.SaslAuthBytes))
-	}
+	response := new(SaslAuthenticateResponse)
+	response.Err = ErrSASLAuthenticationFailed
+	msg := "err"
+	response.ErrorMessage = &msg
+	response.SaslAuthBytes = []byte(`msg`)
+
+	testResponse(t, "authenticate reponse", response, saslAuthenticatResponseErr)
 }