Browse Source

move tls and pubkey object creation to Config.normalize() (#958)

This is still less than ideal since we cannot directly pass in
tls.Config into Config and have it be used, but it is sill backwards
compatable.  In the future this should be revisited to be able to use a
custome tls.Config passed directly in without string
parsing/registering.
Brandon Bennett 6 years ago
parent
commit
877a9775f0
4 changed files with 76 additions and 27 deletions
  1. 2 1
      AUTHORS
  2. 27 23
      dsn.go
  3. 46 2
      dsn_test.go
  4. 1 1
      utils.go

+ 2 - 1
AUTHORS

@@ -90,11 +90,12 @@ Zhenye Xie <xiezhenye at gmail.com>
 
 Barracuda Networks, Inc.
 Counting Ltd.
+Facebook Inc.
 GitHub Inc.
 Google Inc.
 InfoSum Ltd.
 Keybase Inc.
+Multiplay Ltd.
 Percona LLC
 Pivotal Inc.
 Stripe Inc.
-Multiplay Ltd.

+ 27 - 23
dsn.go

@@ -113,17 +113,35 @@ func (cfg *Config) normalize() error {
 		default:
 			return errors.New("default addr for network '" + cfg.Net + "' unknown")
 		}
-
 	} else if cfg.Net == "tcp" {
 		cfg.Addr = ensureHavePort(cfg.Addr)
 	}
 
-	if cfg.tls != nil {
-		if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
-			host, _, err := net.SplitHostPort(cfg.Addr)
-			if err == nil {
-				cfg.tls.ServerName = host
-			}
+	switch cfg.TLSConfig {
+	case "false", "":
+		// don't set anything
+	case "true":
+		cfg.tls = &tls.Config{}
+	case "skip-verify", "preferred":
+		cfg.tls = &tls.Config{InsecureSkipVerify: true}
+	default:
+		cfg.tls = getTLSConfigClone(cfg.TLSConfig)
+		if cfg.tls == nil {
+			return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
+		}
+	}
+
+	if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
+		host, _, err := net.SplitHostPort(cfg.Addr)
+		if err == nil {
+			cfg.tls.ServerName = host
+		}
+	}
+
+	if cfg.ServerPubKey != "" {
+		cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
+		if cfg.pubKey == nil {
+			return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
 		}
 	}
 
@@ -552,13 +570,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
 			if err != nil {
 				return fmt.Errorf("invalid value for server pub key name: %v", err)
 			}
-
-			if pubKey := getServerPubKey(name); pubKey != nil {
-				cfg.ServerPubKey = name
-				cfg.pubKey = pubKey
-			} else {
-				return errors.New("invalid value / unknown server pub key name: " + name)
-			}
+			cfg.ServerPubKey = name
 
 		// Strict mode
 		case "strict":
@@ -577,25 +589,17 @@ func parseDSNParams(cfg *Config, params string) (err error) {
 			if isBool {
 				if boolValue {
 					cfg.TLSConfig = "true"
-					cfg.tls = &tls.Config{}
 				} else {
 					cfg.TLSConfig = "false"
 				}
 			} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
 				cfg.TLSConfig = vl
-				cfg.tls = &tls.Config{InsecureSkipVerify: true}
 			} else {
 				name, err := url.QueryUnescape(value)
 				if err != nil {
 					return fmt.Errorf("invalid value for TLS config name: %v", err)
 				}
-
-				if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
-					cfg.TLSConfig = name
-					cfg.tls = tlsConfig
-				} else {
-					return errors.New("invalid value / unknown config name: " + name)
-				}
+				cfg.TLSConfig = name
 			}
 
 		// I/O write Timeout

+ 46 - 2
dsn_test.go

@@ -39,8 +39,8 @@ var testDSNs = []struct {
 	"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
 	&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"},
 }, {
-	"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216",
-	&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216},
+	"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true",
+	&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
 }, {
 	"user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0",
 	&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false},
@@ -358,6 +358,50 @@ func TestCloneConfig(t *testing.T) {
 	}
 }
 
+func TestNormalizeTLSConfig(t *testing.T) {
+	tt := []struct {
+		tlsConfig string
+		want      *tls.Config
+	}{
+		{"", nil},
+		{"false", nil},
+		{"true", &tls.Config{ServerName: "myserver"}},
+		{"skip-verify", &tls.Config{InsecureSkipVerify: true}},
+		{"preferred", &tls.Config{InsecureSkipVerify: true}},
+		{"test_tls_config", &tls.Config{ServerName: "myServerName"}},
+	}
+
+	RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"})
+	defer func() { DeregisterTLSConfig("test_tls_config") }()
+
+	for _, tc := range tt {
+		t.Run(tc.tlsConfig, func(t *testing.T) {
+			cfg := &Config{
+				Addr:      "myserver:3306",
+				TLSConfig: tc.tlsConfig,
+			}
+
+			cfg.normalize()
+
+			if cfg.tls == nil {
+				if tc.want != nil {
+					t.Fatal("wanted a tls config but got nil instead")
+				}
+				return
+			}
+
+			if cfg.tls.ServerName != tc.want.ServerName {
+				t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')",
+					tc.want.ServerName, cfg.tls.ServerName)
+			}
+			if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify {
+				t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)",
+					tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify)
+			}
+		})
+	}
+}
+
 func BenchmarkParseDSN(b *testing.B) {
 	b.ReportAllocs()
 

+ 1 - 1
utils.go

@@ -56,7 +56,7 @@ var (
 //  db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
 //
 func RegisterTLSConfig(key string, config *tls.Config) error {
-	if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
+	if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
 		return fmt.Errorf("key '%s' is reserved", key)
 	}