Переглянути джерело

fix bug and add custom SQL count support

xormplus 8 роки тому
батько
коміт
e6d18c7aeb
9 змінених файлів з 145 додано та 44 видалено
  1. 6 3
      rows.go
  2. 1 1
      session.go
  3. 4 2
      session_delete.go
  4. 9 3
      session_find.go
  5. 5 1
      session_get.go
  6. 24 9
      session_sum.go
  7. 22 0
      session_sum_test.go
  8. 21 5
      session_update.go
  9. 53 20
      statement.go

+ 6 - 3
rows.go

@@ -33,8 +33,9 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
 
 	var sqlStr string
 	var args []interface{}
+	var err error
 
-	if err := rows.session.Statement.setRefValue(rValue(bean)); err != nil {
+	if err = rows.session.Statement.setRefValue(rValue(bean)); err != nil {
 		return nil, err
 	}
 
@@ -43,7 +44,10 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
 	}
 
 	if rows.session.Statement.RawSQL == "" {
-		sqlStr, args = rows.session.Statement.genGetSQL(bean)
+		sqlStr, args, err = rows.session.Statement.genGetSQL(bean)
+		if err != nil {
+			return nil, err
+		}
 	} else {
 		sqlStr = rows.session.Statement.RawSQL
 		args = rows.session.Statement.RawParams
@@ -54,7 +58,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
 	}
 
 	rows.session.saveLastSQL(sqlStr, args...)
