Jelajahi Sumber

serious extends bug fixed & correct logger file path

xormplus 9 tahun lalu
induk
melakukan
0911afc633
7 mengubah file dengan 194 tambahan dan 84 penghapusan
  1. 21 35
      column.go
  2. 131 19
      db.go
  3. 4 4
      db_test.go
  4. 14 5
      dialect.go
  5. 1 3
      error.go
  6. 17 14
      ilogger.go
  7. 6 4
      type.go

+ 21 - 35
column.go

@@ -1,10 +1,10 @@
 package core
 
 import (
-	"errors"
 	"fmt"
 	"reflect"
 	"strings"
+	"time"
 )
 
 const (
@@ -35,6 +35,8 @@ type Column struct {
 	DefaultIsEmpty  bool
 	EnumOptions     map[string]int
 	SetOptions      map[string]int
+	DisableTimeZone bool
+	TimeZone        *time.Location // column specified time zone
 }
 
 func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column {
@@ -122,50 +124,34 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
 	}
 
 	if dataStruct.Type().Kind() == reflect.Map {
-		var keyValue reflect.Value
-
-		if len(col.fieldPath) == 1 {
-			keyValue = reflect.ValueOf(col.FieldName)
-		} else if len(col.fieldPath) == 2 {
-			keyValue = reflect.ValueOf(col.fieldPath[1])
-		} else {
-			return nil, fmt.Errorf("Unsupported mutliderive %v", col.FieldName)
-		}
-
+		keyValue := reflect.ValueOf(col.fieldPath[len(col.fieldPath)-1])
 		fieldValue = dataStruct.MapIndex(keyValue)
 		return &fieldValue, nil
+	} else if dataStruct.Type().Kind() == reflect.Interface {
+		structValue := reflect.ValueOf(dataStruct.Interface())
+		dataStruct = &structValue
 	}
 
-	if len(col.fieldPath) == 1 {
-		fieldValue = dataStruct.FieldByName(col.FieldName)
-	} else if len(col.fieldPath) == 2 {
-		parentField := dataStruct.FieldByName(col.fieldPath[0])
-		if parentField.IsValid() {
-			if parentField.Kind() == reflect.Struct {
-				fieldValue = parentField.FieldByName(col.fieldPath[1])
-			} else if parentField.Kind() == reflect.Ptr {
-				if parentField.IsNil() {
-					parentField.Set(reflect.New(parentField.Type().Elem()))
-					fieldValue = parentField.Elem().FieldByName(col.fieldPath[1])
-				} else {
-					parentField = parentField.Elem()
-					if parentField.IsValid() {
-						fieldValue = parentField.FieldByName(col.fieldPath[1])
-					} else {
-						return nil, fmt.Errorf("field  %v is not valid", col.FieldName)
-					}
-				}
+	level := len(col.fieldPath)
+	fieldValue = dataStruct.FieldByName(col.fieldPath[0])
+	for i := 0; i < level-1; i++ {
+		if !fieldValue.IsValid() {
+			break
+		}
+		if fieldValue.Kind() == reflect.Struct {
+			fieldValue = fieldValue.FieldByName(col.fieldPath[i+1])
+		} else if fieldValue.Kind() == reflect.Ptr {
+			if fieldValue.IsNil() {
+				fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
 			}
+			fieldValue = fieldValue.Elem().FieldByName(col.fieldPath[i+1])
 		} else {
-			// so we can use a different struct as conditions
-			fieldValue = dataStruct.FieldByName(col.fieldPath[1])
+			return nil, fmt.Errorf("field  %v is not valid", col.FieldName)
 		}
-	} else {
-		return nil, fmt.Errorf("Unsupported mutliderive %v", col.FieldName)
 	}
 
 	if !fieldValue.IsValid() {
-		return nil, errors.New("no find field matched")
+		return nil, fmt.Errorf("field  %v is not valid", col.FieldName)
 	}
 
 	return &fieldValue, nil

+ 131 - 19
db.go

@@ -2,6 +2,7 @@ package core
 
 import (
 	"database/sql"
+	"database/sql/driver"
 	"errors"
 	"reflect"
 	"regexp"
@@ -29,10 +30,24 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error)
 	}
 
 	args := make([]interface{}, 0)
+	var err error
 	query = re.ReplaceAllStringFunc(query, func(src string) string {
-		args = append(args, vv.Elem().FieldByName(src[1:]).Interface())
+		fv := vv.Elem().FieldByName(src[1:]).Interface()
+		if v, ok := fv.(driver.Valuer); ok {
+			var value driver.Value
+			value, err = v.Value()
+			if err != nil {
+				return "?"
+			}
+			args = append(args, value)
+		} else {
+			args = append(args, fv)
+		}
 		return "?"
 	})
+	if err != nil {
+		return "", []interface{}{}, err
+	}
 	return query, args, nil
 }
 
@@ -43,7 +58,10 @@ type DB struct {
 
 func Open(driverName, dataSourceName string) (*DB, error) {
 	db, err := sql.Open(driverName, dataSourceName)
-	return &DB{db, NewCacheMapper(&SnakeMapper{})}, err
+	if err != nil {
+		return nil, err
+	}
+	return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil
 }
 
 func FromDB(db *sql.DB) *DB {
@@ -52,7 +70,13 @@ func FromDB(db *sql.DB) *DB {
 
 func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
 	rows, err := db.DB.Query(query, args...)
-	return &Rows{rows, db.Mapper}, err
+	if err != nil {
+		if rows != nil {
+			rows.Close()
+		}
+		return nil, err
+	}
+	return &Rows{rows, db.Mapper}, nil
 }
 
 func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
@@ -72,28 +96,114 @@ func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
 }
 
 type Row struct {
-	*sql.Row
+	rows *Rows
 	// One of these two will be non-nil:
-	err    error // deferred error for easy chaining
-	Mapper IMapper
+	err error // deferred error for easy chaining
+}
+
+func (row *Row) Columns() ([]string, error) {
+	if row.err != nil {
+		return nil, row.err
+	}
+	return row.rows.Columns()
 }
 
 func (row *Row) Scan(dest ...interface{}) error {
 	if row.err != nil {
 		return row.err
 	}
-	return row.Row.Scan(dest...)
+	defer row.rows.Close()
+
+	for _, dp := range dest {
+		if _, ok := dp.(*sql.RawBytes); ok {
+			return errors.New("sql: RawBytes isn't allowed on Row.Scan")
+		}
+	}
+
+	if !row.rows.Next() {
+		if err := row.rows.Err(); err != nil {
+			return err
+		}
+		return sql.ErrNoRows
+	}
+	err := row.rows.Scan(dest...)
+	if err != nil {
+		return err
+	}
+	// Make sure the query can be processed to completion with no errors.
+	if err := row.rows.Close(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (row *Row) ScanStructByName(dest interface{}) error {
+	if row.err != nil {
+		return row.err
+	}
+	if !row.rows.Next() {
+		if err := row.rows.Err(); err != nil {
+			return err
+		}
+		return sql.ErrNoRows
+	}
+	return row.rows.ScanStructByName(dest)
+}
+
+func (row *Row) ScanStructByIndex(dest interface{}) error {
+	if row.err != nil {
+		return row.err
+	}
+	if !row.rows.Next() {
+		if err := row.rows.Err(); err != nil {
+			return err
+		}
+		return sql.ErrNoRows
+	}
+	return row.rows.ScanStructByIndex(dest)
+}
+
+// scan data to a slice's pointer, slice's length should equal to columns' number
+func (row *Row) ScanSlice(dest interface{}) error {
+	if row.err != nil {
+		return row.err
+	}
+	if !row.rows.Next() {
+		if err := row.rows.Err(); err != nil {
+			return err
+		}
+		return sql.ErrNoRows
+	}
+	return row.rows.ScanSlice(dest)
+}
+
+// scan data to a map's pointer
+func (row *Row) ScanMap(dest interface{}) error {
+	if row.err != nil {
+		return row.err
+	}
+	if !row.rows.Next() {
+		if err := row.rows.Err(); err != nil {
+			return err
+		}
+		return sql.ErrNoRows
+	}
+	return row.rows.ScanMap(dest)
 }
 
 func (db *DB) QueryRow(query string, args ...interface{}) *Row {
-	row := db.DB.QueryRow(query, args...)
-	return &Row{row, nil, db.Mapper}
+	rows, err := db.Query(query, args...)
+	if err != nil {
+		return &Row{nil, err}
+	}
+	return &Row{rows, nil}
 }
 
 func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
 	query, args, err := MapToSlice(query, mp)
 	if err != nil {
-		return &Row{nil, err, db.Mapper}
+		return &Row{nil, err}
 	}
 	return db.QueryRow(query, args...)
 }
@@ -101,7 +211,7 @@ func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
 func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
 	query, args, err := StructToSlice(query, st)
 	if err != nil {
-		return &Row{nil, err, db.Mapper}
+		return &Row{nil, err}
 	}
 	return db.QueryRow(query, args...)
 }
@@ -191,14 +301,14 @@ func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
 }
 
 func (s *Stmt) QueryRow(args ...interface{}) *Row {
-	row := s.Stmt.QueryRow(args...)
-	return &Row{row, nil, s.Mapper}
+	rows, err := s.Query(args...)
+	return &Row{rows, err}
 }
 
 func (s *Stmt) QueryRowMap(mp interface{}) *Row {
 	vv := reflect.ValueOf(mp)
 	if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
-		return &Row{nil, errors.New("mp should be a map's pointer"), s.Mapper}
+		return &Row{nil, errors.New("mp should be a map's pointer")}
 	}
 
 	args := make([]interface{}, len(s.names))
@@ -212,7 +322,7 @@ func (s *Stmt) QueryRowMap(mp interface{}) *Row {
 func (s *Stmt) QueryRowStruct(st interface{}) *Row {
 	vv := reflect.ValueOf(st)
 	if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
-		return &Row{nil, errors.New("st should be a struct's pointer"), s.Mapper}
+		return &Row{nil, errors.New("st should be a struct's pointer")}
 	}
 
 	args := make([]interface{}, len(s.names))
@@ -425,6 +535,8 @@ func (rs *Rows) ScanMap(dest interface{}) error {
 
 	for i, _ := range cols {
 		newDest[i] = ReflectNew(vvv.Type().Elem()).Interface()
+		//v := reflect.New(vvv.Type().Elem())
+		//newDest[i] = v.Interface()
 	}
 
 	err = rs.Rows.Scan(newDest...)
@@ -542,14 +654,14 @@ func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
 }
 
 func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
-	row := tx.Tx.QueryRow(query, args...)
-	return &Row{row, nil, tx.Mapper}
+	rows, err := tx.Query(query, args...)
+	return &Row{rows, err}
 }
 
 func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
 	query, args, err := MapToSlice(query, mp)
 	if err != nil {
-		return &Row{nil, err, tx.Mapper}
+		return &Row{nil, err}
 	}
 	return tx.QueryRow(query, args...)
 }
@@ -557,7 +669,7 @@ func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
 func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
 	query, args, err := StructToSlice(query, st)
 	if err != nil {
-		return &Row{nil, err, tx.Mapper}
+		return &Row{nil, err}
 	}
 	return tx.QueryRow(query, args...)
 }

