Browse Source

Merge pull request #96 from go-sql-driver/clean-up

Clean up
Julien Schmidt 12 years ago
parent
commit
0a174803c1
10 changed files with 145 additions and 177 deletions
  1. 0 1
      .gitattributes
  2. 1 1
      README.md
  3. 11 13
      connection.go
  4. 2 7
      driver.go
  5. 110 138
      driver_test.go
  6. 0 1
      errors.go
  7. 1 1
      infile.go
  8. 3 6
      packets.go
  9. 8 0
      utils.go
  10. 9 9
      utils_test.go

+ 0 - 1
.gitattributes

@@ -1 +0,0 @@
-README.md merge=ours

+ 1 - 1
README.md

@@ -40,7 +40,7 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa
   * Optional `time.Time` parsing
 
 ## Requirements
-  * Go 1.0.3 or higher
+  * Go 1.1 or higher (use [v1.0](https://github.com/go-sql-driver/mysql/tags) for Go 1.0.x)
   * MySQL (Version 4.1 or higher), MariaDB or Percona Server
 
 ---------------------------------------

+ 11 - 13
connection.go

@@ -35,15 +35,17 @@ type mysqlConn struct {
 }
 
 type config struct {
-	user    string
-	passwd  string
-	net     string
-	addr    string
-	dbname  string
-	params  map[string]string
-	loc     *time.Location
-	timeout time.Duration
-	tls     *tls.Config
+	user            string
+	passwd          string
+	net             string
+	addr            string
+	dbname          string
+	params          map[string]string
+	loc             *time.Location
+	timeout         time.Duration
+	tls             *tls.Config
+	allowAllFiles   bool
+	clientFoundRows bool
 }
 
 // Handles parameters set in DSN
@@ -64,10 +66,6 @@ func (mc *mysqlConn) handleParams() (err error) {
 				return
 			}
 
-		// handled elsewhere
-		case "allowAllFiles", "clientFoundRows":
-			continue
-
 		// time.Time parsing
 		case "parseTime":
 			mc.parseTime = readBool(val)

+ 2 - 7
driver.go

@@ -33,13 +33,8 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
 	}
 
 	// Connect to Server
-	if mc.cfg.timeout > 0 { // with timeout
-		if err == nil {
-			mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, mc.cfg.timeout)
-		}
-	} else { // no timeout
-		mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr)
-	}
+	nd := net.Dialer{Timeout: mc.cfg.timeout}
+	mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)
 	if err != nil {
 		return nil, err
 	}

+ 110 - 138
driver_test.go

