xormplus 8 年之前
父节点
当前提交
7070f93aa5
共有 1 个文件被更改,包括 39 次插入28 次删除
  1. 39 28
      session_update.go

+ 39 - 28
session_update.go

@@ -253,48 +253,59 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 	var condSQL string
 	cond := session.Statement.cond.And(autoCond)
 
-	doIncVer := false
+	var doIncVer = (table != nil && table.Version != "" && session.Statement.checkVersion)
 	var verValue *reflect.Value
-	if table != nil && table.Version != "" && session.Statement.checkVersion {
+	if doIncVer {
 		verValue, err = table.VersionColumn().ValueOf(bean)
 		if err != nil {
 			return 0, err
 		}
 
 		cond = cond.And(builder.Eq{session.Engine.Quote(table.Version): verValue.Interface()})
-		condSQL, condArgs, _ = builder.ToSQL(cond)
-
-		if len(condSQL) > 0 {
-			condSQL = "WHERE " + condSQL
-		}
-
-		if st.LimitN > 0 {
-			condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
-		}
+		colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
+	}
 
-		sqlStr = fmt.Sprintf("UPDATE %v SET %v, %v %v",
-			session.Engine.Quote(session.Statement.TableName()),
-			strings.Join(colNames, ", "),
-			session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1",
-			condSQL)
+	condSQL, condArgs, _ = builder.ToSQL(cond)
+	if len(condSQL) > 0 {
+		condSQL = "WHERE " + condSQL
+	}
 
-		doIncVer = true
-	} else {
-		condSQL, condArgs, _ = builder.ToSQL(cond)
-		if len(condSQL) > 0 {
-			condSQL = "WHERE " + condSQL
-		}
+	if st.OrderStr != "" {
+		condSQL = condSQL + fmt.Sprintf(" ORDER BY %v", st.OrderStr)
+	}
 
-		if st.LimitN > 0 {
+	// TODO: Oracle support needed
+	var top string
+	if st.LimitN > 0 {
+		if st.Engine.dialect.DBType() == core.MYSQL {
 			condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+		} else if st.Engine.dialect.DBType() == core.SQLITE {
+			tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+			cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
+				session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
+			condSQL, condArgs, _ = builder.ToSQL(cond)
+			if len(condSQL) > 0 {
+				condSQL = "WHERE " + condSQL
+			}
+		} else if st.Engine.dialect.DBType() == core.POSTGRES {
+			tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
+			cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
+				session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
+			condSQL, condArgs, _ = builder.ToSQL(cond)
+			if len(condSQL) > 0 {
+				condSQL = "WHERE " + condSQL
+			}
+		} else if st.Engine.dialect.DBType() == core.MSSQL {
+			top = fmt.Sprintf("top (%d) ", st.LimitN)
 		}
-
-		sqlStr = fmt.Sprintf("UPDATE %v SET %v %v",
-			session.Engine.Quote(session.Statement.TableName()),
-			strings.Join(colNames, ", "),
-			condSQL)
 	}
 
+	sqlStr = fmt.Sprintf("UPDATE %v%v SET %v %v",
+		top,
+		session.Engine.Quote(session.Statement.TableName()),
+		strings.Join(colNames, ", "),
+		condSQL)
+
 	res, err := session.exec(sqlStr, append(args, condArgs...)...)
 	if err != nil {
 		return 0, err