Browse Source

Allow interpolateParams only with ascii, latin1 and utf8 collations

INADA Naoki 11 years ago
parent
commit
e517683745
3 changed files with 78 additions and 12 deletions
  1. 11 9
      driver_test.go
  2. 30 3
      utils.go
  3. 37 0
      utils_test.go

+ 11 - 9
driver_test.go

@@ -87,19 +87,21 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 
 	db.Exec("DROP TABLE IF EXISTS test")
 
-	dbp, err := sql.Open("mysql", dsn+"&interpolateParams=true")
-	if err != nil {
-		t.Fatalf("Error connecting: %s", err.Error())
+	dsn2 := dsn + "&interpolateParams=true"
+	var db2 *sql.DB
+	if _, err := parseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
+		db2, err = sql.Open("mysql", dsn2)
 	}
-	defer dbp.Close()
 
 	dbt := &DBTest{t, db}
-	dbtp := &DBTest{t, dbp}
+	dbt2 := &DBTest{t, db2}
 	for _, test := range tests {
 		test(dbt)
 		dbt.db.Exec("DROP TABLE IF EXISTS test")
-		test(dbtp)
-		dbtp.db.Exec("DROP TABLE IF EXISTS test")
+		if db2 != nil {
+			test(dbt2)
+			dbt2.db.Exec("DROP TABLE IF EXISTS test")
+		}
 	}
 }
 
@@ -864,7 +866,7 @@ func TestLoadData(t *testing.T) {
 					dbt.Fatalf("%d != %d", i, id)
 				}
 				if values[i-1] != value {
-					dbt.Fatalf("%s != %s", values[i-1], value)
+					dbt.Fatalf("%q != %q", values[i-1], value)
 				}
 			}
 			err = rows.Err()
@@ -889,7 +891,7 @@ func TestLoadData(t *testing.T) {
 
 		// Local File
 		RegisterLocalFile(file.Name())
-		dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name()))
+		dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
 		verifyLoadDataResult()
 		// negative test
 		_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")

+ 30 - 3
utils.go

@@ -25,9 +25,10 @@ import (
 var (
 	tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
 
-	errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
-	errInvalidDSNAddr      = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
-	errInvalidDSNNoSlash   = errors.New("Invalid DSN: Missing the slash separating the database name")
+	errInvalidDSNUnescaped       = errors.New("Invalid DSN: Did you forget to escape a param value?")
+	errInvalidDSNAddr            = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
+	errInvalidDSNNoSlash         = errors.New("Invalid DSN: Missing the slash separating the database name")
+	errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset")
 )
 
 func init() {
@@ -147,6 +148,32 @@ func parseDSN(dsn string) (cfg *config, err error) {
 		return nil, errInvalidDSNNoSlash
 	}
 
+	if cfg.interpolateParams && cfg.collation != defaultCollation {
+		// A whitelist of collations which safe to interpolate parameters.
+		// ASCII and latin-1 are safe since they are single byte encoding.
+		// utf-8 is safe since it doesn't conatins ASCII characters in trailing bytes.
+		safeCollations := []string{"ascii_", "latin1_", "utf8_", "utf8mb4_"}
+
+		var collationName string
+		for name, collation := range collations {
+			if collation == cfg.collation {
+				collationName = name
+				break
+			}
+		}
+
+		safe := false
+		for _, p := range safeCollations {
+			if strings.HasPrefix(collationName, p) {
+				safe = true
+				break
+			}
+		}
+		if !safe {
+			return nil, errInvalidDSNUnsafeCollation
+		}
+	}
+
 	// Set default network if empty
 	if cfg.net == "" {
 		cfg.net = "tcp"

+ 37 - 0
utils_test.go

@@ -116,6 +116,43 @@ func TestDSNWithCustomTLS(t *testing.T) {
 	DeregisterTLSConfig("utils_test")
 }
 
+func TestDSNUnsafeCollation(t *testing.T) {
+	_, err := parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
+	if err != errInvalidDSNUnsafeCollation {
+		t.Error("Expected %v, Got %v", errInvalidDSNUnsafeCollation, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=gbk_chinese_ci")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=ascii_bin&interpolateParams=true")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+}
+
 func BenchmarkParseDSN(b *testing.B) {
 	b.ReportAllocs()