Browse Source

Fix protocol parsing

Julien Schmidt 12 years ago
parent
commit
6d1a06dd1f
3 changed files with 24 additions and 27 deletions
  1. 1 1
      README.md
  2. 17 23
      utils.go
  3. 6 3
      utils_test.go

+ 1 - 1
README.md

@@ -78,7 +78,7 @@ A DSN in its fullest form:
 username:password@protocol(address)/dbname?param=value
 ```
 
-Except of the databasename, all values are optional. So the minimal DSN is:
+Except for the databasename, all values are optional. So the minimal DSN is:
 ```
 /dbname
 ```

+ 17 - 23
utils.go

@@ -80,10 +80,10 @@ func parseDSN(dsn string) (cfg *config, err error) {
 	// TODO: use strings.IndexByte when we can depend on Go 1.2
 
 	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
-	// Find the last '/' (since the password might contain a '/')
+	// Find the last '/' (since the password or the net addr might contain a '/')
 	for i := len(dsn) - 1; i >= 0; i-- {
 		if dsn[i] == '/' {
-			var j int
+			var j, k int
 
 			// left part is empty if i <= 0
 			if i > 0 {
@@ -93,7 +93,6 @@ func parseDSN(dsn string) (cfg *config, err error) {
 					if dsn[j] == '@' {
 						// username[:password]
 						// Find the first ':' in dsn[:j]
-						var k int
 						for k = 0; k < j; k++ {
 							if dsn[k] == ':' {
 								cfg.passwd = dsn[k+1 : j]
@@ -102,31 +101,26 @@ func parseDSN(dsn string) (cfg *config, err error) {
 						}
 						cfg.user = dsn[:k]
 
-						// [protocol[(address)]]
-						// Find the first '(' in dsn[j+1:i]
-						for k = j + 1; k < i; k++ {
-							if dsn[k] == '(' {
-								// dsn[i-1] must be == ')' if an adress is specified
-								if dsn[i-1] != ')' {
-									if strings.ContainsRune(dsn[k+1:i], ')') {
-										return nil, errInvalidDSNUnescaped
-									}
-									return nil, errInvalidDSNAddr
-								}
-								cfg.addr = dsn[k+1 : i-1]
-								break
-							}
-						}
-						cfg.net = dsn[j+1 : k]
-
 						break
 					}
 				}
 
-				// non-empty left part must contain an '@'
-				if j < 0 {
-					return nil, errInvalidDSNUnescaped
+				// [protocol[(address)]]
+				// Find the first '(' in dsn[j+1:i]
+				for k = j + 1; k < i; k++ {
+					if dsn[k] == '(' {
+						// dsn[i-1] must be == ')' if an adress is specified
+						if dsn[i-1] != ')' {
+							if strings.ContainsRune(dsn[k+1:i], ')') {
+								return nil, errInvalidDSNUnescaped
+							}
+							return nil, errInvalidDSNAddr
+						}
+						cfg.addr = dsn[k+1 : i-1]
+						break
+					}
 				}
+				cfg.net = dsn[j+1 : k]
 			}
 
 			// dbname[?param1=value1&...&paramN=valueN]

+ 6 - 3
utils_test.go

@@ -30,7 +30,7 @@ var testDSNs = []struct {
 	{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
 	{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
 	{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"@unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
 }
 
 func TestDSNParser(t *testing.T) {
@@ -56,8 +56,11 @@ func TestDSNParser(t *testing.T) {
 
 func TestDSNParserInvalid(t *testing.T) {
 	var invalidDSNs = []string{
-		"asdf/dbname",
-		"@net(addr/",
+		"@net(addr/",  // no closing brace
+		"@tcp(/",      // no closing brace
+		"tcp(/",       // no closing brace
+		"(/",          // no closing brace
+		"net(addr)//", // unescaped
 		//"/dbname?arg=/some/unescaped/path",
 	}