Ver código fonte

use net.SplitHostPort and don't separate conf var

Diego Alvarez 5 anos atrás
pai
commit
1628dc1dbf
2 arquivos alterados com 12 adições e 14 exclusões
  1. 6 12
      broker.go
  2. 6 2
      client_tls_test.go

+ 6 - 12
broker.go

@@ -1423,23 +1423,17 @@ func (b *Broker) registerCounter(name string) metrics.Counter {
 	return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry)
 }
 
-func validServerNameTLS(addr string, conf *tls.Config) *tls.Config {
-	cfg := conf
+func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config {
 	if cfg == nil {
 		cfg = &tls.Config{}
 	}
-	// If no ServerName is set, infer the ServerName
-	// from the hostname we're connecting to.
-	// Gets the hostname as tls.DialWithDialer does it.
 	if cfg.ServerName == "" {
-		colonPos := strings.LastIndex(addr, ":")
-		if colonPos == -1 {
-			colonPos = len(addr)
-		}
-		hostname := addr[:colonPos]
-		// Make a copy to avoid polluting argument or default.
 		c := cfg.Clone()
-		c.ServerName = hostname
+		sn, _, err := net.SplitHostPort(addr)
+		if err != nil {
+			Logger.Println(fmt.Errorf("failed to get ServerName from addr %w", err))
+		}
+		c.ServerName = sn
 		cfg = c
 	}
 	return cfg

+ 6 - 2
client_tls_test.go

@@ -212,11 +212,11 @@ func doListenerTLSTest(t *testing.T, expectSuccess bool, serverConfig, clientCon
 }
 
 func TestSetServerName(t *testing.T) {
-	if validServerNameTLS("kafka-server.domain.com", nil).ServerName != "kafka-server.domain.com" {
+	if validServerNameTLS("kafka-server.domain.com:9093", nil).ServerName != "kafka-server.domain.com" {
 		t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config is nil")
 	}
 
-	if validServerNameTLS("kafka-server.domain.com", &tls.Config{}).ServerName != "kafka-server.domain.com" {
+	if validServerNameTLS("kafka-server.domain.com:9093", &tls.Config{}).ServerName != "kafka-server.domain.com" {
 		t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config ServerName is not provided")
 	}
 
@@ -224,4 +224,8 @@ func TestSetServerName(t *testing.T) {
 	if validServerNameTLS("", c).ServerName != "kafka-server-other.domain.com" {
 		t.Fatal("Expected kafka-server-other.domain.com as tls.ServerName when tls config ServerName is provided")
 	}
+
+	if validServerNameTLS("host-no-port", nil).ServerName != "" {
+		t.Fatal("Expected empty ServerName as the broker addr is missing the port")
+	}
 }