Przeglądaj źródła

新增批量执行CRUD操作,返回多个结果集功能

xormplus 9 lat temu
rodzic
commit
f04593cf0c
2 zmienionych plików z 112 dodań i 57 usunięć
  1. 7 29
      engineplus.go
  2. 105 28
      sessionplus.go

+ 7 - 29
engineplus.go

@@ -2,7 +2,6 @@ package xorm
 
 import (
 	"encoding/json"
-	"reflect"
 
 	"gopkg.in/flosch/pongo2.v3"
 )
@@ -92,42 +91,21 @@ func JSONString(v interface{}, IndentJSON bool) (string, error) {
 
 func (engine *Engine) Sqls(sqls interface{}, parmas ...interface{}) *SqlsExecutor {
 	session := engine.NewSession()
+	session.IsAutoClose = true
 	session.IsSqlFuc = true
 	return session.Sqls(sqls, parmas...)
 }
 
-func (engine *Engine) SqlMapsClient(sqls interface{}, parmas ...interface{}) *SqlMapsExecutor {
+func (engine *Engine) SqlMapsClient(sqlkeys interface{}, parmas ...interface{}) *SqlMapsExecutor {
 	session := engine.NewSession()
+	session.IsAutoClose = true
 	session.IsSqlFuc = true
-	return session.SqlMapsClient(sqls, parmas...)
+	return session.SqlMapsClient(sqlkeys, parmas...)
 }
 
-func (engine *Engine) SqlTemplatesClient(sqls interface{}, parmas ...interface{}) *SqlsExecutor {
+func (engine *Engine) SqlTemplatesClient(sqlkeys interface{}, parmas ...interface{}) *SqlTemplatesExecutor {
 	session := engine.NewSession()
+	session.IsAutoClose = true
 	session.IsSqlFuc = true
-	return session.Sqls(sqls, parmas...)
-}
-
-func (engine *Engine) BatchSql(sqls interface{}) *Session {
-	session := engine.NewSession()
-	types := reflect.TypeOf(sqls)
-	if types.Kind() == reflect.Map {
-		engine.logger.Info("sqls is Map")
-		engine.logger.Info(types.Elem())
-		engine.logger.Info(types.Elem().Kind())
-	}
-
-	if types.Kind() == reflect.Slice {
-		engine.logger.Info("sqls is Slice")
-		engine.logger.Info(types.Elem())
-		engine.logger.Info(types.Elem().Kind())
-	}
-
-	switch sqls.(type) {
-	case []string:
-		engine.logger.Info("sqls is []string")
-	case map[string]string:
-		engine.logger.Info("sqls is map[string]string")
-	}
-	return session.BatchSql()
+	return session.SqlTemplatesClient(sqlkeys, parmas...)
 }

+ 105 - 28
sessionplus.go

@@ -11,11 +11,10 @@ import (
 	"fmt"
 	"reflect"
 	"regexp"
+
 	"strconv"
 	"strings"
 	"time"
-	//	"unsafe"
-	"runtime"
 
 	"github.com/Chronokeeper/anyxml"
 	"github.com/xormplus/core"
@@ -63,7 +62,7 @@ func (resultBean ResultBean) Xml() (bool, string, error) {
 	if !has {
 		return has, "", nil
 	}
-	var anydata = []byte(result) //str2byte(result)
+	var anydata = []byte(result)
 	var i interface{}
 	err = json.Unmarshal(anydata, &i)
 	if err != nil {
@@ -77,12 +76,6 @@ func (resultBean ResultBean) Xml() (bool, string, error) {
 	return resultBean.Has, string(resultByte), err
 }
 
-//func str2byte(s string) []byte {
-//	x := (*[2]uintptr)(unsafe.Pointer(&s))
-//	h := [3]uintptr{x[0], x[1], x[1]}
-//	return *(*[]byte)((unsafe.Pointer(&h)))
-//}
-
 func (resultBean ResultBean) XmlIndent(prefix string, indent string, recordTag string) (bool, string, error) {
 	if resultBean.Error != nil {
 		return false, "", resultBean.Error
@@ -116,10 +109,42 @@ type ResultMap struct {
 	Error   error
 }
 
-func (resultMap ResultMap) GetResults() ([]map[string]interface{}, error) {
+func (resultMap ResultMap) List() ([]map[string]interface{}, error) {
 	return resultMap.Results, resultMap.Error
 }
 
+func (resultMap ResultMap) Count() (int, error) {
+	if resultMap.Error != nil {
+		return 0, resultMap.Error
+	}
+	if resultMap.Results == nil {
+		return 0, nil
+	}
+	return len(resultMap.Results), nil
+}
+
+func (resultMap ResultMap) ListPage(firstResult int, maxResults int) ([]map[string]interface{}, error) {
+	if resultMap.Error != nil {
+		return nil, resultMap.Error
+	}
+	if resultMap.Results == nil {
+		return nil, nil
+	}
+	if firstResult >= maxResults {
+		return nil, ErrParamsFormat
+	}
+	if firstResult < 0 {
+		return nil, ErrParamsFormat
+	}
+	if maxResults < 0 {
+		return nil, ErrParamsFormat
+	}
+	if maxResults > len(resultMap.Results) {
+		return nil, ErrParamsFormat
+	}
+	return resultMap.Results[(firstResult - 1):maxResults], resultMap.Error
+}
+
 func (resultMap ResultMap) Json() (string, error) {
 
 	if resultMap.Error != nil {
@@ -235,8 +260,12 @@ func (session *Session) Search(rowsSlicePtr interface{}, condiBean ...interface{
 	return r
 }
 
-// Exec a raw sql and return records as []map[string]interface{}
+// Exec a raw sql and return records as ResultMap
 func (session *Session) Query() ResultMap {
+	defer session.resetStatement()
+	if session.IsAutoClose {
+		defer session.Close()
+	}
 	sql := session.Statement.RawSQL
 	params := session.Statement.RawParams
 	i := len(params)
@@ -257,8 +286,12 @@ func (session *Session) Query() ResultMap {
 	return r
 }
 
-// Exec a raw sql and return records as []map[string]interface{}
+// Exec a raw sql and return records as ResultMap
 func (session *Session) QueryWithDateFormat(dateFormat string) ResultMap {
+	defer session.resetStatement()
+	if session.IsAutoClose {
+		defer session.Close()
+	}
 	sql := session.Statement.RawSQL
 	params := session.Statement.RawParams
 	i := len(params)
@@ -280,13 +313,14 @@ func (session *Session) QueryWithDateFormat(dateFormat string) ResultMap {
 
 // Execute raw sql
 func (session *Session) Execute() (sql.Result, error) {
-	sqlStr := session.Statement.RawSQL
-	params := session.Statement.RawParams
 	defer session.resetStatement()
 	if session.IsAutoClose {
 		defer session.Close()
 	}
 
+	sqlStr := session.Statement.RawSQL
+	params := session.Statement.RawParams
+
 	i := len(params)
 	if i == 1 {
 		vv := reflect.ValueOf(params[0])
@@ -918,18 +952,6 @@ func (session *Session) queryPreprocessByMap(sqlStr *string, paramMap interface{
 	session.Engine.logSQL(*sqlStr, paramMap)
 }
 
-func (session *Session) BatchSql() *Session {
-	return session
-}
-
-func (session *Session) BatchExecute() interface{} {
-
-	session.Engine.logger.Info(runtime.Caller(0))
-	session.Engine.logger.Info(runtime.Caller(1))
-	session.Engine.logger.Info(runtime.Caller(2))
-	return []string{"GetSqlMap", "sql_2_1", "sql1"}
-}
-
 func (session *Session) Sqls(sqls interface{}, parmas ...interface{}) *SqlsExecutor {
 
 	sqlsExecutor := new(SqlsExecutor)
@@ -940,6 +962,9 @@ func (session *Session) Sqls(sqls interface{}, parmas ...interface{}) *SqlsExecu
 		sqlsExecutor.sqls = sqls.([]string)
 	case map[string]string:
 		sqlsExecutor.sqls = sqls.(map[string]string)
+	default:
+		sqlsExecutor.sqls = nil
+		sqlsExecutor.err = ErrParamsType
 	}
 
 	if len(parmas) == 0 {
@@ -961,7 +986,9 @@ func (session *Session) Sqls(sqls interface{}, parmas ...interface{}) *SqlsExecu
 
 		case map[string]map[string]interface{}:
 			sqlsExecutor.parmas = parmas[0].(map[string]map[string]interface{})
-
+		default:
+			sqlsExecutor.parmas = nil
+			sqlsExecutor.err = ErrParamsType
 		}
 	}
 
@@ -980,6 +1007,9 @@ func (session *Session) SqlMapsClient(sqlkeys interface{}, parmas ...interface{}
 		sqlMapsExecutor.sqlkeys = sqlkeys.([]string)
 	case map[string]string:
 		sqlMapsExecutor.sqlkeys = sqlkeys.(map[string]string)
+	default:
+		sqlMapsExecutor.sqlkeys = nil
+		sqlMapsExecutor.err = ErrParamsType
 	}
 
 	if len(parmas) == 0 {
@@ -1001,7 +1031,9 @@ func (session *Session) SqlMapsClient(sqlkeys interface{}, parmas ...interface{}
 
 		case map[string]map[string]interface{}:
 			sqlMapsExecutor.parmas = parmas[0].(map[string]map[string]interface{})
-
+		default:
+			sqlMapsExecutor.parmas = nil
+			sqlMapsExecutor.err = ErrParamsType
 		}
 	}
 
@@ -1009,3 +1041,48 @@ func (session *Session) SqlMapsClient(sqlkeys interface{}, parmas ...interface{}
 
 	return sqlMapsExecutor
 }
+
+func (session *Session) SqlTemplatesClient(sqlkeys interface{}, parmas ...interface{}) *SqlTemplatesExecutor {
+	sqlTemplatesExecutor := new(SqlTemplatesExecutor)
+
+	switch sqlkeys.(type) {
+	case string:
+		sqlTemplatesExecutor.sqlkeys = sqlkeys.(string)
+	case []string:
+		sqlTemplatesExecutor.sqlkeys = sqlkeys.([]string)
+	case map[string]string:
+		sqlTemplatesExecutor.sqlkeys = sqlkeys.(map[string]string)
+	default:
+		sqlTemplatesExecutor.sqlkeys = nil
+		sqlTemplatesExecutor.err = ErrParamsType
+	}
+
+	if len(parmas) == 0 {
+		sqlTemplatesExecutor.parmas = nil
+	}
+
+	if len(parmas) > 1 {
+		sqlTemplatesExecutor.parmas = nil
+		sqlTemplatesExecutor.err = ErrParamsType
+	}
+
+	if len(parmas) == 1 {
+		switch parmas[0].(type) {
+		case map[string]interface{}:
+			sqlTemplatesExecutor.parmas = parmas[0].(map[string]interface{})
+
+		case []map[string]interface{}:
+			sqlTemplatesExecutor.parmas = parmas[0].([]map[string]interface{})
+
+		case map[string]map[string]interface{}:
+			sqlTemplatesExecutor.parmas = parmas[0].(map[string]map[string]interface{})
+		default:
+			sqlTemplatesExecutor.parmas = nil
+			sqlTemplatesExecutor.err = ErrParamsType
+		}
+	}
+
+	sqlTemplatesExecutor.session = session
+
+	return sqlTemplatesExecutor
+}