浏览代码

Add insert select where support

Unknown 6 年之前
父节点
当前提交
8cb1132511
共有 2 个文件被更改,包括 152 次插入20 次删除
  1. 93 20
      session_insert.go
  2. 59 0
      session_insert_test.go

+ 93 - 20
session_insert.go

@@ -12,6 +12,7 @@ import (
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 
 
+	"github.com/xormplus/builder"
 	"github.com/xormplus/core"
 	"github.com/xormplus/core"
 )
 )
 
 
@@ -346,7 +347,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 	for _, v := range exprColumns {
 	for _, v := range exprColumns {
 		// remove the expr columns
 		// remove the expr columns
 		for i, colName := range colNames {
 		for i, colName := range colNames {
-			if colName == v.colName {
+			if colName == strings.Trim(v.colName, "`") {
 				colNames = append(colNames[:i], colNames[i+1:]...)
 				colNames = append(colNames[:i], colNames[i+1:]...)
 				args = append(args[:i], args[i+1:]...)
 				args = append(args[:i], args[i+1:]...)
 			}
 			}
@@ -372,12 +373,30 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 	if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
 	if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
 		output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
 		output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
 	}
 	}
+
 	if len(colPlaces) > 0 {
 	if len(colPlaces) > 0 {
-		sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
-			session.engine.Quote(tableName),
-			quoteColumns(colNames, session.engine.Quote, ","),
-			output,
-			colPlaces)
+		if session.statement.cond.IsValid() {
+			condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
+			if err != nil {
+				return 0, err
+			}
+
+			sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v",
+				session.engine.Quote(tableName),
+				quoteColumns(colNames, session.engine.Quote, ","),
+				output,
+				colPlaces,
+				session.engine.Quote(tableName),
+				condSQL,
+			)
+			args = append(args, condArgs...)
+		} else {
+			sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
+				session.engine.Quote(tableName),
+				quoteColumns(colNames, session.engine.Quote, ","),
+				output,
+				colPlaces)
+		}
 	} else {
 	} else {
 		if session.engine.dialect.DBType() == core.MYSQL {
 		if session.engine.dialect.DBType() == core.MYSQL {
 			sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
 			sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
@@ -665,6 +684,11 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
 		return 0, ErrParamsType
 		return 0, ErrParamsType
 	}
 	}
 
 
+	tableName := session.statement.TableName()
+	if len(tableName) <= 0 {
+		return 0, ErrTableNotFound
+	}
+
 	var columns = make([]string, 0, len(m))
 	var columns = make([]string, 0, len(m))
 	for k := range m {
 	for k := range m {
 		columns = append(columns, k)
 		columns = append(columns, k)
@@ -672,19 +696,40 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
 	sort.Strings(columns)
 	sort.Strings(columns)
 
 
 	qm := strings.Repeat("?,", len(columns))
 	qm := strings.Repeat("?,", len(columns))
-	qm = "(" + qm[:len(qm)-1] + ")"
 
 
-	tableName := session.statement.TableName()
-	if len(tableName) <= 0 {
-		return 0, ErrTableNotFound
-	}
-
-	var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
 	var args = make([]interface{}, 0, len(m))
 	var args = make([]interface{}, 0, len(m))
 	for _, colName := range columns {
 	for _, colName := range columns {
 		args = append(args, m[colName])
 		args = append(args, m[colName])
 	}
 	}
 
 
+	// insert expr columns, override if exists
+	exprColumns := session.statement.getExpr()
+	for _, col := range exprColumns {
+		columns = append(columns, strings.Trim(col.colName, "`"))
+		qm = qm + col.expr + ","
+	}
+
+	qm = qm[:len(qm)-1]
+
+	var sql string
+
+	if session.statement.cond.IsValid() {
+		condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
+		if err != nil {
+			return 0, err
+		}
+		sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
+			session.engine.Quote(tableName),
+			strings.Join(columns, "`,`"),
+			qm,
+			session.engine.Quote(tableName),
+			condSQL,
+		)
+		args = append(args, condArgs...)
+	} else {
+		sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
+	}
+
 	if err := session.cacheInsert(tableName); err != nil {
 	if err := session.cacheInsert(tableName); err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
@@ -706,24 +751,52 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
 		return 0, ErrParamsType
 		return 0, ErrParamsType
 	}
 	}
 
 
+	tableName := session.statement.TableName()
+	if len(tableName) <= 0 {
+		return 0, ErrTableNotFound
+	}
+
 	var columns = make([]string, 0, len(m))
 	var columns = make([]string, 0, len(m))
 	for k := range m {
 	for k := range m {
 		columns = append(columns, k)
 		columns = append(columns, k)
 	}
 	}
 	sort.Strings(columns)
 	sort.Strings(columns)
 
 
+	var args = make([]interface{}, 0, len(m))
+	for _, colName := range columns {
+		args = append(args, m[colName])
+	}
+
 	qm := strings.Repeat("?,", len(columns))
 	qm := strings.Repeat("?,", len(columns))
 	qm = "(" + qm[:len(qm)-1] + ")"
 	qm = "(" + qm[:len(qm)-1] + ")"
 
 
-	tableName := session.statement.TableName()
-	if len(tableName) <= 0 {
-		return 0, ErrTableNotFound
+	// insert expr columns, override if exists
+	exprColumns := session.statement.getExpr()
+	for _, col := range exprColumns {
+		columns = append(columns, strings.Trim(col.colName, "`"))
+		qm = qm + col.expr + ","
 	}
 	}
 
 
-	var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
-	var args = make([]interface{}, 0, len(m))
-	for _, colName := range columns {
-		args = append(args, m[colName])
+	qm = qm[:len(qm)-1]
+
+	var sql string
+
+	if session.statement.cond.IsValid() {
+		qm = "(" + qm[:len(qm)-1] + ")"
+		condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
+		if err != nil {
+			return 0, err
+		}
+		sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
+			session.engine.Quote(tableName),
+			strings.Join(columns, "`,`"),
+			qm,
+			session.engine.Quote(tableName),
+			condSQL,
+		)
+		args = append(args, condArgs...)
+	} else {
+		sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
 	}
 	}
 
 
 	if err := session.cacheInsert(tableName); err != nil {
 	if err := session.cacheInsert(tableName); err != nil {

+ 59 - 0
session_insert_test.go

@@ -850,3 +850,62 @@ func TestInsertMap(t *testing.T) {
 	assert.EqualValues(t, 10, ims[3].Height)
 	assert.EqualValues(t, 10, ims[3].Height)
 	assert.EqualValues(t, "lunny", ims[3].Name)
 	assert.EqualValues(t, "lunny", ims[3].Name)
 }
 }
+
+/*INSERT INTO `issue` (`repo_id`, `poster_id`, ... ,`name`, `content`, ... ,`index`)
+SELECT $1, $2, ..., $14, $15, ..., MAX(`index`) + 1 FROM `issue` WHERE `repo_id` = $1;
+*/
+func TestInsertWhere(t *testing.T) {
+	type InsertWhere struct {
+		Id     int64
+		Index  int   `xorm:"unique(s) notnull"`
+		RepoId int64 `xorm:"unique(s)"`
+		Width  uint32
+		Height uint32
+		Name   string
+	}
+
+	assert.NoError(t, prepareEngine())
+	assertSync(t, new(InsertWhere))
+
+	var i = InsertWhere{
+		RepoId: 1,
+		Width:  10,
+		Height: 20,
+		Name:   "trest",
+	}
+
+	inserted, err := testEngine.SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
+		Where("repo_id=?", 1).
+		Insert(&i)
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, inserted)
+	assert.EqualValues(t, 1, i.Id)
+
+	var j InsertWhere
+	has, err := testEngine.ID(i.Id).Get(&j)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	i.Index = 1
+	assert.EqualValues(t, i, j)
+
+	inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1).
+		SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
+		Insert(map[string]interface{}{
+			"repo_id": 1,
+			"width":   20,
+			"height":  40,
+			"name":    "trest2",
+		})
+	assert.NoError(t, err)
+	assert.EqualValues(t, 1, inserted)
+
+	var j2 InsertWhere
+	has, err = testEngine.ID(2).Get(&j2)
+	assert.NoError(t, err)
+	assert.True(t, has)
+	assert.EqualValues(t, 1, j2.RepoId)
+	assert.EqualValues(t, 20, j2.Width)
+	assert.EqualValues(t, 40, j2.Height)
+	assert.EqualValues(t, "trest2", j2.Name)
+	assert.EqualValues(t, 2, j2.Index)
+}