瀏覽代碼

fix update map with table name

* fix update map with table name

* fix bug update map when cache enabled

* refactor cacheInsert

* fix cache test
xormplus 7 年之前
父節點
當前提交
8b624bbda4
共有 11 個文件被更改,包括 97 次插入103 次删除
  1. 37 52
      engine.go
  2. 2 0
      interface.go
  3. 2 2
      session_delete.go
  4. 7 3
      session_find.go
  5. 3 2
      session_get.go
  6. 14 23
      session_insert.go
  7. 6 7
      session_update.go
  8. 17 2
      session_update_test.go
  9. 5 5
      statement.go
  10. 2 6
      tag_cache_test.go
  11. 2 1
      xorm.go

+ 37 - 52
engine.go

@@ -52,6 +52,35 @@ type Engine struct {
 	tagHandlers map[string]tagHandler
 
 	engineGroup *EngineGroup
+
+	cachers    map[string]core.Cacher
+	cacherLock sync.RWMutex
+}
+
+func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
+	engine.cacherLock.Lock()
+	engine.cachers[tableName] = cacher
+	engine.cacherLock.Unlock()
+}
+
+func (engine *Engine) SetCacher(tableName string, cacher core.Cacher) {
+	engine.setCacher(tableName, cacher)
+}
+
+func (engine *Engine) getCacher(tableName string) core.Cacher {
+	var cacher core.Cacher
+	var ok bool
+	engine.cacherLock.RLock()
+	cacher, ok = engine.cachers[tableName]
+	engine.cacherLock.RUnlock()
+	if !ok && !engine.disableGlobalCache {
+		cacher = engine.Cacher
+	}
+	return cacher
+}
+
+func (engine *Engine) GetCacher(tableName string) core.Cacher {
+	return engine.getCacher(tableName)
 }
 
 // BufferSize sets buffer size for iterate
@@ -248,13 +277,7 @@ func (engine *Engine) NoCascade() *Session {
 
 // MapCacher Set a table use a special cacher
 func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error {
-	v := rValue(bean)
-	tb, err := engine.autoMapType(v)
-	if err != nil {
-		return err
-	}
-
-	tb.Cacher = cacher
+	engine.setCacher(engine.TableName(bean, true), cacher)
 	return nil
 }
 
@@ -843,15 +866,6 @@ func addIndex(indexName string, table *core.Table, col *core.Column, indexType i
 	}
 }
 
