Browse Source

added Sum, Sums, SumsInt methods

xormplus 9 years ago
parent
commit
df8f4aa522
5 changed files with 152 additions and 14 deletions
  1. 21 0
      engine.go
  2. 106 11
      session.go
  3. 3 2
      sessionplus.go
  4. 21 0
      statement.go
  5. 1 1
      xorm.go

+ 21 - 0
engine.go

@@ -1555,6 +1555,27 @@ func (engine *Engine) Count(bean interface{}) (int64, error) {
 	return session.Count(bean)
 }
 
+// Sum sum the records by some column. bean's non-empty fields are conditions.
+func (engine *Engine) Sum(bean interface{}, colName string) (float64, error) {
+	session := engine.NewSession()
+	defer session.Close()
+	return session.Sum(bean, colName)
+}
+
+// Sums sum the records by some columns. bean's non-empty fields are conditions.
+func (engine *Engine) Sums(bean interface{}, colNames ...string) ([]float64, error) {
+	session := engine.NewSession()
+	defer session.Close()
+	return session.Sums(bean, colNames...)
+}
+
+// SumsInt like Sums but return slice of int64 instead of float64.
+func (engine *Engine) SumsInt(bean interface{}, colNames ...string) ([]int64, error) {
+	session := engine.NewSession()
+	defer session.Close()
+	return session.SumsInt(bean, colNames...)
+}
+
 // Import SQL DDL file
 func (engine *Engine) ImportFile(ddlPath string) ([]sql.Result, error) {
 	file, err := os.Open(ddlPath)

+ 106 - 11
session.go

@@ -1075,21 +1075,115 @@ func (session *Session) Count(bean interface{}) (int64, error) {
 		args = session.Statement.RawParams
 	}
 
-	resultsSlice, err := session.query(sqlStr, args...)
+	session.queryPreprocess(&sqlStr, args...)
+
+	var err error
+	var total int64
+	if session.IsAutoCommit {
+		err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
+	} else {
+		err = session.Tx.QueryRow(sqlStr, args...).Scan(&total)
+	}
 	if err != nil {
 		return 0, err
 	}
 
-	var total int64
-	if len(resultsSlice) > 0 {
-		results := resultsSlice[0]
-		for _, value := range results {
-			total, err = strconv.ParseInt(string(value), 10, 64)
-			break
-		}
+	return total, nil
+}
+
+// Sum call sum some column. bean's non-empty fields are conditions.
+func (session *Session) Sum(bean interface{}, columnName string) (float64, error) {
+	defer session.resetStatement()
+	if session.IsAutoClose {
+		defer session.Close()
 	}
 
-	return int64(total), err
+	var sqlStr string
+	var args []interface{}
+	if len(session.Statement.RawSQL) == 0 {
+		sqlStr, args = session.Statement.genSumSql(bean, columnName)
+	} else {
+		sqlStr = session.Statement.RawSQL
+		args = session.Statement.RawParams
+	}
+
+	session.queryPreprocess(&sqlStr, args...)
+
+	var err error
+	var res float64
+	if session.IsAutoCommit {
+		err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
+	} else {
+		err = session.Tx.QueryRow(sqlStr, args...).Scan(&res)
+	}
+	if err != nil {
+		return 0, err
+	}
+
+	return res, nil
+}
+
+// Sums call sum some columns. bean's non-empty fields are conditions.
+func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) {
+	defer session.resetStatement()
+	if session.IsAutoClose {
+		defer session.Close()
+	}
+
+	var sqlStr string
+	var args []interface{}
+	if len(session.Statement.RawSQL) == 0 {
+		sqlStr, args = session.Statement.genSumSql(bean, columnNames...)
+	} else {
+		sqlStr = session.Statement.RawSQL
+		args = session.Statement.RawParams
+	}
+
+	session.queryPreprocess(&sqlStr, args...)
+
+	var err error
+	var res = make([]float64, len(columnNames), len(columnNames))
+	if session.IsAutoCommit {
+		err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
+	} else {
+		err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res)
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	return res, nil
+}
+
+func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) {
+	defer session.resetStatement()
+	if session.IsAutoClose {
+		defer session.Close()
+	}
+
+	var sqlStr string
+	var args []interface{}
+	if len(session.Statement.RawSQL) == 0 {
+		sqlStr, args = session.Statement.genSumSql(bean, columnNames...)
+	} else {
+		sqlStr = session.Statement.RawSQL
+		args = session.Statement.RawParams
+	}
+
+	session.queryPreprocess(&sqlStr, args...)
+
+	var err error
+	var res = make([]int64, 0, len(columnNames))
+	if session.IsAutoCommit {
+		err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
+	} else {
+		err = session.Tx.QueryRow(sqlStr, args...).ScanSlice(&res)
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	return res, nil
 }
 
 // Find retrieve records from table, condiBeans's non-empty fields
@@ -1861,8 +1955,9 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
 								return err
 							}
 							if has {
-								v := structInter.Elem().Interface()
-								fieldValue.Set(reflect.ValueOf(v))
+								//v := structInter.Elem().Interface()
+								//fieldValue.Set(reflect.ValueOf(v))
+								fieldValue.Set(structInter.Elem())
 							} else {
 								return errors.New("cascade obj is not exist!")
 							}

+ 3 - 2
sessionplus.go

@@ -928,8 +928,9 @@ func (session *Session) _row2BeanWithDateFormat(dateFormat string, rows *core.Ro
 								return err
 							}
 							if has {
-								v := structInter.Elem().Interface()
-								fieldValue.Set(reflect.ValueOf(v))
+								//v := structInter.Elem().Interface()
+								//fieldValue.Set(reflect.ValueOf(v))
+								fieldValue.Set(structInter.Elem())
 							} else {
 								return errors.New("cascade obj is not exist!")
 							}

+ 21 - 0
statement.go

@@ -1220,6 +1220,27 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}
 	return statement.genSelectSQL(fmt.Sprintf("count(%v)", id)), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...)
 }
 
+func (statement *Statement) genSumSql(bean interface{}, columns ...string) (string, []interface{}) {
+	table := statement.Engine.TableInfo(bean)
+	statement.RefTable = table
+
+	var addedTableName = (len(statement.JoinStr) > 0)
+
+	if !statement.noAutoCondition {
+		colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName)
+
+		statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ")
+		statement.BeanArgs = args
+	}
+
+	statement.attachInSql()
+	var sumStrs = make([]string, 0, len(columns))
+	for _, colName := range columns {
+		sumStrs = append(sumStrs, fmt.Sprintf("sum(%s)", colName))
+	}
+	return statement.genSelectSQL(strings.Join(sumStrs, ", ")), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...)
+}
+
 func (statement *Statement) genSelectSQL(columnStr string) (a string) {
 	var distinct string
 	if statement.IsDistinct {

+ 1 - 1
xorm.go

@@ -17,7 +17,7 @@ import (
 
 const (
 	// Version show the xorm's version
-	Version string = "0.5.4.0630"
+	Version string = "0.5.5.0707"
 )
 
 func regDrvsNDialects() bool {