Browse Source

bug fixed for innerInsert

xormplus 10 years ago
parent
commit
30355e215f
2 changed files with 29 additions and 30 deletions
  1. 4 4
      engine.go
  2. 25 26
      session.go

+ 4 - 4
engine.go

@@ -32,10 +32,10 @@ type Engine struct {
 	TableMapper   core.IMapper
 	TagIdentifier string
 	Tables        map[reflect.Type]*core.Table
-	SqlMap  SqlMap
-	SqlTemplate  SqlTemplate
-	mutex  *sync.RWMutex
-	Cacher core.Cacher
+	SqlMap        SqlMap
+	SqlTemplate   SqlTemplate
+	mutex         *sync.RWMutex
+	Cacher        core.Cacher
 
 	ShowSQL bool
 

+ 25 - 26
session.go

@@ -3162,9 +3162,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 
 	// for postgres, many of them didn't implement lastInsertId, so we should
 	// implemented it ourself.
+	if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
+		//assert table.AutoIncrement != ""
+		res, err := session.query("select seq_atable.currval from dual", args...)
 
-	if session.Engine.DriverName() != core.POSTGRES || table.AutoIncrement == "" {
-		res, err := session.exec(sqlStr, args...)
 		if err != nil {
 			return 0, err
 		} else {
@@ -3184,14 +3185,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 			}
 		}
 
-		if table.AutoIncrement == "" {
-			return res.RowsAffected()
+		if len(res) < 1 {
+			return 0, errors.New("insert no error but not returned id")
 		}
 
-		var id int64 = 0
-		id, err = res.LastInsertId()
-		if err != nil || id <= 0 {
-			return res.RowsAffected()
+		idByte := res[0][table.AutoIncrement]
+		id, err := strconv.ParseInt(string(idByte), 10, 64)
+		if err != nil {
+			return 1, err
 		}
 
 		aiValue, err := table.AutoIncrColumn().ValueOf(bean)
@@ -3199,8 +3200,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 			session.Engine.LogError(err)
 		}
 
-		if aiValue == nil || !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() {
-			return res.RowsAffected()
+		if aiValue == nil || !aiValue.IsValid() /*|| aiValue. != 0*/ || !aiValue.CanSet() {
+			return 1, nil
 		}
 
 		var v interface{} = id
@@ -3218,10 +3219,11 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 		}
 		aiValue.Set(reflect.ValueOf(v))
 
-		return res.RowsAffected()
-	} else if session.Engine.DriverName() == core.ORACLE {
+		return 1, nil
+	} else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
 		//assert table.AutoIncrement != ""
-		res, err := session.query("select seq_atable.currval from dual", args...)
+		sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement)
+		res, err := session.query(sqlStr, args...)
 
 		if err != nil {
 			return 0, err
@@ -3278,10 +3280,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 
 		return 1, nil
 	} else {
-		//assert table.AutoIncrement != ""
-		sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement)
-		res, err := session.query(sqlStr, args...)
-
+		res, err := session.exec(sqlStr, args...)
 		if err != nil {
 			return 0, err
 		} else {
@@ -3301,14 +3300,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 			}
 		}
 
-		if len(res) < 1 {
-			return 0, errors.New("insert no error but not returned id")
+		if table.AutoIncrement == "" {
+			return res.RowsAffected()
 		}
 
-		idByte := res[0][table.AutoIncrement]
-		id, err := strconv.ParseInt(string(idByte), 10, 64)
-		if err != nil {
-			return 1, err
+		var id int64 = 0
+		id, err = res.LastInsertId()
+		if err != nil || id <= 0 {
+			return res.RowsAffected()
 		}
 
 		aiValue, err := table.AutoIncrColumn().ValueOf(bean)
@@ -3316,8 +3315,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 			session.Engine.LogError(err)
 		}
 
-		if aiValue == nil || !aiValue.IsValid() /*|| aiValue. != 0*/ || !aiValue.CanSet() {
-			return 1, nil
+		if aiValue == nil || !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() {
+			return res.RowsAffected()
 		}
 
 		var v interface{} = id
@@ -3335,7 +3334,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
 		}
 		aiValue.Set(reflect.ValueOf(v))
 
-		return 1, nil
+		return res.RowsAffected()
 	}
 }