-func (engine *Engine) newTable() *core.Table {
-	table := core.NewEmptyTable()
-
-	if !engine.disableGlobalCache {
-		table.Cacher = engine.Cacher
-	}
-	return table
-}
-
 // TableName table name interface to define customerize table name
 type TableName interface {
 	TableName() string
@@ -863,7 +877,7 @@ var (
 
 func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
 	t := v.Type()
-	table := engine.newTable()
+	table := core.NewEmptyTable()
 	table.Type = t
 	table.Name = engine.tbNameForMap(v)
 
@@ -1019,15 +1033,15 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
 	if hasCacheTag {
 		if engine.Cacher != nil { // !nash! use engine's cacher if provided
 			engine.logger.Info("enable cache on table:", table.Name)
-			table.Cacher = engine.Cacher
+			engine.setCacher(table.Name, engine.Cacher)
 		} else {
 			engine.logger.Info("enable LRU cache on table:", table.Name)
-			table.Cacher = NewLRUCacher2(NewMemoryStore(), time.Hour, 10000) // !nashtsai! HACK use LRU cacher for now
+			engine.setCacher(table.Name, NewLRUCacher2(NewMemoryStore(), time.Hour, 10000))
 		}
 	}
 	if hasNoCacheTag {
-		engine.logger.Info("no cache on table:", table.Name)
-		table.Cacher = nil
+		engine.logger.Info("disable cache on table:", table.Name)
+		engine.setCacher(table.Name, nil)
 	}
 
 	return table, nil
@@ -1132,26 +1146,10 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
 	return session.CreateUniques(bean)
 }
 
-func (engine *Engine) getCacher2(table *core.Table) core.Cacher {
-	return table.Cacher
-}
-
 // ClearCacheBean if enabled cache, clear the cache bean
 func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
-	v := rValue(bean)
-	t := v.Type()
-	if t.Kind() != reflect.Struct {
-		return errors.New("error params")
-	}
 	tableName := engine.TableName(bean)
-	table, err := engine.autoMapType(v)
-	if err != nil {
-		return err
-	}
-	cacher := table.Cacher
-	if cacher == nil {
-		cacher = engine.Cacher
-	}
+	cacher := engine.getCacher(tableName)
 	if cacher != nil {
 		cacher.ClearIds(tableName)
 		cacher.DelBean(tableName, id)
@@ -1162,21 +1160,8 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
 // ClearCache if enabled cache, clear some tables' cache
 func (engine *Engine) ClearCache(beans ...interface{}) error {
 	for _, bean := range beans {
-		v := rValue(bean)
-		t := v.Type()
-		if t.Kind() != reflect.Struct {
-			return errors.New("error params")
-		}
 		tableName := engine.TableName(bean)
-		table, err := engine.autoMapType(v)
-		if err != nil {
-			return err
-		}
-
-		cacher := table.Cacher
-		if cacher == nil {
-			cacher = engine.Cacher
-		}
+		cacher := engine.getCacher(tableName)
 		if cacher != nil {
 			cacher.ClearIds(tableName)
 			cacher.ClearBeans(tableName)

+ 2 - 0
interface.go

@@ -77,6 +77,7 @@ type EngineInterface interface {
 	Dialect() core.Dialect
 	DropTables(...interface{}) error
 	DumpAllToFile(fp string, tp ...core.DbType) error
+	GetCacher(string) core.Cacher
 	GetColumnMapper() core.IMapper
 	GetDefaultCacher() core.Cacher
 	GetTableMapper() core.IMapper
@@ -85,6 +86,7 @@ type EngineInterface interface {
 	NewSession() *Session
 	NoAutoTime() *Session
 	Quote(string) string
+	SetCacher(string, core.Cacher)
 	SetDefaultCacher(core.Cacher)
 	SetLogLevel(core.LogLevel)
 	SetMapper(core.IMapper)

+ 2 - 2
session_delete.go

@@ -27,7 +27,7 @@ func (session *Session) cacheDelete(table *core.Table, tableName, sqlStr string,
 		return ErrCacheFailed
 	}
 
-	cacher := session.engine.getCacher2(table)
+	cacher := session.engine.getCacher(tableName)
 	pkColumns := table.PKColumns()
 	ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
 	if err != nil {
@@ -199,7 +199,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
 		})
 	}
 
-	if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
+	if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
 		session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...)
 	}
 

+ 7 - 3
session_find.go

@@ -197,7 +197,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 	}
 
 	if session.canCache() {
-		if cacher := session.engine.getCacher2(table); cacher != nil &&
+		if cacher := session.engine.getCacher(table.Name); cacher != nil &&
 			!session.statement.IsDistinct &&
 			!session.statement.unscoped {
 			err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
@@ -369,6 +369,12 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
 		return ErrCacheFailed
 	}
 
+	tableName := session.statement.TableName()
+	cacher := session.engine.getCacher(tableName)
+	if cacher == nil {
+		return nil
+	}
+
 	for _, filter := range session.engine.dialect.Filters() {
 		sqlStr = filter.Do(sqlStr, session.engine.dialect, session.statement.RefTable)
 	}
@@ -378,9 +384,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
 		return ErrCacheFailed
 	}
 
-	tableName := session.statement.TableName()
 	table := session.statement.RefTable
-	cacher := session.engine.getCacher2(table)
 	ids, err := core.GetCacheSql(cacher, tableName, newsql, args)
 	if err != nil {
 		rows, err := session.queryRows(newsql, args...)

+ 3 - 2
session_get.go

@@ -68,7 +68,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
 	table := session.statement.RefTable
 
 	if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
-		if cacher := session.engine.getCacher2(table); cacher != nil &&
+		if cacher := session.engine.getCacher(table.Name); cacher != nil &&
 			!session.statement.unscoped {
 			has, err := session.cacheGet(bean, sqlStr, args...)
 			if err != ErrCacheFailed {
@@ -145,8 +145,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
 		return false, ErrCacheFailed
 	}
 
-	cacher := session.engine.getCacher2(session.statement.RefTable)
 	tableName := session.statement.TableName()
+	cacher := session.engine.getCacher(tableName)
+
 	session.engine.logger.Debug("[cacheGet] find sql:", newsql, args)
 	table := session.statement.RefTable
 	ids, err := core.GetCacheSql(cacher, tableName, newsql, args)

+ 14 - 23
session_insert.go

@@ -70,7 +70,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 		return 0, err
 	}
 
-	if len(session.statement.TableName()) <= 0 {
+	tableName := session.statement.TableName()
+	if len(tableName) <= 0 {
 		return 0, ErrTableNotFound
 	}
 
@@ -205,7 +206,6 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 
 	var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
 	var statement string
-	var tableName = session.statement.TableName()
 	if session.engine.dialect.DBType() == core.ORACLE {
 		sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
 		temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
@@ -232,9 +232,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 		return 0, err
 	}
 
-	if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
-		session.cacheInsert(table, tableName)
-	}
+	session.cacheInsert(tableName)
 
 	lenAfterClosures := len(session.afterClosures)
 	for i := 0; i < size; i++ {
@@ -394,9 +392,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 
 		defer handleAfterInsertProcessorFunc(bean)
 
-		if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
-			session.cacheInsert(table, tableName)
-		}
+		session.cacheInsert(tableName)
 
 		if table.Version != "" && session.statement.checkVersion {
 			verValue, err := table.VersionColumn().ValueOf(bean)
@@ -439,9 +435,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 		}
 		defer handleAfterInsertProcessorFunc(bean)
 
-		if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
-			session.cacheInsert(table, tableName)
-		}
+		session.cacheInsert(tableName)
 
 		if table.Version != "" && session.statement.checkVersion {
 			verValue, err := table.VersionColumn().ValueOf(bean)
@@ -482,9 +476,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 
 		defer handleAfterInsertProcessorFunc(bean)
 
-		if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
-			session.cacheInsert(table, tableName)
-		}
+		session.cacheInsert(tableName)
 
 		if table.Version != "" && session.statement.checkVersion {
 			verValue, err := table.VersionColumn().ValueOf(bean)
@@ -531,17 +523,16 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
 	return session.innerInsert(bean)
 }
 
-func (session *Session) cacheInsert(table *core.Table, tables ...string) error {
-	if table == nil {
-		return ErrCacheFailed
+func (session *Session) cacheInsert(table string) error {
+	if !session.statement.UseCache {
+		return nil
 	}
-
-	cacher := session.engine.getCacher2(table)
-	for _, t := range tables {
-		session.engine.logger.Debug("[cache] clear sql:", t)
-		cacher.ClearIds(t)
+	cacher := session.engine.getCacher(table)
+	if cacher == nil {
+		return nil
 	}
-
+	session.engine.logger.Debug("[cache] clear sql:", table)
+	cacher.ClearIds(table)
 	return nil
 }
 

+ 6 - 7
session_update.go

@@ -40,7 +40,7 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
 		}
 	}
 
-	cacher := session.engine.getCacher2(table)
+	cacher := session.engine.getCacher(tableName)
 	session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:])
 	ids, err := core.GetCacheSql(cacher, tableName, newsql, args[nStart:])
 	if err != nil {
@@ -361,12 +361,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 		}
 	}
 
