فهرست منبع

improve processors

xormplus 8 سال پیش
والد
کامیت
0e84b39c04
6فایلهای تغییر یافته به همراه170 افزوده شده و 29 حذف شده
  1. 33 7
      processors.go
  2. 65 0
      processors_test.go
  3. 7 3
      rows.go
  4. 52 16
      session.go
  5. 6 1
      session_find.go
  6. 7 2
      session_get.go

+ 33 - 7
processors.go

@@ -29,13 +29,6 @@ type AfterSetProcessor interface {
 	AfterSet(string, Cell)
 }
 
-// !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations
-//// Executed before an object is validated
-//type BeforeValidateProcessor interface {
-//    BeforeValidate()
-//}
-// --
-
 // AfterInsertProcessor executed after an object is persisted to the database
 type AfterInsertProcessor interface {
 	AfterInsert()
@@ -50,3 +43,36 @@ type AfterUpdateProcessor interface {
 type AfterDeleteProcessor interface {
 	AfterDelete()
 }
+
+// AfterLoadProcessor executed after an ojbect has been loaded from database
+type AfterLoadProcessor interface {
+	AfterLoad()
+}
+
+// AfterLoadSessionProcessor executed after an ojbect has been loaded from database with session parameter
+type AfterLoadSessionProcessor interface {
+	AfterLoad(*Session)
+}
+
+type executedProcessorFunc func(*Session, interface{}) error
+
+type executedProcessor struct {
+	fun     executedProcessorFunc
+	session *Session
+	bean    interface{}
+}
+
+func (executor *executedProcessor) execute() error {
+	return executor.fun(executor.session, executor.bean)
+}
+
+func (session *Session) executeProcessors() error {
+	processors := session.afterProcessors
+	session.afterProcessors = make([]executedProcessor, 0)
+	for _, processor := range processors {
+		if err := processor.execute(); err != nil {
+			return err
+		}
+	}
+	return nil
+}

+ 65 - 0
processors_test.go

@@ -964,3 +964,68 @@ func TestProcessorsTx(t *testing.T) {
 	session.Close()
 	// --
 }
+
+type AfterLoadStructA struct {
+	Id      int64
+	Content string
+}
+
+type AfterLoadStructB struct {
+	Id      int64
+	Content string
+	AId     int64
+	A       AfterLoadStructA `xorm:"-"`
+	Err     error            `xorm:"-"`
+}
+
+func (s *AfterLoadStructB) AfterLoad(session *Session) {
+	has, err := session.ID(s.AId).NoAutoCondition().Get(&s.A)
+	if err != nil {
+		s.Err = err
+		return
+	}
+	if !has {
+		s.Err = ErrNotExist
+	}
+}
+
+func TestAfterLoadProcessor(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+
+	assertSync(t, new(AfterLoadStructA), new(AfterLoadStructB))
+
+	var a = AfterLoadStructA{
+		Content: "testa",
+	}
+	_, err := testEngine.Insert(&a)
+	assert.NoError(t, err)
+
+	var b = AfterLoadStructB{
+		Content: "testb",
+		AId:     a.Id,
+	}
+	_, err = testEngine.Insert(&b)
+	assert.NoError(t, err)
+
+	var b2 AfterLoadStructB
+	has, err := testEngine.ID(b.Id).Get(&b2)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, a.Id, b2.A.Id)
+	assert.EqualValues(t, a.Content, b2.A.Content)
+	assert.NoError(t, b2.Err)
+
+	b.Id = 0
+	_, err = testEngine.Insert(&b)
+	assert.NoError(t, err)
+
+	var bs []AfterLoadStructB
+	err = testEngine.Find(&bs)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 2, len(bs))
+	for i := 0; i < len(bs); i++ {
+		assert.EqualValues(t, a.Id, bs[i].A.Id)
+		assert.EqualValues(t, a.Content, bs[i].A.Content)
+		assert.NoError(t, bs[i].Err)
+	}
+}

+ 7 - 3
rows.go

@@ -99,13 +99,17 @@ func (rows *Rows) Scan(bean interface{}) error {
 		return err
 	}
 
-	scanResults, err := rows.session.row2Slice(rows.rows, rows.fields, len(rows.fields), bean)
+	scanResults, err := rows.session.row2Slice(rows.rows, rows.fields, bean)
 	if err != nil {
 		return err
 	}
 
-	_, err = rows.session.slice2Bean(scanResults, rows.fields, len(rows.fields), bean, &dataStruct, rows.session.statement.RefTable)
-	return err
+	_, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable)
+	if err != nil {
+		return err
+	}
+
+	return rows.session.executeProcessors()
 }
 
 // Close session if session.IsAutoClose is true, and claimed any opened resources

+ 52 - 16
session.go

@@ -43,6 +43,8 @@ type Session struct {
 	beforeClosures []func(interface{})
 	afterClosures  []func(interface{})
 
+	afterProcessors []executedProcessor
+
 	prepareStmt bool
 	stmtCache   map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr))
 
