Browse Source

fix Scan NullTime type failed at sqlite3

xormplus 8 years ago
parent
commit
c5b6cd8dd2
3 changed files with 69 additions and 16 deletions
  1. 12 13
      db_test.go
  2. 2 3
      pk_test.go
  3. 55 0
      scan.go

+ 12 - 13
db_test.go

@@ -2,7 +2,7 @@ package core
 
 import (
 	"errors"
-	"fmt"
+	"flag"
 	"os"
 	"testing"
 	"time"
@@ -12,8 +12,7 @@ import (
 )
 
 var (
-	//dbtype         string = "sqlite3"
-	dbtype         string = "mysql"
+	dbtype         = flag.String("dbtype", "mysql", "database type")
 	createTableSql string
 )
 
@@ -28,7 +27,8 @@ type User struct {
 }
 
 func init() {
-	switch dbtype {
+	flag.Parse()
+	switch *dbtype {
 	case "sqlite3":
 		createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " +
 			"`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);"
@@ -41,7 +41,7 @@ func init() {
 }
 
 func testOpen() (*DB, error) {
-	switch dbtype {
+	switch *dbtype {
 	case "sqlite3":
 		os.Remove("./test.db")
 		return Open("sqlite3", "./test.db")
@@ -133,7 +133,7 @@ func BenchmarkStructQuery(b *testing.B) {
 				b.Error(err)
 			}
 			if user.Name != "xlw" {
-				fmt.Println(user)
+				b.Log(user)
 				b.Error(errors.New("name should be xlw"))
 			}
 		}
@@ -179,7 +179,7 @@ func BenchmarkStruct2Query(b *testing.B) {
 				b.Error(err)
 			}
 			if user.Name != "xlw" {
-				fmt.Println(user)
+				b.Log(user)
 				b.Error(errors.New("name should be xlw"))
 			}
 		}
@@ -228,9 +228,8 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) {
 			if err != nil {
 				b.Error(err)
 			}
-			fmt.Println(slice)
+			b.Log(slice)
 			if *slice[1].(*string) != "xlw" {
-				fmt.Println(slice)
 				b.Error(errors.New("name should be xlw"))
 			}
 		}
@@ -332,7 +331,7 @@ func BenchmarkSliceStringQuery(b *testing.B) {
 				b.Error(err)
 			}
 			if (*slice[1]) != "xlw" {
-				fmt.Println(slice)
+				b.Log(slice)
 				b.Error(errors.New("name should be xlw"))
 			}
 		}
@@ -378,7 +377,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) {
 				b.Error(err)
 			}
 			if m["name"].(string) != "xlw" {
-				fmt.Println(m)
+				b.Log(m)
 				b.Error(errors.New("name should be xlw"))
 			}
 		}
@@ -579,7 +578,7 @@ func TestExecMap(t *testing.T) {
 		if err != nil {
 			t.Error(err)
 		}
-		fmt.Println("--", user)
+		t.Log("--", user)
 	}
 }
 
@@ -621,7 +620,7 @@ func TestExecStruct(t *testing.T) {
 		if err != nil {
 			t.Error(err)
 		}
-		fmt.Println("1--", user)
+		t.Log("1--", user)
 	}
 }
 

+ 2 - 3
pk_test.go

@@ -1,7 +1,6 @@
 package core
 
 import (
-	"fmt"
 	"reflect"
 	"testing"
 )
@@ -12,14 +11,14 @@ func TestPK(t *testing.T) {
 	if err != nil {
 		t.Error(err)
 	}
-	fmt.Println(str)
+	t.Log(str)
 
 	s := &PK{}
 	err = s.FromString(str)
 	if err != nil {
 		t.Error(err)
 	}
-	fmt.Println(s)
+	t.Log(s)
 
 	if len(*p) != len(*s) {
 		t.Fatal("p", *p, "should be equal", *s)

+ 55 - 0
scan.go

@@ -0,0 +1,55 @@
+package core
+
+import (
+	"database/sql/driver"
+	"fmt"
+	"time"
+)
+
+type NullTime time.Time
+
+var (
+	_ driver.Valuer = NullTime{}
+)
+
+func (ns *NullTime) Scan(value interface{}) error {
+	if value == nil {
+		return nil
+	}
+	return convertTime(ns, value)
+}
+
+// Value implements the driver Valuer interface.
+func (ns NullTime) Value() (driver.Value, error) {
+	if (time.Time)(ns).IsZero() {
+		return nil, nil
+	}
+	return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil
+}
+
+func convertTime(dest *NullTime, src interface{}) error {
+	// Common cases, without reflect.
+	switch s := src.(type) {
+	case string:
+		t, err := time.Parse("2006-01-02 15:04:05", s)
+		if err != nil {
+			return err
+		}
+		*dest = NullTime(t)
+		return nil
+	case []uint8:
+		t, err := time.Parse("2006-01-02 15:04:05", string(s))
+		if err != nil {
+			return err
+		}
+		*dest = NullTime(t)
+		return nil
+	case time.Time:
+		*dest = NullTime(s)
+		return nil
+	case nil:
+	default:
+		return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
+	}
+	return nil
+}