ソースを参照

Merge pull request #130 from go-sql-driver/dsn_parser

New DSN parser
Julien Schmidt 12 年 前
コミット
58ac805a5b
5 ファイル変更207 行追加111 行削除
  1. 1 0
      CHANGELOG.md
  2. 5 3
      README.md
  3. 2 1
      driver_test.go
  4. 151 91
      utils.go
  5. 48 16
      utils_test.go

+ 1 - 0
CHANGELOG.md

@@ -13,6 +13,7 @@ Changes:
   - Refactored the driver tests
   - Added more benchmarks and moved all to a separate file
   - Other small refactoring
+  - DSN parameter values must now be url.QueryEscape'ed. This allows text values to contain special characters, such as '&'.
 
 New Features:
 

+ 5 - 3
README.md

@@ -78,7 +78,7 @@ A DSN in its fullest form:
 username:password@protocol(address)/dbname?param=value
 ```
 
-Except of the databasename, all values are optional. So the minimal DSN is:
+Except for the databasename, all values are optional. So the minimal DSN is:
 ```
 /dbname
 ```
@@ -110,7 +110,7 @@ Possible Parameters are:
   * `allowOldPasswords`: `allowOldPasswords=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords).
   * `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset fails. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
   * `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
-  * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
+  * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details. Please keep in mind that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `US%2FPacific`.
   * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
   * `strict`: Enable strict mode. MySQL warnings are treated as errors.
   * `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
@@ -122,6 +122,8 @@ All other parameters are interpreted as system variables:
   * `tx_isolation`: *"SET [tx_isolation](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation)=`value`"*
   * `param`: *"SET `param`=`value`"*
 
+***The values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!***
+
 #### Examples
 ```
 user@unix(/path/to/socket)/dbname
@@ -132,7 +134,7 @@ user:password@tcp(localhost:5555)/dbname?autocommit=true
 ```
 
 ```
-user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8
+user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?tls=skip-verify&charset=utf8mb4,utf8&sys_var=withSlash%2FandAt%40
 ```
 
 ```

+ 2 - 1
driver_test.go

@@ -15,6 +15,7 @@ import (
 	"io"
 	"io/ioutil"
 	"net"
+	"net/url"
 	"os"
 	"strings"
 	"testing"
@@ -206,7 +207,7 @@ func TestTimezoneConversion(t *testing.T) {
 	}
 
 	for _, tz := range zones {
-		runTests(t, dsn+"&parseTime=true&loc="+tz, tzTest)
+		runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
 	}
 }
 

+ 151 - 91
utils.go

