tx.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "fmt"
  5. )
  6. type (
  7. beginnable func(*sql.DB) (trans, error)
  8. trans interface {
  9. Session
  10. Commit() error
  11. Rollback() error
  12. }
  13. txSession struct {
  14. *sql.Tx
  15. }
  16. )
  17. func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
  18. return exec(t.Tx, q, args...)
  19. }
  20. func (t txSession) Prepare(q string) (StmtSession, error) {
  21. stmt, err := t.Tx.Prepare(q)
  22. if err != nil {
  23. return nil, err
  24. }
  25. return statement{
  26. stmt: stmt,
  27. }, nil
  28. }
  29. func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
  30. return query(t.Tx, func(rows *sql.Rows) error {
  31. return unmarshalRow(v, rows, true)
  32. }, q, args...)
  33. }
  34. func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  35. return query(t.Tx, func(rows *sql.Rows) error {
  36. return unmarshalRow(v, rows, false)
  37. }, q, args...)
  38. }
  39. func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
  40. return query(t.Tx, func(rows *sql.Rows) error {
  41. return unmarshalRows(v, rows, true)
  42. }, q, args...)
  43. }
  44. func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  45. return query(t.Tx, func(rows *sql.Rows) error {
  46. return unmarshalRows(v, rows, false)
  47. }, q, args...)
  48. }
  49. func begin(db *sql.DB) (trans, error) {
  50. tx, err := db.Begin()
  51. if err != nil {
  52. return nil, err
  53. }
  54. return txSession{
  55. Tx: tx,
  56. }, nil
  57. }
  58. func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
  59. conn, err := getSqlConn(db.driverName, db.datasource)
  60. if err != nil {
  61. logInstanceError(db.datasource, err)
  62. return err
  63. }
  64. return transactOnConn(conn, b, fn)
  65. }
  66. func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
  67. var tx trans
  68. tx, err = b(conn)
  69. if err != nil {
  70. return
  71. }
  72. defer func() {
  73. if p := recover(); p != nil {
  74. if e := tx.Rollback(); e != nil {
  75. err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
  76. } else {
  77. err = fmt.Errorf("recoveer from %#v", p)
  78. }
  79. } else if err != nil {
  80. if e := tx.Rollback(); e != nil {
  81. err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
  82. }
  83. } else {
  84. err = tx.Commit()
  85. }
  86. }()
  87. return fn(tx)
  88. }