Ver código fonte

Return ErrBadConn for invalid connections

This helps with #142
The database/sql package only retries if it encounters an ErrBadConn.
With errInvalidConn it returns immediately
Julien Schmidt 12 anos atrás
pai
commit
4d3764bbcb
3 arquivos alterados com 19 adições e 11 exclusões
  1. 8 4
      connection.go
  2. 5 4
      driver_test.go
  3. 6 3
      statement.go

+ 8 - 4
connection.go

@@ -100,7 +100,8 @@ func (mc *mysqlConn) handleParams() (err error) {
 
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
 	if mc.netConn == nil {
-		return nil, errInvalidConn
+		errLog.Print(errInvalidConn)
+		return nil, driver.ErrBadConn
 	}
 	err := mc.exec("START TRANSACTION")
 	if err == nil {
@@ -126,7 +127,8 @@ func (mc *mysqlConn) Close() (err error) {
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	if mc.netConn == nil {
-		return nil, errInvalidConn
+		errLog.Print(errInvalidConn)
+		return nil, driver.ErrBadConn
 	}
 	// Send command
 	err := mc.writeCommandPacketStr(comStmtPrepare, query)
@@ -157,7 +159,8 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
 	if mc.netConn == nil {
-		return nil, errInvalidConn
+		errLog.Print(errInvalidConn)
+		return nil, driver.ErrBadConn
 	}
 	if len(args) == 0 { // no args, fastpath
 		mc.affectedRows = 0
@@ -201,7 +204,8 @@ func (mc *mysqlConn) exec(query string) error {
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
 	if mc.netConn == nil {
-		return nil, errInvalidConn
+		errLog.Print(errInvalidConn)
+		return nil, driver.ErrBadConn
 	}
 	if len(args) == 0 { // no args, fastpath
 		// Send command

+ 5 - 4
driver_test.go

@@ -11,6 +11,7 @@ package mysql
 import (
 	"crypto/tls"
 	"database/sql"
+	"database/sql/driver"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -860,8 +861,8 @@ func TestReuseClosedConnection(t *testing.T) {
 		t.Skipf("MySQL-Server not running on %s", netAddr)
 	}
 
-	driver := &MySQLDriver{}
-	conn, err := driver.Open(dsn)
+	md := &MySQLDriver{}
+	conn, err := md.Open(dsn)
 	if err != nil {
 		t.Fatalf("Error connecting: %s", err.Error())
 	}
@@ -884,9 +885,9 @@ func TestReuseClosedConnection(t *testing.T) {
 		}
 	}()
 	_, err = stmt.Exec(nil)
-	if err != nil && err != errInvalidConn {
+	if err != nil && err != driver.ErrBadConn {
 		t.Errorf("Unexpected error '%s', expected '%s'",
-			err.Error(), errInvalidConn.Error())
+			err.Error(), driver.ErrBadConn.Error())
 	}
 }
 

+ 6 - 3
statement.go

@@ -21,7 +21,8 @@ type mysqlStmt struct {
 
 func (stmt *mysqlStmt) Close() error {
 	if stmt.mc == nil || stmt.mc.netConn == nil {
-		return errInvalidConn
+		errLog.Print(errInvalidConn)
+		return driver.ErrBadConn
 	}
 
 	err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
@@ -35,7 +36,8 @@ func (stmt *mysqlStmt) NumInput() int {
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	if stmt.mc.netConn == nil {
-		return nil, errInvalidConn
+		errLog.Print(errInvalidConn)
+		return nil, driver.ErrBadConn
 	}
 	// Send command
 	err := stmt.writeExecutePacket(args)
@@ -74,7 +76,8 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 
 func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 	if stmt.mc.netConn == nil {
-		return nil, errInvalidConn
+		errLog.Print(errInvalidConn)
+		return nil, driver.ErrBadConn
 	}
 	// Send command
 	err := stmt.writeExecutePacket(args)