sqlconn.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "github.com/tal-tech/go-zero/core/breaker"
  5. )
  // ErrNotFound is the same value as sql.ErrNoRows, so callers may compare
  // against either. Presumably it is what the QueryRow helpers surface when the
  // query yields no rows — confirm against unmarshalRow.
var ErrNotFound = sql.ErrNoRows
  7. type (
  8. // Session stands for raw connections or transaction sessions
  9. Session interface {
  10. Exec(query string, args ...interface{}) (sql.Result, error)
  11. Prepare(query string) (StmtSession, error)
  12. QueryRow(v interface{}, query string, args ...interface{}) error
  13. QueryRowPartial(v interface{}, query string, args ...interface{}) error
  14. QueryRows(v interface{}, query string, args ...interface{}) error
  15. QueryRowsPartial(v interface{}, query string, args ...interface{}) error
  16. }
  17. // SqlConn only stands for raw connections, so Transact method can be called.
  18. SqlConn interface {
  19. Session
  20. Transact(func(session Session) error) error
  21. }
  22. SqlOption func(*commonSqlConn)
  23. StmtSession interface {
  24. Close() error
  25. Exec(args ...interface{}) (sql.Result, error)
  26. QueryRow(v interface{}, args ...interface{}) error
  27. QueryRowPartial(v interface{}, args ...interface{}) error
  28. QueryRows(v interface{}, args ...interface{}) error
  29. QueryRowsPartial(v interface{}, args ...interface{}) error
  30. }
  31. // thread-safe
  32. // Because CORBA doesn't support PREPARE, so we need to combine the
  33. // query arguments into one string and do underlying query without arguments
  34. commonSqlConn struct {
  35. driverName string
  36. datasource string
  37. beginTx beginnable
  38. brk breaker.Breaker
  39. accept func(error) bool
  40. }
  41. sessionConn interface {
  42. Exec(query string, args ...interface{}) (sql.Result, error)
  43. Query(query string, args ...interface{}) (*sql.Rows, error)
  44. }
  45. statement struct {
  46. stmt *sql.Stmt
  47. }
  48. stmtConn interface {
  49. Exec(args ...interface{}) (sql.Result, error)
  50. Query(args ...interface{}) (*sql.Rows, error)
  51. }
  52. )
  53. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
  54. conn := &commonSqlConn{
  55. driverName: driverName,
  56. datasource: datasource,
  57. beginTx: begin,
  58. brk: breaker.NewBreaker(),
  59. }
  60. for _, opt := range opts {
  61. opt(conn)
  62. }
  63. return conn
  64. }
  65. func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
  66. err = db.brk.DoWithAcceptable(func() error {
  67. var conn *sql.DB
  68. conn, err = getSqlConn(db.driverName, db.datasource)
  69. if err != nil {
  70. logInstanceError(db.datasource, err)
  71. return err
  72. }
  73. result, err = exec(conn, q, args...)
  74. return err
  75. }, db.acceptable)
  76. return
  77. }
  78. func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  79. err = db.brk.DoWithAcceptable(func() error {
  80. var conn *sql.DB
  81. conn, err = getSqlConn(db.driverName, db.datasource)
  82. if err != nil {
  83. logInstanceError(db.datasource, err)
  84. return err
  85. }
  86. st, err := conn.Prepare(query)
  87. if err != nil {
  88. return err
  89. }
  90. stmt = statement{
  91. stmt: st,
  92. }
  93. return nil
  94. }, db.acceptable)
  95. return
  96. }
  97. func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
  98. return db.queryRows(func(rows *sql.Rows) error {
  99. return unmarshalRow(v, rows, true)
  100. }, q, args...)
  101. }
  102. func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  103. return db.queryRows(func(rows *sql.Rows) error {
  104. return unmarshalRow(v, rows, false)
  105. }, q, args...)
  106. }
  107. func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
  108. return db.queryRows(func(rows *sql.Rows) error {
  109. return unmarshalRows(v, rows, true)
  110. }, q, args...)
  111. }
  112. func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  113. return db.queryRows(func(rows *sql.Rows) error {
  114. return unmarshalRows(v, rows, false)
  115. }, q, args...)
  116. }
  117. func (db *commonSqlConn) Transact(fn func(Session) error) error {
  118. return db.brk.DoWithAcceptable(func() error {
  119. return transact(db, db.beginTx, fn)
  120. }, db.acceptable)
  121. }
  122. func (db *commonSqlConn) acceptable(err error) bool {
  123. ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
  124. if db.accept == nil {
  125. return ok
  126. }
  127. return ok || db.accept(err)
  128. }
  129. func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
  130. var qerr error
  131. return db.brk.DoWithAcceptable(func() error {
  132. conn, err := getSqlConn(db.driverName, db.datasource)
  133. if err != nil {
  134. logInstanceError(db.datasource, err)
  135. return err
  136. }
  137. return query(conn, func(rows *sql.Rows) error {
  138. qerr = scanner(rows)
  139. return qerr
  140. }, q, args...)
  141. }, func(err error) bool {
  142. return qerr == err || db.acceptable(err)
  143. })
  144. }
  145. func (s statement) Close() error {
  146. return s.stmt.Close()
  147. }
  148. func (s statement) Exec(args ...interface{}) (sql.Result, error) {
  149. return execStmt(s.stmt, args...)
  150. }
  151. func (s statement) QueryRow(v interface{}, args ...interface{}) error {
  152. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  153. return unmarshalRow(v, rows, true)
  154. }, args...)
  155. }
  156. func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
  157. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  158. return unmarshalRow(v, rows, false)
  159. }, args...)
  160. }
  161. func (s statement) QueryRows(v interface{}, args ...interface{}) error {
  162. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  163. return unmarshalRows(v, rows, true)
  164. }, args...)
  165. }
  166. func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
  167. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  168. return unmarshalRows(v, rows, false)
  169. }, args...)
  170. }