|
|
@@ -85,6 +85,23 @@ type DBTest struct {
|
|
|
db *sql.DB
|
|
|
}
|
|
|
|
|
|
+type netErrorMock struct {
|
|
|
+ temporary bool
|
|
|
+ timeout bool
|
|
|
+}
|
|
|
+
|
|
|
+func (e netErrorMock) Temporary() bool {
|
|
|
+ return e.temporary
|
|
|
+}
|
|
|
+
|
|
|
+func (e netErrorMock) Timeout() bool {
|
|
|
+ return e.timeout
|
|
|
+}
|
|
|
+
|
|
|
+func (e netErrorMock) Error() string {
|
|
|
+ return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout)
|
|
|
+}
|
|
|
+
|
|
|
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
|
|
|
if !available {
|
|
|
t.Skipf("MySQL server not running on %s", netAddr)
|
|
|
@@ -1801,6 +1818,38 @@ func TestConcurrent(t *testing.T) {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+func testDialError(t *testing.T, dialErr error, expectErr error) {
|
|
|
+ RegisterDial("mydial", func(addr string) (net.Conn, error) {
|
|
|
+ return nil, dialErr
|
|
|
+ })
|
|
|
+
|
|
|
+ db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("error connecting: %s", err.Error())
|
|
|
+ }
|
|
|
+ defer db.Close()
|
|
|
+
|
|
|
+ _, err = db.Exec("DO 1")
|
|
|
+ if err != expectErr {
|
|
|
+ t.Fatalf("was expecting %s. Got: %s", dialErr, err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestDialUnknownError(t *testing.T) {
|
|
|
+ testErr := fmt.Errorf("test")
|
|
|
+ testDialError(t, testErr, testErr)
|
|
|
+}
|
|
|
+
|
|
|
+func TestDialNonRetryableNetErr(t *testing.T) {
|
|
|
+ testErr := netErrorMock{}
|
|
|
+ testDialError(t, testErr, testErr)
|
|
|
+}
|
|
|
+
|
|
|
+func TestDialTemporaryNetErr(t *testing.T) {
|
|
|
+ testErr := netErrorMock{temporary: true}
|
|
|
+ testDialError(t, testErr, driver.ErrBadConn)
|
|
|
+}
|
|
|
+
|
|
|
// Tests custom dial functions
|
|
|
func TestCustomDial(t *testing.T) {
|
|
|
if !available {
|