ソースを参照

PR done, all tests succeed and problematic ones are auto-skipped

Arne Hormann 11 年 前
コミット
f3e6a605b4
1 ファイル変更54 行追加16 行削除
  1. 54 16
      driver_test.go

+ 54 - 16
driver_test.go

@@ -338,17 +338,28 @@ type timeTest struct {
 	t time.Time
 }
 
+func (t timeTest) genQuery(dbtype string, binaryProtocol bool) string {
+	var inner string
+	if binaryProtocol {
+		inner = "?"
+	} else {
+		inner = `"%s"`
+	}
+	if len(dbtype) >= 9 && dbtype[:9] == "TIMESTAMP" {
+		return `SELECT TIMESTAMPADD(SECOND,0,CAST(` + inner + ` AS DATETIME` + dbtype[9:] + `))`
+	}
+	return `SELECT CAST(` + inner + ` AS ` + dbtype + `)`
+}
+
 func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, binaryProtocol bool) {
-	const queryBin = `SELECT CAST(? AS %[2]s)`
-	const queryTxt = `SELECT CAST("%[1]s" AS %[2]s)`
 	var rows *sql.Rows
 	var protocol string
-	if binaryProtocol {
+	if query := t.genQuery(dbtype, binaryProtocol); binaryProtocol {
 		protocol = "binary"
-		rows = dbt.mustQuery(fmt.Sprintf(queryBin, t.s, dbtype), t.t)
+		rows = dbt.mustQuery(query, t.t)
 	} else {
 		protocol = "text"
-		rows = dbt.mustQuery(fmt.Sprintf(queryTxt, t.s, dbtype))
+		rows = dbt.mustQuery(fmt.Sprintf(query, t.s))
 	}
 	defer rows.Close()
 	var err error
@@ -396,16 +407,17 @@ func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, binaryProtocol bool)
 }
 
 func TestDateTime(t *testing.T) {
-	afterTime0 := func(d string) time.Time {
+	afterTime := func(t time.Time, d string) time.Time {
 		dur, err := time.ParseDuration(d)
 		if err != nil {
 			panic(err)
 		}
-		return time.Time{}.Add(dur)
+		return t.Add(dur)
 	}
 	// NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests
 	format := "2006-01-02 15:04:05.999999"
 	t0 := time.Time{}
+	ts0 := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)
 	tstr0 := "0000-00-00 00:00:00.000000"
 	testcases := []timeTests{
 		{"DATE", format[:10], []timeTest{
@@ -425,20 +437,32 @@ func TestDateTime(t *testing.T) {
 			{t: t0, s: tstr0},
 		}},
 		{"TIME", format[11:19], []timeTest{
-			{t: afterTime0("12345s")},
-			{t: afterTime0("-12345s")},
+			{t: afterTime(t0, "12345s")},
+			{t: afterTime(t0, "-12345s")},
 			{t: t0, s: tstr0[11:19]},
 		}},
 		{"TIME(1)", format[11:21], []timeTest{
-			{t: afterTime0("12345600ms")},
-			{t: afterTime0("-12345600ms")},
+			{t: afterTime(t0, "12345600ms")},
+			{t: afterTime(t0, "-12345600ms")},
 			{t: t0, s: tstr0[11:21]},
 		}},
 		{"TIME(6)", format[11:], []timeTest{
-			{t: afterTime0("1234567890123000ns")},
-			{t: afterTime0("-1234567890123000ns")},
+			{t: afterTime(t0, "1234567890123000ns")},
+			{t: afterTime(t0, "-1234567890123000ns")},
 			{t: t0, s: tstr0[11:]},
 		}},
+		{"TIMESTAMP", format[:19], []timeTest{
+			{t: afterTime(ts0, "12345s")},
+			{t: ts0, s: "1970-01-01 00:00:00"},
+		}},
+		{"TIMESTAMP(1)", format[:21], []timeTest{
+			{t: afterTime(ts0, "12345600ms")},
+			{t: ts0, s: "1970-01-01 00:00:00.0"},
+		}},
+		{"TIMESTAMP(6)", format, []timeTest{
+			{t: afterTime(ts0, "1234567890123000ns")},
+			{t: ts0, s: "1970-01-01 00:00:00.000000"},
+		}},
 	}
 	dsns := map[string]bool{
 		dsn + "&parseTime=true":                               true,
@@ -446,13 +470,28 @@ func TestDateTime(t *testing.T) {
 		dsn + "&parseTime=false":                              false,
 		dsn + "&sql_mode=ALLOW_INVALID_DATES&parseTime=false": false,
 	}
+	var withFrac bool
+	if db, err := sql.Open("mysql", dsn); err != nil {
+		t.Fatal(err)
+	} else {
+		rows, err := db.Query(`SELECT CAST("00:00:00.123" AS TIME(3)) = "00:00:00.123"`)
+		if err == nil {
+			withFrac = true
+			rows.Close()
+		}
+		db.Close()
+	}
 	for testdsn, parseTime := range dsns {
 		var _ = parseTime
 		runTests(t, testdsn, func(dbt *DBTest) {
 			for _, setups := range testcases {
+				if t := setups.dbtype; !withFrac && t[len(t)-1:] == ")" {
+					// skip fractional tests if unsupported by DB
+					continue
+				}
 				for _, setup := range setups.tests {
 					if setup.s == "" {
-						// fill time string where Go can reliable produce it
+						// fill time string whereever Go can reliable produce it
 						setup.s = setup.t.Format(setups.tlayout)
 					}
 					setup.run(dbt, setups.dbtype, setups.tlayout, true)
@@ -1053,9 +1092,8 @@ func TestTimezoneConversion(t *testing.T) {
 		dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
 
 		// Insert local time into database (should be converted)
-		utc, _ := time.LoadLocation("UTC")
 		usCentral, _ := time.LoadLocation("US/Central")
-		reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, utc).In(usCentral)
+		reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral)
 		dbt.mustExec("INSERT INTO test VALUE (?)", reftime)
 
 		// Retrieve time from DB