|
- package sqlx
- import (
- "database/sql"
- "errors"
- "testing"
- "time"
- "github.com/stretchr/testify/assert"
- )
// errMockedPlaceholder is the sentinel error the mocked Query methods in this
// file return when the mock has no explicit error configured.
var (
	errMockedPlaceholder = errors.New("placeholder")
)
- func TestStmt_exec(t *testing.T) {
- tests := []struct {
- name string
- args []interface{}
- delay bool
- formatError bool
- hasError bool
- err error
- lastInsertId int64
- rowsAffected int64
- }{
- {
- name: "normal",
- args: []interface{}{1},
- lastInsertId: 1,
- rowsAffected: 2,
- },
- {
- name: "wrong format",
- args: []interface{}{1, 2},
- formatError: true,
- hasError: true,
- },
- {
- name: "exec error",
- args: []interface{}{1},
- hasError: true,
- err: errors.New("exec"),
- },
- {
- name: "slowcall",
- args: []interface{}{1},
- delay: true,
- lastInsertId: 1,
- rowsAffected: 2,
- },
- }
- for _, test := range tests {
- test := test
- fns := []func(args ...interface{}) (sql.Result, error){
- func(args ...interface{}) (sql.Result, error) {
- return exec(&mockedSessionConn{
- lastInsertId: test.lastInsertId,
- rowsAffected: test.rowsAffected,
- err: test.err,
- delay: test.delay,
- }, "select user from users where id=?", args...)
- },
- func(args ...interface{}) (sql.Result, error) {
- return execStmt(&mockedStmtConn{
- lastInsertId: test.lastInsertId,
- rowsAffected: test.rowsAffected,
- err: test.err,
- delay: test.delay,
- }, args...)
- },
- }
- for i, fn := range fns {
- i := i
- fn := fn
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
- res, err := fn(test.args...)
- if i == 0 && test.formatError {
- assert.NotNil(t, err)
- return
- }
- if !test.formatError && test.hasError {
- assert.NotNil(t, err)
- return
- }
- assert.Nil(t, err)
- lastInsertId, err := res.LastInsertId()
- assert.Nil(t, err)
- assert.Equal(t, test.lastInsertId, lastInsertId)
- rowsAffected, err := res.RowsAffected()
- assert.Nil(t, err)
- assert.Equal(t, test.rowsAffected, rowsAffected)
- })
- }
- }
- }
- func TestStmt_query(t *testing.T) {
- tests := []struct {
- name string
- args []interface{}
- delay bool
- formatError bool
- hasError bool
- err error
- }{
- {
- name: "normal",
- args: []interface{}{1},
- },
- {
- name: "wrong format",
- args: []interface{}{1, 2},
- formatError: true,
- hasError: true,
- },
- {
- name: "query error",
- args: []interface{}{1},
- hasError: true,
- err: errors.New("exec"),
- },
- {
- name: "slowcall",
- args: []interface{}{1},
- delay: true,
- },
- }
- for _, test := range tests {
- test := test
- fns := []func(args ...interface{}) error{
- func(args ...interface{}) error {
- return query(&mockedSessionConn{
- err: test.err,
- delay: test.delay,
- }, func(rows *sql.Rows) error {
- return nil
- }, "select user from users where id=?", args...)
- },
- func(args ...interface{}) error {
- return queryStmt(&mockedStmtConn{
- err: test.err,
- delay: test.delay,
- }, func(rows *sql.Rows) error {
- return nil
- }, args...)
- },
- }
- for i, fn := range fns {
- i := i
- fn := fn
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
- err := fn(test.args...)
- if i == 0 && test.formatError {
- assert.NotNil(t, err)
- return
- }
- if !test.formatError && test.hasError {
- assert.NotNil(t, err)
- return
- }
- assert.Equal(t, errMockedPlaceholder, err)
- })
- }
- }
- }
- type mockedSessionConn struct {
- lastInsertId int64
- rowsAffected int64
- err error
- delay bool
- }
- func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
- if m.delay {
- time.Sleep(slowThreshold + time.Millisecond)
- }
- return mockedResult{
- lastInsertId: m.lastInsertId,
- rowsAffected: m.rowsAffected,
- }, m.err
- }
- func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
- if m.delay {
- time.Sleep(slowThreshold + time.Millisecond)
- }
- err := errMockedPlaceholder
- if m.err != nil {
- err = m.err
- }
- return new(sql.Rows), err
- }
- type mockedStmtConn struct {
- lastInsertId int64
- rowsAffected int64
- err error
- delay bool
- }
- func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
- if m.delay {
- time.Sleep(slowThreshold + time.Millisecond)
- }
- return mockedResult{
- lastInsertId: m.lastInsertId,
- rowsAffected: m.rowsAffected,
- }, m.err
- }
- func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
- if m.delay {
- time.Sleep(slowThreshold + time.Millisecond)
- }
- err := errMockedPlaceholder
- if m.err != nil {
- err = m.err
- }
- return new(sql.Rows), err
- }
// mockedResult is a fixed-value result type exposing LastInsertId and
// RowsAffected; the mocked connections return it from Exec.
type mockedResult struct {
	lastInsertId, rowsAffected int64
}

// LastInsertId reports the preconfigured insert id and never fails.
func (r mockedResult) LastInsertId() (int64, error) {
	id := r.lastInsertId
	return id, nil
}

// RowsAffected reports the preconfigured row count and never fails.
func (r mockedResult) RowsAffected() (int64, error) {
	count := r.rowsAffected
	return count, nil
}
|