xormplus 7 年 前
コミット
44ee0d00dc
2 ファイル変更48 行追加73 行削除
  1. 45 12
      db.go
  2. 3 61
      rows.go

+ 45 - 12
db.go

@@ -7,6 +7,11 @@ import (
 	"fmt"
 	"fmt"
 	"reflect"
 	"reflect"
 	"regexp"
 	"regexp"
+	"sync"
+)
+
+var (
+	DefaultCacheSize = 200
 )
 )
 
 
 func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
 func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
@@ -58,9 +63,16 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error)
 	return query, args, nil
 	return query, args, nil
 }
 }
 
 
+type cacheStruct struct {
+	value reflect.Value
+	idx   int
+}
+
 type DB struct {
 type DB struct {
 	*sql.DB
 	*sql.DB
-	Mapper IMapper
+	Mapper            IMapper
+	reflectCache      map[reflect.Type]*cacheStruct
+	reflectCacheMutex sync.RWMutex
 }
 }
 
 
 func Open(driverName, dataSourceName string) (*DB, error) {
 func Open(driverName, dataSourceName string) (*DB, error) {
@@ -68,11 +80,32 @@ func Open(driverName, dataSourceName string) (*DB, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil
+	return &DB{
+		DB:           db,
+		Mapper:       NewCacheMapper(&SnakeMapper{}),
+		reflectCache: make(map[reflect.Type]*cacheStruct),
+	}, nil
 }
 }
 
 
 func FromDB(db *sql.DB) *DB {
 func FromDB(db *sql.DB) *DB {
-	return &DB{db, NewCacheMapper(&SnakeMapper{})}
+	return &DB{
+		DB:           db,
+		Mapper:       NewCacheMapper(&SnakeMapper{}),
+		reflectCache: make(map[reflect.Type]*cacheStruct),
+	}
+}
+
+func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
+	db.reflectCacheMutex.Lock()
+	defer db.reflectCacheMutex.Unlock()
+	cs, ok := db.reflectCache[typ]
+	if !ok || cs.idx+1 > DefaultCacheSize-1 {
+		cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0}
+		db.reflectCache[typ] = cs
+	} else {
+		cs.idx = cs.idx + 1
+	}
+	return cs.value.Index(cs.idx).Addr()
 }
 }
 
 
 func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
 func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
@@ -83,7 +116,7 @@ func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
 		}
 		}
 		return nil, err
 		return nil, err
 	}
 	}
-	return &Rows{rows, db.Mapper}, nil
+	return &Rows{rows, db}, nil
 }
 }
 
 
 func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
 func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
@@ -128,8 +161,8 @@ func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
 
 
 type Stmt struct {
 type Stmt struct {
 	*sql.Stmt
 	*sql.Stmt
-	Mapper IMapper
-	names  map[string]int
+	db    *DB
+	names map[string]int
 }
 }
 
 
 func (db *DB) Prepare(query string) (*Stmt, error) {
 func (db *DB) Prepare(query string) (*Stmt, error) {
@@ -145,7 +178,7 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return &Stmt{stmt, db.Mapper, names}, nil
+	return &Stmt{stmt, db, names}, nil
 }
 }
 
 
 func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
 func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
@@ -179,7 +212,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return &Rows{rows, s.Mapper}, nil
+	return &Rows{rows, s.db}, nil
 }
 }
 
 
 func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
 func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
