Browse Source

fewer driver.ErrBadConn to prevent repeated queries (#302)

According to the database/sql/driver documentation, ErrBadConn should only
be used when the database was not affected. The driver restarts the same
query on a different connection, then.
The mysql driver did not follow this advice, so queries were repeated if
ErrBadConn is returned but a query succeeded.

This is fixed by changing most ErrBadConn errors to ErrInvalidConn.

The only valid returns of ErrBadConn are at the beginning of a database
interaction when no data was sent to the database yet.

Those valid cases are located the following funcs before attempting to write
to the network or if 0 bytes were written:

* Begin
* BeginTx
* Exec
* ExecContext
* Prepare
* PrepareContext
* Query
* QueryContext

Commit and Rollback could arguably also be on that list, but are left out as
some engines like MyISAM are not supporting transactions.

Tests in b/packets_test.go were changed because they simulate a read not
preceded by a write to the db. This cannot happen as the client has to send
the query first.
Arne Hormann 8 years ago
parent
commit
26471af196
5 changed files with 46 additions and 28 deletions
  1. 16 7
      connection.go
  2. 6 0
      errors.go
  3. 16 12
      packets.go
  4. 6 7
      packets_test.go
  5. 2 2
      statement.go

+ 16 - 7
connection.go

@@ -81,6 +81,16 @@ func (mc *mysqlConn) handleParams() (err error) {
 	return
 }
 
+func (mc *mysqlConn) markBadConn(err error) error {
+	if mc == nil {
+		return err
+	}
+	if err != errBadConnNoWrite {
+		return err
+	}
+	return driver.ErrBadConn
+}
+
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
 	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
@@ -90,8 +100,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
 	if err == nil {
 		return &mysqlTx{mc}, err
 	}
-
-	return nil, err
+	return nil, mc.markBadConn(err)
 }
 
 func (mc *mysqlConn) Close() (err error) {
@@ -142,7 +151,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	// Send command
 	err := mc.writeCommandPacketStr(comStmtPrepare, query)
 	if err != nil {
-		return nil, err
+		return nil, mc.markBadConn(err)
 	}
 
 	stmt := &mysqlStmt{
@@ -176,7 +185,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
 	if buf == nil {
 		// can not take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
-		return "", driver.ErrBadConn
+		return "", ErrInvalidConn
 	}
 	buf = buf[:0]
 	argPos := 0
@@ -314,14 +323,14 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
 			insertId:     int64(mc.insertId),
 		}, err
 	}
-	return nil, err
+	return nil, mc.markBadConn(err)
 }
 
 // Internal function to execute commands
 func (mc *mysqlConn) exec(query string) error {
 	// Send command
 	if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
-		return err
+		return mc.markBadConn(err)
 	}
 
 	// Read Result
@@ -390,7 +399,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
 			return rows, err
 		}
 	}
-	return nil, err
+	return nil, mc.markBadConn(err)
 }
 
 // Gets the value of the given MySQL System Variable

+ 6 - 0
errors.go

@@ -31,6 +31,12 @@ var (
 	ErrPktSyncMul        = errors.New("commands out of sync. Did you run multiple statements at once?")
 	ErrPktTooLarge       = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
 	ErrBusyBuffer        = errors.New("busy buffer")
+
+	// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
+	// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
+	// to trigger a resend.
+	// See https://github.com/go-sql-driver/mysql/pull/302
+	errBadConnNoWrite = errors.New("bad connection")
 )
 
 var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))

+ 16 - 12
packets.go

@@ -35,7 +35,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 			}
 			errLog.Print(err)
 			mc.Close()
-			return nil, driver.ErrBadConn
+			return nil, ErrInvalidConn
 		}
 
 		// packet length [24 bit]
@@ -57,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 			if prevData == nil {
 				errLog.Print(ErrMalformPkt)
 				mc.Close()
-				return nil, driver.ErrBadConn
+				return nil, ErrInvalidConn
 			}
 
 			return prevData, nil
@@ -71,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 			}
 			errLog.Print(err)
 			mc.Close()
