Browse Source

支持子事务回滚,整个嵌套事务回滚

xormplus 9 years ago
parent
commit
f4836b2a74
1 changed files with 30 additions and 13 deletions
  1. 30 13
      transaction.go

+ 30 - 13
transaction.go

@@ -231,11 +231,17 @@ func (transaction *Transaction) Commit() error {
 		if !transaction.IsExistingTransaction() {
 			return ErrNotInTransaction
 		}
+
 		if !transaction.isNested {
 			err := transaction.txSession.commit()
 			if err != nil {
 				return err
 			}
+		} else if transaction.txSession.rollbackSavePointID == transaction.savePointID {
+			if err := transaction.RollbackToSavePoint(transaction.savePointID); err != nil {
+				transaction.txSession.rollbackSavePointID = ""
+				return err
+			}
 		}
 		return nil
 	default:
@@ -250,16 +256,25 @@ func (transaction *Transaction) Rollback() error {
 		if !transaction.IsExistingTransaction() {
 			return ErrNotInTransaction
 		}
-		err := transaction.txSession.rollback()
-		if err != nil {
-			return err
+		if transaction.savePointID == "" {
+			err := transaction.txSession.rollback()
+			if err != nil {
+				return err
+			}
+		} else {
+			transaction.txSession.rollbackSavePointID = transaction.savePointID
 		}
+
 		return nil
 	case PROPAGATION_SUPPORTS:
 		if transaction.IsExistingTransaction() {
-			err := transaction.txSession.rollback()
-			if err != nil {
-				return err
+			if transaction.savePointID == "" {
+				err := transaction.txSession.rollback()
+				if err != nil {
+					return err
+				}
+			} else {
+				transaction.txSession.rollbackSavePointID = transaction.savePointID
 			}
 			return nil
 		}
@@ -268,18 +283,15 @@ func (transaction *Transaction) Rollback() error {
 		if !transaction.IsExistingTransaction() {
 			return ErrNotInTransaction
 		}
-		if transaction.savePointID != "" {
-			if err := transaction.RollbackToSavePoint(transaction.savePointID); err != nil {
-				return err
-			}
-			return nil
-		} else {
+		if transaction.savePointID == "" {
 			err := transaction.txSession.rollback()
 			if err != nil {
 				return err
 			}
-			return nil
+		} else {
+			transaction.txSession.rollbackSavePointID = transaction.savePointID
 		}
+		return nil
 
 	case PROPAGATION_REQUIRES_NEW:
 		if !transaction.IsExistingTransaction() {
@@ -304,6 +316,11 @@ func (transaction *Transaction) Rollback() error {
 		if !transaction.IsExistingTransaction() {
 			return ErrNotInTransaction
 		}
+
+		if transaction.txSession.rollbackSavePointID == transaction.savePointID {
+			return nil
+		}
+
 		if transaction.isNested {
 			if err := transaction.RollbackToSavePoint(transaction.savePointID); err != nil {
 				return err