Browse Source

Set server name only for the current broker

Fixes https://github.com/Shopify/sarama/issues/1700
Diego Alvarez 4 years ago
parent
commit
d871f336e6
2 changed files with 38 additions and 19 deletions
  1. 23 19
      broker.go
  2. 15 0
      client_tls_test.go

+ 23 - 19
broker.go

@@ -162,29 +162,11 @@ func (b *Broker) Open(conf *Config) error {
 			atomic.StoreInt32(&b.opened, 0)
 			return
 		}
-
 		if conf.Net.TLS.Enable {
-			Logger.Printf("Using tls")
-			cfg := conf.Net.TLS.Config
-			if cfg == nil {
-				cfg = &tls.Config{}
-			}
-			// If no ServerName is set, infer the ServerName
-			// from the hostname we're connecting to.
-			// Gets the hostname as tls.DialWithDialer does it.
-			if cfg.ServerName == "" {
-				colonPos := strings.LastIndex(b.addr, ":")
-				if colonPos == -1 {
-					colonPos = len(b.addr)
-				}
-				hostname := b.addr[:colonPos]
-				cfg.ServerName = hostname
-			}
-			b.conn = tls.Client(b.conn, cfg)
+			b.conn = tls.Client(b.conn, validServerNameTLS(b.addr, conf.Net.TLS.Config))
 		}
 
 		b.conn = newBufConn(b.conn)
-
 		b.conf = conf
 
 		// Create or reuse the global metrics shared between brokers
@@ -1440,3 +1422,25 @@ func (b *Broker) registerCounter(name string) metrics.Counter {
 	b.registeredMetrics = append(b.registeredMetrics, nameForBroker)
 	return metrics.GetOrRegisterCounter(nameForBroker, b.conf.MetricRegistry)
 }
+
+func validServerNameTLS(addr string, conf *tls.Config) *tls.Config {
+	cfg := conf
+	if cfg == nil {
+		cfg = &tls.Config{}
+	}
+	// If no ServerName is set, infer the ServerName
+	// from the hostname we're connecting to.
+	// Gets the hostname as tls.DialWithDialer does it.
+	if cfg.ServerName == "" {
+		colonPos := strings.LastIndex(addr, ":")
+		if colonPos == -1 {
+			colonPos = len(addr)
+		}
+		hostname := addr[:colonPos]
+		// Make a copy to avoid polluting argument or default.
+		c := cfg.Clone()
+		c.ServerName = hostname
+		cfg = c
+	}
+	return cfg
+}

+ 15 - 0
client_tls_test.go

@@ -210,3 +210,18 @@ func doListenerTLSTest(t *testing.T, expectSuccess bool, serverConfig, clientCon
 		}
 	}
 }
+
+func TestSetServerName(t *testing.T) {
+	if validServerNameTLS("kafka-server.domain.com", nil).ServerName != "kafka-server.domain.com" {
+		t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config is nil")
+	}
+
+	if validServerNameTLS("kafka-server.domain.com", &tls.Config{}).ServerName != "kafka-server.domain.com" {
+		t.Fatal("Expected kafka-server.domain.com as tls.ServerName when tls config ServerName is not provided")
+	}
+
+	c := &tls.Config{ServerName: "kafka-server-other.domain.com"}
+	if validServerNameTLS("", c).ServerName != "kafka-server-other.domain.com" {
+		t.Fatal("Expected kafka-server-other.domain.com as tls.ServerName when tls config ServerName is provided")
+	}
+}