@@ -274,7 +307,7 @@ func (EmptyScanner) Scan(src interface{}) error {
 
 
 type Tx struct {
 type Tx struct {
 	*sql.Tx
 	*sql.Tx
-	Mapper IMapper
+	db *DB
 }
 }
 
 
 func (db *DB) Begin() (*Tx, error) {
 func (db *DB) Begin() (*Tx, error) {
@@ -282,7 +315,7 @@ func (db *DB) Begin() (*Tx, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return &Tx{tx, db.Mapper}, nil
+	return &Tx{tx, db}, nil
 }
 }
 
 
 func (tx *Tx) Prepare(query string) (*Stmt, error) {
 func (tx *Tx) Prepare(query string) (*Stmt, error) {
@@ -298,7 +331,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return &Stmt{stmt, tx.Mapper, names}, nil
+	return &Stmt{stmt, tx.db, names}, nil
 }
 }
 
 
 func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
 func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
@@ -327,7 +360,7 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return &Rows{rows, tx.Mapper}, nil
+	return &Rows{rows, tx.db}, nil
 }
 }
 
 
 func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
 func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {

+ 3 - 61
rows.go

@@ -9,7 +9,7 @@ import (
 
 
 type Rows struct {
 type Rows struct {
 	*sql.Rows
 	*sql.Rows
-	Mapper IMapper
+	db *DB
 }
 }
 
 
 func (rs *Rows) ToMapString() ([]map[string]string, error) {
 func (rs *Rows) ToMapString() ([]map[string]string, error) {
@@ -105,7 +105,7 @@ func (rs *Rows) ScanStructByName(dest interface{}) error {
 	newDest := make([]interface{}, len(cols))
 	newDest := make([]interface{}, len(cols))
 	var v EmptyScanner
 	var v EmptyScanner
 	for j, name := range cols {
 	for j, name := range cols {
-		f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name))
+		f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name))
 		if f.IsValid() {
 		if f.IsValid() {
 			newDest[j] = f.Addr().Interface()
 			newDest[j] = f.Addr().Interface()
 		} else {
 		} else {
@@ -116,36 +116,6 @@ func (rs *Rows) ScanStructByName(dest interface{}) error {
 	return rs.Rows.Scan(newDest...)
 	return rs.Rows.Scan(newDest...)
 }
 }
 
 
-type cacheStruct struct {
-	value reflect.Value
-	idx   int
-}
-
-var (
-	reflectCache      = make(map[reflect.Type]*cacheStruct)
-	reflectCacheMutex sync.RWMutex
-)
-
-func ReflectNew(typ reflect.Type) reflect.Value {
-	reflectCacheMutex.RLock()
-	cs, ok := reflectCache[typ]
-	reflectCacheMutex.RUnlock()
-
-	const newSize = 200
-
-	if !ok || cs.idx+1 > newSize-1 {
-		cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0}
-		reflectCacheMutex.Lock()
-		reflectCache[typ] = cs
-		reflectCacheMutex.Unlock()
-	} else {
-		reflectCacheMutex.Lock()
-		cs.idx = cs.idx + 1
-		reflectCacheMutex.Unlock()
-	}
-	return cs.value.Index(cs.idx).Addr()
-}
-
 // scan data to a slice's pointer, slice's length should equal to columns' number
 // scan data to a slice's pointer, slice's length should equal to columns' number
 func (rs *Rows) ScanSlice(dest interface{}) error {
 func (rs *Rows) ScanSlice(dest interface{}) error {
 	vv := reflect.ValueOf(dest)
 	vv := reflect.ValueOf(dest)
@@ -197,9 +167,7 @@ func (rs *Rows) ScanMap(dest interface{}) error {
 	vvv := vv.Elem()
 	vvv := vv.Elem()
 
 
 	for i, _ := range cols {
 	for i, _ := range cols {
-		newDest[i] = ReflectNew(vvv.Type().Elem()).Interface()
-		//v := reflect.New(vvv.Type().Elem())
-		//newDest[i] = v.Interface()
+		newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface()
 	}
 	}
 
 
 	err = rs.Rows.Scan(newDest...)
 	err = rs.Rows.Scan(newDest...)
@@ -215,32 +183,6 @@ func (rs *Rows) ScanMap(dest interface{}) error {
 	return nil
 	return nil
 }
 }
 
 
-/*func (rs *Rows) ScanMap(dest interface{}) error {
-	vv := reflect.ValueOf(dest)
-	if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
-		return errors.New("dest should be a map's pointer")
-	}
-
-	cols, err := rs.Columns()
-	if err != nil {
-		return err
-	}
-
-	newDest := make([]interface{}, len(cols))
-	err = rs.ScanSlice(newDest)
-	if err != nil {
-		return err
-	}
-
-	vvv := vv.Elem()
-
-	for i, name := range cols {
-		vname := reflect.ValueOf(name)
-		vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
-	}
-
-	return nil
-}*/
 type Row struct {
 type Row struct {
 	rows *Rows
 	rows *Rows
 	// One of these two will be non-nil:
 	// One of these two will be non-nil: