Ver Fonte

refactor driver tests

- Go 1.1 API
- moved TestFoundRows
- Refactored charset tests
- err formatting
- ...
Julien Schmidt há 12 anos atrás
pai
commit
e44f1b6291
1 ficheiros alterados com 108 adições e 136 exclusões
  1. 108 136
      driver_test.go

+ 108 - 136
driver_test.go

@@ -13,7 +13,6 @@ import (
 )
 
 var (
-	charset   string
 	dsn       string
 	netAddr   string
 	available bool
@@ -42,7 +41,6 @@ func init() {
 	prot := env("MYSQL_TEST_PROT", "tcp")
 	addr := env("MYSQL_TEST_ADDR", "localhost:3306")
 	dbname := env("MYSQL_TEST_DBNAME", "gotest")
-	charset = "charset=utf8"
 	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
 	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
 	c, err := net.Dial(prot, addr)
@@ -57,15 +55,14 @@ type DBTest struct {
 	db *sql.DB
 }
 
-func runTests(t *testing.T, name, dsn string, tests ...func(dbt *DBTest)) {
+func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 	if !available {
-		t.Logf("MySQL-Server not running on %s. Skipping %s", netAddr, name)
-		return
+		t.Skipf("MySQL-Server not running on %s", netAddr)
 	}
 
 	db, err := sql.Open("mysql", dsn)
 	if err != nil {
-		t.Fatalf("Error connecting: %v", err)
+		t.Fatalf("Error connecting: %s", err.Error())
 	}
 	defer db.Close()
 
@@ -82,7 +79,7 @@ func (dbt *DBTest) fail(method, query string, err error) {
 	if len(query) > 300 {
 		query = "[query too large to print]"
 	}
-	dbt.Fatalf("Error on %s %s: %v", method, query, err)
+	dbt.Fatalf("Error on %s %s: %s", method, query, err.Error())
 }
 
 func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
@@ -102,32 +99,26 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
 }
 
 func TestCharset(t *testing.T) {
-	mustSetCharset := func(charsetParam, expected string) {
-		db, err := sql.Open("mysql", dsn+"&"+charsetParam)
-		if err != nil {
-			t.Fatalf("Error on Open: %v", err)
-		}
-		defer db.Close()
-
-		dbt := &DBTest{t, db}
-		rows := dbt.mustQuery("SELECT @@character_set_connection")
-		defer rows.Close()
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
 
-		if !rows.Next() {
-			dbt.Fatalf("Error getting connection charset: %v", err)
-		}
+	mustSetCharset := func(charsetParam, expected string) {
+		runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
+			rows := dbt.mustQuery("SELECT @@character_set_connection")
+			defer rows.Close()
 
-		var got string
-		rows.Scan(&got)
+			if !rows.Next() {
+				dbt.Fatalf("Error getting connection charset: %s", rows.Err())
+			}
 
-		if got != expected {
-			dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
-		}
-	}
+			var got string
+			rows.Scan(&got)
 
-	if !available {
-		t.Logf("MySQL-Server not running on %s. Skipping TestCharset", netAddr)
-		return
+			if got != expected {
+				dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
+			}
+		})
 	}
 
 	// non utf8 test
@@ -142,26 +133,18 @@ func TestCharset(t *testing.T) {
 }
 
 func TestFailingCharset(t *testing.T) {
-	if !available {
-		t.Logf("MySQL-Server not running on %s. Skipping TestFailingCharset", netAddr)
-		return
-	}
-	db, err := sql.Open("mysql", dsn+"&charset=none")
-	if err != nil {
-		t.Fatalf("Error on Open: %v", err)
-	}
-	defer db.Close()
-
-	// run query to really establish connection...
-	_, err = db.Exec("SELECT 1")
-	if err == nil {
-		db.Close()
-		t.Fatalf("Connection must not succeed without a valid charset")
-	}
+	runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
+		// run query to really establish connection...
+		_, err := dbt.db.Exec("SELECT 1")
+		if err == nil {
+			dbt.db.Close()
+			t.Fatalf("Connection must not succeed without a valid charset")
+		}
+	})
 }
 
 func TestRawBytesResultExceedsBuffer(t *testing.T) {
-	runTests(t, "TestRawBytesResultExceedsBuffer", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		// defaultBufSize from buffer.go
 		expected := strings.Repeat("abc", defaultBufSize)
 		rows := dbt.mustQuery("SELECT '" + expected + "'")
@@ -178,7 +161,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) {
 }
 
 func TestCRUD(t *testing.T) {
-	runTests(t, "TestCRUD", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		// Create Table
 		dbt.mustExec("CREATE TABLE test (value BOOL)")
 
@@ -193,7 +176,7 @@ func TestCRUD(t *testing.T) {
 		res := dbt.mustExec("INSERT INTO test VALUES (1)")
 		count, err := res.RowsAffected()
 		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
 		}
 		if count != 1 {
 			dbt.Fatalf("Expected 1 affected row, got %d", count)
@@ -201,7 +184,7 @@ func TestCRUD(t *testing.T) {
 
 		id, err := res.LastInsertId()
 		if err != nil {
-			dbt.Fatalf("res.LastInsertId() returned error: %v", err)
+			dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
 		}
 		if id != 0 {
 			dbt.Fatalf("Expected InsertID 0, got %d", id)
@@ -226,7 +209,7 @@ func TestCRUD(t *testing.T) {
 		res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true)
 		count, err = res.RowsAffected()
 		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
 		}
 		if count != 1 {
 			dbt.Fatalf("Expected 1 affected row, got %d", count)
@@ -251,7 +234,7 @@ func TestCRUD(t *testing.T) {
 		res = dbt.mustExec("DELETE FROM test WHERE value = ?", false)
 		count, err = res.RowsAffected()
 		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
 		}
 		if count != 1 {
 			dbt.Fatalf("Expected 1 affected row, got %d", count)
@@ -261,7 +244,7 @@ func TestCRUD(t *testing.T) {
 		res = dbt.mustExec("DELETE FROM test")
 		count, err = res.RowsAffected()
 		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
 		}
 		if count != 0 {
 			dbt.Fatalf("Expected 0 affected row, got %d", count)
@@ -270,7 +253,7 @@ func TestCRUD(t *testing.T) {
 }
 
 func TestInt(t *testing.T) {
-	runTests(t, "TestInt", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
 		in := int64(42)
 		var out int64
@@ -317,7 +300,7 @@ func TestInt(t *testing.T) {
 }
 
 func TestFloat(t *testing.T) {
-	runTests(t, "TestFloat", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		types := [2]string{"FLOAT", "DOUBLE"}
 		in := float32(42.23)
 		var out float32
@@ -340,7 +323,7 @@ func TestFloat(t *testing.T) {
 }
 
 func TestString(t *testing.T) {
-	runTests(t, "TestString", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}
 		in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах  น่าฟังเอย"
 		var out string
@@ -380,7 +363,7 @@ func TestString(t *testing.T) {
 
 		err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out)
 		if err != nil {
-			dbt.Fatalf("Error on BLOB-Query: %v", err)
+			dbt.Fatalf("Error on BLOB-Query: %s", err.Error())
 		} else if out != in {
 			dbt.Errorf("BLOB: %s != %s", in, out)
 		}
@@ -429,7 +412,7 @@ func TestDateTime(t *testing.T) {
 				dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
 				var sOut string
 				if err := rows.Scan(&sOut); err != nil {
-					dbt.Errorf("%s (%s %s): %v", sqltype, resulttype, mode, err)
+					dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
 				} else if test.sOut != sOut {
 					dbt.Errorf("%s (%s %s): %s != %s", sqltype, resulttype, mode, test.sOut, sOut)
 				}
@@ -438,7 +421,7 @@ func TestDateTime(t *testing.T) {
 				dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
 				var tOut time.Time
 				if err := rows.Scan(&tOut); err != nil {
-					dbt.Errorf("%s (%s %s): %v", sqltype, resulttype, mode, err)
+					dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
 				} else if test.tOut != tOut || test.tIsZero != tOut.IsZero() {
 					dbt.Errorf("%s (%s %s): %s [%t] != %s [%t]", sqltype, resulttype, mode, test.tOut, test.tIsZero, tOut, tOut.IsZero())
 				}
@@ -460,8 +443,8 @@ func TestDateTime(t *testing.T) {
 						s.test(dbt, rows, test, sqltype, s.vartype, mode)
 					} else {
 						if err := rows.Err(); err != nil {
-							dbt.Errorf("%s (%s %s): %v",
-								sqltype, s.vartype, mode, err)
+							dbt.Errorf("%s (%s %s): %s",
+								sqltype, s.vartype, mode, err.Error())
 						} else {
 							dbt.Errorf("%s (%s %s): no data",
 								sqltype, s.vartype, mode)
@@ -476,12 +459,12 @@ func TestDateTime(t *testing.T) {
 	timeDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
 	for _, v := range setups {
 		s = v
-		runTests(t, "TestDateTime", timeDsn+s.dsnSuffix, testTime)
+		runTests(t, timeDsn+s.dsnSuffix, testTime)
 	}
 }
 
 func TestNULL(t *testing.T) {
-	runTests(t, "TestNULL", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		nullStmt, err := dbt.db.Prepare("SELECT NULL")
 		if err != nil {
 			dbt.Fatal(err)
@@ -597,7 +580,7 @@ func TestNULL(t *testing.T) {
 }
 
 func TestLongData(t *testing.T) {
-	runTests(t, "TestLongData", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		var maxAllowedPacketSize int
 		err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
 		if err != nil {
@@ -658,7 +641,7 @@ func TestLongData(t *testing.T) {
 }
 
 func TestLoadData(t *testing.T) {
-	runTests(t, "TestLoadData", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		verifyLoadDataResult := func() {
 			rows, err := dbt.db.Query("SELECT * FROM test")
 			if err != nil {
@@ -744,10 +727,55 @@ func TestLoadData(t *testing.T) {
 	})
 }
 
+func TestFoundRows(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
+		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
+
+		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
+		count, err := res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 2 {
+			dbt.Fatalf("Expected 2 affected rows, got %d", count)
+		}
+		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
+		count, err = res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 2 {
+			dbt.Fatalf("Expected 2 affected rows, got %d", count)
+		}
+	})
+	runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
+		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
+
+		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
+		count, err := res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 2 {
+			dbt.Fatalf("Expected 2 matched rows, got %d", count)
+		}
+		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
+		count, err = res.RowsAffected()
+		if err != nil {
+			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
+		}
+		if count != 3 {
+			dbt.Fatalf("Expected 3 matched rows, got %d", count)
+		}
+	})
+}
+
 func TestStrict(t *testing.T) {
 	// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
 	relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
-	runTests(t, "TestStrict", relaxedDsn, func(dbt *DBTest) {
+	runTests(t, relaxedDsn, func(dbt *DBTest) {
 		dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
 
 		var queries = [...]struct {
@@ -812,21 +840,10 @@ func TestStrict(t *testing.T) {
 }
 
 func TestTLS(t *testing.T) {
-	runTests(t, "TestTLS", dsn+"&tls=skip-verify", func(dbt *DBTest) {
-		/* TODO: GO 1.1 API */
-		/*if err := dbt.db.Ping(); err != nil {
-		    if err == errNoTLS {
-		        dbt.Skip("Server does not support TLS. Skipping TestTLS")
-		    } else {
-		        dbt.Fatalf("Error on Ping: %s", err.Error())
-		    }
-		}*/
-
-		/* GO 1.0 API */
-		if _, err := dbt.db.Exec("DO 1"); err != nil {
+	runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) {
+		if err := dbt.db.Ping(); err != nil {
 			if err == errNoTLS {
-				dbt.Log("Server does not support TLS. Skipping TestTLS")
-				return
+				dbt.Skip("Server does not support TLS")
 			} else {
 				dbt.Fatalf("Error on Ping: %s", err.Error())
 			}
@@ -850,7 +867,7 @@ func TestTLS(t *testing.T) {
 // Special cases
 
 func TestRowsClose(t *testing.T) {
-	runTests(t, "TestRowsClose", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		rows, err := dbt.db.Query("SELECT 1")
 		if err != nil {
 			dbt.Fatal(err)
@@ -875,7 +892,7 @@ func TestRowsClose(t *testing.T) {
 // dangling statements
 // http://code.google.com/p/go/issues/detail?id=3865
 func TestCloseStmtBeforeRows(t *testing.T) {
-	runTests(t, "TestCloseStmtBeforeRows", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		stmt, err := dbt.db.Prepare("SELECT 1")
 		if err != nil {
 			dbt.Fatal(err)
@@ -904,7 +921,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
 			var out bool
 			err = rows.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != true {
 				dbt.Errorf("true != %t", out)
@@ -916,7 +933,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
 // It is valid to have multiple Rows for the same Stmt
 // http://code.google.com/p/go/issues/detail?id=3734
 func TestStmtMultiRows(t *testing.T) {
-	runTests(t, "TestStmtMultiRows", dsn, func(dbt *DBTest) {
+	runTests(t, dsn, func(dbt *DBTest) {
 		stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0")
 		if err != nil {
 			dbt.Fatal(err)
@@ -949,7 +966,7 @@ func TestStmtMultiRows(t *testing.T) {
 
 			err = rows1.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != true {
 				dbt.Errorf("true != %t", out)
@@ -966,7 +983,7 @@ func TestStmtMultiRows(t *testing.T) {
 
 			err = rows2.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != true {
 				dbt.Errorf("true != %t", out)
@@ -984,7 +1001,7 @@ func TestStmtMultiRows(t *testing.T) {
 
 			err = rows1.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != false {
 				dbt.Errorf("false != %t", out)
@@ -1009,7 +1026,7 @@ func TestStmtMultiRows(t *testing.T) {
 
 			err = rows2.Scan(&out)
 			if err != nil {
-				dbt.Fatalf("Error on rows.Scan(): %v", err)
+				dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
 			}
 			if out != false {
 				dbt.Errorf("false != %t", out)
@@ -1028,14 +1045,14 @@ func TestStmtMultiRows(t *testing.T) {
 
 func TestConcurrent(t *testing.T) {
 	if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
-		t.Log("CONCURRENT env var not set. Skipping TestConcurrent")
-		return
+		t.Skip("CONCURRENT env var not set")
 	}
-	runTests(t, "TestConcurrent", dsn, func(dbt *DBTest) {
+
+	runTests(t, dsn, func(dbt *DBTest) {
 		var max int
 		err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)
 		if err != nil {
-			dbt.Fatalf("%v", err)
+			dbt.Fatalf("%s", err.Error())
 		}
 		dbt.Logf("Testing up to %d concurrent connections \r\n", max)
 		canStop := false
@@ -1076,51 +1093,6 @@ func TestConcurrent(t *testing.T) {
 	})
 }
 
-func TestFoundRows(t *testing.T) {
-	runTests(t, "TestFoundRows1", dsn, func(dbt *DBTest) {
-		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
-		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
-
-		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
-		count, err := res.RowsAffected()
-		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-		}
-		if count != 2 {
-			dbt.Fatalf("Expected 2 affected rows, got %d", count)
-		}
-		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
-		count, err = res.RowsAffected()
-		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-		}
-		if count != 2 {
-			dbt.Fatalf("Expected 2 affected rows, got %d", count)
-		}
-	})
-	runTests(t, "TestFoundRows2", dsn+"&clientFoundRows=true", func(dbt *DBTest) {
-		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
-		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
-
-		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
-		count, err := res.RowsAffected()
-		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-		}
-		if count != 2 {
-			dbt.Fatalf("Expected 2 matched rows, got %d", count)
-		}
-		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
-		count, err = res.RowsAffected()
-		if err != nil {
-			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-		}
-		if count != 3 {
-			dbt.Fatalf("Expected 3 matched rows, got %d", count)
-		}
-	})
-}
-
 // BENCHMARKS
 var sample []byte