-	var err error
 	if rows.session.prepareStmt {
 		rows.stmt, err = rows.session.DB().Prepare(sqlStr)
 		if err != nil {

+ 1 - 1
session.go

@@ -626,7 +626,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
 						structInter := reflect.New(fieldValue.Type())
 						newsession := session.Engine.NewSession()
 						defer newsession.Close()
-						has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
+						has, err := newsession.ID(pk).NoCascade().Get(structInter.Interface())
 						if err != nil {
 							return nil, err
 						}

+ 4 - 2
session_delete.go

@@ -98,8 +98,10 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
 		processor.BeforeDelete()
 	}
 
-	// --
-	condSQL, condArgs, _ := session.Statement.genConds(bean)
+	condSQL, condArgs, err := session.Statement.genConds(bean)
+	if err != nil {
+		return 0, err
+	}
 	if len(condSQL) == 0 && session.Statement.LimitN == 0 {
 		return 0, ErrNeedDeletedCond
 	}

+ 9 - 3
session_find.go

@@ -91,6 +91,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 
 	var sqlStr string
 	var args []interface{}
+	var err error
 	if session.Statement.RawSQL == "" {
 		if len(session.Statement.TableName()) <= 0 {
 			return ErrTableNotFound
@@ -122,10 +123,16 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 			}
 		}
 
-		condSQL, condArgs, _ := builder.ToSQL(session.Statement.cond.And(autoCond))
+		condSQL, condArgs, err := builder.ToSQL(session.Statement.cond.And(autoCond))
+		if err != nil {
+			return err
+		}
 
 		args = append(session.Statement.joinArgs, condArgs...)
-		sqlStr = session.Statement.genSelectSQL(columnStr, condSQL)
+		sqlStr, err = session.Statement.genSelectSQL(columnStr, condSQL)
+		if err != nil {
+			return err
+		}
 		// for mssql and use limit
 		qs := strings.Count(sqlStr, "?")
 		if len(args)*2 == qs {
@@ -136,7 +143,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
 		args = session.Statement.RawParams
 	}
 
-	var err error
 	if session.canCache() {
 		if cacher := session.Engine.getCacher2(table); cacher != nil &&
 			!session.Statement.IsDistinct &&

+ 5 - 1
session_get.go

@@ -33,13 +33,17 @@ func (session *Session) Get(bean interface{}) (bool, error) {
 
 	var sqlStr string
 	var args []interface{}
+	var err error
 
 	if session.Statement.RawSQL == "" {
 		if len(session.Statement.TableName()) <= 0 {
 			return false, ErrTableNotFound
 		}
 		session.Statement.Limit(1)
-		sqlStr, args = session.Statement.genGetSQL(bean)
+		sqlStr, args, err = session.Statement.genGetSQL(bean)
+		if err != nil {
+			return false, err
+		}
 	} else {
 		sqlStr = session.Statement.RawSQL
 		args = session.Statement.RawParams

+ 24 - 9
session_sum.go

@@ -8,7 +8,7 @@ import "database/sql"
 
 // Count counts the records. bean's non-empty fields
 // are conditions.
-func (session *Session) Count(bean interface{}) (int64, error) {
+func (session *Session) Count(bean ...interface{}) (int64, error) {
 	defer session.resetStatement()
 	if session.IsAutoClose {
 		defer session.Close()
@@ -16,8 +16,15 @@ func (session *Session) Count(bean interface{}) (int64, error) {
 
 	var sqlStr string
 	var args []interface{}
+	var err error
 	if session.Statement.RawSQL == "" {
-		sqlStr, args = session.Statement.genCountSQL(bean)
+		if len(bean) == 0 {
+			return 0, ErrTableNotFound
+		}
+		sqlStr, args, err = session.Statement.genCountSQL(bean[0])
+		if err != nil {
+			return 0, err
+		}
 	} else {
 		sqlStr = session.Statement.RawSQL
 		args = session.Statement.RawParams
@@ -25,7 +32,6 @@ func (session *Session) Count(bean interface{}) (int64, error) {
 
 	session.queryPreprocess(&sqlStr, args...)
 
-	var err error
 	var total int64
 	if session.IsAutoCommit {
 		err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
@@ -49,8 +55,12 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error
 
 	var sqlStr string
 	var args []interface{}
+	var err error
 	if len(session.Statement.RawSQL) == 0 {
-		sqlStr, args = session.Statement.genSumSQL(bean, columnName)
+		sqlStr, args, err = session.Statement.genSumSQL(bean, columnName)
+		if err != nil {
+			return 0, err
+		}
 	} else {
 		sqlStr = session.Statement.RawSQL
 		args = session.Statement.RawParams
@@ -58,7 +68,6 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error
 
 	session.queryPreprocess(&sqlStr, args...)
 
-	var err error
 	var res float64
 	if session.IsAutoCommit {
 		err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
@@ -81,8 +90,12 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64
 
 	var sqlStr string
 	var args []interface{}
+	var err error
 	if len(session.Statement.RawSQL) == 0 {
-		sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
+		sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
+		if err != nil {
+			return nil, err
+		}
 	} else {
 		sqlStr = session.Statement.RawSQL
 		args = session.Statement.RawParams
@@ -90,7 +103,6 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64
 
 	session.queryPreprocess(&sqlStr, args...)
 
-	var err error
 	var res = make([]float64, len(columnNames), len(columnNames))
 	if session.IsAutoCommit {
 		err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
@@ -113,8 +125,12 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6
 
 	var sqlStr string
 	var args []interface{}
+	var err error
 	if len(session.Statement.RawSQL) == 0 {
-		sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
+		sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
+		if err != nil {
+			return nil, err
+		}
 	} else {
 		sqlStr = session.Statement.RawSQL
 		args = session.Statement.RawParams
@@ -122,7 +138,6 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6
 
 	session.queryPreprocess(&sqlStr, args...)
 
-	var err error
 	var res = make([]int64, len(columnNames), len(columnNames))
 	if session.IsAutoCommit {
 		err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)

+ 22 - 0
session_sum_test.go

@@ -128,3 +128,25 @@ func TestCount(t *testing.T) {
 	assert.NoError(t, err)
 	assert.EqualValues(t, 1, total)
 }
+
+func TestSQLCount(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+
+	type UserinfoCount2 struct {
+		Id         int64
+		Departname string
+	}
+
+	type UserinfoBooks struct {
+		Id     int64
+		Pid    int64
+		IsOpen bool
+	}
+
+	assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
+
+	total, err := testEngine.SQL("SELECT count(id) FROM userinfo_count2").
+		Count()
+	assert.NoError(t, err)
+	assert.EqualValues(t, 0, total)
+}

+ 21 - 5
session_update.go

@@ -236,7 +236,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 		colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr)
 	}
 
-	session.Statement.processIDParam()
+	if err = session.Statement.processIDParam(); err != nil {
+		return 0, err
+	}
 
 	var autoCond builder.Cond
 	if !session.Statement.noAutoCondition && len(condiBean) > 0 {
@@ -267,7 +269,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 		colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
 	}
 
-	condSQL, condArgs, _ = builder.ToSQL(cond)
+	condSQL, condArgs, err = builder.ToSQL(cond)
+	if err != nil {
+		return 0, err
+	}
+
 	if len(condSQL) > 0 {
 		condSQL = "WHERE " + condSQL
 	}
@@ -285,7 +291,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 			tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
 			cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
 				session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
-			condSQL, condArgs, _ = builder.ToSQL(cond)
+			condSQL, condArgs, err = builder.ToSQL(cond)
+			if err != nil {
+				return 0, err
+			}
 			if len(condSQL) > 0 {
 				condSQL = "WHERE " + condSQL
 			}
@@ -293,7 +302,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 			tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
 			cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
 				session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
-			condSQL, condArgs, _ = builder.ToSQL(cond)
+			condSQL, condArgs, err = builder.ToSQL(cond)
+			if err != nil {
+				return 0, err
+			}
+
 			if len(condSQL) > 0 {
 				condSQL = "WHERE " + condSQL
 			}
@@ -304,7 +317,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
 					table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
 					session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...)
 
-				condSQL, condArgs, _ = builder.ToSQL(cond)
+				condSQL, condArgs, err = builder.ToSQL(cond)
+				if err != nil {
+					return 0, err
+				}
 				if len(condSQL) > 0 {
 					condSQL = "WHERE " + condSQL
 				}

+ 53 - 20
statement.go

@@ -1118,12 +1118,14 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
 		statement.cond = statement.cond.And(autoCond)
 	}
 
-	statement.processIDParam()
+	if err := statement.processIDParam(); err != nil {
+		return "", nil, err
+	}
 
 	return builder.ToSQL(statement.cond)
 }
 
-func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
+func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
 	v := rValue(bean)
 	isStruct := v.Kind() == reflect.Struct
 	if isStruct {
@@ -1158,19 +1160,31 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{})
 
 	var condSQL string
 	var condArgs []interface{}
+	var err error
 	if isStruct {
-		condSQL, condArgs, _ = statement.genConds(bean)
+		condSQL, condArgs, err = statement.genConds(bean)
 	} else {
-		condSQL, condArgs, _ = builder.ToSQL(statement.cond)
+		condSQL, condArgs, err = builder.ToSQL(statement.cond)
+	}
+	if err != nil {
+		return "", nil, err
+	}
+
+	sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
+	if err != nil {
+		return "", nil, err
 	}
 
-	return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
+	return sqlStr, append(statement.joinArgs, condArgs...), nil
 }
 
-func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) {
+func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) {
 	statement.setRefValue(rValue(bean))
 
-	condSQL, condArgs, _ := statement.genConds(bean)
+	condSQL, condArgs, err := statement.genConds(bean)
+	if err != nil {
+		return "", nil, err
+	}
 
 	var selectSQL = statement.selectStr
 	if len(selectSQL) <= 0 {
@@ -1180,10 +1194,15 @@ func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}
 			selectSQL = "count(*)"
 		}
 	}
-	return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...)
+	sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
+	if err != nil {
+		return "", nil, err
+	}
+
+	return sqlStr, append(statement.joinArgs, condArgs...), nil
 }
 
-func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) {
+func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
 	statement.setRefValue(rValue(bean))
 
 	var sumStrs = make([]string, 0, len(columns))
@@ -1195,12 +1214,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
 	}
 	sumSelect := strings.Join(sumStrs, ", ")
 
-	condSQL, condArgs, _ := statement.genConds(bean)
+	condSQL, condArgs, err := statement.genConds(bean)
+	if err != nil {
+		return "", nil, err
+	}
+
+	sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
+	if err != nil {
+		return "", nil, err
+	}
 
-	return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...)
+	return sqlStr, append(statement.joinArgs, condArgs...), nil
 }
 
-func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
+func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
 	var distinct string
 	if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
 		distinct = "DISTINCT "
@@ -1211,7 +1238,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
 	var top string
 	var mssqlCondi string
 
-	statement.processIDParam()
+	if err := statement.processIDParam(); err != nil {
+		return "", err
+	}
 
 	var buf bytes.Buffer
 	if len(condSQL) > 0 {
@@ -1314,19 +1343,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
 	return
 }
 
-func (statement *Statement) processIDParam() {
+func (statement *Statement) processIDParam() error {
 	if statement.idParam == nil {
-		return
+		return nil
+	}
+
+	if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
+		return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
+			len(statement.RefTable.PrimaryKeys),
+			len(*statement.idParam),
+		)
 	}
 
 	for i, col := range statement.RefTable.PKColumns() {
 		var colName = statement.colName(col, statement.TableName())
-		if i < len(*(statement.idParam)) {
-			statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
-		} else {
-			statement.cond = statement.cond.And(builder.Eq{colName: ""})
-		}
+		statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
 	}
+	return nil
 }
 
 func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {