stmt_test.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package sqlx
  2. import (
  3. "database/sql"
  4. "errors"
  5. "testing"
  6. "time"
  7. "github.com/stretchr/testify/assert"
  8. )
  9. var errMockedPlaceholder = errors.New("placeholder")
  10. func TestStmt_exec(t *testing.T) {
  11. tests := []struct {
  12. name string
  13. args []interface{}
  14. delay bool
  15. formatError bool
  16. hasError bool
  17. err error
  18. lastInsertId int64
  19. rowsAffected int64
  20. }{
  21. {
  22. name: "normal",
  23. args: []interface{}{1},
  24. lastInsertId: 1,
  25. rowsAffected: 2,
  26. },
  27. {
  28. name: "wrong format",
  29. args: []interface{}{1, 2},
  30. formatError: true,
  31. hasError: true,
  32. },
  33. {
  34. name: "exec error",
  35. args: []interface{}{1},
  36. hasError: true,
  37. err: errors.New("exec"),
  38. },
  39. {
  40. name: "slowcall",
  41. args: []interface{}{1},
  42. delay: true,
  43. lastInsertId: 1,
  44. rowsAffected: 2,
  45. },
  46. }
  47. for _, test := range tests {
  48. test := test
  49. fns := []func(args ...interface{}) (sql.Result, error){
  50. func(args ...interface{}) (sql.Result, error) {
  51. return exec(&mockedSessionConn{
  52. lastInsertId: test.lastInsertId,
  53. rowsAffected: test.rowsAffected,
  54. err: test.err,
  55. delay: test.delay,
  56. }, "select user from users where id=?", args...)
  57. },
  58. func(args ...interface{}) (sql.Result, error) {
  59. return execStmt(&mockedStmtConn{
  60. lastInsertId: test.lastInsertId,
  61. rowsAffected: test.rowsAffected,
  62. err: test.err,
  63. delay: test.delay,
  64. }, args...)
  65. },
  66. }
  67. for i, fn := range fns {
  68. i := i
  69. fn := fn
  70. t.Run(test.name, func(t *testing.T) {
  71. t.Parallel()
  72. res, err := fn(test.args...)
  73. if i == 0 && test.formatError {
  74. assert.NotNil(t, err)
  75. return
  76. }
  77. if !test.formatError && test.hasError {
  78. assert.NotNil(t, err)
  79. return
  80. }
  81. assert.Nil(t, err)
  82. lastInsertId, err := res.LastInsertId()
  83. assert.Nil(t, err)
  84. assert.Equal(t, test.lastInsertId, lastInsertId)
  85. rowsAffected, err := res.RowsAffected()
  86. assert.Nil(t, err)
  87. assert.Equal(t, test.rowsAffected, rowsAffected)
  88. })
  89. }
  90. }
  91. }
  92. func TestStmt_query(t *testing.T) {
  93. tests := []struct {
  94. name string
  95. args []interface{}
  96. delay bool
  97. formatError bool
  98. hasError bool
  99. err error
  100. }{
  101. {
  102. name: "normal",
  103. args: []interface{}{1},
  104. },
  105. {
  106. name: "wrong format",
  107. args: []interface{}{1, 2},
  108. formatError: true,
  109. hasError: true,
  110. },
  111. {
  112. name: "query error",
  113. args: []interface{}{1},
  114. hasError: true,
  115. err: errors.New("exec"),
  116. },
  117. {
  118. name: "slowcall",
  119. args: []interface{}{1},
  120. delay: true,
  121. },
  122. }
  123. for _, test := range tests {
  124. test := test
  125. fns := []func(args ...interface{}) error{
  126. func(args ...interface{}) error {
  127. return query(&mockedSessionConn{
  128. err: test.err,
  129. delay: test.delay,
  130. }, func(rows *sql.Rows) error {
  131. return nil
  132. }, "select user from users where id=?", args...)
  133. },
  134. func(args ...interface{}) error {
  135. return queryStmt(&mockedStmtConn{
  136. err: test.err,
  137. delay: test.delay,
  138. }, func(rows *sql.Rows) error {
  139. return nil
  140. }, args...)
  141. },
  142. }
  143. for i, fn := range fns {
  144. i := i
  145. fn := fn
  146. t.Run(test.name, func(t *testing.T) {
  147. t.Parallel()
  148. err := fn(test.args...)
  149. if i == 0 && test.formatError {
  150. assert.NotNil(t, err)
  151. return
  152. }
  153. if !test.formatError && test.hasError {
  154. assert.NotNil(t, err)
  155. return
  156. }
  157. assert.Equal(t, errMockedPlaceholder, err)
  158. })
  159. }
  160. }
  161. }
  162. type mockedSessionConn struct {
  163. lastInsertId int64
  164. rowsAffected int64
  165. err error
  166. delay bool
  167. }
  168. func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
  169. if m.delay {
  170. time.Sleep(slowThreshold + time.Millisecond)
  171. }
  172. return mockedResult{
  173. lastInsertId: m.lastInsertId,
  174. rowsAffected: m.rowsAffected,
  175. }, m.err
  176. }
  177. func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
  178. if m.delay {
  179. time.Sleep(slowThreshold + time.Millisecond)
  180. }
  181. err := errMockedPlaceholder
  182. if m.err != nil {
  183. err = m.err
  184. }
  185. return new(sql.Rows), err
  186. }
  187. type mockedStmtConn struct {
  188. lastInsertId int64
  189. rowsAffected int64
  190. err error
  191. delay bool
  192. }
  193. func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
  194. if m.delay {
  195. time.Sleep(slowThreshold + time.Millisecond)
  196. }
  197. return mockedResult{
  198. lastInsertId: m.lastInsertId,
  199. rowsAffected: m.rowsAffected,
  200. }, m.err
  201. }
  202. func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
  203. if m.delay {
  204. time.Sleep(slowThreshold + time.Millisecond)
  205. }
  206. err := errMockedPlaceholder
  207. if m.err != nil {
  208. err = m.err
  209. }
  210. return new(sql.Rows), err
  211. }
  212. type mockedResult struct {
  213. lastInsertId int64
  214. rowsAffected int64
  215. }
  216. func (m mockedResult) LastInsertId() (int64, error) {
  217. return m.lastInsertId, nil
  218. }
  219. func (m mockedResult) RowsAffected() (int64, error) {
  220. return m.rowsAffected, nil
  221. }