瀏覽代碼

fix FindAndCount bug with Limit

xormplus 7 年之前
父節點
當前提交
1216548386
共有 4 個文件被更改,包括 27 次插入19 次删除
  1. 1 1
      session_find.go
  2. 10 4
      session_find_test.go
  3. 1 1
      session_query.go
  4. 15 13
      statement.go

+ 1 - 1
session_find.go

@@ -151,7 +151,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 		}
 
 		args = append(session.statement.joinArgs, condArgs...)
-		sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL)
+		sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true)
 		if err != nil {
 			return err
 		}

+ 10 - 4
session_find_test.go

@@ -523,9 +523,9 @@ func TestFindMark(t *testing.T) {
 
 func TestFindAndCountOneFunc(t *testing.T) {
 	type FindAndCountStruct struct {
-		Id  int64
+		Id      int64
 		Content string
-		Msg bool `xorm:"bit"`
+		Msg     bool `xorm:"bit"`
 	}
 
 	assert.NoError(t, prepareEngine())
@@ -534,11 +534,11 @@ func TestFindAndCountOneFunc(t *testing.T) {
 	cnt, err := testEngine.Insert([]FindAndCountStruct{
 		{
 			Content: "111",
-			Msg: false,
+			Msg:     false,
 		},
 		{
 			Content: "222",
-			Msg: true,
+			Msg:     true,
 		},
 	})
 	assert.NoError(t, err)
@@ -555,4 +555,10 @@ func TestFindAndCountOneFunc(t *testing.T) {
 	assert.NoError(t, err)
 	assert.EqualValues(t, 1, len(results))
 	assert.EqualValues(t, 1, cnt)
+
+	results = make([]FindAndCountStruct, 0, 1)
+	cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, len(results))
+	assert.EqualValues(t, 1, cnt)
 }

+ 1 - 1
session_query.go

@@ -90,7 +90,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
 	}
 
 	args := append(session.statement.joinArgs, condArgs...)
-	sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL)
+	sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true)
 	if err != nil {
 		return "", nil, err
 	}

+ 15 - 13
statement.go

@@ -988,7 +988,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
 		return "", nil, err
 	}
 
-	sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
+	sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true)
 	if err != nil {
 		return "", nil, err
 	}
@@ -1018,7 +1018,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
 			selectSQL = "count(*)"
 		}
 	}
-	sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
+	sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false)
 	if err != nil {
 		return "", nil, err
 	}
@@ -1043,7 +1043,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
 		return "", nil, err
 	}
 
-	sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
+	sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true)
 	if err != nil {
 		return "", nil, err
 	}
@@ -1051,7 +1051,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
 	return sqlStr, append(statement.joinArgs, condArgs...), nil
 }
 
-func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
+func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) {
 	var distinct string
 	if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
 		distinct = "DISTINCT "
@@ -1149,15 +1149,17 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
 	if statement.OrderStr != "" {
 		a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
 	}
-	if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
-		if statement.Start > 0 {
-			a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
-		} else if statement.LimitN > 0 {
-			a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
-		}
-	} else if dialect.DBType() == core.ORACLE {
-		if statement.Start != 0 || statement.LimitN != 0 {
-			a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
+	if needLimit {
+		if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
+			if statement.Start > 0 {
+				a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
+			} else if statement.LimitN > 0 {
+				a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
+			}
+		} else if dialect.DBType() == core.ORACLE {
+			if statement.Start != 0 || statement.LimitN != 0 {
+				a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
+			}
 		}
 	}
 	if statement.IsForUpdate {