-	if table != nil {
-		if cacher := session.engine.getCacher2(table); cacher != nil && session.statement.UseCache {
-			//session.cacheUpdate(table, tableName, sqlStr, args...)
-			cacher.ClearIds(tableName)
-			cacher.ClearBeans(tableName)
-		}
+	if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
+		//session.cacheUpdate(table, tableName, sqlStr, args...)
+		session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
+		cacher.ClearIds(tableName)
+		cacher.ClearBeans(tableName)
 	}
 
 	// handle after update processors

+ 17 - 2
session_update_test.go

@@ -892,7 +892,6 @@ func TestUpdateSameMapper(t *testing.T) {
 	}
 
 	{
-
 		col1 := &UpdateIncr{}
 		err = testEngine.Sync(col1)
 		if err != nil {
@@ -1199,7 +1198,7 @@ func TestUpdateMapContent(t *testing.T) {
 	assert.EqualValues(t, 0, c1.Age)
 
 	cnt, err = testEngine.Table(new(UpdateMapContent)).ID(c.Id).Update(map[string]interface{}{
-		"age": 16,
+		"age":    16,
 		"is_man": false,
 		"gender": 2,
 	})
@@ -1213,4 +1212,20 @@ func TestUpdateMapContent(t *testing.T) {
 	assert.EqualValues(t, 16, c2.Age)
 	assert.EqualValues(t, false, c2.IsMan)
 	assert.EqualValues(t, 2, c2.Gender)
+
+	cnt, err = testEngine.Table(testEngine.TableName(new(UpdateMapContent))).ID(c.Id).Update(map[string]interface{}{
+		"age":    15,
+		"is_man": true,
+		"gender": 1,
+	})
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
+
+	var c3 UpdateMapContent
+	has, err = testEngine.ID(c.Id).Get(&c3)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 15, c3.Age)
+	assert.EqualValues(t, true, c3.IsMan)
+	assert.EqualValues(t, 1, c3.Gender)
 }

+ 5 - 5
statement.go

@@ -948,14 +948,14 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
 		columnStr = "*"
 	}
 
-	if err := statement.processIDParam(); err != nil {
-		return "", nil, err
-	}
-
 	if isStruct {
 		if err := statement.mergeConds(bean); err != nil {
 			return "", nil, err
 		}
+	} else {
+		if err := statement.processIDParam(); err != nil {
+			return "", nil, err
+		}
 	}
 	condSQL, condArgs, err := builder.ToSQL(statement.cond)
 	if err != nil {
@@ -1141,7 +1141,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
 }
 
 func (statement *Statement) processIDParam() error {
-	if statement.idParam == nil {
+	if statement.idParam == nil || statement.RefTable == nil {
 		return nil
 	}
 

+ 2 - 6
tag_cache_test.go

@@ -19,9 +19,7 @@ func TestCacheTag(t *testing.T) {
 	}
 
 	assert.NoError(t, testEngine.CreateTables(&CacheDomain{}))
-
-	table := testEngine.TableInfo(&CacheDomain{})
-	assert.True(t, table.Cacher != nil)
+	assert.True(t, testEngine.GetCacher(testEngine.TableName(&CacheDomain{})) != nil)
 }
 
 func TestNoCacheTag(t *testing.T) {
@@ -33,7 +31,5 @@ func TestNoCacheTag(t *testing.T) {
 	}
 
 	assert.NoError(t, testEngine.CreateTables(&NoCacheDomain{}))
-
-	table := testEngine.TableInfo(&NoCacheDomain{})
-	assert.True(t, table.Cacher == nil)
+	assert.True(t, testEngine.GetCacher(testEngine.TableName(&NoCacheDomain{})) == nil)
 }

+ 2 - 1
xorm.go

@@ -17,7 +17,7 @@ import (
 
 const (
 	// Version show the xorm's version
-	Version string = "0.6.5.0411"
+	Version string = "0.6.5.0412"
 )
 
 func regDrvsNDialects() bool {
@@ -90,6 +90,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
 		TagIdentifier: "xorm",
 		TZLocation:    time.Local,
 		tagHandlers:   defaultTagHandlers,
+		cachers:       make(map[string]core.Cacher),
 	}
 
 	if uri.DbType == core.SQLITE {