stmt_test.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "errors"
  5. "testing"
  6. "time"
  7. "github.com/stretchr/testify/assert"
  8. )
  // errMockedPlaceholder is the sentinel error the mocked Query methods hand
// back when no explicit error has been configured on the mock.
var (
	errMockedPlaceholder = errors.New("placeholder")
)
  10. func TestStmt_exec(t *testing.T) {
  11. tests := []struct {
  12. name string
  13. args []interface{}
  14. delay bool
  15. hasError bool
  16. err error
  17. lastInsertId int64
  18. rowsAffected int64
  19. }{
  20. {
  21. name: "normal",
  22. args: []interface{}{1},
  23. lastInsertId: 1,
  24. rowsAffected: 2,
  25. },
  26. {
  27. name: "exec error",
  28. args: []interface{}{1},
  29. hasError: true,
  30. err: errors.New("exec"),
  31. },
  32. {
  33. name: "slowcall",
  34. args: []interface{}{1},
  35. delay: true,
  36. lastInsertId: 1,
  37. rowsAffected: 2,
  38. },
  39. }
  40. for _, test := range tests {
  41. test := test
  42. fns := []func(args ...interface{}) (sql.Result, error){
  43. func(args ...interface{}) (sql.Result, error) {
  44. return exec(&mockedSessionConn{
  45. lastInsertId: test.lastInsertId,
  46. rowsAffected: test.rowsAffected,
  47. err: test.err,
  48. delay: test.delay,
  49. }, "select user from users where id=?", args...)
  50. },
  51. func(args ...interface{}) (sql.Result, error) {
  52. return execStmt(&mockedStmtConn{
  53. lastInsertId: test.lastInsertId,
  54. rowsAffected: test.rowsAffected,
  55. err: test.err,
  56. delay: test.delay,
  57. }, args...)
  58. },
  59. }
  60. for _, fn := range fns {
  61. fn := fn
  62. t.Run(test.name, func(t *testing.T) {
  63. t.Parallel()
  64. res, err := fn(test.args...)
  65. if test.hasError {
  66. assert.NotNil(t, err)
  67. return
  68. }
  69. assert.Nil(t, err)
  70. lastInsertId, err := res.LastInsertId()
  71. assert.Nil(t, err)
  72. assert.Equal(t, test.lastInsertId, lastInsertId)
  73. rowsAffected, err := res.RowsAffected()
  74. assert.Nil(t, err)
  75. assert.Equal(t, test.rowsAffected, rowsAffected)
  76. })
  77. }
  78. }
  79. }
  80. func TestStmt_query(t *testing.T) {
  81. tests := []struct {
  82. name string
  83. args []interface{}
  84. delay bool
  85. hasError bool
  86. err error
  87. }{
  88. {
  89. name: "normal",
  90. args: []interface{}{1},
  91. },
  92. {
  93. name: "query error",
  94. args: []interface{}{1},
  95. hasError: true,
  96. err: errors.New("exec"),
  97. },
  98. {
  99. name: "slowcall",
  100. args: []interface{}{1},
  101. delay: true,
  102. },
  103. }
  104. for _, test := range tests {
  105. test := test
  106. fns := []func(args ...interface{}) error{
  107. func(args ...interface{}) error {
  108. return query(&mockedSessionConn{
  109. err: test.err,
  110. delay: test.delay,
  111. }, func(rows *sql.Rows) error {
  112. return nil
  113. }, "select user from users where id=?", args...)
  114. },
  115. func(args ...interface{}) error {
  116. return queryStmt(&mockedStmtConn{
  117. err: test.err,
  118. delay: test.delay,
  119. }, func(rows *sql.Rows) error {
  120. return nil
  121. }, args...)
  122. },
  123. }
  124. for _, fn := range fns {
  125. fn := fn
  126. t.Run(test.name, func(t *testing.T) {
  127. t.Parallel()
  128. err := fn(test.args...)
  129. if test.hasError {
  130. assert.NotNil(t, err)
  131. return
  132. }
  133. assert.Equal(t, errMockedPlaceholder, err)
  134. })
  135. }
  136. }
  137. }
  138. type mockedSessionConn struct {
  139. lastInsertId int64
  140. rowsAffected int64
  141. err error
  142. delay bool
  143. }
  144. func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
  145. if m.delay {
  146. time.Sleep(slowThreshold + time.Millisecond)
  147. }
  148. return mockedResult{
  149. lastInsertId: m.lastInsertId,
  150. rowsAffected: m.rowsAffected,
  151. }, m.err
  152. }
  153. func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
  154. if m.delay {
  155. time.Sleep(slowThreshold + time.Millisecond)
  156. }
  157. err := errMockedPlaceholder
  158. if m.err != nil {
  159. err = m.err
  160. }
  161. return new(sql.Rows), err
  162. }
  163. type mockedStmtConn struct {
  164. lastInsertId int64
  165. rowsAffected int64
  166. err error
  167. delay bool
  168. }
  169. func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
  170. if m.delay {
  171. time.Sleep(slowThreshold + time.Millisecond)
  172. }
  173. return mockedResult{
  174. lastInsertId: m.lastInsertId,
  175. rowsAffected: m.rowsAffected,
  176. }, m.err
  177. }
  178. func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
  179. if m.delay {
  180. time.Sleep(slowThreshold + time.Millisecond)
  181. }
  182. err := errMockedPlaceholder
  183. if m.err != nil {
  184. err = m.err
  185. }
  186. return new(sql.Rows), err
  187. }
  // mockedResult is a canned sql.Result whose accessors echo back the values
// it was constructed with and never report an error.
type mockedResult struct {
	lastInsertId, rowsAffected int64
}

// LastInsertId reports the configured lastInsertId with a nil error.
func (m mockedResult) LastInsertId() (int64, error) {
	id := m.lastInsertId
	return id, nil
}

// RowsAffected reports the configured rowsAffected with a nil error.
func (m mockedResult) RowsAffected() (int64, error) {
	n := m.rowsAffected
	return n, nil
}