Просмотр исходного кода

1。fix oracle insert multiple records
2。remove unused method
3。remove unused field of rows

xormplus 8 лет назад
Родитель
Сommit
59e5b46fc6
3 измененных файлов с 43 добавлено и 36 удалено
  1. 7 9
      rows.go
  2. 27 11
      session_insert.go
  3. 9 16
      statement.go

+ 7 - 9
rows.go

@@ -16,13 +16,12 @@ import (
 type Rows struct {
 	NoTypeCheck bool
 
-	session     *Session
-	stmt        *core.Stmt
-	rows        *core.Rows
-	fields      []string
-	fieldsCount int
-	beanType    reflect.Type
-	lastError   error
+	session   *Session
+	stmt      *core.Stmt
+	rows      *core.Rows
+	fields    []string
+	beanType  reflect.Type
+	lastError error
 }
 
 func newRows(session *Session, bean interface{}) (*Rows, error) {
@@ -82,7 +81,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
 		rows.Close()
 		return nil, err
 	}
-	rows.fieldsCount = len(rows.fields)
 
 	return rows, nil
 }
@@ -114,7 +112,7 @@ func (rows *Rows) Scan(bean interface{}) error {
 		return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
 	}
 
-	_, err := rows.session.row2Bean(rows.rows, rows.fields, rows.fieldsCount, bean)
+	_, err := rows.session.row2Bean(rows.rows, rows.fields, len(rows.fields), bean)
 	return err
 }
 

+ 27 - 11
session_insert.go

@@ -210,13 +210,29 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
 	}
 	cleanupProcessorsClosures(&session.beforeClosures)
 
-	statement := fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
-		session.Engine.Quote(session.Statement.TableName()),
-		session.Engine.QuoteStr(),
-		strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()),
-		session.Engine.QuoteStr(),
-		strings.Join(colMultiPlaces, "),("))
-
+	var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)"
+	var statement string
+	if session.Engine.dialect.DBType() == core.ORACLE {
+		sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL"
+		temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
+			session.Engine.Quote(session.Statement.TableName()),
+			session.Engine.QuoteStr(),
+			strings.Join(colNames, session.Engine.QuoteStr() + ", " + session.Engine.QuoteStr()),
+			session.Engine.QuoteStr())
+		statement = fmt.Sprintf(sql,
+			session.Engine.Quote(session.Statement.TableName()),
+			session.Engine.QuoteStr(),
+			strings.Join(colNames, session.Engine.QuoteStr() + ", " + session.Engine.QuoteStr()),
+			session.Engine.QuoteStr(),
+			strings.Join(colMultiPlaces, temp))
+	} else {
+		statement = fmt.Sprintf(sql,
+			session.Engine.Quote(session.Statement.TableName()),
+			session.Engine.QuoteStr(),
+			strings.Join(colNames, session.Engine.QuoteStr() + ", " + session.Engine.QuoteStr()),
+			session.Engine.QuoteStr(),
+			strings.Join(colMultiPlaces, "),("))
+	}
 	res, err := session.exec(statement, args...)
 	if err != nil {
 		return 0, err
@@ -309,8 +325,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 		// remove the expr columns
 		for i, colName := range colNames {
 			if colName == v.colName {
-				colNames = append(colNames[:i], colNames[i+1:]...)
-				args = append(args[:i], args[i+1:]...)
+				colNames = append(colNames[:i], colNames[i + 1:]...)
+				args = append(args[:i], args[i + 1:]...)
 			}
 		}
 
@@ -319,11 +335,11 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 		exprColVals = append(exprColVals, v.expr)
 	}
 
-	colPlaces := strings.Repeat("?, ", len(colNames)-len(exprColumns))
+	colPlaces := strings.Repeat("?, ", len(colNames) - len(exprColumns))
 	if len(exprColVals) > 0 {
 		colPlaces = colPlaces + strings.Join(exprColVals, ", ")
 	} else {
-		colPlaces = colPlaces[0 : len(colPlaces)-2]
+		colPlaces = colPlaces[0 : len(colPlaces) - 2]
 	}
 
 	sqlStr := fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",

+ 9 - 16
statement.go

@@ -39,7 +39,7 @@ type Statement struct {
 	Engine          *Engine
 	Start           int
 	LimitN          int
-	IdParam         *core.PK
+	idParam         *core.PK
 	OrderStr        string
 	JoinStr         string
 	joinArgs        []interface{}
@@ -91,7 +91,7 @@ func (statement *Statement) Init() {
 	statement.columnMap = make(map[string]bool)
 	statement.AltTableName = ""
 	statement.tableName = ""
-	statement.IdParam = nil
+	statement.idParam = nil
 	statement.RawSQL = ""
 	statement.RawParams = make([]interface{}, 0)
 	statement.UseCache = true
@@ -698,13 +698,6 @@ func (statement *Statement) TableName() string {
 	return statement.tableName
 }
 
-// Id generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
-//
-// Deprecated: use ID instead
-func (statement *Statement) Id(id interface{}) *Statement {
-	return statement.ID(id)
-}
-
 // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
 func (statement *Statement) ID(id interface{}) *Statement {
 	idValue := reflect.ValueOf(id)
@@ -713,23 +706,23 @@ func (statement *Statement) ID(id interface{}) *Statement {
 	switch idType {
 	case ptrPkType:
 		if pkPtr, ok := (id).(*core.PK); ok {
-			statement.IdParam = pkPtr
+			statement.idParam = pkPtr
 			return statement
 		}
 	case pkType:
 		if pk, ok := (id).(core.PK); ok {
-			statement.IdParam = &pk
+			statement.idParam = &pk
 			return statement
 		}
 	}
 
 	switch idType.Kind() {
 	case reflect.String:
-		statement.IdParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
+		statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
 		return statement
 	}
 
-	statement.IdParam = &core.PK{id}
+	statement.idParam = &core.PK{id}
 	return statement
 }
 
@@ -1281,14 +1274,14 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
 }
 
 func (statement *Statement) processIDParam() {
-	if statement.IdParam == nil {
+	if statement.idParam == nil {
 		return
 	}
 
 	for i, col := range statement.RefTable.PKColumns() {
 		var colName = statement.colName(col, statement.TableName())
-		if i < len(*(statement.IdParam)) {
-			statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.IdParam))[i]})
+		if i < len(*(statement.idParam)) {
+			statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
 		} else {
 			statement.cond = statement.cond.And(builder.Eq{colName: ""})
 		}