Explorar o código

Return ErrBadConn for temporary Dial error (#867)

When `Dial()` returned error and it's `Timeout() == true`, return ErrBadConn to
database/sql retry new connection.
Tom Jenkinson %!s(int64=7) %!d(string=hai) anos
pai
achega
fd197cdcfa
Modificáronse 3 ficheiros con 54 adicións e 0 borrados
  1. 1 0
      AUTHORS
  2. 4 0
      driver.go
  3. 49 0
      driver_test.go

+ 1 - 0
AUTHORS

@@ -74,6 +74,7 @@ Soroush Pour <me at soroushjp.com>
 Stan Putrya <root.vagner at gmail.com>
 Stanley Gunawan <gunawan.stanley at gmail.com>
 Thomas Wodarek <wodarekwebpage at gmail.com>
+Tom Jenkinson <tom at tjenkinson.me>
 Xiangyu Hu <xiangyu.hu at outlook.com>
 Xiaobing Jiang <s7v7nislands at gmail.com>
 Xiuming Chen <cc at cxm.cc>

+ 4 - 0
driver.go

@@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 		mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
 	}
 	if err != nil {
+		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
+			errLog.Print("net.Error from Dial()': ", nerr.Error())
+			return nil, driver.ErrBadConn
+		}
 		return nil, err
 	}
 

+ 49 - 0
driver_test.go

@@ -85,6 +85,23 @@ type DBTest struct {
 	db *sql.DB
 }
 
+type netErrorMock struct {
+	temporary bool
+	timeout   bool
+}
+
+func (e netErrorMock) Temporary() bool {
+	return e.temporary
+}
+
+func (e netErrorMock) Timeout() bool {
+	return e.timeout
+}
+
+func (e netErrorMock) Error() string {
+	return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout)
+}
+
 func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 	if !available {
 		t.Skipf("MySQL server not running on %s", netAddr)
@@ -1801,6 +1818,38 @@ func TestConcurrent(t *testing.T) {
 	})
 }
 
+func testDialError(t *testing.T, dialErr error, expectErr error) {
+	RegisterDial("mydial", func(addr string) (net.Conn, error) {
+		return nil, dialErr
+	})
+
+	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
+	if err != nil {
+		t.Fatalf("error connecting: %s", err.Error())
+	}
+	defer db.Close()
+
+	_, err = db.Exec("DO 1")
+	if err != expectErr {
+		t.Fatalf("was expecting %s. Got: %s", dialErr, err)
+	}
+}
+
+func TestDialUnknownError(t *testing.T) {
+	testErr := fmt.Errorf("test")
+	testDialError(t, testErr, testErr)
+}
+
+func TestDialNonRetryableNetErr(t *testing.T) {
+	testErr := netErrorMock{}
+	testDialError(t, testErr, testErr)
+}
+
+func TestDialTemporaryNetErr(t *testing.T) {
+	testErr := netErrorMock{temporary: true}
+	testDialError(t, testErr, driver.ErrBadConn)
+}
+
 // Tests custom dial functions
 func TestCustomDial(t *testing.T) {
 	if !available {