Browse Source

DSN: check for separating slash

Fixes #186
Julien Schmidt 12 years ago
parent
commit
54917e3c77
2 changed files with 17 additions and 10 deletions
  1. 7 0
      utils.go
  2. 10 10
      utils_test.go

+ 7 - 0
utils.go

@@ -26,6 +26,7 @@ var (
 
 	errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
 	errInvalidDSNAddr      = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
+	errInvalidDSNNoSlash   = errors.New("Invalid DSN: Missing the slash separating the database name")
 )
 
 func init() {
@@ -77,8 +78,10 @@ func parseDSN(dsn string) (cfg *config, err error) {
 
 	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
 	// Find the last '/' (since the password or the net addr might contain a '/')
+	foundSlash := false
 	for i := len(dsn) - 1; i >= 0; i-- {
 		if dsn[i] == '/' {
+			foundSlash = true
 			var j, k int
 
 			// left part is empty if i <= 0
@@ -135,6 +138,10 @@ func parseDSN(dsn string) (cfg *config, err error) {
 		}
 	}
 
+	if !foundSlash && len(dsn) > 0 {
+		return nil, errInvalidDSNNoSlash
+	}
+
 	// Set default network if empty
 	if cfg.net == "" {
 		cfg.net = "tcp"

+ 10 - 10
utils_test.go

@@ -9,9 +9,9 @@
 package mysql
 
 import (
+	"bytes"
 	"fmt"
 	"testing"
-	"bytes"
 	"time"
 )
 
@@ -57,11 +57,12 @@ func TestDSNParser(t *testing.T) {
 
 func TestDSNParserInvalid(t *testing.T) {
 	var invalidDSNs = []string{
-		"@net(addr/",  // no closing brace
-		"@tcp(/",      // no closing brace
-		"tcp(/",       // no closing brace
-		"(/",          // no closing brace
-		"net(addr)//", // unescaped
+		"@net(addr/",                  // no closing brace
+		"@tcp(/",                      // no closing brace
+		"tcp(/",                       // no closing brace
+		"(/",                          // no closing brace
+		"net(addr)//",                 // unescaped
+		"user:pass@tcp(1.2.3.4:3306)", // no trailing slash
 		//"/dbname?arg=/some/unescaped/path",
 	}
 
@@ -126,8 +127,8 @@ func TestScanNullTime(t *testing.T) {
 
 func TestLengthEncodedInteger(t *testing.T) {
 	var integerTests = []struct {
-	        num     uint64
-	        encoded []byte
+		num     uint64
+		encoded []byte
 	}{
 		{0x0000000000000000, []byte{0x00}},
 		{0x0000000000000012, []byte{0x12}},
@@ -155,10 +156,9 @@ func TestLengthEncodedInteger(t *testing.T) {
 			t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen)
 		}
 		encoded := appendLengthEncodedInteger(nil, num)
-		if (!bytes.Equal(encoded, tst.encoded)) {
+		if !bytes.Equal(encoded, tst.encoded) {
 			t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded)
 		}
 	}
 
-
 }