Browse Source

Default TLS ServerName to the host in the DSN.

A TLS configuration must either have a ServerName or specify
InsecureSkipVerify. In most cases, the ServerName value will
match the host part of the address in the DSN. This change
updates the DSN parser to default the ServerName to the host
value provided unless InsecureSkipVerify is specified.
Andrew Metcalf 11 years ago
parent
commit
8dc06d8c2a
3 changed files with 50 additions and 0 deletions
  1. 1 0
      AUTHORS
  2. 8 0
      utils.go
  3. 41 0
      utils_test.go

+ 1 - 0
AUTHORS

@@ -34,3 +34,4 @@ Xiuming Chen <cc at cxm.cc>
 
 
 Barracuda Networks, Inc.
 Barracuda Networks, Inc.
 Google Inc.
 Google Inc.
+Stripe Inc.

+ 8 - 0
utils.go

@@ -16,6 +16,7 @@ import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"net"
 	"net/url"
 	"net/url"
 	"strings"
 	"strings"
 	"time"
 	"time"
@@ -244,6 +245,13 @@ func parseDSNParams(cfg *config, params string) (err error) {
 				if strings.ToLower(value) == "skip-verify" {
 				if strings.ToLower(value) == "skip-verify" {
 					cfg.tls = &tls.Config{InsecureSkipVerify: true}
 					cfg.tls = &tls.Config{InsecureSkipVerify: true}
 				} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
 				} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
+					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
+						host, _, err := net.SplitHostPort(cfg.addr)
+						if err == nil {
+							tlsConfig.ServerName = host
+						}
+					}
+
 					cfg.tls = tlsConfig
 					cfg.tls = tlsConfig
 				} else {
 				} else {
 					return fmt.Errorf("Invalid value / unknown config name: %s", value)
 					return fmt.Errorf("Invalid value / unknown config name: %s", value)

+ 41 - 0
utils_test.go

@@ -10,6 +10,7 @@ package mysql
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"crypto/tls"
 	"encoding/binary"
 	"encoding/binary"
 	"fmt"
 	"fmt"
 	"testing"
 	"testing"
@@ -74,6 +75,46 @@ func TestDSNParserInvalid(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestDSNWithCustomTLS(t *testing.T) {
+	baseDSN := "user:password@tcp(localhost:5555)/dbname?tls="
+	tlsCfg := tls.Config{}
+
+	RegisterTLSConfig("utils_test", &tlsCfg)
+
+	// Custom TLS is missing
+	tst := baseDSN + "invalid_tls"
+	cfg, err := parseDSN(tst)
+	if err == nil {
+		t.Errorf("Invalid custom TLS in DSN (%s) but did not error.  Got config: %#v", tst, cfg)
+	}
+
+	tst = baseDSN + "utils_test"
+
+	// Custom TLS with a server name
+	name := "foohost"
+	tlsCfg.ServerName = name
+	cfg, err = parseDSN(tst)
+
+	if err != nil {
+		t.Error(err.Error())
+	} else if cfg.tls.ServerName != name {
+		t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
+	}
+
+	// Custom TLS without a server name
+	name = "localhost"
+	tlsCfg.ServerName = ""
+	cfg, err = parseDSN(tst)
+
+	if err != nil {
+		t.Error(err.Error())
+	} else if cfg.tls.ServerName != name {
+		t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
+	}
+
+	DeregisterTLSConfig("utils_test")
+}
+
 func BenchmarkParseDSN(b *testing.B) {
 func BenchmarkParseDSN(b *testing.B) {
 	b.ReportAllocs()
 	b.ReportAllocs()