|
|
@@ -11,6 +11,8 @@ package mysql
|
|
|
import (
|
|
|
"context"
|
|
|
"database/sql/driver"
|
|
|
+ "errors"
|
|
|
+ "net"
|
|
|
"testing"
|
|
|
)
|
|
|
|
|
|
@@ -108,3 +110,48 @@ func TestCleanCancel(t *testing.T) {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestPingMarkBadConnection(t *testing.T) {
|
|
|
+ nc := badConnection{err: errors.New("boom")}
|
|
|
+ ms := &mysqlConn{
|
|
|
+ netConn: nc,
|
|
|
+ buf: newBuffer(nc),
|
|
|
+ maxAllowedPacket: defaultMaxAllowedPacket,
|
|
|
+ }
|
|
|
+
|
|
|
+ err := ms.Ping(context.Background())
|
|
|
+
|
|
|
+ if err != driver.ErrBadConn {
|
|
|
+ t.Errorf("expected driver.ErrBadConn, got %#v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestPingErrInvalidConn(t *testing.T) {
|
|
|
+ nc := badConnection{err: errors.New("failed to write"), n: 10}
|
|
|
+ ms := &mysqlConn{
|
|
|
+ netConn: nc,
|
|
|
+ buf: newBuffer(nc),
|
|
|
+ maxAllowedPacket: defaultMaxAllowedPacket,
|
|
|
+ closech: make(chan struct{}),
|
|
|
+ }
|
|
|
+
|
|
|
+ err := ms.Ping(context.Background())
|
|
|
+
|
|
|
+ if err != ErrInvalidConn {
|
|
|
+ t.Errorf("expected ErrInvalidConn, got %#v", err)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+type badConnection struct {
|
|
|
+ n int
|
|
|
+ err error
|
|
|
+ net.Conn
|
|
|
+}
|
|
|
+
|
|
|
+func (bc badConnection) Write(b []byte) (n int, err error) {
|
|
|
+ return bc.n, bc.err
|
|
|
+}
|
|
|
+
|
|
|
+func (bc badConnection) Close() error {
|
|
|
+ return nil
|
|
|
+}
|