xormplus 8 jaren geleden
bovenliggende
commit
efd026b06b
2 gewijzigde bestanden met toevoegingen van 41 en 3 verwijderingen
  1. 28 0
      session_sum_test.go
  2. 13 3
      statement.go

+ 28 - 0
session_sum_test.go

@@ -71,3 +71,31 @@ func TestSum(t *testing.T) {
 	assert.EqualValues(t, 1, len(sumsInt))
 	assert.EqualValues(t, i, int(sumsInt[0]))
 }
+
+func TestSumCustomColumn(t *testing.T) {
+	assert.NoError(t, prepareEngine())
+
+	type SumStruct struct {
+		Int   int
+		Float float32
+	}
+
+	var (
+		cases = []SumStruct{
+			{1, 6.2},
+			{2, 5.3},
+			{92, -0.2},
+		}
+	)
+
+	assert.NoError(t, testEngine.Sync2(new(SumStruct)))
+
+	cnt, err := testEngine.Insert(cases)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 3, cnt)
+
+	sumInt, err := testEngine.Sum(new(SumStruct),
+		"CASE WHEN `int` <= 2 THEN `int` ELSE 0 END")
+	assert.NoError(t, err)
+	assert.EqualValues(t, 3, int(sumInt))
+}

+ 13 - 3
statement.go

@@ -1188,12 +1188,16 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
 
 	var sumStrs = make([]string, 0, len(columns))
 	for _, colName := range columns {
-		sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", statement.Engine.Quote(colName)))
+		if !strings.Contains(colName, " ") && strings.Contains(colName, "(") {
+			colName = statement.Engine.Quote(colName)
+		}
+		sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
 	}
+	sumSelect := strings.Join(sumStrs, ", ")
 
 	condSQL, condArgs, _ := statement.genConds(bean)
 
-	return statement.genSelectSQL(strings.Join(sumStrs, ", "), condSQL), append(statement.joinArgs, condArgs...)
+	return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...)
 }
 
 func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
@@ -1214,8 +1218,14 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
 		fmt.Fprintf(&buf, " WHERE %v", condSQL)
 	}
 	var whereStr = buf.String()
+	var fromStr = " FROM "
+
+	if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
+		fromStr += statement.TableName()
+	} else {
+		fromStr += quote(statement.TableName())
+	}
 
-	var fromStr = " FROM " + quote(statement.TableName())
 	if statement.TableAlias != "" {
 		if dialect.DBType() == core.ORACLE {
 			fromStr += " " + quote(statement.TableAlias)