Ver Fonte

Add test for []byte(nil)

Also reordered the tests so that the general tests are on top
Julien Schmidt há 12 anos atrás
pai
commit
b95ecaab58
1 ficheiros alterados com 207 adições e 191 exclusões
  1. 207 191
      driver_test.go

+ 207 - 191
driver_test.go

@@ -108,143 +108,6 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
 	return rows
 }
 
-func TestReuseClosedConnection(t *testing.T) {
-	// this test does not use sql.database, it uses the driver directly
-	if !available {
-		t.Skipf("MySQL-Server not running on %s", netAddr)
-	}
-	driver := &MySQLDriver{}
-	conn, err := driver.Open(dsn)
-	if err != nil {
-		t.Fatalf("Error connecting: %s", err.Error())
-	}
-	stmt, err := conn.Prepare("DO 1")
-	if err != nil {
-		t.Fatalf("Error preparing statement: %s", err.Error())
-	}
-	_, err = stmt.Exec(nil)
-	if err != nil {
-		t.Fatalf("Error executing statement: %s", err.Error())
-	}
-	err = conn.Close()
-	if err != nil {
-		t.Fatalf("Error closing connection: %s", err.Error())
-	}
-	defer func() {
-		if err := recover(); err != nil {
-			t.Errorf("Panic after reusing a closed connection: %v", err)
-		}
-	}()
-	_, err = stmt.Exec(nil)
-	if err != nil && err != errInvalidConn {
-		t.Errorf("Unexpected error '%s', expected '%s'",
-			err.Error(), errInvalidConn.Error())
-	}
-}
-
-func TestCharset(t *testing.T) {
-	if !available {
-		t.Skipf("MySQL-Server not running on %s", netAddr)
-	}
-
-	mustSetCharset := func(charsetParam, expected string) {
-		runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
-			rows := dbt.mustQuery("SELECT @@character_set_connection")
-			defer rows.Close()
-
-			if !rows.Next() {
-				dbt.Fatalf("Error getting connection charset: %s", rows.Err())
-			}
-
-			var got string
-			rows.Scan(&got)
-
-			if got != expected {
-				dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
-			}
-		})
-	}
-
-	// non utf8 test
-	mustSetCharset("charset=ascii", "ascii")
-
-	// when the first charset is invalid, use the second
-	mustSetCharset("charset=none,utf8", "utf8")
-
-	// when the first charset is valid, use it
-	mustSetCharset("charset=ascii,utf8", "ascii")
-	mustSetCharset("charset=utf8,ascii", "utf8")
-}
-
-func TestFailingCharset(t *testing.T) {
-	runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
-		// run query to really establish connection...
-		_, err := dbt.db.Exec("SELECT 1")
-		if err == nil {
-			dbt.db.Close()
-			t.Fatalf("Connection must not succeed without a valid charset")
-		}
-	})
-}
-
-func TestRawBytesResultExceedsBuffer(t *testing.T) {
-	runTests(t, dsn, func(dbt *DBTest) {
-		// defaultBufSize from buffer.go
-		expected := strings.Repeat("abc", defaultBufSize)
-		rows := dbt.mustQuery("SELECT '" + expected + "'")
-		defer rows.Close()
-		if !rows.Next() {
-			dbt.Error("expected result, got none")
-		}
-		var result sql.RawBytes
-		rows.Scan(&result)
-		if expected != string(result) {
-			dbt.Error("result did not match expected value")
-		}
-	})
-}
-
-func TestTimezoneConversion(t *testing.T) {
-
-	zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
-
-	// Regression test for timezone handling
-	tzTest := func(dbt *DBTest) {
-
-		// Create table
-		dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
-
-		// Insert local time into database (should be converted)
-		usCentral, _ := time.LoadLocation("US/Central")
-		now := time.Now().In(usCentral)
-		dbt.mustExec("INSERT INTO test VALUE (?)", now)
-
-		// Retrieve time from DB
-		rows := dbt.mustQuery("SELECT ts FROM test")
-		if !rows.Next() {
-			dbt.Fatal("Didn't get any rows out")
-		}
-
-		var nowDB time.Time
-		err := rows.Scan(&nowDB)
-		if err != nil {
-			dbt.Fatal("Err", err)
-		}
-
-		// Check that dates match
-		if now.Unix() != nowDB.Unix() {
-			dbt.Errorf("Times don't match.\n")
-			dbt.Errorf(" Now(%v)=%v\n", usCentral, now)
-			dbt.Errorf(" Now(UTC)=%v\n", nowDB)
-		}
-
-	}
-
-	for _, tz := range zones {
-		runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
-	}
-}
-
 func TestCRUD(t *testing.T) {
 	runTests(t, dsn, func(dbt *DBTest) {
 		// Create Table
@@ -548,44 +411,6 @@ func TestDateTime(t *testing.T) {
 	}
 }
 
-// This tests for https://github.com/go-sql-driver/mysql/pull/139
-//
-// An extra (invisible) nil byte was being added to the beginning of positive
-// time strings.
-func TestTimeSign(t *testing.T) {
-	runTests(t, dsn, func(dbt *DBTest) {
-		var sTimes = []struct {
-			value     string
-			fieldType string
-		}{
-			{"12:34:56", "TIME"},
-			{"-12:34:56", "TIME"},
-			// As described in http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html
-			// they *should* work, but only in 5.6+.
-			// { "12:34:56.789", "TIME(3)" },
-			// { "-12:34:56.789", "TIME(3)" },
-		}
-
-		for _, sTime := range sTimes {
-			dbt.db.Exec("DROP TABLE IF EXISTS test")
-			dbt.mustExec("CREATE TABLE test (id INT, time_field " + sTime.fieldType + ")")
-			dbt.mustExec("INSERT INTO test (id, time_field) VALUES(1, '" + sTime.value + "')")
-			rows := dbt.mustQuery("SELECT time_field FROM test WHERE id = ?", 1)
-			if rows.Next() {
-				var oTime string
-				rows.Scan(&oTime)
-				if oTime != sTime.value {
-					dbt.Errorf(`time values differ: got %q, expected %q.`, oTime, sTime.value)
-				}
-			} else {
-				dbt.Error("expecting at least one row.")
-			}
-		}
-
-	})
-
-}
-
 func TestNULL(t *testing.T) {
 	runTests(t, dsn, func(dbt *DBTest) {
 		nullStmt, err := dbt.db.Prepare("SELECT NULL")
@@ -603,16 +428,14 @@ func TestNULL(t *testing.T) {
 		// NullBool
 		var nb sql.NullBool
 		// Invalid
-		err = nullStmt.QueryRow().Scan(&nb)
-		if err != nil {
+		if err = nullStmt.QueryRow().Scan(&nb); err != nil {
 			dbt.Fatal(err)
 		}
 		if nb.Valid {
 			dbt.Error("Valid NullBool which should be invalid")
 		}
 		// Valid
-		err = nonNullStmt.QueryRow().Scan(&nb)
-		if err != nil {
+		if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
 			dbt.Fatal(err)
 		}
 		if !nb.Valid {
@@ -624,16 +447,14 @@ func TestNULL(t *testing.T) {
 		// NullFloat64
 		var nf sql.NullFloat64
 		// Invalid
-		err = nullStmt.QueryRow().Scan(&nf)
-		if err != nil {
+		if err = nullStmt.QueryRow().Scan(&nf); err != nil {
 			dbt.Fatal(err)
 		}
 		if nf.Valid {
 			dbt.Error("Valid NullFloat64 which should be invalid")
 		}
 		// Valid
-		err = nonNullStmt.QueryRow().Scan(&nf)
-		if err != nil {
+		if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
 			dbt.Fatal(err)
 		}
 		if !nf.Valid {
@@ -645,16 +466,14 @@ func TestNULL(t *testing.T) {
 		// NullInt64
 		var ni sql.NullInt64
 		// Invalid
-		err = nullStmt.QueryRow().Scan(&ni)
-		if err != nil {
+		if err = nullStmt.QueryRow().Scan(&ni); err != nil {
 			dbt.Fatal(err)
 		}
 		if ni.Valid {
 			dbt.Error("Valid NullInt64 which should be invalid")
 		}
 		// Valid
-		err = nonNullStmt.QueryRow().Scan(&ni)
-		if err != nil {
+		if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
 			dbt.Fatal(err)
 		}
 		if !ni.Valid {
@@ -666,16 +485,14 @@ func TestNULL(t *testing.T) {
 		// NullString
 		var ns sql.NullString
 		// Invalid
-		err = nullStmt.QueryRow().Scan(&ns)
-		if err != nil {
+		if err = nullStmt.QueryRow().Scan(&ns); err != nil {
 			dbt.Fatal(err)
 		}
 		if ns.Valid {
 			dbt.Error("Valid NullString which should be invalid")
 		}
 		// Valid
-		err = nonNullStmt.QueryRow().Scan(&ns)
-		if err != nil {
+		if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
 			dbt.Fatal(err)
 		}
 		if !ns.Valid {
@@ -684,6 +501,31 @@ func TestNULL(t *testing.T) {
 			dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)")
 		}
 
+		// nil-bytes
+		var b []byte
+		// Read nil
+		if err = nullStmt.QueryRow().Scan(&b); err != nil {
+			dbt.Fatal(err)
+		}
+		if b != nil {
+			dbt.Error("Non-nil []byte wich should be nil")
+		}
+		// Read non-nil
+		if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
+			dbt.Fatal(err)
+		}
+		if b == nil {
+			dbt.Error("Nil []byte wich should be non-nil")
+		}
+		// Check input==output (==nil)
+		b = nil
+		if err = dbt.db.QueryRow("SELECT ?", nil).Scan(&b); err != nil {
+			dbt.Fatal(err)
+		}
+		if b != nil {
+			dbt.Error("Non-nil echo from nil input")
+		}
+
 		// Insert NULL
 		dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)")
 
@@ -995,6 +837,180 @@ func TestTLS(t *testing.T) {
 	runTests(t, dsn+"&tls=custom-skip-verify", tlsTest)
 }
 
+func TestReuseClosedConnection(t *testing.T) {
+	// this test does not use sql.database, it uses the driver directly
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
+
+	driver := &MySQLDriver{}
+	conn, err := driver.Open(dsn)
+	if err != nil {
+		t.Fatalf("Error connecting: %s", err.Error())
+	}
+	stmt, err := conn.Prepare("DO 1")
+	if err != nil {
+		t.Fatalf("Error preparing statement: %s", err.Error())
+	}
+	_, err = stmt.Exec(nil)
+	if err != nil {
+		t.Fatalf("Error executing statement: %s", err.Error())
+	}
+	err = conn.Close()
+	if err != nil {
+		t.Fatalf("Error closing connection: %s", err.Error())
+	}
+
+	defer func() {
+		if err := recover(); err != nil {
+			t.Errorf("Panic after reusing a closed connection: %v", err)
+		}
+	}()
+	_, err = stmt.Exec(nil)
+	if err != nil && err != errInvalidConn {
+		t.Errorf("Unexpected error '%s', expected '%s'",
+			err.Error(), errInvalidConn.Error())
+	}
+}
+
+func TestCharset(t *testing.T) {
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
+
+	mustSetCharset := func(charsetParam, expected string) {
+		runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
+			rows := dbt.mustQuery("SELECT @@character_set_connection")
+			defer rows.Close()
+
+			if !rows.Next() {
+				dbt.Fatalf("Error getting connection charset: %s", rows.Err())
+			}
+
+			var got string
+			rows.Scan(&got)
+
+			if got != expected {
+				dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
+			}
+		})
+	}
+
+	// non utf8 test
+	mustSetCharset("charset=ascii", "ascii")
+
+	// when the first charset is invalid, use the second
+	mustSetCharset("charset=none,utf8", "utf8")
+
+	// when the first charset is valid, use it
+	mustSetCharset("charset=ascii,utf8", "ascii")
+	mustSetCharset("charset=utf8,ascii", "utf8")
+}
+
+func TestFailingCharset(t *testing.T) {
+	runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
+		// run query to really establish connection...
+		_, err := dbt.db.Exec("SELECT 1")
+		if err == nil {
+			dbt.db.Close()
+			t.Fatalf("Connection must not succeed without a valid charset")
+		}
+	})
+}
+
+func TestRawBytesResultExceedsBuffer(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		// defaultBufSize from buffer.go
+		expected := strings.Repeat("abc", defaultBufSize)
+
+		rows := dbt.mustQuery("SELECT '" + expected + "'")
+		defer rows.Close()
+		if !rows.Next() {
+			dbt.Error("expected result, got none")
+		}
+		var result sql.RawBytes
+		rows.Scan(&result)
+		if expected != string(result) {
+			dbt.Error("result did not match expected value")
+		}
+	})
+}
+
+func TestTimezoneConversion(t *testing.T) {
+	zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
+
+	// Regression test for timezone handling
+	tzTest := func(dbt *DBTest) {
+
+		// Create table
+		dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
+
+		// Insert local time into database (should be converted)
+		usCentral, _ := time.LoadLocation("US/Central")
+		now := time.Now().In(usCentral)
+		dbt.mustExec("INSERT INTO test VALUE (?)", now)
+
+		// Retrieve time from DB
+		rows := dbt.mustQuery("SELECT ts FROM test")
+		if !rows.Next() {
+			dbt.Fatal("Didn't get any rows out")
+		}
+
+		var nowDB time.Time
+		err := rows.Scan(&nowDB)
+		if err != nil {
+			dbt.Fatal("Err", err)
+		}
+
+		// Check that dates match
+		if now.Unix() != nowDB.Unix() {
+			dbt.Errorf("Times don't match.\n")
+			dbt.Errorf(" Now(%v)=%v\n", usCentral, now)
+			dbt.Errorf(" Now(UTC)=%v\n", nowDB)
+		}
+	}
+
+	for _, tz := range zones {
+		runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
+	}
+}
+
+// This tests for https://github.com/go-sql-driver/mysql/pull/139
+//
+// An extra (invisible) nil byte was being added to the beginning of positive
+// time strings.
+func TestTimeSign(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		var sTimes = []struct {
+			value     string
+			fieldType string
+		}{
+			{"12:34:56", "TIME"},
+			{"-12:34:56", "TIME"},
+			// As described in http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html
+			// they *should* work, but only in 5.6+.
+			// { "12:34:56.789", "TIME(3)" },
+			// { "-12:34:56.789", "TIME(3)" },
+		}
+
+		for _, sTime := range sTimes {
+			dbt.db.Exec("DROP TABLE IF EXISTS test")
+			dbt.mustExec("CREATE TABLE test (id INT, time_field " + sTime.fieldType + ")")
+			dbt.mustExec("INSERT INTO test (id, time_field) VALUES(1, '" + sTime.value + "')")
+			rows := dbt.mustQuery("SELECT time_field FROM test WHERE id = ?", 1)
+			if rows.Next() {
+				var oTime string
+				rows.Scan(&oTime)
+				if oTime != sTime.value {
+					dbt.Errorf(`time values differ: got %q, expected %q.`, oTime, sTime.value)
+				}
+			} else {
+				dbt.Error("expecting at least one row.")
+			}
+		}
+	})
+}
+
 // Special cases
 
 func TestRowsClose(t *testing.T) {