sqlconn.go 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. // copy from core/stores/sqlx/sqlconn.go
  2. package mocksql
  3. import (
  4. "database/sql"
  5. "github.com/tal-tech/go-zero/core/stores/sqlx"
  6. )
  // MockConn holds a *sql.DB; its query/exec methods forward that handle to the
// package-level helpers (exec, query).
type MockConn struct {
	db *sql.DB
}

// statement wraps a *sql.Stmt obtained from MockConn.Prepare; its methods
// forward the wrapped statement to execStmt/queryStmt.
type statement struct {
	stmt *sql.Stmt
}
  15. func NewMockConn(db *sql.DB) *MockConn {
  16. return &MockConn{db: db}
  17. }
  18. func (conn *MockConn) Exec(query string, args ...interface{}) (sql.Result, error) {
  19. return exec(conn.db, query, args...)
  20. }
  21. func (conn *MockConn) Prepare(query string) (sqlx.StmtSession, error) {
  22. st, err := conn.db.Prepare(query)
  23. return statement{stmt: st}, err
  24. }
  25. func (conn *MockConn) QueryRow(v interface{}, q string, args ...interface{}) error {
  26. return query(conn.db, func(rows *sql.Rows) error {
  27. return unmarshalRow(v, rows, true)
  28. }, q, args...)
  29. }
  30. func (conn *MockConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
  31. return query(conn.db, func(rows *sql.Rows) error {
  32. return unmarshalRow(v, rows, false)
  33. }, q, args...)
  34. }
  35. func (conn *MockConn) QueryRows(v interface{}, q string, args ...interface{}) error {
  36. return query(conn.db, func(rows *sql.Rows) error {
  37. return unmarshalRows(v, rows, true)
  38. }, q, args...)
  39. }
  40. func (conn *MockConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
  41. return query(conn.db, func(rows *sql.Rows) error {
  42. return unmarshalRows(v, rows, false)
  43. }, q, args...)
  44. }
  45. func (conn *MockConn) Transact(func(session sqlx.Session) error) error {
  46. return nil
  47. }
  48. func (s statement) Close() error {
  49. return s.stmt.Close()
  50. }
  51. func (s statement) Exec(args ...interface{}) (sql.Result, error) {
  52. return execStmt(s.stmt, args...)
  53. }
  54. func (s statement) QueryRow(v interface{}, args ...interface{}) error {
  55. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  56. return unmarshalRow(v, rows, true)
  57. }, args...)
  58. }
  59. func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
  60. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  61. return unmarshalRow(v, rows, false)
  62. }, args...)
  63. }
  64. func (s statement) QueryRows(v interface{}, args ...interface{}) error {
  65. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  66. return unmarshalRows(v, rows, true)
  67. }, args...)
  68. }
  69. func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
  70. return queryStmt(s.stmt, func(rows *sql.Rows) error {
  71. return unmarshalRows(v, rows, false)
  72. }, args...)
  73. }