Browse Source

Refactor NullTime as go1.13's sql.NullTime (#995)

Olivier Mengué 6 years ago
parent
commit
d9aa6d3abe
6 changed files with 177 additions and 95 deletions
  1. 50 0
      nulltime.go
  2. 31 0
      nulltime_go113.go
  3. 34 0
      nulltime_legacy.go
  4. 62 0
      nulltime_test.go
  5. 0 54
      utils.go
  6. 0 41
      utils_test.go

+ 50 - 0
nulltime.go

@@ -0,0 +1,50 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+	"database/sql/driver"
+	"fmt"
+	"time"
+)
+
+// Scan implements the Scanner interface.
+// The value type must be time.Time or string / []byte (formatted time-string),
+// otherwise Scan fails.
+func (nt *NullTime) Scan(value interface{}) (err error) {
+	if value == nil {
+		nt.Time, nt.Valid = time.Time{}, false
+		return
+	}
+
+	switch v := value.(type) {
+	case time.Time:
+		nt.Time, nt.Valid = v, true
+		return
+	case []byte:
+		nt.Time, err = parseDateTime(string(v), time.UTC)
+		nt.Valid = (err == nil)
+		return
+	case string:
+		nt.Time, err = parseDateTime(v, time.UTC)
+		nt.Valid = (err == nil)
+		return
+	}
+
+	nt.Valid = false
+	return fmt.Errorf("Can't convert %T to time.Time", value)
+}
+
+// Value implements the driver Valuer interface.
+func (nt NullTime) Value() (driver.Value, error) {
+	if !nt.Valid {
+		return nil, nil
+	}
+	return nt.Time, nil
+}

+ 31 - 0
nulltime_go113.go

@@ -0,0 +1,31 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build go1.13
+
+package mysql
+
+import (
+	"database/sql"
+)
+
+// NullTime represents a time.Time that may be NULL.
+// NullTime implements the Scanner interface so
+// it can be used as a scan destination:
+//
+//  var nt NullTime
+//  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
+//  ...
+//  if nt.Valid {
+//     // use nt.Time
+//  } else {
+//     // NULL value
+//  }
+//
+// This NullTime implementation is not driver-specific
+type NullTime sql.NullTime

+ 34 - 0
nulltime_legacy.go

@@ -0,0 +1,34 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build !go1.13
+
+package mysql
+
+import (
+	"time"
+)
+
+// NullTime represents a time.Time that may be NULL.
+// NullTime implements the Scanner interface so
+// it can be used as a scan destination:
+//
+//  var nt NullTime
+//  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
+//  ...
+//  if nt.Valid {
+//     // use nt.Time
+//  } else {
+//     // NULL value
+//  }
+//
+// This NullTime implementation is not driver-specific
+type NullTime struct {
+	Time  time.Time
+	Valid bool // Valid is true if Time is not NULL
+}

+ 62 - 0
nulltime_test.go

@@ -0,0 +1,62 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+	"database/sql"
+	"database/sql/driver"
+	"testing"
+	"time"
+)
+
+var (
+	// Check implementation of interfaces
+	_ driver.Valuer = NullTime{}
+	_ sql.Scanner   = (*NullTime)(nil)
+)
+
+func TestScanNullTime(t *testing.T) {
+	var scanTests = []struct {
+		in    interface{}
+		error bool
+		valid bool
+		time  time.Time
+	}{
+		{tDate, false, true, tDate},
+		{sDate, false, true, tDate},
+		{[]byte(sDate), false, true, tDate},
+		{tDateTime, false, true, tDateTime},
+		{sDateTime, false, true, tDateTime},
+		{[]byte(sDateTime), false, true, tDateTime},
+		{tDate0, false, true, tDate0},
+		{sDate0, false, true, tDate0},
+		{[]byte(sDate0), false, true, tDate0},
+		{sDateTime0, false, true, tDate0},
+		{[]byte(sDateTime0), false, true, tDate0},
+		{"", true, false, tDate0},
+		{"1234", true, false, tDate0},
+		{0, true, false, tDate0},
+	}
+
+	var nt = NullTime{}
+	var err error
+
+	for _, tst := range scanTests {
+		err = nt.Scan(tst.in)
+		if (err != nil) != tst.error {
+			t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil))
+		}
+		if nt.Valid != tst.valid {
+			t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid)
+		}
+		if nt.Time != tst.time {
+			t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
+		}
+	}
+}

