Browse Source

support microseconds with MySQL 5.7+

passes all tests except TIME(1) -> string in binary protocol.
TIMESTAMP support with microsecond resolution is still incomplete.
Arne Hormann 11 years ago
parent
commit
1009a046eb
6 changed files with 194 additions and 173 deletions
  1. 0 1
      const.go
  2. 134 126
      driver_test.go
  3. 27 12
      packets.go
  4. 4 2
      rows.go
  5. 20 23
      utils.go
  6. 9 9
      utils_test.go

+ 0 - 1
const.go

@@ -11,7 +11,6 @@ package mysql
 const (
 	minProtocolVersion byte = 10
 	maxPacketSize           = 1<<24 - 1
-	timeFormat              = "2006-01-02 15:04:05"
 )
 
 // MySQL constants documentation:

+ 134 - 126
driver_test.go

@@ -327,96 +327,139 @@ func TestString(t *testing.T) {
 	})
 }
 
-func TestDateTime(t *testing.T) {
-	type testmode struct {
-		selectSuffix string
-		args         []interface{}
+type timeTests struct {
+	dbtype  string
+	tlayout string
+	tests   []timeTest
+}
+
+type timeTest struct {
+	s string
+	t time.Time
+}
+
+func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, binaryProtocol bool) {
+	const queryBin = `SELECT CAST(? AS %[2]s)`
+	const queryTxt = `SELECT CAST("%[1]s" AS %[2]s)`
+	var rows *sql.Rows
+	var protocol string
+	if binaryProtocol {
+		protocol = "binary"
+		rows = dbt.mustQuery(fmt.Sprintf(queryBin, t.s, dbtype), t.t)
+	} else {
+		protocol = "text"
+		rows = dbt.mustQuery(fmt.Sprintf(queryTxt, t.s, dbtype))
 	}
-	type timetest struct {
-		in      interface{}
-		sOut    string
-		tOut    time.Time
-		tIsZero bool
+	defer rows.Close()
+	var err error
+	if !rows.Next() {
+		err = rows.Err()
+		if err == nil {
+			err = fmt.Errorf("no data")
+		}
+		dbt.Errorf("%s [%s]: %s",
+			dbtype, protocol, err,
+		)
+		return
 	}
-	type tester func(dbt *DBTest, rows *sql.Rows,
-		test *timetest, sqltype, resulttype, mode string)
-	type setup struct {
-		vartype   string
-		dsnSuffix string
-		test      tester
+	var dst interface{}
+	err = rows.Scan(&dst)
+	if err != nil {
+		dbt.Errorf("%s [%s]: %s",
+			dbtype, protocol, err,
+		)
+		return
 	}
-	var (
-		modes = map[string]*testmode{
-			"text":   &testmode{},
-			"binary": &testmode{" WHERE 1 = ?", []interface{}{1}},
-		}
-		timetests = map[string][]*timetest{
-			"DATE": {
-				{sDate, sDate, tDate, false},
-				{sDate0, sDate0, tDate0, true},
-				{tDate, sDate, tDate, false},
-				{tDate0, sDate0, tDate0, true},
-			},
-			"DATETIME": {
-				{sDateTime, sDateTime, tDateTime, false},
-				{sDateTime0, sDateTime0, tDate0, true},
-				{tDateTime, sDateTime, tDateTime, false},
-				{tDate0, sDateTime0, tDate0, true},
-			},
-		}
-		setups = []*setup{
-			{"string", "&parseTime=false", func(
-				dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
-				var sOut string
-				if err := rows.Scan(&sOut); err != nil {
-					dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
-				} else if test.sOut != sOut {
-					dbt.Errorf("%s (%s %s): %s != %s", sqltype, resulttype, mode, test.sOut, sOut)
-				}
-			}},
-			{"time.Time", "&parseTime=true", func(
-				dbt *DBTest, rows *sql.Rows, test *timetest, sqltype, resulttype, mode string) {
-				var tOut time.Time
-				if err := rows.Scan(&tOut); err != nil {
-					dbt.Errorf("%s (%s %s): %s", sqltype, resulttype, mode, err.Error())
-				} else if test.tOut != tOut || test.tIsZero != tOut.IsZero() {
-					dbt.Errorf("%s (%s %s): %s [%t] != %s [%t]", sqltype, resulttype, mode, test.tOut, test.tIsZero, tOut, tOut.IsZero())
-				}
-			}},
-		}
-	)
+	switch val := dst.(type) {
+	case []uint8:
+		str := string(val)
+		if str == t.s {
+			return
+		}
+		dbt.Errorf("%s to string [%s]: expected '%s', got '%s'",
+			dbtype, protocol,
+			t.s, str,
+		)
+	case time.Time:
+		if val == t.t {
+			return
+		}
+		dbt.Errorf("%s to string [%s]: expected '%s', got '%s'",
+			dbtype, protocol,
+			t.s, val.Format(tlayout),
+		)
+	default:
+		dbt.Errorf("%s [%s]: unhandled type %T (is '%s')",
+			dbtype, protocol, val, val,
+		)
+	}
+}
 
-	var s *setup
-	testTime := func(dbt *DBTest) {
-		var rows *sql.Rows
-		for sqltype, tests := range timetests {
-			dbt.mustExec("CREATE TABLE test (value " + sqltype + ")")
-			for _, test := range tests {
-				for mode, q := range modes {
-					dbt.mustExec("TRUNCATE test")
-					dbt.mustExec("INSERT INTO test VALUES (?)", test.in)
-					rows = dbt.mustQuery("SELECT value FROM test"+q.selectSuffix, q.args...)
-					if rows.Next() {
-						s.test(dbt, rows, test, sqltype, s.vartype, mode)
-					} else {
-						if err := rows.Err(); err != nil {
-							dbt.Errorf("%s (%s %s): %s",
-								sqltype, s.vartype, mode, err.Error())
-						} else {
-							dbt.Errorf("%s (%s %s): no data",
-								sqltype, s.vartype, mode)
-						}
+func TestDateTime(t *testing.T) {
+	afterTime0 := func(d string) time.Time {
+		dur, err := time.ParseDuration(d)
+		if err != nil {
+			panic(err)
+		}
+		return time.Time{}.Add(dur)
+	}
+	// NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests
+	format := "2006-01-02 15:04:05.999999"
+	t0 := time.Time{}
+	tstr0 := "0000-00-00 00:00:00.000000"
+	testcases := []timeTests{
+		{"DATE", format[:10], []timeTest{
+			{t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)},
+			{t: t0, s: tstr0[:10]},
+		}},
+		{"DATETIME", format[:19], []timeTest{
+			{t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)},
+			{t: t0, s: tstr0[:19]},
+		}},
+		{"DATETIME(1)", format[:21], []timeTest{
+			{t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)},
+			{t: t0, s: tstr0[:21]},
+		}},
+		{"DATETIME(6)", format, []timeTest{
+			{t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)},
+			{t: t0, s: tstr0},
+		}},
+		{"TIME", format[11:19], []timeTest{
+			{t: afterTime0("12345s")},
+			{t: afterTime0("-12345s")},
+			{t: t0, s: tstr0[11:19]},
+		}},
+		{"TIME(1)", format[11:21], []timeTest{
+			{t: afterTime0("12345600ms")},
+			{t: afterTime0("-12345600ms")},
+			{t: t0, s: tstr0[11:21]},
+		}},
+		{"TIME(6)", format[11:], []timeTest{
+			{t: afterTime0("1234567890123000ns")},
+			{t: afterTime0("-1234567890123000ns")},
+			{t: t0, s: tstr0[11:]},
+		}},
+	}
+	dsns := map[string]bool{
+		dsn + "&parseTime=true":                               true,
+		dsn + "&sql_mode=ALLOW_INVALID_DATES&parseTime=true":  true,
+		dsn + "&parseTime=false":                              false,
+		dsn + "&sql_mode=ALLOW_INVALID_DATES&parseTime=false": false,
+	}
+	for testdsn, parseTime := range dsns {
+		var _ = parseTime
+		runTests(t, testdsn, func(dbt *DBTest) {
+			for _, setups := range testcases {
+				for _, setup := range setups.tests {
+					if setup.s == "" {
+						// fill time string where Go can reliable produce it
+						setup.s = setup.t.Format(setups.tlayout)
 					}
+					setup.run(dbt, setups.dbtype, setups.tlayout, true)
+					setup.run(dbt, setups.dbtype, setups.tlayout, false)
 				}
 			}
-			dbt.mustExec("DROP TABLE IF EXISTS test")
-		}
-	}
-
-	timeDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
-	for _, v := range setups {
-		s = v
-		runTests(t, timeDsn+s.dsnSuffix, testTime)
+		})
 	}
 }
 
@@ -1010,9 +1053,10 @@ func TestTimezoneConversion(t *testing.T) {
 		dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
 
 		// Insert local time into database (should be converted)
+		utc, _ := time.LoadLocation("UTC")
 		usCentral, _ := time.LoadLocation("US/Central")
-		now := time.Now().In(usCentral)
-		dbt.mustExec("INSERT INTO test VALUE (?)", now)
+		reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, utc).In(usCentral)
+		dbt.mustExec("INSERT INTO test VALUE (?)", reftime)
 
 		// Retrieve time from DB
 		rows := dbt.mustQuery("SELECT ts FROM test")
@@ -1020,17 +1064,17 @@ func TestTimezoneConversion(t *testing.T) {
 			dbt.Fatal("Didn't get any rows out")
 		}
 
-		var nowDB time.Time
-		err := rows.Scan(&nowDB)
+		var dbTime time.Time
+		err := rows.Scan(&dbTime)
 		if err != nil {
 			dbt.Fatal("Err", err)
 		}
 
 		// Check that dates match
-		if now.Unix() != nowDB.Unix() {
+		if reftime.Unix() != dbTime.Unix() {
 			dbt.Errorf("Times don't match.\n")
-			dbt.Errorf(" Now(%v)=%v\n", usCentral, now)
-			dbt.Errorf(" Now(UTC)=%v\n", nowDB)
+			dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime)
+			dbt.Errorf(" Now(UTC)=%v\n", dbTime)
 		}
 	}
 
@@ -1039,42 +1083,6 @@ func TestTimezoneConversion(t *testing.T) {
 	}
 }
 
-// This tests for https://github.com/go-sql-driver/mysql/pull/139
-//
-// An extra (invisible) nil byte was being added to the beginning of positive
-// time strings.
-func TestTimeSign(t *testing.T) {
-	runTests(t, dsn, func(dbt *DBTest) {
-		var sTimes = []struct {
-			value     string
-			fieldType string
-		}{
-			{"12:34:56", "TIME"},
-			{"-12:34:56", "TIME"},
-			// As described in http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html
-			// they *should* work, but only in 5.6+.
-			// { "12:34:56.789", "TIME(3)" },
-			// { "-12:34:56.789", "TIME(3)" },
-		}
-
-		for _, sTime := range sTimes {
-			dbt.db.Exec("DROP TABLE IF EXISTS test")
-			dbt.mustExec("CREATE TABLE test (id INT, time_field " + sTime.fieldType + ")")
-			dbt.mustExec("INSERT INTO test (id, time_field) VALUES(1, '" + sTime.value + "')")
-			rows := dbt.mustQuery("SELECT time_field FROM test WHERE id = ?", 1)
-			if rows.Next() {
-				var oTime string
-				rows.Scan(&oTime)
-				if oTime != sTime.value {
-					dbt.Errorf(`time values differ: got %q, expected %q.`, oTime, sTime.value)
-				}
-			} else {
-				dbt.Error("expecting at least one row.")
-			}
-		}
-	})
-}
-
 // Special cases
 
 func TestRowsClose(t *testing.T) {

+ 27 - 12
packets.go

@@ -16,6 +16,7 @@ import (
 	"fmt"
 	"io"
 	"math"
+	"strconv"
 	"time"
 )
 
@@ -557,20 +558,24 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
 			return nil, err
 		}
 
-		// Filler [1 byte]
-		// Charset [16 bit uint]
-		// Length [32 bit uint]
-		pos += n + 1 + 2 + 4
+		// Filler [uint8]
+		// Charset [charset, collation uint8]
+		pos += n + 1 + 2
 
-		// Field type [byte]
+		// Length [uint32]
+		columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
+		pos += 4
+
+		// Field type [uint8]
 		columns[i].fieldType = data[pos]
 		pos++
 
-		// Flags [16 bit uint]
+		// Flags [uint16]
 		columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
-		//pos += 2
+		pos += 2
 
-		// Decimals [8 bit uint]
+		// Decimals [uint8]
+		columns[i].decimals = data[pos]
 		//pos++
 
 		// Default value [len coded binary]
@@ -950,6 +955,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 
 // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
 func (rows *binaryRows) readRow(dest []driver.Value) error {
+	timestr := "00:00:00.000000"
 	data, err := rows.mc.readPacket()
 	if err != nil {
 		return err
@@ -1068,7 +1074,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 			if rows.mc.parseTime {
 				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc)
 			} else {
-				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], false)
+				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], 10)
 			}
 
 			if err == nil {
@@ -1088,7 +1094,11 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 					dest[i] = nil
 					continue
 				} else {
-					dest[i] = []byte("00:00:00")
+					length := uint8(8)
+					if rows.columns[i].decimals > 0 {
+						length += 1 + uint8(rows.columns[i].decimals)
+					}
+					dest[i] = []byte(timestr[:length])
 					continue
 				}
 			}