@@ -13,7 +13,6 @@ import (
 )
 
 var (
-	charset   string
 	dsn       string
 	netAddr   string
 	available bool
@@ -42,7 +41,6 @@ func init() {
 	prot := env("MYSQL_TEST_PROT", "tcp")
 	addr := env("MYSQL_TEST_ADDR", "localhost:3306")
 	dbname := env("MYSQL_TEST_DBNAME", "gotest")
-	charset = "charset=utf8"
 	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
 	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
 	c, err := net.Dial(prot, addr)
@@ -57,15 +55,14 @@ type DBTest struct {
 	db *sql.DB
 }
 
-func runTests(t *testing.T, name, dsn string, tests ...func(dbt *DBTest)) {
+func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 	if !available {
-		t.Logf("MySQL-Server not running on %s. Skipping %s", netAddr, name)
-		return
+		t.Skipf("MySQL-Server not running on %s", netAddr)
 	}
 
 	db, err := sql.Open("mysql", dsn)
 	if err != nil {
-		t.Fatalf("Error connecting: %v", err)
+		t.Fatalf("Error connecting: %s", err.Error())
 	}
 	defer db.Close()
 
@@ -82,7 +79,7 @@ func (dbt *DBTest) fail(method, query string, err error) {
 	if len(query) > 300 {
 		query = "[query too large to print]"
 	}
-	dbt.Fatalf("Error on %s %s: %v", method, query, err)
+	dbt.Fatalf("Error on %s %s: %s", method, query, err.Error())
 }
 
 func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
@@ -102,32 +99,26 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
 }
 
 func TestCharset(t *testing.T) {
-	mustSetCharset := func(charsetParam, expected string) {
-		db, err := sql.Open("mysql", dsn+"&"+charsetParam)
-		if err != nil {
-			t.Fatalf("Error on Open: %v", err)
-		}
-		defer db.Close()
-
-		dbt := &DBTest{t, db}
-		rows := dbt.mustQuery("SELECT @@character_set_connection")
-		defer rows.Close()
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
 
-		if !rows.Next() {
-			dbt.Fatalf("Error getting connection charset: %v", err)
-		}
+	mustSetCharset := func(charsetParam, expected string) {
+		runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
+			rows := dbt.mustQuery("SELECT @@character_set_connection")
+			defer rows.Close()
 
-		var got string
-		rows.Scan(&got)
+			if !rows.Next() {
+				dbt.Fatalf("Error getting connection charset: %s", rows.Err())
+			}
 
-		if got != expected {
-			dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
-		}
-	}
+			var got string
+			rows.Scan(&got)
 
-	if !available {
-		t.Logf("MySQL-Server not running on %s. Skipping TestCharset", netAddr)
-		return
+			if got != expected {
+				dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
+			}
+		})
 	}
 
 	// non utf8 test
@@ -142,26 +133,18 @@ func TestCharset(t *testing.T) {
 }
 
 func TestFailingCharset(t *testing.T) {
-	if !available {
-		t.Logf("MySQL-Server not running on %s. Skipping TestFailingCharset", netAddr)
-		return
-	}
-	db, err := sql.Open("mysql", dsn+"&charset=none")
-	if err != nil {
-		t.Fatalf("Error on Open: %v", err)
-	}
-	defer db.Close()
-
-	// run query to really establish connection...
-	_, err = db.Exec("SELECT 1")
-	if err == nil {
-		db.Close()
-		t.Fatalf("Connection must not succeed without a valid charset")
-	}
+	runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
+		// run query to really establish connection...
+		_, err := dbt.db.Exec("SELECT 1")
+		if err == nil {
+			dbt.db.Close()
+			t.Fatalf("Connection must not succeed without a valid charset")
+		}
+	})
 }
 
 func TestRawBytesResultExceedsBuffer(t *testing.T) {
-	runTests(t, "TestRawBytesResultExceedsBuffer", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		// defaultBufSize from buffer.go
 		expected := strings.Repeat("abc", defaultBufSize)
 		rows := dbt.mustQuery("SELECT '" + expected + "'")
@@ -178,7 +161,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) {
 }
 
 func TestCRUD(t *testing.T) {
-	runTests(t, "TestCRUD", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		// Create Table
 		dbt.mustExec("CREATE TABLE test (value BOOL)")
 
@@ -193,7 +176,7 @@ func TestCRUD(t *testing.T) {
 		res := dbt.mustExec("INSERT INTO test VALUES (1)")
 		count, err := res.RowsAffected()
 		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
 		}
 		if count != 1 {
 			dbt.Fatalf("Expected 1 affected row, got %d", count)
@@ -201,7 +184,7 @@ func TestCRUD(t *testing.T) {
 
 		id, err := res.LastInsertId()
 		if err != nil {
-			dbt.Fatalf("res.LastInsertId() returned error: %v", err)
+			dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
 		}
 		if id != 0 {
 			dbt.Fatalf("Expected InsertID 0, got %d", id)
@@ -226,7 +209,7 @@ func TestCRUD(t *testing.T) {
 		res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true)
 		count, err = res.RowsAffected()
 		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
 		}
 		if count != 1 {
 			dbt.Fatalf("Expected 1 affected row, got %d", count)
@@ -251,7 +234,7 @@ func TestCRUD(t *testing.T) {
 		res = dbt.mustExec("DELETE FROM test WHERE value = ?", false)
 		count, err = res.RowsAffected()
 		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
 		}
 		if count != 1 {
 			dbt.Fatalf("Expected 1 affected row, got %d", count)
@@ -261,7 +244,7 @@ func TestCRUD(t *testing.T) {
 		res = dbt.mustExec("DELETE FROM test")
 		count, err = res.RowsAffected()
 		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
 		}
 		if count != 0 {
 			dbt.Fatalf("Expected 0 affected row, got %d", count)
@@ -270,7 +253,7 @@ func TestCRUD(t *testing.T) {
 }
 
 func TestInt(t *testing.T) {
-	runTests(t, "TestInt", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
 		in := int64(42)
 		var out int64
@@ -317,7 +300,7 @@ func TestInt(t *testing.T) {
 }
 
 func TestFloat(t *testing.T) {
-	runTests(t, "TestFloat", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		types := [2]string{"FLOAT", "DOUBLE"}
 		in := float32(42.23)
 		var out float32
@@ -340,7 +323,7 @@ func TestFloat(t *testing.T) {
 }
 
 func TestString(t *testing.T) {
-	runTests(t, "TestString", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}
 		in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах  น่าฟังเอย"
 		var out string
@@ -380,7 +363,7 @@ func TestString(t *testing.T) {
 
 		err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out)
 		if err != nil {
-			dbt.Fatalf("Error on BLOB-Query: %v", err)
+			dbt.Fatalf("Error on BLOB-Query: %s", err.Error())
 		} else if out != in {
 			dbt.Errorf("BLOB: %s != %s", in, out)
 		}
@@ -429,7 +412,7 @@ func TestDateTime(t *testing.T) {
 				dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
 				var sOut string
 				if err := rows.Scan(&sOut); err != nil {
-					dbt.Errorf("%s (%s %s): %v", sqltype, resulttype, mode, err)
+					dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
 				} else if test.sOut != sOut {
 					dbt.Errorf("%s (%s %s): %s != %s", sqltype, resulttype, mode, test.sOut, sOut)
 				}
@@ -438,7 +421,7 @@ func TestDateTime(t *testing.T) {
 				dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
 				var tOut time.Time
 				if err := rows.Scan(&tOut); err != nil {
-					dbt.Errorf("%s (%s %s): %v", sqltype, resulttype, mode, err)
+					dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
 				} else if test.tOut != tOut || test.tIsZero != tOut.IsZero() {
 					dbt.Errorf("%s (%s %s): %s [%t] != %s [%t]", sqltype, resulttype, mode, test.tOut, test.tIsZero, tOut, tOut.IsZero())
 				}
@@ -460,8 +443,8 @@ func TestDateTime(t *testing.T) {
 						s.test(dbt, rows, test, sqltype, s.vartype, mode)
 					} else {
 						if err := rows.Err(); err != nil {
-							dbt.Errorf("%s (%s %s): %v",
-								sqltype, s.vartype, mode, err)
+							dbt.Errorf("%s (%s %s): %s",
+								sqltype, s.vartype, mode, err.Error())
 						} else {
 							dbt.Errorf("%s (%s %s): no data",
 								sqltype, s.vartype, mode)
@@ -476,12 +459,12 @@ func TestDateTime(t *testing.T) {
 	timeDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
 	for _, v := range setups {
 		s = v
-		runTests(t, "TestDateTime", timeDsn+s.dsnSuffix, testTime)
+		runTests(t, timeDsn+s.dsnSuffix, testTime)
 	}
 }
 
 func TestNULL(t *testing.T) {
-	runTests(t, "TestNULL", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		nullStmt, err := dbt.db.Prepare("SELECT NULL")
 		if err != nil {
 			dbt.Fatal(err)
@@ -597,7 +580,7 @@ func TestNULL(t *testing.T) {
 }
 
 func TestLongData(t *testing.T) {
-	runTests(t, "TestLongData", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		var maxAllowedPacketSize int
 		err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
 		if err != nil {
@@ -658,7 +641,7 @@ func TestLongData(t *testing.T) {
 }
 
 func TestLoadData(t *testing.T) {
-	runTests(t, "TestLoadData", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		verifyLoadDataResult := func() {
 			rows, err := dbt.db.Query("SELECT * FROM test")
 			if err != nil {
@@ -744,10 +727,55 @@ func TestLoadData(t *testing.T) {
 	})
 }
 
+func TestFoundRows(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
+		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
+
+		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
+		count, err := res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 2 {
+			dbt.Fatalf("Expected 2 affected rows, got %d", count)
+		}
+		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
+		count, err = res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 2 {
+			dbt.Fatalf("Expected 2 affected rows, got %d", count)
+		}
+	})
+	runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
+		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
+
+		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
+		count, err := res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 2 {
+			dbt.Fatalf("Expected 2 matched rows, got %d", count)
+		}
+		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
+		count, err = res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 3 {
+			dbt.Fatalf("Expected 3 matched rows, got %d", count)
+		}
+	})
+}
+
 func TestStrict(t *testing.T) {
 	// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
 	relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
-	runTests(t, "TestStrict", relaxedDsn, func(dbt *DBTest) {
+	runTests(t, relaxedDsn, func(dbt *DBTest) {
 		dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
 
 		var queries = [...]struct {
@@ -797,7 +825,7 @@ func TestStrict(t *testing.T) {
 		for i := range queries {
 			stmt, err = dbt.db.Prepare(queries[i].in)
 			if err != nil {
-				dbt.Error("Error on preparing query %: ", queries[i].in, err.Error())
+				dbt.Errorf("Error on preparing query %s: %s", queries[i].in, err.Error())
 			}
 
 			_, err = stmt.Exec()
@@ -805,28 +833,17 @@ func TestStrict(t *testing.T) {
 
 			err = stmt.Close()
 			if err != nil {
-				dbt.Error("Error on closing stmt for query %: ", queries[i].in, err.Error())
+				dbt.Errorf("Error on closing stmt for query %s: %s", queries[i].in, err.Error())
 			}
 		}
 	})
 }
 
 func TestTLS(t *testing.T) {
-	runTests(t, "TestTLS", dsn+"&tls=skip-verify", func(dbt *DBTest) {
-		/* TODO: GO 1.1 API */
-		/*if err := dbt.db.Ping(); err != nil {
-		    if err == errNoTLS {
-		        dbt.Skip("Server does not support TLS. Skipping TestTLS")
-		    } else {
-		        dbt.Fatalf("Error on Ping: %s", err.Error())
-		    }
-		}*/
-
-		/* GO 1.0 API */
-		if _, err := dbt.db.Exec("DO 1"); err != nil {
+	runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) {
+		if err := dbt.db.Ping(); err != nil {
 			if err == errNoTLS {
-				dbt.Log("Server does not support TLS. Skipping TestTLS")
-				return
+				dbt.Skip("Server does not support TLS")
 			} else {
 				dbt.Fatalf("Error on Ping: %s", err.Error())
 			}
@@ -850,7 +867,7 @@ func TestTLS(t *testing.T) {
 // Special cases
 
 func TestRowsClose(t *testing.T) {
-	runTests(t, "TestRowsClose", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		rows, err := dbt.db.Query("SELECT 1")
 		if err != nil {
 			dbt.Fatal(err)
@@ -875,7 +892,7 @@ func TestRowsClose(t *testing.T) {
 // dangling statements
 // http://code.google.com/p/go/issues/detail?id=3865
 func TestCloseStmtBeforeRows(t *testing.T) {
-	runTests(t, "TestCloseStmtBeforeRows", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		stmt, err := dbt.db.Prepare("SELECT 1")
 		if err != nil {
 			dbt.Fatal(err)
@@ -904,7 +921,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
 			var out bool
 			err = rows.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != true {
 				dbt.Errorf("true != %t", out)
@@ -916,7 +933,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
 // It is valid to have multiple Rows for the same Stmt
 // http://code.google.com/p/go/issues/detail?id=3734
 func TestStmtMultiRows(t *testing.T) {
-	runTests(t, "TestStmtMultiRows", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0")
 		if err != nil {
 			dbt.Fatal(err)
@@ -949,7 +966,7 @@ func TestStmtMultiRows(t *testing.T) {
 
 			err = rows1.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != true {
 				dbt.Errorf("true != %t", out)
@@ -966,7 +983,7 @@ func TestStmtMultiRows(t *testing.T) {
 
 			err = rows2.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != true {
 				dbt.Errorf("true != %t", out)
@@ -984,7 +1001,7 @@ func TestStmtMultiRows(t *testing.T) {
 
 			err = rows1.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != false {
 				dbt.Errorf("false != %t", out)
@@ -1009,7 +1026,7 @@ func TestStmtMultiRows(t *testing.T) {
 
 			err = rows2.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != false {
 				dbt.Errorf("false != %t", out)
@@ -1028,14 +1045,14 @@ func TestStmtMultiRows(t *testing.T) {
 
 func TestConcurrent(t *testing.T) {
 	if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
-		t.Log("CONCURRENT env var not set. Skipping TestConcurrent")
-		return
+		t.Skip("CONCURRENT env var not set")
 	}
-	runTests(t, "TestConcurrent", dsn, func(dbt *DBTest) {
+
+	runTests(t, dsn, func(dbt *DBTest) {
 		var max int
 		err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)
 		if err != nil {
-			dbt.Fatalf("%v", err)
+			dbt.Fatalf("%s", err.Error())
 		}
 		dbt.Logf("Testing up to %d concurrent connections \r\n", max)
 		canStop := false
@@ -1076,51 +1093,6 @@ func TestConcurrent(t *testing.T) {
 	})
 }
 
-func TestFoundRows(t *testing.T) {
-	runTests(t, "TestFoundRows1", dsn, func(dbt *DBTest) {
-		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
-		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
-
-		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
-		count, err := res.RowsAffected()
-		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-		}
-		if count != 2 {
-			dbt.Fatalf("Expected 2 affected rows, got %d", count)
-		}
-		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
-		count, err = res.RowsAffected()
-		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-		}
-		if count != 2 {
-			dbt.Fatalf("Expected 2 affected rows, got %d", count)
-		}
-	})
-	runTests(t, "TestFoundRows2", dsn+"&clientFoundRows=true", func(dbt *DBTest) {
-		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
-		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
-
-		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
-		count, err := res.RowsAffected()
-		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-		}
-		if count != 2 {
-			dbt.Fatalf("Expected 2 matched rows, got %d", count)
-		}
-		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
-		count, err = res.RowsAffected()
-		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-		}
-		if count != 3 {
-			dbt.Fatalf("Expected 3 matched rows, got %d", count)
-		}
-	})
-}
-
 // BENCHMARKS
 var sample []byte
 

+ 0 - 1
errors.go

@@ -102,5 +102,4 @@ func (mc *mysqlConn) getWarnings() (err error) {
 			return
 		}
 	}
-	return
 }

+ 1 - 1
infile.go

@@ -74,7 +74,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 		}
 	} else { // File
 		name = strings.Trim(name, `"`)
-		if fileRegister[name] || mc.cfg.params[`allowAllFiles`] == `true` {
+		if mc.cfg.allowAllFiles || fileRegister[name] {
 			rdr, err = os.Open(name)
 		} else {
 			err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)

+ 3 - 6
packets.go

@@ -215,7 +215,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
 		clientLocalFiles |
 		mc.flags&clientLongFlag
 
-	if _, ok := mc.cfg.params["clientFoundRows"]; ok {
+	if mc.cfg.clientFoundRows {
 		clientFlags |= clientFoundRows
 	}
 
@@ -231,9 +231,9 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
 
 	// To specify a db name
-	if len(mc.cfg.dbname) > 0 {
+	if n := len(mc.cfg.dbname); n > 0 {
 		clientFlags |= clientConnectWithDB
-		pktLen += len(mc.cfg.dbname) + 1
+		pktLen += n + 1
 	}
 
 	// Calculate packet length and make buffer with that size
@@ -569,8 +569,6 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 
 		i++
 	}
-
-	return
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
@@ -636,7 +634,6 @@ func (mc *mysqlConn) readUntilEOF() (err error) {
 		}
 		return // Err or EOF
 	}
-	return
 }
 
 /******************************************************************************

+ 8 - 0
utils.go

@@ -124,6 +124,14 @@ func parseDSN(dsn string) (cfg *config, err error) {
 				// cfg params
 				switch value := param[1]; param[0] {
 
+				// Disable INFILE whitelist / enable all files
+				case "allowAllFiles":
+					cfg.allowAllFiles = readBool(value)
+
+				// Switch "rowsAffected" mode
+				case "clientFoundRows":
+					cfg.clientFoundRows = readBool(value)
+
 				// Time Location
 				case "loc":
 					cfg.loc, err = time.LoadLocation(value)

+ 9 - 9
utils_test.go

@@ -21,15 +21,15 @@ func TestDSNParser(t *testing.T) {
 		out string
 		loc *time.Location
 	}{
-		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user:password@/dbname?loc=UTC&timeout=30s", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil>}", time.UTC},
-		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.Local},
-		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
-		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true clientFoundRows:true}", time.UTC},
+		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.Local},
+		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
 	}
 
 	var cfg *config