Переглянути джерело

Exec support builder

* fix group by error
xormplus 7 роки тому
батько
коміт
6a1a8be9ff
7 змінених файлів з 52 додано та 21 видалено
  1. 10 2
      engine.go
  2. 1 1
      interface.go
  3. 2 2
      session_find.go
  4. 9 0
      session_find_test.go
  5. 3 13
      session_query.go
  6. 25 1
      session_raw.go
  7. 2 2
      statement.go

+ 10 - 2
engine.go

@@ -180,6 +180,14 @@ func (engine *Engine) QuoteStr() string {
 	return engine.dialect.QuoteStr()
 }
 
+func (engine *Engine) quoteColumns(columnStr string) string {
+	columns := strings.Split(columnStr, ",")
+	for i := 0; i < len(columns); i++ {
+		columns[i] = engine.Quote(strings.TrimSpace(columns[i]))
+	}
+	return strings.Join(columns, ",")
+}
+
 // Quote Use QuoteStr quote the string sql
 func (engine *Engine) Quote(value string) string {
 	value = strings.TrimSpace(value)
@@ -1342,10 +1350,10 @@ func (engine *Engine) DropIndexes(bean interface{}) error {
 }
 
 // Exec raw sql
-func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) {
+func (engine *Engine) Exec(sqlorArgs ...interface{}) (sql.Result, error) {
 	session := engine.NewSession()
 	defer session.Close()
-	return session.Exec(sql, args...)
+	return session.Exec(sqlorArgs...)
 }
 
 // Query a raw sql and return records as []map[string][]byte

+ 1 - 1
interface.go

@@ -27,7 +27,7 @@ type Interface interface {
 	Delete(interface{}) (int64, error)
 	Distinct(columns ...string) *Session
 	DropIndexes(bean interface{}) error
-	Exec(string, ...interface{}) (sql.Result, error)
+	Exec(sqlOrAgrs ...interface{}) (sql.Result, error)
 	Exist(bean ...interface{}) (bool, error)
 	Find(interface{}, ...interface{}) error
 	FindAndCount(interface{}, ...interface{}) (int64, error)

+ 2 - 2
session_find.go

@@ -156,7 +156,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 			if session.statement.JoinStr == "" {
 				if columnStr == "" {
 					if session.statement.GroupByStr != "" {
-						columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
+						columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
 					} else {
 						columnStr = session.statement.genColumnStr()
 					}
@@ -164,7 +164,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
 			} else {
 				if columnStr == "" {
 					if session.statement.GroupByStr != "" {
-						columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
+						columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
 					} else {
 						columnStr = "*"
 					}

+ 9 - 0
session_find_test.go

@@ -268,6 +268,15 @@ func TestOrder(t *testing.T) {
 	fmt.Println(users2)
 }
 
+func TestGroupBy(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+	assertSync(t, new(Userinfo))
+
+	users := make([]Userinfo, 0)
+	err := testEngine.GroupBy("id, username").Find(&users)
+	assert.NoError(t, err)
+}
+
 func TestHaving(t *testing.T) {
 	assert.NoError(t, prepareEngine())
 	assertSync(t, new(Userinfo))

+ 3 - 13
session_query.go

@@ -17,17 +17,7 @@ import (
 
 func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interface{}, error) {
 	if len(sqlorArgs) > 0 {
-		switch sqlorArgs[0].(type) {
-		case string:
-			return sqlorArgs[0].(string), sqlorArgs[1:], nil
-		case *builder.Builder:
-			return sqlorArgs[0].(*builder.Builder).ToSQL()
-		case builder.Builder:
-			bd := sqlorArgs[0].(builder.Builder)
-			return bd.ToSQL()
-		default:
-			return "", nil, ErrUnSupportedType
-		}
+		return convertSQLOrArgs(sqlorArgs...)
 	}
 
 	if session.statement.RawSQL != "" {
@@ -65,7 +55,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
 		if session.statement.JoinStr == "" {
 			if columnStr == "" {
 				if session.statement.GroupByStr != "" {
-					columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
+					columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
 				} else {
 					columnStr = session.statement.genColumnStr()
 				}
@@ -73,7 +63,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
 		} else {
 			if columnStr == "" {
 				if session.statement.GroupByStr != "" {
-					columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1))
+					columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
 				} else {
 					columnStr = "*"
 				}

+ 25 - 1
session_raw.go

@@ -9,6 +9,7 @@ import (
 	"reflect"
 	"time"
 
+	"github.com/go-xorm/builder"
 	"github.com/xormplus/core"
 )
 
@@ -308,11 +309,34 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
 	return session.DB().Exec(sqlStr, args...)
 }
 
+func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) {
+	switch sqlorArgs[0].(type) {
+	case string:
+		return sqlorArgs[0].(string), sqlorArgs[1:], nil
+	case *builder.Builder:
+		return sqlorArgs[0].(*builder.Builder).ToSQL()
+	case builder.Builder:
+		bd := sqlorArgs[0].(builder.Builder)
+		return bd.ToSQL()
+	}
+
+	return "", nil, ErrUnSupportedType
+}
+
 // Exec raw sql
-func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
+func (session *Session) Exec(sqlorArgs ...interface{}) (sql.Result, error) {
 	if session.isAutoClose {
 		defer session.Close()
 	}
 
+	if len(sqlorArgs) == 0 {
+		return nil, ErrUnSupportedType
+	}
+
+	sqlStr, args, err := convertSQLOrArgs(sqlorArgs...)
+	if err != nil {
+		return nil, err
+	}
+
 	return session.exec(sqlStr, args...)
 }

+ 2 - 2
statement.go

@@ -931,7 +931,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
 		if len(statement.JoinStr) == 0 {
 			if len(columnStr) == 0 {
 				if len(statement.GroupByStr) > 0 {
-					columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
+					columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
 				} else {
 					columnStr = statement.genColumnStr()
 				}
@@ -939,7 +939,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
 		} else {
 			if len(columnStr) == 0 {
 				if len(statement.GroupByStr) > 0 {
-					columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
+					columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
 				}
 			}
 		}