@@ -1109,8 +1119,9 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 				pos += 8
 				continue
 			case 12:
+				decimals := strconv.FormatInt(int64(rows.columns[i].decimals), 10)
 				dest[i] = []byte(fmt.Sprintf(
-					sign+"%02d:%02d:%02d.%06d",
+					sign+"%02d:%02d:%02d.%0"+decimals+"d",
 					uint16(data[pos+1])*24+uint16(data[pos+5]),
 					data[pos+6],
 					data[pos+7],
@@ -1136,7 +1147,11 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
 			if rows.mc.parseTime {
 				dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc)
 			} else {
-				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], true)
+				length := uint8(19)
+				if rows.columns[i].decimals > 0 {
+					length += 1 + uint8(rows.columns[i].decimals)
+				}
+				dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], length)
 			}
 
 			if err == nil {

+ 4 - 2
rows.go

@@ -14,9 +14,11 @@ import (
 )
 
 type mysqlField struct {
-	fieldType byte
-	flags     fieldFlag
 	name      string
+	length    uint32 // length as string: DATETIME(4) => 24
+	flags     fieldFlag
+	fieldType byte
+	decimals  byte // numeric precision: DATETIME(4) => 4, also for DECIMAL etc.
 }
 
 type mysqlRows struct {

+ 20 - 23
utils.go

@@ -29,6 +29,9 @@ var (
 	errInvalidDSNNoSlash   = errors.New("Invalid DSN: Missing the slash separating the database name")
 )
 
+// timeFormat must not be changed
+var timeFormat = "2006-01-02 15:04:05.999999"
+
 func init() {
 	tlsConfigRegister = make(map[string]*tls.Config)
 }
@@ -451,17 +454,13 @@ func (nt NullTime) Value() (driver.Value, error) {
 }
 
 func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
+	base := "0000-00-00 00:00:00.0000000"
 	switch len(str) {
-	case 10: // YYYY-MM-DD
-		if str == "0000-00-00" {
-			return
-		}
-		t, err = time.Parse(timeFormat[:10], str)
-	case 19: // YYYY-MM-DD HH:MM:SS
-		if str == "0000-00-00 00:00:00" {
+	case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM"
+		if str == base[:len(str)] {
 			return
 		}
-		t, err = time.Parse(timeFormat, str)
+		t, err = time.Parse(timeFormat[:len(str)], str)
 	default:
 		err = fmt.Errorf("Invalid Time-String: %s", str)
 		return
@@ -519,24 +518,22 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
 // if the DATE or DATETIME has the zero value.
 // It must never be changed.
 // The current behavior depends on database/sql copying the result.
-var zeroDateTime = []byte("0000-00-00 00:00:00")
+var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
 
-func formatBinaryDateTime(src []byte, withTime bool) (driver.Value, error) {
+func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
 	if len(src) == 0 {
-		if withTime {
-			return zeroDateTime, nil
-		}
-		return zeroDateTime[:10], nil
+		return zeroDateTime[:length], nil
 	}
 	var dst []byte
-	if withTime {
-		if len(src) == 11 {
-			dst = []byte("0000-00-00 00:00:00.000000")
-		} else {
-			dst = []byte("0000-00-00 00:00:00")
-		}
-	} else {
+	switch length {
+	case 10:
 		dst = []byte("0000-00-00")
+	case 19:
+		dst = []byte("0000-00-00 00:00:00")
+	case 21, 22, 23, 24, 25, 26:
+		dst = []byte("0000-00-00 00:00:00.000000")
+	default:
+		return nil, fmt.Errorf("illegal datetime length %d", length)
 	}
 	switch len(src) {
 	case 11:
@@ -584,10 +581,10 @@ func formatBinaryDateTime(src []byte, withTime bool) (driver.Value, error) {
 		tmp16, year = tmp16/10, tmp16
 		dst[1] += byte(year - 10*tmp16)
 		dst[0] += byte(tmp16)
-		return dst, nil
+		return dst[:length], nil
 	}
 	var t string
-	if withTime {
+	if length >= 19 {
 		t = "DATETIME"
 	} else {
 		t = "DATE"

+ 9 - 9
utils_test.go

@@ -191,22 +191,22 @@ func TestFormatBinaryDateTime(t *testing.T) {
 	rawDate[5] = 46                                    // minutes
 	rawDate[6] = 23                                    // seconds
 	binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds
-	expect := func(expected string, length int, withTime bool) {
-		actual, _ := formatBinaryDateTime(rawDate[:length], withTime)
+	expect := func(expected string, inlen, outlen uint8) {
+		actual, _ := formatBinaryDateTime(rawDate[:inlen], outlen)
 		bytes, ok := actual.([]byte)
 		if !ok {
 			t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
 		}
 		if string(bytes) != expected {
 			t.Errorf(
-				"expected %q, got %q for length %d, withTime %v",
-				bytes, actual, length, withTime,
+				"expected %q, got %q for length in %d, out %d",
+				bytes, actual, inlen, outlen,
 			)
 		}
 	}
-	expect("0000-00-00", 0, false)
-	expect("0000-00-00 00:00:00", 0, true)
-	expect("1978-12-30", 4, false)
-	expect("1978-12-30 15:46:23", 7, true)
-	expect("1978-12-30 15:46:23.987654", 11, true)
+	expect("0000-00-00", 0, 10)
+	expect("0000-00-00 00:00:00", 0, 19)
+	expect("1978-12-30", 4, 10)
+	expect("1978-12-30 15:46:23", 7, 19)
+	expect("1978-12-30 15:46:23.987654", 11, 26)
 }