Explorar el Código

Add insert map support

* add insert map

* fix insert map bug when cache enabled
xormplus hace 6 años
padre
commit
0aa454323c
Se han modificado 1 ficheros con 119 adiciones y 19 borrados
  1. 119 19
      session_insert.go

+ 119 - 19
session_insert.go

@@ -8,6 +8,7 @@ import (
 	"errors"
 	"fmt"
 	"reflect"
+	"sort"
 	"strconv"
 	"strings"
 
@@ -24,32 +25,67 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
 	}
 
 	for _, bean := range beans {
-		sliceValue := reflect.Indirect(reflect.ValueOf(bean))
-		if sliceValue.Kind() == reflect.Slice {
-			size := sliceValue.Len()
-			if size > 0 {
-				if session.engine.SupportInsertMany() {
-					cnt, err := session.innerInsertMulti(bean)
-					if err != nil {
-						return affected, err
-					}
-					affected += cnt
-				} else {
-					for i := 0; i < size; i++ {
-						cnt, err := session.innerInsert(sliceValue.Index(i).Interface())
+		switch bean.(type) {
+		case map[string]interface{}:
+			cnt, err := session.insertMapInterface(bean.(map[string]interface{}))
+			if err != nil {
+				return affected, err
+			}
+			affected += cnt
+		case []map[string]interface{}:
+			s := bean.([]map[string]interface{})
+			session.autoResetStatement = false
+			for i := 0; i < len(s); i++ {
+				cnt, err := session.insertMapInterface(s[i])
+				if err != nil {
+					return affected, err
+				}
+				affected += cnt
+			}
+		case map[string]string:
+			cnt, err := session.insertMapString(bean.(map[string]string))
+			if err != nil {
+				return affected, err
+			}
+			affected += cnt
+		case []map[string]string:
+			s := bean.([]map[string]string)
+			session.autoResetStatement = false
+			for i := 0; i < len(s); i++ {
+				cnt, err := session.insertMapString(s[i])
+				if err != nil {
+					return affected, err
+				}
+				affected += cnt
+			}
+		default:
+			sliceValue := reflect.Indirect(reflect.ValueOf(bean))
+			if sliceValue.Kind() == reflect.Slice {
+				size := sliceValue.Len()
+				if size > 0 {
+					if session.engine.SupportInsertMany() {
+						cnt, err := session.innerInsertMulti(bean)
 						if err != nil {
 							return affected, err
 						}
 						affected += cnt
+					} else {
+						for i := 0; i < size; i++ {
+							cnt, err := session.innerInsert(sliceValue.Index(i).Interface())
+							if err != nil {
+								return affected, err
+							}
+							affected += cnt
+						}
 					}
 				}
+			} else {
+				cnt, err := session.innerInsert(bean)
+				if err != nil {
+					return affected, err
+				}
+				affected += cnt
 			}
-		} else {
-			cnt, err := session.innerInsert(bean)
-			if err != nil {
-				return affected, err
-			}
-			affected += cnt
 		}
 	}
 
@@ -622,3 +658,67 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
 	}
 	return colNames, args, nil
 }
+
+func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
+	var columns = make([]string, 0, len(m))
+	for k := range m {
+		columns = append(columns, k)
+	}
+	sort.Strings(columns)
+
+	qm := strings.Repeat("?,", len(columns))
+	qm = "(" + qm[:len(qm)-1] + ")"
+
+	tableName := session.statement.AltTableName
+	var sql = "INSERT INTO `" + tableName + "` (`" + strings.Join(columns, "`,`") + "`) VALUES " + qm
+	var args = make([]interface{}, 0, len(m))
+	for _, colName := range columns {
+		args = append(args, m[colName])
+	}
+
+	if err := session.cacheInsert(tableName); err != nil {
+		return 0, err
+	}
+
+	res, err := session.exec(sql, args...)
+	if err != nil {
+		return 0, err
+	}
+	affected, err := res.RowsAffected()
+	if err != nil {
+		return 0, err
+	}
+	return affected, nil
+}
+
+func (session *Session) insertMapString(m map[string]string) (int64, error) {
+	var columns = make([]string, 0, len(m))
+	for k := range m {
+		columns = append(columns, k)
+	}
+	sort.Strings(columns)
+
+	qm := strings.Repeat("?,", len(columns))
+	qm = "(" + qm[:len(qm)-1] + ")"
+
+	tableName := session.statement.AltTableName
+	var sql = "INSERT INTO `" + tableName + "` (`" + strings.Join(columns, "`,`") + "`) VALUES " + qm
+	var args = make([]interface{}, 0, len(m))
+	for _, colName := range columns {
+		args = append(args, m[colName])
+	}
+
+	if err := session.cacheInsert(tableName); err != nil {
+		return 0, err
+	}
+
+	res, err := session.exec(sql, args...)
+	if err != nil {
+		return 0, err
+	}
+	affected, err := res.RowsAffected()
+	if err != nil {
+		return 0, err
+	}
+	return affected, nil
+}