@@ -13,30 +13,26 @@ import (
 	"crypto/tls"
 	"database/sql/driver"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"log"
+	"net/url"
 	"os"
-	"regexp"
 	"strings"
 	"time"
 )
 
 var (
 	errLog            *log.Logger            // Error Logger
-	dsnPattern        *regexp.Regexp         // Data Source Name Parser
 	tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
+
+	errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
+	errInvalidDSNAddr      = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
 )
 
 func init() {
 	errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
-
-	dsnPattern = regexp.MustCompile(
-		`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
-			`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
-			`\/(?P<dbname>.*?)` + // /dbname
-			`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
-
 	tlsConfigRegister = make(map[string]*tls.Config)
 }
 
@@ -77,98 +73,69 @@ func DeregisterTLSConfig(key string) {
 	delete(tlsConfigRegister, key)
 }
 
+// parseDSN parses the DSN string to a config
 func parseDSN(dsn string) (cfg *config, err error) {
 	cfg = new(config)
-	cfg.params = make(map[string]string)
-
-	matches := dsnPattern.FindStringSubmatch(dsn)
-	names := dsnPattern.SubexpNames()
-
-	for i, match := range matches {
-		switch names[i] {
-		case "user":
-			cfg.user = match
-		case "passwd":
-			cfg.passwd = match
-		case "net":
-			cfg.net = match
-		case "addr":
-			cfg.addr = match
-		case "dbname":
-			cfg.dbname = match
-		case "params":
-			for _, v := range strings.Split(match, "&") {
-				param := strings.SplitN(v, "=", 2)
-				if len(param) != 2 {
-					continue
-				}
-
-				// cfg params
-				switch value := param[1]; param[0] {
-
-				// Disable INFILE whitelist / enable all files
-				case "allowAllFiles":
-					var isBool bool
-					cfg.allowAllFiles, isBool = readBool(value)
-					if !isBool {
-						err = fmt.Errorf("Invalid Bool value: %s", value)
-						return
-					}
 
-				// Switch "rowsAffected" mode
-				case "clientFoundRows":
-					var isBool bool
-					cfg.clientFoundRows, isBool = readBool(value)
-					if !isBool {
-						err = fmt.Errorf("Invalid Bool value: %s", value)
-						return
-					}
+	// TODO: use strings.IndexByte when we can depend on Go 1.2
+
+	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
+	// Find the last '/' (since the password or the net addr might contain a '/')
+	for i := len(dsn) - 1; i >= 0; i-- {
+		if dsn[i] == '/' {
+			var j, k int
+
+			// left part is empty if i <= 0
+			if i > 0 {
+				// [username[:password]@][protocol[(address)]]
+				// Find the last '@' in dsn[:i]
+				for j = i; j >= 0; j-- {
+					if dsn[j] == '@' {
+						// username[:password]
+						// Find the first ':' in dsn[:j]
+						for k = 0; k < j; k++ {
+							if dsn[k] == ':' {
+								cfg.passwd = dsn[k+1 : j]
+								break
+							}
+						}
+						cfg.user = dsn[:k]
 
-				// Use old authentication mode (pre MySQL 4.1)
-				case "allowOldPasswords":
-					var isBool bool
-					cfg.allowOldPasswords, isBool = readBool(value)
-					if !isBool {
-						err = fmt.Errorf("Invalid Bool value: %s", value)
-						return
+						break
 					}
+				}
 
-				// Time Location
-				case "loc":
-					cfg.loc, err = time.LoadLocation(value)
-					if err != nil {
-						return
+				// [protocol[(address)]]
+				// Find the first '(' in dsn[j+1:i]
+				for k = j + 1; k < i; k++ {
+					if dsn[k] == '(' {
+						// dsn[i-1] must be == ')' if an address is specified
+						if dsn[i-1] != ')' {
+							if strings.ContainsRune(dsn[k+1:i], ')') {
+								return nil, errInvalidDSNUnescaped
+							}
+							return nil, errInvalidDSNAddr
+						}
+						cfg.addr = dsn[k+1 : i-1]
+						break
 					}
+				}
+				cfg.net = dsn[j+1 : k]
+			}
 
-				// Dial Timeout
-				case "timeout":
-					cfg.timeout, err = time.ParseDuration(value)
-					if err != nil {
+			// dbname[?param1=value1&...&paramN=valueN]
+			// Find the first '?' in dsn[i+1:]
+			for j = i + 1; j < len(dsn); j++ {
+				if dsn[j] == '?' {
+					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
 						return
 					}
-
-				// TLS-Encryption
-				case "tls":
-					boolValue, isBool := readBool(value)
-					if isBool {
-						if boolValue {
-							cfg.tls = &tls.Config{}
-						}
-					} else {
-						if strings.ToLower(value) == "skip-verify" {
-							cfg.tls = &tls.Config{InsecureSkipVerify: true}
-						} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
-							cfg.tls = tlsConfig
-						} else {
-							err = fmt.Errorf("Invalid value / unknown config name: %s", value)
-							return
-						}
-					}
-
-				default:
-					cfg.params[param[0]] = value
+					break
 				}
 			}
+			cfg.dbname = dsn[i+1 : j]
+
+			break
 		}
 	}
 
@@ -179,10 +146,18 @@ func parseDSN(dsn string) (cfg *config, err error) {
 
 	// Set default address if empty
 	if cfg.addr == "" {
-		cfg.addr = "127.0.0.1:3306"
+		switch cfg.net {
+		case "tcp":
+			cfg.addr = "127.0.0.1:3306"
+		case "unix":
+			cfg.addr = "/tmp/mysql.sock"
+		default:
+			return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
+		}
+
 	}
 
-	// Set default location if not set
+	// Set default location if empty
 	if cfg.loc == nil {
 		cfg.loc = time.UTC
 	}
@@ -190,6 +165,91 @@ func parseDSN(dsn string) (cfg *config, err error) {
 	return
 }
 
+// parseDSNParams parses the DSN "query string"
+// Values must be url.QueryEscape'ed
+func parseDSNParams(cfg *config, params string) (err error) {
+	for _, v := range strings.Split(params, "&") {
+		param := strings.SplitN(v, "=", 2)
+		if len(param) != 2 {
+			continue
+		}
+
+		// cfg params
+		switch value := param[1]; param[0] {
+
+		// Disable INFILE whitelist / enable all files
+		case "allowAllFiles":
+			var isBool bool
+			cfg.allowAllFiles, isBool = readBool(value)
+			if !isBool {
+				return fmt.Errorf("Invalid Bool value: %s", value)
+			}
+
+		// Switch "rowsAffected" mode
+		case "clientFoundRows":
+			var isBool bool
+			cfg.clientFoundRows, isBool = readBool(value)
+			if !isBool {
+				return fmt.Errorf("Invalid Bool value: %s", value)
+			}
+
+		// Use old authentication mode (pre MySQL 4.1)
+		case "allowOldPasswords":
+			var isBool bool
+			cfg.allowOldPasswords, isBool = readBool(value)
+			if !isBool {
+				return fmt.Errorf("Invalid Bool value: %s", value)
+			}
+
+		// Time Location
+		case "loc":
+			if value, err = url.QueryUnescape(value); err != nil {
+				return
+			}
+			cfg.loc, err = time.LoadLocation(value)
+			if err != nil {
+				return
+			}
+
+		// Dial Timeout
+		case "timeout":
+			cfg.timeout, err = time.ParseDuration(value)
+			if err != nil {
+				return
+			}
+
+		// TLS-Encryption
+		case "tls":
+			boolValue, isBool := readBool(value)
+			if isBool {
+				if boolValue {
+					cfg.tls = &tls.Config{}
+				}
+			} else {
+				if strings.ToLower(value) == "skip-verify" {
+					cfg.tls = &tls.Config{InsecureSkipVerify: true}
+				} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
+					cfg.tls = tlsConfig
+				} else {
+					return fmt.Errorf("Invalid value / unknown config name: %s", value)
+				}
+			}
+
+		default:
+			// lazy init
+			if cfg.params == nil {
+				cfg.params = make(map[string]string)
+			}
+
+			if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
+				return
+			}
+		}
+	}
+
+	return
+}
+
 // Returns the bool value of the input.
 // The 2nd return value indicates if the input was a valid bool value
 func readBool(input string) (value bool, valid bool) {

+ 48 - 16
utils_test.go

@@ -14,23 +14,26 @@ import (
 	"time"
 )
 
-func TestDSNParser(t *testing.T) {
-	var testDSNs = []struct {
-		in  string
-		out string
-		loc *time.Location
-	}{
-		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-		{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
-		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
-		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	}
+var testDSNs = []struct {
+	in  string
+	out string
+	loc *time.Location
+}{
+	{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
+	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
+	{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+}
 
+func TestDSNParser(t *testing.T) {
 	var cfg *config
 	var err error
 	var res string
@@ -51,6 +54,35 @@ func TestDSNParser(t *testing.T) {
 	}
 }
 
+func TestDSNParserInvalid(t *testing.T) {
+	var invalidDSNs = []string{
+		"@net(addr/",  // no closing brace
+		"@tcp(/",      // no closing brace
+		"tcp(/",       // no closing brace
+		"(/",          // no closing brace
+		"net(addr)//", // unescaped
+		//"/dbname?arg=/some/unescaped/path",
+	}
+
+	for i, tst := range invalidDSNs {
+		if _, err := parseDSN(tst); err == nil {
+			t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
+		}
+	}
+}
+
+func BenchmarkParseDSN(b *testing.B) {
+	b.ReportAllocs()
+
+	for i := 0; i < b.N; i++ {
+		for _, tst := range testDSNs {
+			if _, err := parseDSN(tst.in); err != nil {
+				b.Error(err.Error())
+			}
+		}
+	}
+}
+
 func TestScanNullTime(t *testing.T) {
 	var scanTests = []struct {
 		in    interface{}