Bladeren bron

Call markBadConn in Ping method (#875)

* Call markBadConn in Ping method

* Add myself to AUTHORS
Ilia Cimpoes 7 jaren geleden
bovenliggende
commit
0f257fc7d3
3 gewijzigde bestanden met toevoegingen van 49 en 1 verwijderingen
  1. 1 0
      AUTHORS
  2. 1 1
      connection.go
  3. 47 0
      connection_test.go

+ 1 - 0
AUTHORS

@@ -35,6 +35,7 @@ Hanno Braun <mail at hannobraun.com>
 Henri Yandell <flamefew at gmail.com>
 Hirotaka Yamamoto <ymmt2005 at gmail.com>
 ICHINOSE Shogo <shogo82148 at gmail.com>
+Ilia Cimpoes <ichimpoesh at gmail.com>
 INADA Naoki <songofacandy at gmail.com>
 Jacek Szwec <szwec.jacek at gmail.com>
 James Harr <james.harr at gmail.com>

+ 1 - 1
connection.go

@@ -475,7 +475,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
 	defer mc.finish()
 
 	if err = mc.writeCommandPacket(comPing); err != nil {
-		return
+		return mc.markBadConn(err)
 	}
 
 	return mc.readResultOK()

+ 47 - 0
connection_test.go

@@ -11,6 +11,8 @@ package mysql
 import (
 	"context"
 	"database/sql/driver"
+	"errors"
+	"net"
 	"testing"
 )
 
@@ -108,3 +110,48 @@ func TestCleanCancel(t *testing.T) {
 		}
 	}
 }
+
+func TestPingMarkBadConnection(t *testing.T) {
+	nc := badConnection{err: errors.New("boom")}
+	ms := &mysqlConn{
+		netConn:          nc,
+		buf:              newBuffer(nc),
+		maxAllowedPacket: defaultMaxAllowedPacket,
+	}
+
+	err := ms.Ping(context.Background())
+
+	if err != driver.ErrBadConn {
+		t.Errorf("expected driver.ErrBadConn, got  %#v", err)
+	}
+}
+
+func TestPingErrInvalidConn(t *testing.T) {
+	nc := badConnection{err: errors.New("failed to write"), n: 10}
+	ms := &mysqlConn{
+		netConn:          nc,
+		buf:              newBuffer(nc),
+		maxAllowedPacket: defaultMaxAllowedPacket,
+		closech:          make(chan struct{}),
+	}
+
+	err := ms.Ping(context.Background())
+
+	if err != ErrInvalidConn {
+		t.Errorf("expected ErrInvalidConn, got  %#v", err)
+	}
+}
+
+type badConnection struct {
+	n   int
+	err error
+	net.Conn
+}
+
+func (bc badConnection) Write(b []byte) (n int, err error) {
+	return bc.n, bc.err
+}
+
+func (bc badConnection) Close() error {
+	return nil
+}