Ver Fonte

Merge pull request #283 from metcalf/am-default-servername

Default TLS ServerName to the host in the DSN.
Julien Schmidt há 11 anos atrás
pai
commit
cebeb3599d
3 ficheiros alterados com 50 adições e 0 exclusões
  1. 1 0
      AUTHORS
  2. 8 0
      utils.go
  3. 41 0
      utils_test.go

+ 1 - 0
AUTHORS

@@ -34,3 +34,4 @@ Xiuming Chen <cc at cxm.cc>
 
 Barracuda Networks, Inc.
 Google Inc.
+Stripe Inc.

+ 8 - 0
utils.go

@@ -16,6 +16,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 	"net/url"
 	"strings"
 	"time"
@@ -244,6 +245,13 @@ func parseDSNParams(cfg *config, params string) (err error) {
 				if strings.ToLower(value) == "skip-verify" {
 					cfg.tls = &tls.Config{InsecureSkipVerify: true}
 				} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
+					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
+						host, _, err := net.SplitHostPort(cfg.addr)
+						if err == nil {
+							tlsConfig.ServerName = host
+						}
+					}
+
 					cfg.tls = tlsConfig
 				} else {
 					return fmt.Errorf("Invalid value / unknown config name: %s", value)

+ 41 - 0
utils_test.go

@@ -10,6 +10,7 @@ package mysql
 
 import (
 	"bytes"
+	"crypto/tls"
 	"encoding/binary"
 	"fmt"
 	"testing"
@@ -74,6 +75,46 @@ func TestDSNParserInvalid(t *testing.T) {
 	}
 }
 
+func TestDSNWithCustomTLS(t *testing.T) {
+	baseDSN := "user:password@tcp(localhost:5555)/dbname?tls="
+	tlsCfg := tls.Config{}
+
+	RegisterTLSConfig("utils_test", &tlsCfg)
+
+	// Custom TLS is missing
+	tst := baseDSN + "invalid_tls"
+	cfg, err := parseDSN(tst)
+	if err == nil {
+		t.Errorf("Invalid custom TLS in DSN (%s) but did not error.  Got config: %#v", tst, cfg)
+	}
+
+	tst = baseDSN + "utils_test"
+
+	// Custom TLS with a server name
+	name := "foohost"
+	tlsCfg.ServerName = name
+	cfg, err = parseDSN(tst)
+
+	if err != nil {
+		t.Error(err.Error())
+	} else if cfg.tls.ServerName != name {
+		t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
+	}
+
+	// Custom TLS without a server name
+	name = "localhost"
+	tlsCfg.ServerName = ""
+	cfg, err = parseDSN(tst)
+
+	if err != nil {
+		t.Error(err.Error())
+	} else if cfg.tls.ServerName != name {
+		t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
+	}
+
+	DeregisterTLSConfig("utils_test")
+}
+
 func BenchmarkParseDSN(b *testing.B) {
 	b.ReportAllocs()