Browse Source

Fix tls=true didn't work with host without port (#718)

Fixes #717
INADA Naoki 8 years ago
parent
commit
9181e3a86a
2 changed files with 37 additions and 11 deletions
  1. 9 11
      dsn.go
  2. 28 0
      dsn_test.go

+ 9 - 11
dsn.go

@@ -94,6 +94,15 @@ func (cfg *Config) normalize() error {
 		cfg.Addr = ensureHavePort(cfg.Addr)
 	}
 
+	if cfg.tls != nil {
+		if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
+			host, _, err := net.SplitHostPort(cfg.Addr)
+			if err == nil {
+				cfg.tls.ServerName = host
+			}
+		}
+	}
+
 	return nil
 }
 
@@ -521,10 +530,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
 				if boolValue {
 					cfg.TLSConfig = "true"
 					cfg.tls = &tls.Config{}
-					host, _, err := net.SplitHostPort(cfg.Addr)
-					if err == nil {
-						cfg.tls.ServerName = host
-					}
 				} else {
 					cfg.TLSConfig = "false"
 				}
@@ -538,13 +543,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
 				}
 
 				if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
-					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
-						host, _, err := net.SplitHostPort(cfg.Addr)
-						if err == nil {
-							tlsConfig.ServerName = host
-						}
-					}
-
 					cfg.TLSConfig = name
 					cfg.tls = tlsConfig
 				} else {

+ 28 - 0
dsn_test.go

@@ -177,6 +177,34 @@ func TestDSNWithCustomTLS(t *testing.T) {
 	DeregisterTLSConfig("utils_test")
 }
 
+func TestDSNTLSConfig(t *testing.T) {
+	expectedServerName := "example.com"
+	dsn := "tcp(example.com:1234)/?tls=true"
+
+	cfg, err := ParseDSN(dsn)
+	if err != nil {
+		t.Error(err.Error())
+	}
+	if cfg.tls == nil {
+		t.Error("cfg.tls should not be nil")
+	}
+	if cfg.tls.ServerName != expectedServerName {
+		t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
+	}
+
+	dsn = "tcp(example.com)/?tls=true"
+	cfg, err = ParseDSN(dsn)
+	if err != nil {
+		t.Error(err.Error())
+	}
+	if cfg.tls == nil {
+		t.Error("cfg.tls should not be nil")
+	}
+	if cfg.tls.ServerName != expectedServerName {
+		t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
+	}
+}
+
 func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
 	const configKey = "&%!:"
 	dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)