|
|
@@ -13,30 +13,26 @@ import (
|
|
|
"crypto/tls"
|
|
|
"database/sql/driver"
|
|
|
"encoding/binary"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"log"
|
|
|
+ "net/url"
|
|
|
"os"
|
|
|
- "regexp"
|
|
|
"strings"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
var (
|
|
|
errLog *log.Logger // Error Logger
|
|
|
- dsnPattern *regexp.Regexp // Data Source Name Parser
|
|
|
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
|
|
|
+
|
|
|
+ errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
|
|
|
+ errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
|
|
|
)
|
|
|
|
|
|
func init() {
|
|
|
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
|
|
|
-
|
|
|
- dsnPattern = regexp.MustCompile(
|
|
|
- `^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
|
|
- `(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
|
|
|
- `\/(?P<dbname>.*?)` + // /dbname
|
|
|
- `(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN]
|
|
|
-
|
|
|
tlsConfigRegister = make(map[string]*tls.Config)
|
|
|
}
|
|
|
|
|
|
@@ -77,98 +73,69 @@ func DeregisterTLSConfig(key string) {
|
|
|
delete(tlsConfigRegister, key)
|
|
|
}
|
|
|
|
|
|
+// parseDSN parses the DSN string to a config
|
|
|
func parseDSN(dsn string) (cfg *config, err error) {
|
|
|
cfg = new(config)
|
|
|
- cfg.params = make(map[string]string)
|
|
|
-
|
|
|
- matches := dsnPattern.FindStringSubmatch(dsn)
|
|
|
- names := dsnPattern.SubexpNames()
|
|
|
-
|
|
|
- for i, match := range matches {
|
|
|
- switch names[i] {
|
|
|
- case "user":
|
|
|
- cfg.user = match
|
|
|
- case "passwd":
|
|
|
- cfg.passwd = match
|
|
|
- case "net":
|
|
|
- cfg.net = match
|
|
|
- case "addr":
|
|
|
- cfg.addr = match
|
|
|
- case "dbname":
|
|
|
- cfg.dbname = match
|
|
|
- case "params":
|
|
|
- for _, v := range strings.Split(match, "&") {
|
|
|
- param := strings.SplitN(v, "=", 2)
|
|
|
- if len(param) != 2 {
|
|
|
- continue
|
|
|
- }
|
|
|
-
|
|
|
- // cfg params
|
|
|
- switch value := param[1]; param[0] {
|
|
|
-
|
|
|
- // Disable INFILE whitelist / enable all files
|
|
|
- case "allowAllFiles":
|
|
|
- var isBool bool
|
|
|
- cfg.allowAllFiles, isBool = readBool(value)
|
|
|
- if !isBool {
|
|
|
- err = fmt.Errorf("Invalid Bool value: %s", value)
|
|
|
- return
|
|
|
- }
|
|
|
|
|
|
- // Switch "rowsAffected" mode
|
|
|
- case "clientFoundRows":
|
|
|
- var isBool bool
|
|
|
- cfg.clientFoundRows, isBool = readBool(value)
|
|
|
- if !isBool {
|
|
|
- err = fmt.Errorf("Invalid Bool value: %s", value)
|
|
|
- return
|
|
|
- }
|
|
|
+ // TODO: use strings.IndexByte when we can depend on Go 1.2
|
|
|
+
|
|
|
+ // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
|
|
|
+ // Find the last '/' (since the password or the net addr might contain a '/')
|
|
|
+ for i := len(dsn) - 1; i >= 0; i-- {
|
|
|
+ if dsn[i] == '/' {
|
|
|
+ var j, k int
|
|
|
+
|
|
|
+ // left part is empty if i <= 0
|
|
|
+ if i > 0 {
|
|
|
+ // [username[:password]@][protocol[(address)]]
|
|
|
+ // Find the last '@' in dsn[:i]
|
|
|
+ for j = i; j >= 0; j-- {
|
|
|
+ if dsn[j] == '@' {
|
|
|
+ // username[:password]
|
|
|
+ // Find the first ':' in dsn[:j]
|
|
|
+ for k = 0; k < j; k++ {
|
|
|
+ if dsn[k] == ':' {
|
|
|
+ cfg.passwd = dsn[k+1 : j]
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ cfg.user = dsn[:k]
|
|
|
|
|
|
- // Use old authentication mode (pre MySQL 4.1)
|
|
|
- case "allowOldPasswords":
|
|
|
- var isBool bool
|
|
|
- cfg.allowOldPasswords, isBool = readBool(value)
|
|
|
- if !isBool {
|
|
|
- err = fmt.Errorf("Invalid Bool value: %s", value)
|
|
|
- return
|
|
|
+ break
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- // Time Location
|
|
|
- case "loc":
|
|
|
- cfg.loc, err = time.LoadLocation(value)
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
+ // [protocol[(address)]]
|
|
|
+ // Find the first '(' in dsn[j+1:i]
|
|
|
+ for k = j + 1; k < i; k++ {
|
|
|
+ if dsn[k] == '(' {
|
|
|
+ // dsn[i-1] must be == ')' if an adress is specified
|
|
|
+ if dsn[i-1] != ')' {
|
|
|
+ if strings.ContainsRune(dsn[k+1:i], ')') {
|
|
|
+ return nil, errInvalidDSNUnescaped
|
|
|
+ }
|
|
|
+ return nil, errInvalidDSNAddr
|
|
|
+ }
|
|
|
+ cfg.addr = dsn[k+1 : i-1]
|
|
|
+ break
|
|
|
}
|
|
|
+ }
|
|
|
+ cfg.net = dsn[j+1 : k]
|
|
|
+ }
|
|
|
|
|
|
- // Dial Timeout
|
|
|
- case "timeout":
|
|
|
- cfg.timeout, err = time.ParseDuration(value)
|
|
|
- if err != nil {
|
|
|
+ // dbname[?param1=value1&...¶mN=valueN]
|
|
|
+ // Find the first '?' in dsn[i+1:]
|
|
|
+ for j = i + 1; j < len(dsn); j++ {
|
|
|
+ if dsn[j] == '?' {
|
|
|
+ if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
|
|
|
return
|
|
|
}
|
|
|
-
|
|
|
- // TLS-Encryption
|
|
|
- case "tls":
|
|
|
- boolValue, isBool := readBool(value)
|
|
|
- if isBool {
|
|
|
- if boolValue {
|
|
|
- cfg.tls = &tls.Config{}
|
|
|
- }
|
|
|
- } else {
|
|
|
- if strings.ToLower(value) == "skip-verify" {
|
|
|
- cfg.tls = &tls.Config{InsecureSkipVerify: true}
|
|
|
- } else if tlsConfig, ok := tlsConfigRegister[value]; ok {
|
|
|
- cfg.tls = tlsConfig
|
|
|
- } else {
|
|
|
- err = fmt.Errorf("Invalid value / unknown config name: %s", value)
|
|
|
- return
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- default:
|
|
|
- cfg.params[param[0]] = value
|
|
|
+ break
|
|
|
}
|
|
|
}
|
|
|
+ cfg.dbname = dsn[i+1 : j]
|
|
|
+
|
|
|
+ break
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -179,10 +146,18 @@ func parseDSN(dsn string) (cfg *config, err error) {
|
|
|
|
|
|
// Set default adress if empty
|
|
|
if cfg.addr == "" {
|
|
|
- cfg.addr = "127.0.0.1:3306"
|
|
|
+ switch cfg.net {
|
|
|
+ case "tcp":
|
|
|
+ cfg.addr = "127.0.0.1:3306"
|
|
|
+ case "unix":
|
|
|
+ cfg.addr = "/tmp/mysql.sock"
|
|
|
+ default:
|
|
|
+ return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
|
|
|
+ }
|
|
|
+
|
|
|
}
|
|
|
|
|
|
- // Set default location if not set
|
|
|
+ // Set default location if empty
|
|
|
if cfg.loc == nil {
|
|
|
cfg.loc = time.UTC
|
|
|
}
|
|
|
@@ -190,6 +165,91 @@ func parseDSN(dsn string) (cfg *config, err error) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
+// parseDSNParams parses the DSN "query string"
|
|
|
+// Values must be url.QueryEscape'ed
|
|
|
+func parseDSNParams(cfg *config, params string) (err error) {
|
|
|
+ for _, v := range strings.Split(params, "&") {
|
|
|
+ param := strings.SplitN(v, "=", 2)
|
|
|
+ if len(param) != 2 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ // cfg params
|
|
|
+ switch value := param[1]; param[0] {
|
|
|
+
|
|
|
+ // Disable INFILE whitelist / enable all files
|
|
|
+ case "allowAllFiles":
|
|
|
+ var isBool bool
|
|
|
+ cfg.allowAllFiles, isBool = readBool(value)
|
|
|
+ if !isBool {
|
|
|
+ return fmt.Errorf("Invalid Bool value: %s", value)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Switch "rowsAffected" mode
|
|
|
+ case "clientFoundRows":
|
|
|
+ var isBool bool
|
|
|
+ cfg.clientFoundRows, isBool = readBool(value)
|
|
|
+ if !isBool {
|
|
|
+ return fmt.Errorf("Invalid Bool value: %s", value)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Use old authentication mode (pre MySQL 4.1)
|
|
|
+ case "allowOldPasswords":
|
|
|
+ var isBool bool
|
|
|
+ cfg.allowOldPasswords, isBool = readBool(value)
|
|
|
+ if !isBool {
|
|
|
+ return fmt.Errorf("Invalid Bool value: %s", value)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Time Location
|
|
|
+ case "loc":
|
|
|
+ if value, err = url.QueryUnescape(value); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ cfg.loc, err = time.LoadLocation(value)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // Dial Timeout
|
|
|
+ case "timeout":
|
|
|
+ cfg.timeout, err = time.ParseDuration(value)
|
|
|
+ if err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ // TLS-Encryption
|
|
|
+ case "tls":
|
|
|
+ boolValue, isBool := readBool(value)
|
|
|
+ if isBool {
|
|
|
+ if boolValue {
|
|
|
+ cfg.tls = &tls.Config{}
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if strings.ToLower(value) == "skip-verify" {
|
|
|
+ cfg.tls = &tls.Config{InsecureSkipVerify: true}
|
|
|
+ } else if tlsConfig, ok := tlsConfigRegister[value]; ok {
|
|
|
+ cfg.tls = tlsConfig
|
|
|
+ } else {
|
|
|
+ return fmt.Errorf("Invalid value / unknown config name: %s", value)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ default:
|
|
|
+ // lazy init
|
|
|
+ if cfg.params == nil {
|
|
|
+ cfg.params = make(map[string]string)
|
|
|
+ }
|
|
|
+
|
|
|
+ if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return
|
|
|
+}
|
|
|
+
|
|
|
// Returns the bool value of the input.
|
|
|
// The 2nd return value indicates if the input was a valid bool value
|
|
|
func readBool(input string) (value bool, valid bool) {
|