浏览代码

add deleted condition for Update operation

xormplus 6 年之前
父节点
当前提交
1d70ccbd77
共有 4 个文件被更改,包括 78 次插入17 次删除
  1. 1 1
      README.md
  2. 4 0
      session_exist.go
  3. 31 15
      session_update.go
  4. 42 1
      session_update_test.go

+ 1 - 1
README.md

@@ -480,7 +480,7 @@ id := engine.SqlTemplateClient(key, &paramMap).Query().Results[0]["id"] //返回
 
 * 事物的简写方法
  ```Go
-res, err := engine.Transaction(func(sess *xorm.Session) (interface{}, error) {
+res, err := engine.Transaction(func(session *xorm.Session) (interface{}, error) {
     user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
     if _, err := session.Insert(&user1); err != nil {
         return nil, err

+ 4 - 0
session_exist.go

@@ -42,6 +42,8 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
 
 				if session.engine.dialect.DBType() == core.MSSQL {
 					sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s WHERE %s", tableName, condSQL)
+				} else if session.engine.dialect.DBType() == core.ORACLE {
+					sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) AND ROWNUM=1", tableName, condSQL)
 				} else {
 					sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL)
 				}
@@ -49,6 +51,8 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
 			} else {
 				if session.engine.dialect.DBType() == core.MSSQL {
 					sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s", tableName)
+				} else if session.engine.dialect.DBType() == core.ORACLE {
+					sqlStr = fmt.Sprintf("SELECT * FROM  %s WHERE ROWNUM=1", tableName)
 				} else {
 					sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName)
 				}

+ 31 - 15
session_update.go

@@ -244,23 +244,39 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 	}
 
 	var autoCond builder.Cond
-	if !session.statement.noAutoCondition && len(condiBean) > 0 {
-		if c, ok := condiBean[0].(map[string]interface{}); ok {
-			autoCond = builder.Eq(c)
-		} else {
-			ct := reflect.TypeOf(condiBean[0])
-			k := ct.Kind()
-			if k == reflect.Ptr {
-				k = ct.Elem().Kind()
+	if !session.statement.noAutoCondition {
+		condBeanIsStruct := false
+		if len(condiBean) > 0 {
+			if c, ok := condiBean[0].(map[string]interface{}); ok {
+				autoCond = builder.Eq(c)
+			} else {
+				ct := reflect.TypeOf(condiBean[0])
+				k := ct.Kind()
+				if k == reflect.Ptr {
+					k = ct.Elem().Kind()
+				}
+				if k == reflect.Struct {
+					var err error
+					autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
+					if err != nil {
+						return 0, err
+					}
+					condBeanIsStruct = true
+				} else {
+					return 0, ErrConditionType
+				}
 			}
-			if k == reflect.Struct {
-				var err error
-				autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
-				if err != nil {
-					return 0, err
+		}
+
+		if !condBeanIsStruct && table != nil {
+			if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled
+				autoCond1 := session.engine.CondDeleted(session.engine.Quote(col.Name))
+
+				if autoCond == nil {
+					autoCond = autoCond1
+				} else {
+					autoCond = autoCond.And(autoCond1)
 				}
-			} else {
-				return 0, ErrConditionType
 			}
 		}
 	}

+ 42 - 1
session_update_test.go

@@ -110,7 +110,7 @@ func setupForUpdate(engine EngineInterface) error {
 }
 
 func TestForUpdate(t *testing.T) {
-	if testEngine.Dialect().DriverName() != "mysql" && testEngine.Dialect().DriverName() != "mymysql" {
+	if *ignoreSelectUpdate {
 		return
 	}
 
@@ -1349,3 +1349,44 @@ func TestWhereCondErrorWhenUpdate(t *testing.T) {
 	assert.Error(t, err)
 	assert.EqualValues(t, ErrConditionType, err)
 }
+
+func TestUpdateDeleted(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+
+	type UpdateDeletedStruct struct {
+		Id        int64
+		Name      string
+		DeletedAt time.Time `xorm:"deleted"`
+	}
+
+	assertSync(t, new(UpdateDeletedStruct))
+
+	var s = UpdateDeletedStruct{
+		Name: "test",
+	}
+	cnt, err := testEngine.Insert(&s)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
+
+	cnt, err = testEngine.ID(s.Id).Delete(&UpdateDeletedStruct{})
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
+
+	cnt, err = testEngine.ID(s.Id).Update(&UpdateDeletedStruct{
+		Name: "test1",
+	})
+	assert.NoError(t, err)
+	assert.EqualValues(t, 0, cnt)
+
+	cnt, err = testEngine.Table(&UpdateDeletedStruct{}).ID(s.Id).Update(map[string]interface{}{
+		"name": "test1",
+	})
+	assert.NoError(t, err)
+	assert.EqualValues(t, 0, cnt)
+
+	cnt, err = testEngine.ID(s.Id).Unscoped().Update(&UpdateDeletedStruct{
+		Name: "test1",
+	})
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, cnt)
+}