Ver Fonte

Default TLS ServerName to the host in the DSN.

A TLS configuration must either have a ServerName or specify
InsecureSkipVerify. In most cases, the ServerName value will
match the host part of the address in the DSN. This change
updates the DSN parser to default the ServerName to the host
value provided unless InsecureSkipVerify is specified.
Andrew Metcalf há 11 anos atrás
pai
commit
8dc06d8c2a
3 ficheiros alterados com 50 adições e 0 exclusões
  1. 1 0
      AUTHORS
  2. 8 0
      utils.go
  3. 41 0
      utils_test.go

+ 1 - 0
AUTHORS

@@ -34,3 +34,4 @@ Xiuming Chen <cc at cxm.cc>
 
 Barracuda Networks, Inc.
 Google Inc.
+Stripe Inc.

+ 8 - 0
utils.go

@@ -16,6 +16,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 	"net/url"
 	"strings"
 	"time"
@@ -244,6 +245,13 @@ func parseDSNParams(cfg *config, params string) (err error) {
 				if strings.ToLower(value) == "skip-verify" {
 					cfg.tls = &tls.Config{InsecureSkipVerify: true}
 				} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
+					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
+						host, _, err := net.SplitHostPort(cfg.addr)
+						if err == nil {
+							tlsConfig.ServerName = host
+						}
+					}
+
 					cfg.tls = tlsConfig
 				} else {
 					return fmt.Errorf("Invalid value / unknown config name: %s", value)

+ 41 - 0
utils_test.go

@@ -10,6 +10,7 @@ package mysql
 
 import (
 	"bytes"
+	"crypto/tls"
 	"encoding/binary"
 	"fmt"
 	"testing"
@@ -74,6 +75,46 @@ func TestDSNParserInvalid(t *testing.T) {
 	}
 }
 
+func TestDSNWithCustomTLS(t *testing.T) {
+	baseDSN := "user:password@tcp(localhost:5555)/dbname?tls="
+	tlsCfg := tls.Config{}
+
+	RegisterTLSConfig("utils_test", &tlsCfg)
+
+	// Custom TLS is missing
+	tst := baseDSN + "invalid_tls"
+	cfg, err := parseDSN(tst)
+	if err == nil {
+		t.Errorf("Invalid custom TLS in DSN (%s) but did not error.  Got config: %#v", tst, cfg)
+	}
+
+	tst = baseDSN + "utils_test"
+
+	// Custom TLS with a server name
+	name := "foohost"
+	tlsCfg.ServerName = name
+	cfg, err = parseDSN(tst)
+
+	if err != nil {
+		t.Error(err.Error())
+	} else if cfg.tls.ServerName != name {
+		t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
+	}
+
+	// Custom TLS without a server name
+	name = "localhost"
+	tlsCfg.ServerName = ""
+	cfg, err = parseDSN(tst)
+
+	if err != nil {
+		t.Error(err.Error())
+	} else if cfg.tls.ServerName != name {
+		t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
+	}
+
+	DeregisterTLSConfig("utils_test")
+}
+
 func BenchmarkParseDSN(b *testing.B) {
 	b.ReportAllocs()