-			return nil, driver.ErrBadConn
+			return nil, ErrInvalidConn
 		}
 
 		// return data if this was the last packet
@@ -137,10 +137,14 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 			if cerr := mc.canceled.Value(); cerr != nil {
 				return cerr
 			}
+			if n == 0 && pktLen == len(data)-4 {
+				// only for the first loop iteration when nothing was written yet
+				return errBadConnNoWrite
+			}
 			mc.cleanup()
 			errLog.Print(err)
 		}
-		return driver.ErrBadConn
+		return ErrInvalidConn
 	}
 }
 
@@ -274,7 +278,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	if data == nil {
 		// can not take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
-		return driver.ErrBadConn
+		return errBadConnNoWrite
 	}
 
 	// ClientFlags [32 bit]
@@ -362,7 +366,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
 	if data == nil {
 		// can not take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
-		return driver.ErrBadConn
+		return errBadConnNoWrite
 	}
 
 	// Add the scrambled password [null terminated string]
@@ -381,7 +385,7 @@ func (mc *mysqlConn) writeClearAuthPacket() error {
 	if data == nil {
 		// can not take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
-		return driver.ErrBadConn
+		return errBadConnNoWrite
 	}
 
 	// Add the clear password [null terminated string]
@@ -404,7 +408,7 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
 	if data == nil {
 		// can not take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
-		return driver.ErrBadConn
+		return errBadConnNoWrite
 	}
 
 	// Add the scramble
@@ -425,7 +429,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
 	if data == nil {
 		// can not take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
-		return driver.ErrBadConn
+		return errBadConnNoWrite
 	}
 
 	// Add command byte
@@ -444,7 +448,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
 	if data == nil {
 		// can not take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
-		return driver.ErrBadConn
+		return errBadConnNoWrite
 	}
 
 	// Add command byte
@@ -465,7 +469,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
 	if data == nil {
 		// can not take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
-		return driver.ErrBadConn
+		return errBadConnNoWrite
 	}
 
 	// Add command byte
@@ -931,7 +935,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	if data == nil {
 		// can not take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
-		return driver.ErrBadConn
+		return errBadConnNoWrite
 	}
 
 	// command [1 byte]

+ 6 - 7
packets_test.go

@@ -9,7 +9,6 @@
 package mysql
 
 import (
-	"database/sql/driver"
 	"errors"
 	"net"
 	"testing"
@@ -252,8 +251,8 @@ func TestReadPacketFail(t *testing.T) {
 	conn.data = []byte{0x00, 0x00, 0x00, 0x00}
 	conn.maxReads = 1
 	_, err := mc.readPacket()
-	if err != driver.ErrBadConn {
-		t.Errorf("expected ErrBadConn, got %v", err)
+	if err != ErrInvalidConn {
+		t.Errorf("expected ErrInvalidConn, got %v", err)
 	}
 
 	// reset
@@ -264,8 +263,8 @@ func TestReadPacketFail(t *testing.T) {
 	// fail to read header
 	conn.closed = true
 	_, err = mc.readPacket()
-	if err != driver.ErrBadConn {
-		t.Errorf("expected ErrBadConn, got %v", err)
+	if err != ErrInvalidConn {
+		t.Errorf("expected ErrInvalidConn, got %v", err)
 	}
 
 	// reset
@@ -277,7 +276,7 @@ func TestReadPacketFail(t *testing.T) {
 	// fail to read body
 	conn.maxReads = 1
 	_, err = mc.readPacket()
-	if err != driver.ErrBadConn {
-		t.Errorf("expected ErrBadConn, got %v", err)
+	if err != ErrInvalidConn {
+		t.Errorf("expected ErrInvalidConn, got %v", err)
 	}
 }

+ 2 - 2
statement.go

@@ -52,7 +52,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {
-		return nil, err
+		return nil, stmt.mc.markBadConn(err)
 	}
 
 	mc := stmt.mc
@@ -100,7 +100,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {
-		return nil, err
+		return nil, stmt.mc.markBadConn(err)
 	}
 
 	mc := stmt.mc