Sfoglia il codice sorgente

Default TLS ServerName to the host in the DSN.

A TLS configuration must either have a ServerName or specify
InsecureSkipVerify. In most cases, the ServerName value will
match the host part of the address in the DSN. This change
updates the DSN parser to default the ServerName to the host
value provided unless InsecureSkipVerify is specified.
Andrew Metcalf 11 anni fa
parent
commit
8dc06d8c2a
3 ha cambiato i file con 50 aggiunte e 0 eliminazioni
  1. 1 0
      AUTHORS
  2. 8 0
      utils.go
  3. 41 0
      utils_test.go

+ 1 - 0
AUTHORS

@@ -34,3 +34,4 @@ Xiuming Chen <cc at cxm.cc>
 
 Barracuda Networks, Inc.
 Google Inc.
+Stripe Inc.

+ 8 - 0
utils.go

@@ -16,6 +16,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 	"net/url"
 	"strings"
 	"time"
@@ -244,6 +245,13 @@ func parseDSNParams(cfg *config, params string) (err error) {
 				if strings.ToLower(value) == "skip-verify" {
 					cfg.tls = &tls.Config{InsecureSkipVerify: true}
 				} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
+					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
+						host, _, err := net.SplitHostPort(cfg.addr)
+						if err == nil {
+							tlsConfig.ServerName = host
+						}
+					}
+
 					cfg.tls = tlsConfig
 				} else {
 					return fmt.Errorf("Invalid value / unknown config name: %s", value)

+ 41 - 0
utils_test.go

@@ -10,6 +10,7 @@ package mysql
 
 import (
 	"bytes"
+	"crypto/tls"
 	"encoding/binary"
 	"fmt"
 	"testing"
@@ -74,6 +75,46 @@ func TestDSNParserInvalid(t *testing.T) {
 	}
 }
 
+func TestDSNWithCustomTLS(t *testing.T) {
+	baseDSN := "user:password@tcp(localhost:5555)/dbname?tls="
+	tlsCfg := tls.Config{}
+
+	RegisterTLSConfig("utils_test", &tlsCfg)
+
+	// Custom TLS is missing
+	tst := baseDSN + "invalid_tls"
+	cfg, err := parseDSN(tst)
+	if err == nil {
+		t.Errorf("Invalid custom TLS in DSN (%s) but did not error.  Got config: %#v", tst, cfg)
+	}
+
+	tst = baseDSN + "utils_test"
+
+	// Custom TLS with a server name
+	name := "foohost"
+	tlsCfg.ServerName = name
+	cfg, err = parseDSN(tst)
+
+	if err != nil {
+		t.Error(err.Error())
+	} else if cfg.tls.ServerName != name {
+		t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
+	}
+
+	// Custom TLS without a server name
+	name = "localhost"
+	tlsCfg.ServerName = ""
+	cfg, err = parseDSN(tst)
+
+	if err != nil {
+		t.Error(err.Error())
+	} else if cfg.tls.ServerName != name {
+		t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
+	}
+
+	DeregisterTLSConfig("utils_test")
+}
+
 func BenchmarkParseDSN(b *testing.B) {
 	b.ReportAllocs()