Преглед на файлове

DSN: Add cfg.Format method

Julien Schmidt преди 10 години
родител
ревизия
8b688fbe79
променени са 4 файла, в които са добавени 278 реда и са изтрити 54 реда
  1. 11 11
      collations.go
  2. 209 19
      dsn.go
  3. 45 19
      dsn_test.go
  4. 13 5
      packets.go

+ 11 - 11
collations.go

@@ -8,7 +8,7 @@
 
 package mysql
 
-const defaultCollation byte = 33 // utf8_general_ci
+const defaultCollation = "utf8_general_ci"
 
 // A list of available collations mapped to the internal ID.
 // To update this map use the following MySQL query:
@@ -237,14 +237,14 @@ var collations = map[string]byte{
 
 // A blacklist of collations which is unsafe to interpolate parameters.
 // These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
-var unsafeCollations = map[byte]bool{
-	1:  true, // big5_chinese_ci
-	13: true, // sjis_japanese_ci
-	28: true, // gbk_chinese_ci
-	84: true, // big5_bin
-	86: true, // gb2312_bin
-	87: true, // gbk_bin
-	88: true, // sjis_bin
-	95: true, // cp932_japanese_ci
-	96: true, // cp932_bin
+var unsafeCollations = map[string]bool{
+	"big5_chinese_ci":   true,
+	"sjis_japanese_ci":  true,
+	"gbk_chinese_ci":    true,
+	"big5_bin":          true,
+	"gb2312_bin":        true,
+	"gbk_bin":           true,
+	"sjis_bin":          true,
+	"cp932_japanese_ci": true,
+	"cp932_bin":         true,
 }

+ 209 - 19
dsn.go

@@ -9,6 +9,7 @@
 package mysql
 
 import (
+	"bytes"
 	"crypto/tls"
 	"errors"
 	"fmt"
@@ -33,12 +34,13 @@ type Config struct {
 	Addr         string            // Network address
 	DBName       string            // Database name
 	Params       map[string]string // Connection parameters
+	Collation    string            // Connection collation
 	Loc          *time.Location    // Location for time.Time values
-	TLS          *tls.Config       // TLS configuration
+	TLSConfig    string            // TLS configuration name
+	tls          *tls.Config       // TLS configuration
 	Timeout      time.Duration     // Dial timeout
 	ReadTimeout  time.Duration     // I/O read timeout
 	WriteTimeout time.Duration     // I/O write timeout
-	Collation    uint8             // Connection collation
 
 	AllowAllFiles           bool // Allow all files to be used with LOAD DATA LOCAL INFILE
 	AllowCleartextPasswords bool // Allows the cleartext client side plugin
@@ -51,6 +53,194 @@ type Config struct {
 	Strict                  bool // Return warnings as errors
 }
 
+// FormatDSN formats the given Config into a DSN string which can be passed to
+// the driver.
+func (cfg *Config) FormatDSN() string {
+	var buf bytes.Buffer
+
+	// [username[:password]@]
+	if len(cfg.User) > 0 {
+		buf.WriteString(cfg.User)
+		if len(cfg.Passwd) > 0 {
+			buf.WriteByte(':')
+			buf.WriteString(cfg.Passwd)
+		}
+		buf.WriteByte('@')
+	}
+
+	// [protocol[(address)]]
+	if len(cfg.Net) > 0 {
+		buf.WriteString(cfg.Net)
+		if len(cfg.Addr) > 0 {
+			buf.WriteByte('(')
+			buf.WriteString(cfg.Addr)
+			buf.WriteByte(')')
+		}
+	}
+
+	// /dbname
+	buf.WriteByte('/')
+	buf.WriteString(cfg.DBName)
+
+	// [?param1=value1&...&paramN=valueN]
+	hasParam := false
+
+	if cfg.AllowAllFiles {
+		hasParam = true
+		buf.WriteString("?allowAllFiles=true")
+	}
+
+	if cfg.AllowCleartextPasswords {
+		if hasParam {
+			buf.WriteString("&allowCleartextPasswords=true")
+		} else {
+			hasParam = true
+			buf.WriteString("?allowCleartextPasswords=true")
+		}
+	}
+
+	if cfg.AllowOldPasswords {
+		if hasParam {
+			buf.WriteString("&allowOldPasswords=true")
+		} else {
+			hasParam = true
+			buf.WriteString("?allowOldPasswords=true")
+		}
+	}
+
+	if cfg.ClientFoundRows {
+		if hasParam {
+			buf.WriteString("&clientFoundRows=true")
+		} else {
+			hasParam = true
+			buf.WriteString("?clientFoundRows=true")
+		}
+	}
+
+	if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
+		if hasParam {
+			buf.WriteString("&collation=")
+		} else {
+			hasParam = true
+			buf.WriteString("?collation=")
+		}
+		buf.WriteString(col)
+	}
+
+	if cfg.ColumnsWithAlias {
+		if hasParam {
+			buf.WriteString("&columnsWithAlias=true")
+		} else {
+			hasParam = true
+			buf.WriteString("?columnsWithAlias=true")
+		}
+	}
+
+	if cfg.InterpolateParams {
+		if hasParam {
+			buf.WriteString("&interpolateParams=true")
+		} else {
+			hasParam = true
+			buf.WriteString("?interpolateParams=true")
+		}
+	}
+
+	if cfg.Loc != time.UTC && cfg.Loc != nil {
+		if hasParam {
+			buf.WriteString("&loc=")
+		} else {
+			hasParam = true
+			buf.WriteString("?loc=")
+		}
+		buf.WriteString(url.QueryEscape(cfg.Loc.String()))
+	}
+
+	if cfg.MultiStatements {
+		if hasParam {
+			buf.WriteString("&multiStatements=true")
+		} else {
+			hasParam = true
+			buf.WriteString("?multiStatements=true")
+		}
+	}
+
+	if cfg.ParseTime {
+		if hasParam {
+			buf.WriteString("&parseTime=true")
+		} else {
+			hasParam = true
+			buf.WriteString("?parseTime=true")
+		}
+	}
+
+	if cfg.ReadTimeout > 0 {
+		if hasParam {
+			buf.WriteString("&readTimeout=")
+		} else {
+			hasParam = true
+			buf.WriteString("?readTimeout=")
+		}
+		buf.WriteString(cfg.ReadTimeout.String())
+	}
+
+	if cfg.Strict {
+		if hasParam {
+			buf.WriteString("&strict=true")
+		} else {
+			hasParam = true
+			buf.WriteString("?strict=true")
+		}
+	}
+
+	if cfg.Timeout > 0 {
+		if hasParam {
+			buf.WriteString("&timeout=")
+		} else {
+			hasParam = true
+			buf.WriteString("?timeout=")
+		}
+		buf.WriteString(cfg.Timeout.String())
+	}
+
+	if len(cfg.TLSConfig) > 0 {
+		if hasParam {
+			buf.WriteString("&tls=")
+		} else {
+			hasParam = true
+			buf.WriteString("?tls=")
+		}
+		buf.WriteString(url.QueryEscape(cfg.TLSConfig))
+	}
+
+	if cfg.WriteTimeout > 0 {
+		if hasParam {
+			buf.WriteString("&writeTimeout=")
+		} else {
+			hasParam = true
+			buf.WriteString("?writeTimeout=")
+		}
+		buf.WriteString(cfg.WriteTimeout.String())
+	}
+
+	// other params
+	if cfg.Params != nil {
+		for param, value := range cfg.Params {
+			if hasParam {
+				buf.WriteByte('&')
+			} else {
+				hasParam = true
+				buf.WriteByte('?')
+			}
+
+			buf.WriteString(param)
+			buf.WriteByte('=')
+			buf.WriteString(url.QueryEscape(value))
+		}
+	}
+
+	return buf.String()
+}
+
 // ParseDSN parses the DSN string to a Config
 func ParseDSN(dsn string) (cfg *Config, err error) {
 	// New config with some default values
@@ -196,15 +386,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
 
 		// Collation
 		case "collation":
-			collation, ok := collations[value]
-			if !ok {
-				// Note possibility for false negatives:
-				// could be triggered  although the collation is valid if the
-				// collations map does not contain entries the server supports.
-				err = errors.New("unknown collation")
-				return
-			}
-			cfg.Collation = collation
+			cfg.Collation = value
 			break
 
 		case "columnsWithAlias":
@@ -279,14 +461,21 @@ func parseDSNParams(cfg *Config, params string) (err error) {
 			boolValue, isBool := readBool(value)
 			if isBool {
 				if boolValue {
-					cfg.TLS = &tls.Config{}
+					cfg.TLSConfig = "true"
+					cfg.tls = &tls.Config{}
+				} else {
+					cfg.TLSConfig = "false"
 				}
-			} else if value, err := url.QueryUnescape(value); err != nil {
-				return fmt.Errorf("invalid value for TLS config name: %v", err)
+			} else if vl := strings.ToLower(value); vl == "skip-verify" {
+				cfg.TLSConfig = vl
+				cfg.tls = &tls.Config{InsecureSkipVerify: true}
 			} else {
-				if strings.ToLower(value) == "skip-verify" {
-					cfg.TLS = &tls.Config{InsecureSkipVerify: true}
-				} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
+				name, err := url.QueryUnescape(value)
+				if err != nil {
+					return fmt.Errorf("invalid value for TLS config name: %v", err)
+				}
+
+				if tlsConfig, ok := tlsConfigRegister[name]; ok {
 					if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
 						host, _, err := net.SplitHostPort(cfg.Addr)
 						if err == nil {
@@ -294,9 +483,10 @@ func parseDSNParams(cfg *Config, params string) (err error) {
 						}
 					}
 
-					cfg.TLS = tlsConfig
+					cfg.TLSConfig = name
+					cfg.tls = tlsConfig
 				} else {
-					return errors.New("invalid value / unknown config name: " + value)
+					return errors.New("invalid value / unknown config name: " + name)
 				}
 			}
 

+ 45 - 19
dsn_test.go

@@ -19,20 +19,20 @@ var testDSNs = []struct {
 	in  string
 	out string
 }{
-	{"username:password@protocol(address)/dbname?param=value", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:true ParseTime:false Strict:false}"},
-	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{User:user Passwd: Net:unix Addr:/path/to/socket DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8mb4,utf8] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{User:user Passwd:password Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:30s ReadTimeout:1s WriteTimeout:1s Collation:224 AllowAllFiles:true AllowCleartextPasswords:false AllowOldPasswords:true ClientFoundRows:true ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{User:user Passwd:p@ss(word) Net:tcp Addr:[de:ad:be:ef::ca:fe]:80 DBName:dbname Params:map[] Loc:Local TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"/dbname", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"@/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"user:p@/ssword@/", "&{User:user Passwd:p@/ssword Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
-	{"unix/?arg=%2Fsome%2Fpath.ext", "&{User: Passwd: Net:unix Addr:/tmp/mysql.sock DBName: Params:map[arg:/some/path.ext] Loc:UTC TLS:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 Collation:33 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"username:password@protocol(address)/dbname?param=value", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", "&{User:username Passwd:password Net:protocol Addr:address DBName:dbname Params:map[param:value] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:true InterpolateParams:false MultiStatements:true ParseTime:false Strict:false}"},
+	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{User:user Passwd: Net:unix Addr:/path/to/socket DBName:dbname Params:map[charset:utf8] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8] Collation:utf8_general_ci Loc:UTC TLSConfig:true tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{User:user Passwd:password Net:tcp Addr:localhost:5555 DBName:dbname Params:map[charset:utf8mb4,utf8] Collation:utf8_general_ci Loc:UTC TLSConfig:skip-verify tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{User:user Passwd:password Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Collation:utf8mb4_unicode_ci Loc:UTC TLSConfig: tls:<nil> Timeout:30s ReadTimeout:1s WriteTimeout:1s AllowAllFiles:true AllowCleartextPasswords:false AllowOldPasswords:true ClientFoundRows:true ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{User:user Passwd:p@ss(word) Net:tcp Addr:[de:ad:be:ef::ca:fe]:80 DBName:dbname Params:map[] Collation:utf8_general_ci Loc:Local TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"/dbname", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName:dbname Params:map[] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"@/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"/", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"", "&{User: Passwd: Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"user:p@/ssword@/", "&{User:user Passwd:p@/ssword Net:tcp Addr:127.0.0.1:3306 DBName: Params:map[] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
+	{"unix/?arg=%2Fsome%2Fpath.ext", "&{User: Passwd: Net:unix Addr:/tmp/mysql.sock DBName: Params:map[arg:/some/path.ext] Collation:utf8_general_ci Loc:UTC TLSConfig: tls:<nil> Timeout:0 ReadTimeout:0 WriteTimeout:0 AllowAllFiles:false AllowCleartextPasswords:false AllowOldPasswords:false ClientFoundRows:false ColumnsWithAlias:false InterpolateParams:false MultiStatements:false ParseTime:false Strict:false}"},
 }
 
 func TestDSNParser(t *testing.T) {
@@ -47,7 +47,7 @@ func TestDSNParser(t *testing.T) {
 		}
 
 		// pointer not static
-		cfg.TLS = nil
+		cfg.tls = nil
 
 		res = fmt.Sprintf("%+v", cfg)
 		if res != tst.out {
@@ -74,6 +74,32 @@ func TestDSNParserInvalid(t *testing.T) {
 	}
 }
 
+func TestDSNReformat(t *testing.T) {
+	for i, tst := range testDSNs {
+		dsn1 := tst.in
+		cfg1, err := ParseDSN(dsn1)
+		if err != nil {
+			t.Error(err.Error())
+			continue
+		}
+		cfg1.tls = nil // pointer not static
+		res1 := fmt.Sprintf("%+v", cfg1)
+
+		dsn2 := cfg1.FormatDSN()
+		cfg2, err := ParseDSN(dsn2)
+		if err != nil {
+			t.Error(err.Error())
+			continue
+		}
+		cfg2.tls = nil // pointer not static
+		res2 := fmt.Sprintf("%+v", cfg2)
+
+		if res1 != res2 {
+			t.Errorf("%d. %q does not match %q", i, res2, res1)
+		}
+	}
+}
+
 func TestDSNWithCustomTLS(t *testing.T) {
 	baseDSN := "User:password@tcp(localhost:5555)/dbname?tls="
 	tlsCfg := tls.Config{}
@@ -96,7 +122,7 @@ func TestDSNWithCustomTLS(t *testing.T) {
 
 	if err != nil {
 		t.Error(err.Error())
-	} else if cfg.TLS.ServerName != name {
+	} else if cfg.tls.ServerName != name {
 		t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
 	}
 
@@ -107,14 +133,14 @@ func TestDSNWithCustomTLS(t *testing.T) {
 
 	if err != nil {
 		t.Error(err.Error())
-	} else if cfg.TLS.ServerName != name {
+	} else if cfg.tls.ServerName != name {
 		t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
 	}
 
 	DeregisterTLSConfig("utils_test")
 }
 
-func TestDSNWithCustomTLS_queryEscape(t *testing.T) {
+func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
 	const configKey = "&%!:"
 	dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)
 	name := "foohost"
@@ -126,7 +152,7 @@ func TestDSNWithCustomTLS_queryEscape(t *testing.T) {
 
 	if err != nil {
 		t.Error(err.Error())
-	} else if cfg.TLS.ServerName != name {
+	} else if cfg.tls.ServerName != name {
 		t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn)
 	}
 }

+ 13 - 5
packets.go

@@ -13,6 +13,7 @@ import (
 	"crypto/tls"
 	"database/sql/driver"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"math"
@@ -166,7 +167,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 	if mc.flags&clientProtocol41 == 0 {
 		return nil, ErrOldProtocol
 	}
-	if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
+	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
 		return nil, ErrNoTLS
 	}
 	pos += 2
@@ -232,7 +233,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	}
 
 	// To enable TLS / SSL
-	if mc.cfg.TLS != nil {
+	if mc.cfg.tls != nil {
 		clientFlags |= clientSSL
 	}
 
@@ -272,18 +273,25 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	data[11] = 0x00
 
 	// Charset [1 byte]
-	data[12] = mc.cfg.Collation
+	var found bool
+	data[12], found = collations[mc.cfg.Collation]
+	if !found {
+		// Note possibility for false negatives:
+		// could be triggered  although the collation is valid if the
+		// collations map does not contain entries the server supports.
+		return errors.New("unknown collation")
+	}
 
 	// SSL Connection Request Packet
 	// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
-	if mc.cfg.TLS != nil {
+	if mc.cfg.tls != nil {
 		// Send TLS / SSL request packet
 		if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
 			return err
 		}
 
 		// Switch to TLS
-		tlsConn := tls.Client(mc.netConn, mc.cfg.TLS)
+		tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
 		if err := tlsConn.Handshake(); err != nil {
 			return err
 		}