浏览代码

added genSelectSql mothed

xormplus 9 年之前
父节点
当前提交
3801d0267e
共有 2 个文件被更改,包括 98 次插入3 次删除
  1. 3 1
      session.go
  2. 95 2
      sessionplus.go

+ 3 - 1
session.go

@@ -1412,7 +1412,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 
 	if sliceValue.Kind() != reflect.Map {
 		if session.IsSqlFuc {
-			sql := session.Statement.RawSQL
+			rownumber := "xorm" + NewShortUUID().String()
+			sql := session.genSelectSql(rownumber)
+
 			params := session.Statement.RawParams
 			i := len(params)
 			if i == 1 {

+ 95 - 2
sessionplus.go

@@ -421,13 +421,77 @@ func (session *Session) Search(rowsSlicePtr interface{}, condiBean ...interface{
 	return r
 }
 
+func (session *Session) genSelectSql(rownumber string) string {
+	var dialect = session.Statement.Engine.Dialect()
+	var sql = session.Statement.RawSQL
+	var orderBys = session.Statement.OrderStr
+
+	if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
+		if session.Statement.Start > 0 {
+			sql = fmt.Sprintf("%v LIMIT %v OFFSET %v", sql, session.Statement.LimitN, session.Statement.Start)
+		} else if session.Statement.LimitN > 0 {
+			sql = fmt.Sprintf("%v LIMIT %v", sql, session.Statement.LimitN)
+		}
+	} else if dialect.DBType() == core.ORACLE {
+		if session.Statement.Start != 0 || session.Statement.LimitN != 0 {
+			sql = fmt.Sprintf("SELECT aat.* FROM (SELECT at.*,ROWNUM %v FROM (%v) at WHERE ROWNUM <= %d) aat WHERE %v > %d",
+				rownumber, sql, session.Statement.Start+session.Statement.LimitN, rownumber, session.Statement.Start)
+		}
+	} else {
+		keepSelect := false
+		var fullQuery string
+		if session.Statement.Start > 0 {
+			fullQuery = fmt.Sprintf("SELECT sq.* FROM (SELECT ROW_NUMBER() OVER (ORDER BY %v) AS %v,", orderBys, rownumber)
+		} else if session.Statement.LimitN > 0 {
+			fullQuery = fmt.Sprintf("SELECT TOP %d", session.Statement.LimitN)
+		} else {
+			keepSelect = true
+		}
+
+		if !keepSelect {
+			expr := `^\s*SELECT\s*`
+			reg, err := regexp.Compile(expr)
+			if err != nil {
+				fmt.Println(err)
+			}
+			sql = strings.ToUpper(sql)
+			if reg.MatchString(sql) {
+				str := reg.FindAllString(sql, -1)
+				fullQuery = fmt.Sprintf("%v %v", fullQuery, sql[len(str[0]):])
+			}
+		}
+
+		if session.Statement.Start > 0 {
+			// T-SQL offset starts with 1, not like MySQL with 0;
+			if session.Statement.LimitN > 0 {
+				fullQuery = fmt.Sprintf("%v) AS sq WHERE %v BETWEEN %d AND %d", fullQuery, rownumber, session.Statement.Start+1, session.Statement.Start+session.Statement.LimitN)
+			} else {
+				fullQuery = fmt.Sprintf("%v) AS sq WHERE %v >= %d", fullQuery, rownumber, session.Statement.Start+1)
+			}
+		} else {
+			fullQuery = fmt.Sprintf("%v ORDER BY %v", fullQuery, orderBys)
+		}
+
+		if keepSelect {
+			if len(orderBys) > 0 {
+				sql = fmt.Sprintf("%v ORDER BY %v", sql, orderBys)
+			}
+		} else {
+			sql = fullQuery
+		}
+	}
+
+	return sql
+}
+
 // Exec a raw sql and return records as ResultMap
 func (session *Session) Query() *ResultMap {
 	defer session.resetStatement()
 	if session.IsAutoClose {
 		defer session.Close()
 	}
-	sql := session.Statement.RawSQL
+	rownumber := "xorm" + NewShortUUID().String()
+	sql := session.genSelectSql(rownumber)
 	params := session.Statement.RawParams
 	i := len(params)
 
@@ -443,6 +507,20 @@ func (session *Session) Query() *ResultMap {
 	} else {
 		result, err = session.queryAll(sql, params...)
 	}
+	var dialect = session.Statement.Engine.Dialect()
+	if dialect.DBType() == core.MSSQL {
+		if session.Statement.Start > 0 {
+			for i, _ := range result {
+				delete(result[i], rownumber)
+			}
+		}
+	} else if dialect.DBType() == core.ORACLE {
+		if session.Statement.Start != 0 || session.Statement.LimitN != 0 {
+			for i, _ := range result {
+				delete(result[i], rownumber)
+			}
+		}
+	}
 	r := &ResultMap{Results: result, Error: err}
 	return r
 }
@@ -453,7 +531,8 @@ func (session *Session) QueryWithDateFormat(dateFormat string) *ResultMap {
 	if session.IsAutoClose {
 		defer session.Close()
 	}
-	sql := session.Statement.RawSQL
+	rownumber := "xorm" + NewShortUUID().String()
+	sql := session.genSelectSql(rownumber)
 	params := session.Statement.RawParams
 	i := len(params)
 	var result []map[string]interface{}
@@ -468,6 +547,20 @@ func (session *Session) QueryWithDateFormat(dateFormat string) *ResultMap {
 	} else {
 		result, err = session.queryAllWithDateFormat(dateFormat, sql, params...)
 	}
+	var dialect = session.Statement.Engine.Dialect()
+	if dialect.DBType() == core.MSSQL {
+		if session.Statement.Start > 0 {
+			for i, _ := range result {
+				delete(result[i], rownumber)
+			}
+		}
+	} else if dialect.DBType() == core.ORACLE {
+		if session.Statement.Start != 0 || session.Statement.LimitN != 0 {
+			for i, _ := range result {
+				delete(result[i], rownumber)
+			}
+		}
+	}
 	r := &ResultMap{Results: result, Error: err}
 	return r
 }