+ 4 - 4
db_test.go

@@ -24,7 +24,7 @@ type User struct {
 	Age      float32
 	Alias    string
 	NickName string
-	Created  time.Time
+	Created  NullTime
 }
 
 func init() {
@@ -85,7 +85,7 @@ func BenchmarkOriQuery(b *testing.B) {
 			var Id int64
 			var Name, Title, Alias, NickName string
 			var Age float32
-			var Created time.Time
+			var Created NullTime
 			err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created)
 			if err != nil {
 				b.Error(err)
@@ -600,7 +600,7 @@ func TestExecStruct(t *testing.T) {
 		Age:      1.2,
 		Alias:    "lunny",
 		NickName: "lunny xiao",
-		Created:  time.Now(),
+		Created:  NullTime(time.Now()),
 	}
 
 	_, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+
@@ -645,7 +645,7 @@ func BenchmarkExecStruct(b *testing.B) {
 		Age:      1.2,
 		Alias:    "lunny",
 		NickName: "lunny xiao",
-		Created:  time.Now(),
+		Created:  NullTime(time.Now()),
 	}
 
 	for i := 0; i < b.N; i++ {

+ 14 - 5
dialect.go

@@ -20,6 +20,7 @@ type Uri struct {
 	Laddr   string
 	Raddr   string
 	Timeout time.Duration
+	Schema  string
 }
 
 // a dialect is a driver's wrapper
@@ -84,7 +85,7 @@ type Base struct {
 	dialect        Dialect
 	driverName     string
 	dataSourceName string
-	Logger         ILogger
+	logger         ILogger
 	*Uri
 }
 
@@ -93,7 +94,7 @@ func (b *Base) DB() *DB {
 }
 
 func (b *Base) SetLogger(logger ILogger) {
-	b.Logger = logger
+	b.logger = logger
 }
 
 func (b *Base) Init(db *DB, dialect Dialect, uri *Uri, drivername, dataSourceName string) error {
@@ -151,10 +152,8 @@ func (db *Base) DropTableSql(tableName string) string {
 }
 
 func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) {
+	db.LogSQL(query, args)
 	rows, err := db.DB().Query(query, args...)
-	if db.Logger != nil {
-		db.Logger.Info("[sql]", query, args)
-	}
 	if err != nil {
 		return false, err
 	}
@@ -277,6 +276,16 @@ func (b *Base) ForUpdateSql(query string) string {
 	return query + " FOR UPDATE"
 }
 
+func (b *Base) LogSQL(sql string, args []interface{}) {
+	if b.logger != nil && b.logger.IsShowSQL() {
+		if len(args) > 0 {
+			b.logger.Info("[sql]", sql, args)
+		} else {
+			b.logger.Info("[sql]", sql)
+		}
+	}
+}
+
 var (
 	dialects = map[DbType]func() Dialect{}
 )

+ 1 - 3
error.go

@@ -4,7 +4,5 @@ import "errors"
 
 var (
 	ErrNoMapPointer    = errors.New("mp should be a map's pointer")
-	ErrNoStructPointer = errors.New("mp should be a map's pointer")
-	//ErrNotExist        = errors.New("Not exist")
-	//ErrIgnore = errors.New("Ignore")
+	ErrNoStructPointer = errors.New("mp should be a struct's pointer")
 )

+ 17 - 14
ilogger.go

@@ -4,25 +4,28 @@ type LogLevel int
 
 const (
 	// !nashtsai! following level also match syslog.Priority value
-	LOG_UNKNOWN LogLevel = iota - 2
-	LOG_OFF     LogLevel = iota - 1
-	LOG_ERR     LogLevel = iota + 3
+	LOG_DEBUG LogLevel = iota
+	LOG_INFO
 	LOG_WARNING
-	LOG_INFO LogLevel = iota + 6
-	LOG_DEBUG
+	LOG_ERR
+	LOG_OFF
+	LOG_UNKNOWN
 )
 
 // logger interface
 type ILogger interface {
-	Debug(v ...interface{}) (err error)
-	Debugf(format string, v ...interface{}) (err error)
-	Err(v ...interface{}) (err error)
-	Errf(format string, v ...interface{}) (err error)
-	Info(v ...interface{}) (err error)
-	Infof(format string, v ...interface{}) (err error)
-	Warning(v ...interface{}) (err error)
-	Warningf(format string, v ...interface{}) (err error)
+	Debug(v ...interface{})
+	Debugf(format string, v ...interface{})
+	Error(v ...interface{})
+	Errorf(format string, v ...interface{})
+	Info(v ...interface{})
+	Infof(format string, v ...interface{})
+	Warn(v ...interface{})
+	Warnf(format string, v ...interface{})
 
 	Level() LogLevel
-	SetLevel(l LogLevel) (err error)
+	SetLevel(l LogLevel)
+
+	ShowSQL(show ...bool)
+	IsShowSQL() bool
 }

+ 6 - 4
type.go

@@ -105,7 +105,8 @@ var (
 	Serial    = "SERIAL"
 	BigSerial = "BIGSERIAL"
 
-	Json = "JSON"
+	Json  = "JSON"
+	Jsonb = "JSONB"
 
 	SqlTypes = map[string]int{
 		Bit:       NUMERIC_TYPE,
@@ -116,9 +117,10 @@ var (
 		Integer:   NUMERIC_TYPE,
 		BigInt:    NUMERIC_TYPE,
 
-		Enum: TEXT_TYPE,
-		Set:  TEXT_TYPE,
-		Json: TEXT_TYPE,
+		Enum:  TEXT_TYPE,
+		Set:   TEXT_TYPE,
+		Json:  TEXT_TYPE,
+		Jsonb: TEXT_TYPE,
 
 		Char:       TEXT_TYPE,
 		Varchar:    TEXT_TYPE,