Browse Source

Add NullTime struct

Julien Schmidt 12 năm trước cách đây
mục cha
commit
f56d52bb2b
4 tập tin đã thay đổi với 137 bổ sung42 xóa
  1. 3 1
      README.md
  2. 19 16
      driver_test.go
  3. 59 9
      utils.go
  4. 56 16
      utils_test.go

+ 3 - 1
README.md

@@ -159,6 +159,8 @@ However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` v
 
 **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
 
+Alternatively you can use the [`NullTime`](http://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both time.Time and string / []byte.
+
 
 
 ## Testing / Development
@@ -167,7 +169,7 @@ To run the driver tests you may need to adjust the configuration. See the [Testi
 Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated.
 If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open) or review a [pull request](https://github.com/go-sql-driver/mysql/pulls).
 
-Code changes must be proposed via a Pull Request and must be reviewed. Only *LGTM*-ed (" *Looks good to me* ") code may be committed to the master branch. 
+Code changes must be proposed via a Pull Request and must be reviewed. Only *LGTM*-ed (" *Looks good to me* ") code may be committed to the master branch.
 
 ---------------------------------------
 

+ 19 - 16
driver_test.go

@@ -19,6 +19,16 @@ var (
 	available bool
 )
 
+var (
+	tDate      = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC)
+	sDate      = "2012-06-14"
+	tDateTime  = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)
+	sDateTime  = "2011-11-20 21:27:37"
+	tDate0     = time.Time{}
+	sDate0     = "0000-00-00"
+	sDateTime0 = "0000-00-00 00:00:00"
+)
+
 // See https://github.com/go-sql-driver/mysql/wiki/Testing
 func init() {
 	env := func(key, defaultValue string) string {
@@ -396,29 +406,22 @@ func TestDateTime(t *testing.T) {
 		test      tester
 	}
 	var (
-		tdate      = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC)
-		sdate      = "2012-06-14"
-		tdatetime  = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)
-		sdatetime  = "2011-11-20 21:27:37"
-		tdate0     = time.Time{}
-		sdate0     = "0000-00-00"
-		sdatetime0 = "0000-00-00 00:00:00"
-		modes      = map[string]*testmode{
+		modes = map[string]*testmode{
 			"text":   &testmode{},
 			"binary": &testmode{" WHERE 1 = ?", []interface{}{1}},
 		}
 		timetests = map[string][]*timetest{
 			"DATE": {
-				{sdate, sdate, tdate, false},
-				{sdate0, sdate0, tdate0, true},
-				{tdate, sdate, tdate, false},
-				{tdate0, sdate0, tdate0, true},
+				{sDate, sDate, tDate, false},
+				{sDate0, sDate0, tDate0, true},
+				{tDate, sDate, tDate, false},
+				{tDate0, sDate0, tDate0, true},
 			},
 			"DATETIME": {
-				{sdatetime, sdatetime, tdatetime, false},
-				{sdatetime0, sdatetime0, tdate0, true},
-				{tdatetime, sdatetime, tdatetime, false},
-				{tdate0, sdatetime0, tdate0, true},
+				{sDateTime, sDateTime, tDateTime, false},
+				{sDateTime0, sDateTime0, tDate0, true},
+				{tDateTime, sDateTime, tDateTime, false},
+				{tDate0, sDateTime0, tDate0, true},
 			},
 		}
 		setups = []*setup{

+ 59 - 9
utils.go

@@ -22,6 +22,58 @@ import (
 	"time"
 )
 
+// NullTime represents a time.Time that may be NULL.
+// NullTime implements the Scanner interface so
+// it can be used as a scan destination:
+//
+//  var nt NullTime
+//  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
+//  ...
+//  if nt.Valid {
+//     // use nt.Time
+//  } else {
+//     // NULL value
+//  }
+//
+// This NullTime implementation is not driver-specific
+type NullTime struct {
+	Time  time.Time
+	Valid bool // Valid is true if Time is not NULL
+}
+
+// Scan implements the Scanner interface.
+// The value type must be time.Time or string / []byte (formatted time-string),
+// otherwise Scan fails.
+func (nt *NullTime) Scan(value interface{}) (err error) {
+	if value == nil {
+		nt.Time, nt.Valid = time.Time{}, false
+	} else {
+		switch v := value.(type) {
+		case time.Time:
+			nt.Time, nt.Valid = v, true
+		case []byte:
+			nt.Time, err = parseDateTime(string(v), time.UTC)
+			nt.Valid = (err == nil)
+		case string:
+			nt.Time, err = parseDateTime(v, time.UTC)
+			nt.Valid = (err == nil)
+		default:
+			nt.Valid = false
+			err = fmt.Errorf("Can't convert %T to time.Time", v)
+		}
+	}
+
+	return
+}
+
+// Value implements the driver Valuer interface.
+func (nt NullTime) Value() (driver.Value, error) {
+	if !nt.Valid {
+		return nil, nil
+	}
+	return nt.Time, nil
+}
+
 // Logger
 var (
 	errLog *log.Logger
@@ -116,33 +168,31 @@ func scramblePassword(scramble, password []byte) []byte {
 	return scramble
 }
 
-func parseDateTime(str string, loc *time.Location) (driver.Value, error) {
-	var t time.Time
-	var err error
-
+func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
 	switch len(str) {
 	case 10: // YYYY-MM-DD
 		if str == "0000-00-00" {
-			return time.Time{}, nil
+			return
 		}
 		t, err = time.Parse(timeFormat[:10], str)
 	case 19: // YYYY-MM-DD HH:MM:SS
 		if str == "0000-00-00 00:00:00" {
-			return time.Time{}, nil
+			return
 		}
 		t, err = time.Parse(timeFormat, str)
 	default:
-		return nil, fmt.Errorf("Invalid Time-String: %s", str)
+		err = fmt.Errorf("Invalid Time-String: %s", str)
+		return
 	}
 
 	// Adjust location
 	if err == nil && loc != time.UTC {
 		y, mo, d := t.Date()
 		h, mi, s := t.Clock()
-		return time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
+		t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil
 	}
 
-	return t, err
+	return
 }
 
 func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) {

+ 56 - 16
utils_test.go

@@ -15,23 +15,23 @@ import (
 	"time"
 )
 
-var testDSNs = []struct {
-	in  string
-	out string
-	loc *time.Location
-}{
-	{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p}", time.UTC},
-	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p}", time.UTC},
-	{"user:password@/dbname?loc=UTC", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[loc:UTC] loc:%p}", time.UTC},
-	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[loc:Local] loc:%p}", time.Local},
-	{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p}", time.UTC},
-	{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
-	{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
-}
-
 func TestDSNParser(t *testing.T) {
+	var testDSNs = []struct {
+		in  string
+		out string
+		loc *time.Location
+	}{
+		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p}", time.UTC},
+		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p}", time.UTC},
+		{"user:password@/dbname?loc=UTC", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[loc:UTC] loc:%p}", time.UTC},
+		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[loc:Local] loc:%p}", time.Local},
+		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p}", time.UTC},
+		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
+		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
+	}
+
 	var cfg *config
 	var err error
 	var res string
@@ -48,3 +48,43 @@ func TestDSNParser(t *testing.T) {
 		}
 	}
 }
+
+func TestScanNullTime(t *testing.T) {
+	var scanTests = []struct {
+		in    interface{}
+		error bool
+		valid bool
+		time  time.Time
+	}{
+		{tDate, false, true, tDate},
+		{sDate, false, true, tDate},
+		{[]byte(sDate), false, true, tDate},
+		{tDateTime, false, true, tDateTime},
+		{sDateTime, false, true, tDateTime},
+		{[]byte(sDateTime), false, true, tDateTime},
+		{tDate0, false, true, tDate0},
+		{sDate0, false, true, tDate0},
+		{[]byte(sDate0), false, true, tDate0},
+		{sDateTime0, false, true, tDate0},
+		{[]byte(sDateTime0), false, true, tDate0},
+		{"", true, false, tDate0},
+		{"1234", true, false, tDate0},
+		{0, true, false, tDate0},
+	}
+
+	var nt = NullTime{}
+	var err error
+
+	for _, tst := range scanTests {
+		err = nt.Scan(tst.in)
+		if (err != nil) != tst.error {
+			t.Errorf("%v: expected error status %b, got %b", tst.in, tst.error, (err != nil))
+		}
+		if nt.Valid != tst.valid {
+			t.Errorf("%v: expected valid status %b, got %b", tst.in, tst.valid, nt.Valid)
+		}
+		if nt.Time != tst.time {
+			t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
+		}
+	}
+}