+ 0 - 54
utils.go

@@ -106,60 +106,6 @@ func readBool(input string) (value bool, valid bool) {
 *                           Time related utils                                *
 ******************************************************************************/
 
-// NullTime represents a time.Time that may be NULL.
-// NullTime implements the Scanner interface so
-// it can be used as a scan destination:
-//
-//  var nt NullTime
-//  err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt)
-//  ...
-//  if nt.Valid {
-//     // use nt.Time
-//  } else {
-//     // NULL value
-//  }
-//
-// This NullTime implementation is not driver-specific
-type NullTime struct {
-	Time  time.Time
-	Valid bool // Valid is true if Time is not NULL
-}
-
-// Scan implements the Scanner interface.
-// The value type must be time.Time or string / []byte (formatted time-string),
-// otherwise Scan fails.
-func (nt *NullTime) Scan(value interface{}) (err error) {
-	if value == nil {
-		nt.Time, nt.Valid = time.Time{}, false
-		return
-	}
-
-	switch v := value.(type) {
-	case time.Time:
-		nt.Time, nt.Valid = v, true
-		return
-	case []byte:
-		nt.Time, err = parseDateTime(string(v), time.UTC)
-		nt.Valid = (err == nil)
-		return
-	case string:
-		nt.Time, err = parseDateTime(v, time.UTC)
-		nt.Valid = (err == nil)
-		return
-	}
-
-	nt.Valid = false
-	return fmt.Errorf("Can't convert %T to time.Time", value)
-}
-
-// Value implements the driver Valuer interface.
-func (nt NullTime) Value() (driver.Value, error) {
-	if !nt.Valid {
-		return nil, nil
-	}
-	return nt.Time, nil
-}
-
 func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
 	base := "0000-00-00 00:00:00.0000000"
 	switch len(str) {

+ 0 - 41
utils_test.go

@@ -14,49 +14,8 @@ import (
 	"database/sql/driver"
 	"encoding/binary"
 	"testing"
-	"time"
 )
 
-func TestScanNullTime(t *testing.T) {
-	var scanTests = []struct {
-		in    interface{}
-		error bool
-		valid bool
-		time  time.Time
-	}{
-		{tDate, false, true, tDate},
-		{sDate, false, true, tDate},
-		{[]byte(sDate), false, true, tDate},
-		{tDateTime, false, true, tDateTime},
-		{sDateTime, false, true, tDateTime},
-		{[]byte(sDateTime), false, true, tDateTime},
-		{tDate0, false, true, tDate0},
-		{sDate0, false, true, tDate0},
-		{[]byte(sDate0), false, true, tDate0},
-		{sDateTime0, false, true, tDate0},
-		{[]byte(sDateTime0), false, true, tDate0},
-		{"", true, false, tDate0},
-		{"1234", true, false, tDate0},
-		{0, true, false, tDate0},
-	}
-
-	var nt = NullTime{}
-	var err error
-
-	for _, tst := range scanTests {
-		err = nt.Scan(tst.in)
-		if (err != nil) != tst.error {
-			t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil))
-		}
-		if nt.Valid != tst.valid {
-			t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid)
-		}
-		if nt.Time != tst.time {
-			t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
-		}
-	}
-}
-
 func TestLengthEncodedInteger(t *testing.T) {
 	var integerTests = []struct {
 		num     uint64