sqlconn.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package sqlx
import (
	"database/sql"
	"errors"

	"git.i2edu.net/i2/go-zero/core/breaker"
)
  6. // ErrNotFound is an alias of sql.ErrNoRows
  7. var ErrNotFound = sql.ErrNoRows
  8. type (
  9. // Session stands for raw connections or transaction sessions
  10. Session interface {
  11. Exec(query string, args ...interface{}) (sql.Result, error)
  12. Prepare(query string) (StmtSession, error)
  13. QueryRow(v interface{}, query string, args ...interface{}) error
  14. QueryRowPartial(v interface{}, query string, args ...interface{}) error
  15. QueryRows(v interface{}, query string, args ...interface{}) error
  16. QueryRowsPartial(v interface{}, query string, args ...interface{}) error
  17. }
  18. // SqlConn only stands for raw connections, so Transact method can be called.
  19. SqlConn interface {
  20. Session
  21. Transact(func(session Session) error) error
  22. }
  23. // SqlOption defines the method to customize a sql connection.
  24. SqlOption func(*commonSqlConn)
  25. // StmtSession interface represents a session that can be used to execute statements.
  26. StmtSession interface {
  27. Close() error
  28. Exec(args ...interface{}) (sql.Result, error)
  29. QueryRow(v interface{}, args ...interface{}) error
  30. QueryRowPartial(v interface{}, args ...interface{}) error
  31. QueryRows(v interface{}, args ...interface{}) error
  32. QueryRowsPartial(v interface{}, args ...interface{}) error
  33. }
  34. // thread-safe
  35. // Because CORBA doesn't support PREPARE, so we need to combine the
  36. // query arguments into one string and do underlying query without arguments
  37. commonSqlConn struct {
  38. driverName string
  39. datasource string
  40. beginTx beginnable
  41. brk breaker.Breaker
  42. accept func(error) bool
  43. }
  44. sessionConn interface {
  45. Exec(query string, args ...interface{}) (sql.Result, error)
  46. Query(query string, args ...interface{}) (*sql.Rows, error)
  47. }
  48. statement struct {
  49. query string
  50. stmt *sql.Stmt
  51. }
  52. stmtConn interface {
  53. Exec(args ...interface{}) (sql.Result, error)
  54. Query(args ...interface{}) (*sql.Rows, error)
  55. }
  56. )
  57. // NewSqlConn returns a SqlConn with given driver name and datasource.
  58. func NewSqlConn(driverName, datasource string, opts ...SqlOption) SqlConn {
  59. conn := &commonSqlConn{
  60. driverName: driverName,
  61. datasource: datasource,
  62. beginTx: begin,
  63. brk: breaker.NewBreaker(),
  64. }
  65. for _, opt := range opts {
  66. opt(conn)
  67. }
  68. return conn
  69. }
  70. func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
  71. err = db.brk.DoWithAcceptable(func() error {
  72. var conn *sql.DB
  73. conn, err = getSqlConn(db.driverName, db.datasource)
  74. if err != nil {
  75. logInstanceError(db.datasource, err)
  76. return err
  77. }
  78. result, err = exec(conn, q, args...)
  79. return err
  80. }, db.acceptable)
  81. return
  82. }
  83. func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
  84. err = db.brk.DoWithAcceptable(func() error {
  85. var conn *sql.DB
  86. conn, err = getSqlConn(db.driverName, db.datasource)
  87. if err != nil {
  88. logInstanceError(db.datasource, err)
  89. return err
  90. }
  91. st, err := conn.Prepare(query)
  92. if err != nil {
  93. return err
  94. }
  95. stmt = statement{
  96. query: query,
  97. stmt: st,
  98. }
  99. return nil
  100. }, db.acceptable)
  101. return
  102. }
  103. func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
  104. return db.queryRows(func(rows *sql.Rows) error {
  105. return unmarshalRow(v, rows, true)
  106. }, q, args...)
  107. }
  108. func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  109. return db.queryRows(func(rows *sql.Rows) error {
  110. return unmarshalRow(v, rows, false)
  111. }, q, args...)
  112. }
  113. func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
  114. return db.queryRows(func(rows *sql.Rows) error {
  115. return unmarshalRows(v, rows, true)
  116. }, q, args...)
  117. }
  118. func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  119. return db.queryRows(func(rows *sql.Rows) error {
  120. return unmarshalRows(v, rows, false)
  121. }, q, args...)
  122. }
  123. func (db *commonSqlConn) Transact(fn func(Session) error) error {
  124. return db.brk.DoWithAcceptable(func() error {
  125. return transact(db, db.beginTx, fn)
  126. }, db.acceptable)
  127. }
  128. func (db *commonSqlConn) acceptable(err error) bool {
  129. ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
  130. if db.accept == nil {
  131. return ok
  132. }
  133. return ok || db.accept(err)
  134. }
  135. func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
  136. var qerr error
  137. return db.brk.DoWithAcceptable(func() error {
  138. conn, err := getSqlConn(db.driverName, db.datasource)
  139. if err != nil {
  140. logInstanceError(db.datasource, err)
  141. return err
  142. }
  143. return query(conn, func(rows *sql.Rows) error {
  144. qerr = scanner(rows)
  145. return qerr
  146. }, q, args...)
  147. }, func(err error) bool {
  148. return qerr == err || db.acceptable(err)
  149. })
  150. }
  151. func (s statement) Close() error {
  152. return s.stmt.Close()
  153. }
  154. func (s statement) Exec(args ...interface{}) (sql.Result, error) {
  155. return execStmt(s.stmt, s.query, args...)
  156. }
  157. func (s statement) QueryRow(v interface{}, args ...interface{}) error {
  158. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  159. return unmarshalRow(v, rows, true)
  160. }, s.query, args...)
  161. }
  162. func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
  163. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  164. return unmarshalRow(v, rows, false)
  165. }, s.query, args...)
  166. }
  167. func (s statement) QueryRows(v interface{}, args ...interface{}) error {
  168. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  169. return unmarshalRows(v, rows, true)
  170. }, s.query, args...)
  171. }
  172. func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
  173. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  174. return unmarshalRows(v, rows, false)
  175. }, s.query, args...)
  176. }