@@ -80,6 +82,8 @@ func (session *Session) Init() {
 	session.beforeClosures = make([]func(interface{}), 0)
 	session.afterClosures = make([]func(interface{}), 0)
 
+	session.afterProcessors = make([]executedProcessor, 0)
+
 	session.lastSQL = ""
 	session.lastSQLArgs = []interface{}{}
 }
@@ -302,37 +306,40 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c
 // Cell cell is a result of one column field
 type Cell *interface{}
 
-func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount int,
+func (session *Session) rows2Beans(rows *core.Rows, fields []string,
 	table *core.Table, newElemFunc func([]string) reflect.Value,
 	sliceValueSetFunc func(*reflect.Value, core.PK) error) error {
 	for rows.Next() {
 		var newValue = newElemFunc(fields)
 		bean := newValue.Interface()
-		dataStruct := rValue(bean)
+		dataStruct := newValue.Elem()
 
 		// handle beforeClosures
-		scanResults, err := session.row2Slice(rows, fields, fieldsCount, bean)
+		scanResults, err := session.row2Slice(rows, fields, bean)
 		if err != nil {
 			return err
 		}
-		pk, err := session.slice2Bean(scanResults, fields, fieldsCount, bean, &dataStruct, table)
-		if err != nil {
-			return err
-		}
-		err = sliceValueSetFunc(&newValue, pk)
+		pk, err := session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
 		if err != nil {
 			return err
 		}
+		session.afterProcessors = append(session.afterProcessors, executedProcessor{
+			fun: func(*Session, interface{}) error {
+				return sliceValueSetFunc(&newValue, pk)
+			},
+			session: session,
+			bean:    bean,
+		})
 	}
 	return nil
 }
 
-func (session *Session) row2Slice(rows *core.Rows, fields []string, fieldsCount int, bean interface{}) ([]interface{}, error) {
+func (session *Session) row2Slice(rows *core.Rows, fields []string, bean interface{}) ([]interface{}, error) {
 	for _, closure := range session.beforeClosures {
 		closure(bean)
 	}
 
-	scanResults := make([]interface{}, fieldsCount)
+	scanResults := make([]interface{}, len(fields))
 	for i := 0; i < len(fields); i++ {
 		var cell interface{}
 		scanResults[i] = &cell
@@ -349,20 +356,49 @@ func (session *Session) row2Slice(rows *core.Rows, fields []string, fieldsCount
 	return scanResults, nil
 }
 
-func (session *Session) slice2Bean(scanResults []interface{}, fields []string, fieldsCount int, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) {
+func (session *Session) slice2Bean(scanResults []interface{}, fields []string, bean interface{}, dataStruct *reflect.Value, table *core.Table) (core.PK, error) {
 	defer func() {
 		if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
 			for ii, key := range fields {
 				b.AfterSet(key, Cell(scanResults[ii].(*interface{})))
 			}
 		}
-
-		// handle afterClosures
-		for _, closure := range session.afterClosures {
-			closure(bean)
-		}
 	}()
 
+	// handle afterClosures
+	for _, closure := range session.afterClosures {
+		session.afterProcessors = append(session.afterProcessors, executedProcessor{
+			fun: func(sess *Session, bean interface{}) error {
+				closure(bean)
+				return nil
+			},
+			session: session,
+			bean:    bean,
+		})
+	}
+
+	if a, has := bean.(AfterLoadProcessor); has {
+		session.afterProcessors = append(session.afterProcessors, executedProcessor{
+			fun: func(sess *Session, bean interface{}) error {
+				a.AfterLoad()
+				return nil
+			},
+			session: session,
+			bean:    bean,
+		})
+	}
+
+	if a, has := bean.(AfterLoadSessionProcessor); has {
+		session.afterProcessors = append(session.afterProcessors, executedProcessor{
+			fun: func(sess *Session, bean interface{}) error {
+				a.AfterLoad(sess)
+				return nil
+			},
+			session: session,
+			bean:    bean,
+		})
+	}
+
 	var tempMap = make(map[string]int)
 	var pk core.PK
 	for ii, key := range fields {

+ 6 - 1
session_find.go

@@ -262,7 +262,12 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
 		if err != nil {
 			return err
 		}
-		return session.rows2Beans(rows, fields, len(fields), tb, newElemFunc, containerValueSetFunc)
+		err = session.rows2Beans(rows, fields, tb, newElemFunc, containerValueSetFunc)
+		rows.Close()
+		if err != nil {
+			return err
+		}
+		return session.executeProcessors()
 	}
 
 	for rows.Next() {

+ 7 - 2
session_get.go

@@ -98,7 +98,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea
 			return true, err
 		}
 
-		scanResults, err := session.row2Slice(rows, fields, len(fields), bean)
+		scanResults, err := session.row2Slice(rows, fields, bean)
 		if err != nil {
 			return false, err
 		}
@@ -106,7 +106,12 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea
 		rows.Close()
 
 		dataStruct := rValue(bean)
-		_, err = session.slice2Bean(scanResults, fields, len(fields), bean, &dataStruct, table)
+		_, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
+		if err != nil {
+			return true, err
+		}
+
+		return true, session.executeProcessors()
 	case reflect.Slice:
 		err = rows.ScanSlice(bean)
 	case reflect.Map: