Parcourir la source

fix: Allow TLS to work over socks proxy. (#1666)

Kevin Cross il y a 4 ans
Parent
commit
c6eb1d42e7
4 fichiers modifiés avec 327 ajouts et 300 suppressions
  1. 8 13
      broker.go
  2. 286 274
      broker_test.go
  3. 20 13
      client_tls_test.go
  4. 13 0
      config.go

+ 8 - 13
broker.go

@@ -154,25 +154,20 @@ func (b *Broker) Open(conf *Config) error {
 	go withRecover(func() {
 		defer b.lock.Unlock()
 
-		dialer := net.Dialer{
-			Timeout:   conf.Net.DialTimeout,
-			KeepAlive: conf.Net.KeepAlive,
-			LocalAddr: conf.Net.LocalAddr,
-		}
-
-		if conf.Net.TLS.Enable {
-			b.conn, b.connErr = tls.DialWithDialer(&dialer, "tcp", b.addr, conf.Net.TLS.Config)
-		} else if conf.Net.Proxy.Enable {
-			b.conn, b.connErr = conf.Net.Proxy.Dialer.Dial("tcp", b.addr)
-		} else {
-			b.conn, b.connErr = dialer.Dial("tcp", b.addr)
-		}
+		dialer := conf.getDialer()
+		b.conn, b.connErr = dialer.Dial("tcp", b.addr)
 		if b.connErr != nil {
 			Logger.Printf("Failed to connect to broker %s: %s\n", b.addr, b.connErr)
 			b.conn = nil
 			atomic.StoreInt32(&b.opened, 0)
 			return
 		}
+
+		if conf.Net.TLS.Enable {
+			Logger.Printf("Using tls")
+			b.conn = tls.Client(b.conn, conf.Net.TLS.Config)
+		}
+
 		b.conn = newBufConn(b.conn)
 
 		b.conf = conf

+ 286 - 274
broker_test.go

@@ -80,38 +80,40 @@ func TestBrokerAccessors(t *testing.T) {
 
 func TestSimpleBrokerCommunication(t *testing.T) {
 	for _, tt := range brokerTestTable {
-		Logger.Printf("Testing broker communication for %s", tt.name)
-		mb := NewMockBroker(t, 0)
-		mb.Returns(&mockEncoder{tt.response})
-		pendingNotify := make(chan brokerMetrics)
-		// Register a callback to be notified about successful requests
-		mb.SetNotifier(func(bytesRead, bytesWritten int) {
-			pendingNotify <- brokerMetrics{bytesRead, bytesWritten}
+		t.Run(tt.name, func(t *testing.T) {
+			Logger.Printf("Testing broker communication for %s", tt.name)
+			mb := NewMockBroker(t, 0)
+			mb.Returns(&mockEncoder{tt.response})
+			pendingNotify := make(chan brokerMetrics)
+			// Register a callback to be notified about successful requests
+			mb.SetNotifier(func(bytesRead, bytesWritten int) {
+				pendingNotify <- brokerMetrics{bytesRead, bytesWritten}
+			})
+			broker := NewBroker(mb.Addr())
+			// Set the broker id in order to validate local broker metrics
+			broker.id = 0
+			conf := NewConfig()
+			conf.Version = tt.version
+			err := broker.Open(conf)
+			if err != nil {
+				t.Fatal(err)
+			}
+			tt.runner(t, broker)
+			// Wait up to 500 ms for the remote broker to process the request and
+			// notify us about the metrics
+			timeout := 500 * time.Millisecond
+			select {
+			case mockBrokerMetrics := <-pendingNotify:
+				validateBrokerMetrics(t, broker, mockBrokerMetrics)
+			case <-time.After(timeout):
+				t.Errorf("No request received for: %s after waiting for %v", tt.name, timeout)
+			}
+			mb.Close()
+			err = broker.Close()
+			if err != nil {
+				t.Error(err)
+			}
 		})
-		broker := NewBroker(mb.Addr())
-		// Set the broker id in order to validate local broker metrics
-		broker.id = 0
-		conf := NewConfig()
-		conf.Version = tt.version
-		err := broker.Open(conf)
-		if err != nil {
-			t.Fatal(err)
-		}
-		tt.runner(t, broker)
-		// Wait up to 500 ms for the remote broker to process the request and
-		// notify us about the metrics
-		timeout := 500 * time.Millisecond
-		select {
-		case mockBrokerMetrics := <-pendingNotify:
-			validateBrokerMetrics(t, broker, mockBrokerMetrics)
-		case <-time.After(timeout):
-			t.Errorf("No request received for: %s after waiting for %v", tt.name, timeout)
-		}
-		mb.Close()
-		err = broker.Close()
-		if err != nil {
-			t.Error(err)
-		}
 	}
 }
 
@@ -204,58 +206,60 @@ func TestSASLOAuthBearer(t *testing.T) {
 	}
 
 	for i, test := range testTable {
-		// mockBroker mocks underlying network logic and broker responses
-		mockBroker := NewMockBroker(t, 0)
+		t.Run(test.name, func(t *testing.T) {
+			// mockBroker mocks underlying network logic and broker responses
+			mockBroker := NewMockBroker(t, 0)
 
-		mockBroker.SetHandlerByMap(map[string]MockResponse{
-			"SaslAuthenticateRequest": test.mockSASLAuthResponse,
-			"SaslHandshakeRequest":    test.mockSASLHandshakeResponse,
-		})
+			mockBroker.SetHandlerByMap(map[string]MockResponse{
+				"SaslAuthenticateRequest": test.mockSASLAuthResponse,
+				"SaslHandshakeRequest":    test.mockSASLHandshakeResponse,
+			})
 
-		// broker executes SASL requests against mockBroker
-		broker := NewBroker(mockBroker.Addr())
-		broker.requestRate = metrics.NilMeter{}
-		broker.outgoingByteRate = metrics.NilMeter{}
-		broker.incomingByteRate = metrics.NilMeter{}
-		broker.requestSize = metrics.NilHistogram{}
-		broker.responseSize = metrics.NilHistogram{}
-		broker.responseRate = metrics.NilMeter{}
-		broker.requestLatency = metrics.NilHistogram{}
-		broker.requestsInFlight = metrics.NilCounter{}
+			// broker executes SASL requests against mockBroker
+			broker := NewBroker(mockBroker.Addr())
+			broker.requestRate = metrics.NilMeter{}
+			broker.outgoingByteRate = metrics.NilMeter{}
+			broker.incomingByteRate = metrics.NilMeter{}
+			broker.requestSize = metrics.NilHistogram{}
+			broker.responseSize = metrics.NilHistogram{}
+			broker.responseRate = metrics.NilMeter{}
+			broker.requestLatency = metrics.NilHistogram{}
+			broker.requestsInFlight = metrics.NilCounter{}
 
-		conf := NewConfig()
-		conf.Net.SASL.Mechanism = SASLTypeOAuth
-		conf.Net.SASL.TokenProvider = test.tokProvider
+			conf := NewConfig()
+			conf.Net.SASL.Mechanism = SASLTypeOAuth
+			conf.Net.SASL.TokenProvider = test.tokProvider
 
-		broker.conf = conf
+			broker.conf = conf
 
-		dialer := net.Dialer{
-			Timeout:   conf.Net.DialTimeout,
-			KeepAlive: conf.Net.KeepAlive,
-			LocalAddr: conf.Net.LocalAddr,
-		}
+			dialer := net.Dialer{
+				Timeout:   conf.Net.DialTimeout,
+				KeepAlive: conf.Net.KeepAlive,
+				LocalAddr: conf.Net.LocalAddr,
+			}
 
-		conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
+			conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
 
-		if err != nil {
-			t.Fatal(err)
-		}
+			if err != nil {
+				t.Fatal(err)
+			}
 
-		broker.conn = conn
+			broker.conn = conn
 
-		err = broker.authenticateViaSASL()
+			err = broker.authenticateViaSASL()
 
-		if test.expectedBrokerError != ErrNoError {
-			if test.expectedBrokerError != err {
-				t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.expectedBrokerError, err)
+			if test.expectedBrokerError != ErrNoError {
+				if test.expectedBrokerError != err {
+					t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.expectedBrokerError, err)
+				}
+			} else if test.expectClientErr && err == nil {
+				t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
+			} else if !test.expectClientErr && err != nil {
+				t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
 			}
-		} else if test.expectClientErr && err == nil {
-			t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
-		} else if !test.expectClientErr && err != nil {
-			t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
-		}
 
-		mockBroker.Close()
+			mockBroker.Close()
+		})
 	}
 }
 
@@ -264,7 +268,7 @@ type MockSCRAMClient struct {
 	done bool
 }
 
-func (m *MockSCRAMClient) Begin(userName, password, authzID string) (err error) {
+func (m *MockSCRAMClient) Begin(_, _, _ string) (err error) {
 	return nil
 }
 
@@ -325,70 +329,72 @@ func TestSASLSCRAMSHAXXX(t *testing.T) {
 	}
 
 	for i, test := range testTable {
-		// mockBroker mocks underlying network logic and broker responses
-		mockBroker := NewMockBroker(t, 0)
-		broker := NewBroker(mockBroker.Addr())
-		// broker executes SASL requests against mockBroker
-		broker.requestRate = metrics.NilMeter{}
-		broker.outgoingByteRate = metrics.NilMeter{}
-		broker.incomingByteRate = metrics.NilMeter{}
-		broker.requestSize = metrics.NilHistogram{}
-		broker.responseSize = metrics.NilHistogram{}
-		broker.responseRate = metrics.NilMeter{}
-		broker.requestLatency = metrics.NilHistogram{}
-		broker.requestsInFlight = metrics.NilCounter{}
+		t.Run(test.name, func(t *testing.T) {
+			// mockBroker mocks underlying network logic and broker responses
+			mockBroker := NewMockBroker(t, 0)
+			broker := NewBroker(mockBroker.Addr())
+			// broker executes SASL requests against mockBroker
+			broker.requestRate = metrics.NilMeter{}
+			broker.outgoingByteRate = metrics.NilMeter{}
+			broker.incomingByteRate = metrics.NilMeter{}
+			broker.requestSize = metrics.NilHistogram{}
+			broker.responseSize = metrics.NilHistogram{}
+			broker.responseRate = metrics.NilMeter{}
+			broker.requestLatency = metrics.NilHistogram{}
+			broker.requestsInFlight = metrics.NilCounter{}
 
-		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte(test.scramChallengeResp))
-		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512})
+			mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).SetAuthBytes([]byte(test.scramChallengeResp))
+			mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).SetEnabledMechanisms([]string{SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512})
 
-		if test.mockSASLAuthErr != ErrNoError {
-			mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockSASLAuthErr)
-		}
-		if test.mockHandshakeErr != ErrNoError {
-			mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
-		}
+			if test.mockSASLAuthErr != ErrNoError {
+				mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockSASLAuthErr)
+			}
+			if test.mockHandshakeErr != ErrNoError {
+				mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
+			}
 
-		mockBroker.SetHandlerByMap(map[string]MockResponse{
-			"SaslAuthenticateRequest": mockSASLAuthResponse,
-			"SaslHandshakeRequest":    mockSASLHandshakeResponse,
-		})
+			mockBroker.SetHandlerByMap(map[string]MockResponse{
+				"SaslAuthenticateRequest": mockSASLAuthResponse,
+				"SaslHandshakeRequest":    mockSASLHandshakeResponse,
+			})
 
-		conf := NewConfig()
-		conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
-		conf.Net.SASL.SCRAMClientGeneratorFunc = func() SCRAMClient { return test.scramClient }
+			conf := NewConfig()
+			conf.Net.SASL.Mechanism = SASLTypeSCRAMSHA512
+			conf.Net.SASL.SCRAMClientGeneratorFunc = func() SCRAMClient { return test.scramClient }
 
-		broker.conf = conf
-		dialer := net.Dialer{
-			Timeout:   conf.Net.DialTimeout,
-			KeepAlive: conf.Net.KeepAlive,
-			LocalAddr: conf.Net.LocalAddr,
-		}
+			broker.conf = conf
+			dialer := net.Dialer{
+				Timeout:   conf.Net.DialTimeout,
+				KeepAlive: conf.Net.KeepAlive,
+				LocalAddr: conf.Net.LocalAddr,
+			}
 
-		conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
+			conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
 
-		if err != nil {
-			t.Fatal(err)
-		}
+			if err != nil {
+				t.Fatal(err)
+			}
 
-		broker.conn = conn
+			broker.conn = conn
 
-		err = broker.authenticateViaSASL()
+			err = broker.authenticateViaSASL()
 
-		if test.mockSASLAuthErr != ErrNoError {
-			if test.mockSASLAuthErr != err {
-				t.Errorf("[%d]:[%s] Expected %s SASL authentication error, got %s\n", i, test.name, test.mockHandshakeErr, err)
-			}
-		} else if test.mockHandshakeErr != ErrNoError {
-			if test.mockHandshakeErr != err {
-				t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+			if test.mockSASLAuthErr != ErrNoError {
+				if test.mockSASLAuthErr != err {
+					t.Errorf("[%d]:[%s] Expected %s SASL authentication error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+				}
+			} else if test.mockHandshakeErr != ErrNoError {
+				if test.mockHandshakeErr != err {
+					t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+				}
+			} else if test.expectClientErr && err == nil {
+				t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
+			} else if !test.expectClientErr && err != nil {
+				t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
 			}
-		} else if test.expectClientErr && err == nil {
-			t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
-		} else if !test.expectClientErr && err != nil {
-			t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
-		}
 
-		mockBroker.Close()
+			mockBroker.Close()
+		})
 	}
 }
 
@@ -424,96 +430,98 @@ func TestSASLPlainAuth(t *testing.T) {
 	}
 
 	for i, test := range testTable {
-		// mockBroker mocks underlying network logic and broker responses
-		mockBroker := NewMockBroker(t, 0)
+		t.Run(test.name, func(t *testing.T) {
+			// mockBroker mocks underlying network logic and broker responses
+			mockBroker := NewMockBroker(t, 0)
 
-		mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
-			SetAuthBytes([]byte(`response_payload`))
+			mockSASLAuthResponse := NewMockSaslAuthenticateResponse(t).
+				SetAuthBytes([]byte(`response_payload`))
 
-		if test.mockAuthErr != ErrNoError {
-			mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
-		}
+			if test.mockAuthErr != ErrNoError {
+				mockSASLAuthResponse = mockSASLAuthResponse.SetError(test.mockAuthErr)
+			}
 
-		mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
-			SetEnabledMechanisms([]string{SASLTypePlaintext})
+			mockSASLHandshakeResponse := NewMockSaslHandshakeResponse(t).
+				SetEnabledMechanisms([]string{SASLTypePlaintext})
 
-		if test.mockHandshakeErr != ErrNoError {
-			mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
-		}
+			if test.mockHandshakeErr != ErrNoError {
+				mockSASLHandshakeResponse = mockSASLHandshakeResponse.SetError(test.mockHandshakeErr)
+			}
 
-		mockBroker.SetHandlerByMap(map[string]MockResponse{
-			"SaslAuthenticateRequest": mockSASLAuthResponse,
-			"SaslHandshakeRequest":    mockSASLHandshakeResponse,
-		})
+			mockBroker.SetHandlerByMap(map[string]MockResponse{
+				"SaslAuthenticateRequest": mockSASLAuthResponse,
+				"SaslHandshakeRequest":    mockSASLHandshakeResponse,
+			})
 
-		// broker executes SASL requests against mockBroker
-		broker := NewBroker(mockBroker.Addr())
-		broker.requestRate = metrics.NilMeter{}
-		broker.outgoingByteRate = metrics.NilMeter{}
-		broker.incomingByteRate = metrics.NilMeter{}
-		broker.requestSize = metrics.NilHistogram{}
-		broker.responseSize = metrics.NilHistogram{}
-		broker.responseRate = metrics.NilMeter{}
-		broker.requestLatency = metrics.NilHistogram{}
-		broker.requestsInFlight = metrics.NilCounter{}
+			// broker executes SASL requests against mockBroker
+			broker := NewBroker(mockBroker.Addr())
+			broker.requestRate = metrics.NilMeter{}
+			broker.outgoingByteRate = metrics.NilMeter{}
+			broker.incomingByteRate = metrics.NilMeter{}
+			broker.requestSize = metrics.NilHistogram{}
+			broker.responseSize = metrics.NilHistogram{}
+			broker.responseRate = metrics.NilMeter{}
+			broker.requestLatency = metrics.NilHistogram{}
+			broker.requestsInFlight = metrics.NilCounter{}
 
-		conf := NewConfig()
-		conf.Net.SASL.Mechanism = SASLTypePlaintext
-		conf.Net.SASL.AuthIdentity = test.authidentity
-		conf.Net.SASL.User = "token"
-		conf.Net.SASL.Password = "password"
-		conf.Net.SASL.Version = SASLHandshakeV1
+			conf := NewConfig()
+			conf.Net.SASL.Mechanism = SASLTypePlaintext
+			conf.Net.SASL.AuthIdentity = test.authidentity
+			conf.Net.SASL.User = "token"
+			conf.Net.SASL.Password = "password"
+			conf.Net.SASL.Version = SASLHandshakeV1
 
-		broker.conf = conf
-		broker.conf.Version = V1_0_0_0
-		dialer := net.Dialer{
-			Timeout:   conf.Net.DialTimeout,
-			KeepAlive: conf.Net.KeepAlive,
-			LocalAddr: conf.Net.LocalAddr,
-		}
-
-		conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
-
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		broker.conn = conn
-
-		err = broker.authenticateViaSASL()
-		if err == nil {
-			for _, rr := range mockBroker.History() {
-				switch r := rr.Request.(type) {
-				case *SaslAuthenticateRequest:
-					x := bytes.SplitN(r.SaslAuthBytes, []byte("\x00"), 3)
-					if string(x[0]) != conf.Net.SASL.AuthIdentity {
-						t.Errorf("[%d]:[%s] expected %s auth identity, got %s\n", i, test.name, conf.Net.SASL.AuthIdentity, x[0])
-					}
-					if string(x[1]) != conf.Net.SASL.User {
-						t.Errorf("[%d]:[%s] expected %s user, got %s\n", i, test.name, conf.Net.SASL.User, x[1])
-					}
-					if string(x[2]) != conf.Net.SASL.Password {
-						t.Errorf("[%d]:[%s] expected %s password, got %s\n", i, test.name, conf.Net.SASL.Password, x[2])
+			broker.conf = conf
+			broker.conf.Version = V1_0_0_0
+			dialer := net.Dialer{
+				Timeout:   conf.Net.DialTimeout,
+				KeepAlive: conf.Net.KeepAlive,
+				LocalAddr: conf.Net.LocalAddr,
+			}
+
+			conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
+
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			broker.conn = conn
+
+			err = broker.authenticateViaSASL()
+			if err == nil {
+				for _, rr := range mockBroker.History() {
+					switch r := rr.Request.(type) {
+					case *SaslAuthenticateRequest:
+						x := bytes.SplitN(r.SaslAuthBytes, []byte("\x00"), 3)
+						if string(x[0]) != conf.Net.SASL.AuthIdentity {
+							t.Errorf("[%d]:[%s] expected %s auth identity, got %s\n", i, test.name, conf.Net.SASL.AuthIdentity, x[0])
+						}
+						if string(x[1]) != conf.Net.SASL.User {
+							t.Errorf("[%d]:[%s] expected %s user, got %s\n", i, test.name, conf.Net.SASL.User, x[1])
+						}
+						if string(x[2]) != conf.Net.SASL.Password {
+							t.Errorf("[%d]:[%s] expected %s password, got %s\n", i, test.name, conf.Net.SASL.Password, x[2])
+						}
 					}
 				}
 			}
-		}
 
-		if test.mockAuthErr != ErrNoError {
-			if test.mockAuthErr != err {
-				t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err)
-			}
-		} else if test.mockHandshakeErr != ErrNoError {
-			if test.mockHandshakeErr != err {
-				t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+			if test.mockAuthErr != ErrNoError {
+				if test.mockAuthErr != err {
+					t.Errorf("[%d]:[%s] Expected %s auth error, got %s\n", i, test.name, test.mockAuthErr, err)
+				}
+			} else if test.mockHandshakeErr != ErrNoError {
+				if test.mockHandshakeErr != err {
+					t.Errorf("[%d]:[%s] Expected %s handshake error, got %s\n", i, test.name, test.mockHandshakeErr, err)
+				}
+			} else if test.expectClientErr && err == nil {
+				t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
+			} else if !test.expectClientErr && err != nil {
+				t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
 			}
-		} else if test.expectClientErr && err == nil {
-			t.Errorf("[%d]:[%s] Expected a client error and got none\n", i, test.name)
-		} else if !test.expectClientErr && err != nil {
-			t.Errorf("[%d]:[%s] Unexpected error, got %s\n", i, test.name, err)
-		}
 
-		mockBroker.Close()
+			mockBroker.Close()
+		})
 	}
 }
 
@@ -616,73 +624,75 @@ func TestGSSAPIKerberosAuth_Authorize(t *testing.T) {
 		},
 	}
 	for i, test := range testTable {
-		mockBroker := NewMockBroker(t, 0)
-		// broker executes SASL requests against mockBroker
+		t.Run(test.name, func(t *testing.T) {
+			mockBroker := NewMockBroker(t, 0)
+			// broker executes SASL requests against mockBroker
+
+			mockBroker.SetGSSAPIHandler(func(bytes []byte) []byte {
+				return nil
+			})
+			broker := NewBroker(mockBroker.Addr())
+			broker.requestRate = metrics.NilMeter{}
+			broker.outgoingByteRate = metrics.NilMeter{}
+			broker.incomingByteRate = metrics.NilMeter{}
+			broker.requestSize = metrics.NilHistogram{}
+			broker.responseSize = metrics.NilHistogram{}
+			broker.responseRate = metrics.NilMeter{}
+			broker.requestLatency = metrics.NilHistogram{}
+			broker.requestsInFlight = metrics.NilCounter{}
+			conf := NewConfig()
+			conf.Net.SASL.Mechanism = SASLTypeGSSAPI
+			conf.Net.SASL.GSSAPI.ServiceName = "kafka"
+			conf.Net.SASL.GSSAPI.KerberosConfigPath = "krb5.conf"
+			conf.Net.SASL.GSSAPI.Realm = "EXAMPLE.COM"
+			conf.Net.SASL.GSSAPI.Username = "kafka"
+			conf.Net.SASL.GSSAPI.Password = "kafka"
+			conf.Net.SASL.GSSAPI.KeyTabPath = "kafka.keytab"
+			conf.Net.SASL.GSSAPI.AuthType = KRB5_USER_AUTH
+			broker.conf = conf
+			broker.conf.Version = V1_0_0_0
+			dialer := net.Dialer{
+				Timeout:   conf.Net.DialTimeout,
+				KeepAlive: conf.Net.KeepAlive,
+				LocalAddr: conf.Net.LocalAddr,
+			}
+
+			conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
 
-		mockBroker.SetGSSAPIHandler(func(bytes []byte) []byte {
-			return nil
-		})
-		broker := NewBroker(mockBroker.Addr())
-		broker.requestRate = metrics.NilMeter{}
-		broker.outgoingByteRate = metrics.NilMeter{}
-		broker.incomingByteRate = metrics.NilMeter{}
-		broker.requestSize = metrics.NilHistogram{}
-		broker.responseSize = metrics.NilHistogram{}
-		broker.responseRate = metrics.NilMeter{}
-		broker.requestLatency = metrics.NilHistogram{}
-		broker.requestsInFlight = metrics.NilCounter{}
-		conf := NewConfig()
-		conf.Net.SASL.Mechanism = SASLTypeGSSAPI
-		conf.Net.SASL.GSSAPI.ServiceName = "kafka"
-		conf.Net.SASL.GSSAPI.KerberosConfigPath = "krb5.conf"
-		conf.Net.SASL.GSSAPI.Realm = "EXAMPLE.COM"
-		conf.Net.SASL.GSSAPI.Username = "kafka"
-		conf.Net.SASL.GSSAPI.Password = "kafka"
-		conf.Net.SASL.GSSAPI.KeyTabPath = "kafka.keytab"
-		conf.Net.SASL.GSSAPI.AuthType = KRB5_USER_AUTH
-		broker.conf = conf
-		broker.conf.Version = V1_0_0_0
-		dialer := net.Dialer{
-			Timeout:   conf.Net.DialTimeout,
-			KeepAlive: conf.Net.KeepAlive,
-			LocalAddr: conf.Net.LocalAddr,
-		}
-
-		conn, err := dialer.Dial("tcp", mockBroker.listener.Addr().String())
-
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		gssapiHandler := KafkaGSSAPIHandler{
-			client:         &MockKerberosClient{},
-			badResponse:    test.badResponse,
-			badKeyChecksum: test.badKeyChecksum,
-		}
-		mockBroker.SetGSSAPIHandler(gssapiHandler.MockKafkaGSSAPI)
-		broker.conn = conn
-		if test.mockKerberosClient {
-			broker.kerberosAuthenticator.NewKerberosClientFunc = func(config *GSSAPIConfig) (KerberosClient, error) {
-				return &MockKerberosClient{
-					mockError:  test.error,
-					errorStage: test.errorStage,
-				}, nil
-			}
-		} else {
-			broker.kerberosAuthenticator.NewKerberosClientFunc = nil
-		}
-
-		err = broker.authenticateViaSASL()
-
-		if err != nil && test.error != nil {
-			if test.error.Error() != err.Error() {
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			gssapiHandler := KafkaGSSAPIHandler{
+				client:         &MockKerberosClient{},
+				badResponse:    test.badResponse,
+				badKeyChecksum: test.badKeyChecksum,
+			}
+			mockBroker.SetGSSAPIHandler(gssapiHandler.MockKafkaGSSAPI)
+			broker.conn = conn
+			if test.mockKerberosClient {
+				broker.kerberosAuthenticator.NewKerberosClientFunc = func(config *GSSAPIConfig) (KerberosClient, error) {
+					return &MockKerberosClient{
+						mockError:  test.error,
+						errorStage: test.errorStage,
+					}, nil
+				}
+			} else {
+				broker.kerberosAuthenticator.NewKerberosClientFunc = nil
+			}
+
+			err = broker.authenticateViaSASL()
+
+			if err != nil && test.error != nil {
+				if test.error.Error() != err.Error() {
+					t.Errorf("[%d] Expected error:%s, got:%s.", i, test.error, err)
+				}
+			} else if (err == nil && test.error != nil) || (err != nil && test.error == nil) {
 				t.Errorf("[%d] Expected error:%s, got:%s.", i, test.error, err)
 			}
-		} else if (err == nil && test.error != nil) || (err != nil && test.error == nil) {
-			t.Errorf("[%d] Expected error:%s, got:%s.", i, test.error, err)
-		}
 
-		mockBroker.Close()
+			mockBroker.Close()
+		})
 	}
 }
 
@@ -723,17 +733,19 @@ func TestBuildClientFirstMessage(t *testing.T) {
 	}
 
 	for i, test := range testTable {
-		actual, err := buildClientFirstMessage(test.token)
-
-		if !reflect.DeepEqual(test.expected, actual) {
-			t.Errorf("Expected %s, got %s\n", test.expected, actual)
-		}
-		if test.expectError && err == nil {
-			t.Errorf("[%d]:[%s] Expected an error but did not get one", i, test.name)
-		}
-		if !test.expectError && err != nil {
-			t.Errorf("[%d]:[%s] Expected no error but got %s\n", i, test.name, err)
-		}
+		t.Run(test.name, func(t *testing.T) {
+			actual, err := buildClientFirstMessage(test.token)
+
+			if !reflect.DeepEqual(test.expected, actual) {
+				t.Errorf("Expected %s, got %s\n", test.expected, actual)
+			}
+			if test.expectError && err == nil {
+				t.Errorf("[%d]:[%s] Expected an error but did not get one", i, test.name)
+			}
+			if !test.expectError && err != nil {
+				t.Errorf("[%d]:[%s] Expected no error but got %s\n", i, test.name, err)
+			}
+		})
 	}
 }
 

+ 20 - 13
client_tls_test.go

@@ -1,16 +1,15 @@
 package sarama
 
 import (
-	"math/big"
-	"net"
-	"testing"
-	"time"
-
 	"crypto/rand"
 	"crypto/rsa"
 	"crypto/tls"
 	"crypto/x509"
 	"crypto/x509/pkix"
+	"math/big"
+	"net"
+	"testing"
+	"time"
 )
 
 func TestTLS(t *testing.T) {
@@ -95,10 +94,12 @@ func TestTLS(t *testing.T) {
 	}
 
 	for _, tc := range []struct {
+		name           string
 		Succeed        bool
 		Server, Client *tls.Config
 	}{
-		{ // Verify client fails if wrong CA cert pool is specified
+		{
+			name:    "Verify client fails if wrong CA cert pool is specified",
 			Succeed: false,
 			Server:  serverTLSConfig,
 			Client: &tls.Config{
@@ -109,7 +110,8 @@ func TestTLS(t *testing.T) {
 				}},
 			},
 		},
-		{ // Verify client fails if wrong key is specified
+		{
+			name:    "Verify client fails if wrong key is specified",
 			Succeed: false,
 			Server:  serverTLSConfig,
 			Client: &tls.Config{
@@ -120,7 +122,8 @@ func TestTLS(t *testing.T) {
 				}},
 			},
 		},
-		{ // Verify client fails if wrong cert is specified
+		{
+			name:    "Verify client fails if wrong cert is specified",
 			Succeed: false,
 			Server:  serverTLSConfig,
 			Client: &tls.Config{
@@ -131,7 +134,8 @@ func TestTLS(t *testing.T) {
 				}},
 			},
 		},
-		{ // Verify client fails if no CAs are specified
+		{
+			name:    "Verify client fails if no CAs are specified",
 			Succeed: false,
 			Server:  serverTLSConfig,
 			Client: &tls.Config{
@@ -141,18 +145,21 @@ func TestTLS(t *testing.T) {
 				}},
 			},
 		},
-		{ // Verify client fails if no keys are specified
+		{
+			name:    "Verify client fails if no keys are specified",
 			Succeed: false,
 			Server:  serverTLSConfig,
 			Client: &tls.Config{
 				RootCAs: pool,
 			},
 		},
-		{ // Finally, verify it all works happily with client and server cert in place
+		{
+			name:    "Finally, verify it all works happily with client and server cert in place",
 			Succeed: true,
 			Server:  serverTLSConfig,
 			Client: &tls.Config{
-				RootCAs: pool,
+				RootCAs:    pool,
+				ServerName: "127.0.0.1",
 				Certificates: []tls.Certificate{{
 					Certificate: [][]byte{clientDer},
 					PrivateKey:  clientkey,
@@ -160,7 +167,7 @@ func TestTLS(t *testing.T) {
 			},
 		},
 	} {
-		doListenerTLSTest(t, tc.Succeed, tc.Server, tc.Client)
+		t.Run(tc.name, func(t *testing.T) { doListenerTLSTest(t, tc.Succeed, tc.Server, tc.Client) })
 	}
 }
 

+ 13 - 0
config.go

@@ -734,3 +734,16 @@ func (c *Config) Validate() error {
 
 	return nil
 }
+
+func (c *Config) getDialer() proxy.Dialer {
+	if c.Net.Proxy.Enable {
+		Logger.Printf("using proxy %s", c.Net.Proxy.Dialer)
+		return c.Net.Proxy.Dialer
+	} else {
+		return &net.Dialer{
+			Timeout:   c.Net.DialTimeout,
+			KeepAlive: c.Net.KeepAlive,
+			LocalAddr: c.Net.LocalAddr,
+		}
+	}
+}