Ver código fonte

no panic on closed connection reuse

Arne Hormann 12 anos atrás
pai
commit
cd41d9f206
6 arquivos alterados com 73 adições e 3 exclusões
  1. 15 0
      connection.go
  2. 35 0
      driver_test.go
  3. 1 0
      errors.go
  4. 2 3
      rows.go
  5. 12 0
      statement.go
  6. 8 0
      transaction.go

+ 15 - 0
connection.go

@@ -96,6 +96,10 @@ func (mc *mysqlConn) handleParams() (err error) {
 }
 
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	err := mc.exec("START TRANSACTION")
 	if err == nil {
 		return &mysqlTx{mc}, err
@@ -119,6 +123,10 @@ func (mc *mysqlConn) Close() (err error) {
 }
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	// Send command
 	err := mc.writeCommandPacketStr(comStmtPrepare, query)
 	if err != nil {
@@ -148,6 +156,10 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 }
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	if len(args) == 0 { // no args, fastpath
 		mc.affectedRows = 0
 		mc.insertId = 0
@@ -191,6 +203,9 @@ func (mc *mysqlConn) exec(query string) (err error) {
 }
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
 	if len(args) == 0 { // no args, fastpath
 		// Send command
 		err := mc.writeCommandPacketStr(comQuery, query)

+ 35 - 0
driver_test.go

@@ -101,6 +101,41 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
 	return rows
 }
 
+func TestReuseClosedConnection(t *testing.T) {
+	// this test does not use sql.database, it uses the driver directly
+	if !available {
+		t.Logf("MySQL-Server not running on %s. Skipping TestReuseClosedConnection", netAddr)
+		return
+	}
+	driver := &mysqlDriver{}
+	conn, err := driver.Open(dsn)
+	if err != nil {
+		t.Fatalf("Error connecting: %s", err.Error())
+	}
+	stmt, err := conn.Prepare("DO 1")
+	if err != nil {
+		t.Fatalf("Error preparing statement: %s", err.Error())
+	}
+	_, err = stmt.Exec(nil)
+	if err != nil {
+		t.Fatalf("Error executing statement: %s", err.Error())
+	}
+	err = conn.Close()
+	if err != nil {
+		t.Fatalf("Error closing connection: %s", err.Error())
+	}
+	defer func() {
+		if err := recover(); err != nil {
+			t.Errorf("Panic after reusing a closed connection: %v", err)
+		}
+	}()
+	_, err = stmt.Exec(nil)
+	if err != nil && err != errInvalidConn {
+		t.Errorf("Unexpected error '%s', expected '%s'",
+			err.Error(), errInvalidConn.Error())
+	}
+}
+
 func TestCharset(t *testing.T) {
 	mustSetCharset := func(charsetParam, expected string) {
 		db, err := sql.Open("mysql", strings.Replace(dsn, charset, charsetParam, 1))

+ 1 - 0
errors.go

@@ -17,6 +17,7 @@ import (
 )
 
 var (
+	errInvalidConn = errors.New("Invalid Connection")
 	errMalformPkt  = errors.New("Malformed Packet")
 	errPktSync     = errors.New("Commands out of sync. You can't run this command now")
 	errPktSyncMul  = errors.New("Commands out of sync. Did you run multiple statements at once?")

+ 2 - 3
rows.go

@@ -11,7 +11,6 @@ package mysql
 
 import (
 	"database/sql/driver"
-	"errors"
 	"io"
 )
 
@@ -44,7 +43,7 @@ func (rows *mysqlRows) Close() (err error) {
 	// Remove unread packets from stream
 	if !rows.eof {
 		if rows.mc == nil || rows.mc.netConn == nil {
-			return errors.New("Invalid Connection")
+			return errInvalidConn
 		}
 
 		err = rows.mc.readUntilEOF()
@@ -63,7 +62,7 @@ func (rows *mysqlRows) Next(dest []driver.Value) error {
 	}
 
 	if rows.mc == nil || rows.mc.netConn == nil {
-		return errors.New("Invalid Connection")
+		return errInvalidConn
 	}
 
 	// Fetch next row from stream

+ 12 - 0
statement.go

@@ -21,6 +21,10 @@ type mysqlStmt struct {
 }
 
 func (stmt *mysqlStmt) Close() (err error) {
+	if stmt.mc == nil || stmt.mc.netConn == nil {
+		return errInvalidConn
+	}
+
 	err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
 	stmt.mc = nil
 	return
@@ -31,6 +35,10 @@ func (stmt *mysqlStmt) NumInput() int {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
+	if stmt.mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	stmt.mc.affectedRows = 0
 	stmt.mc.insertId = 0
 
@@ -66,6 +74,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 }
 
 func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
+	if stmt.mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {

+ 8 - 0
transaction.go

@@ -14,12 +14,20 @@ type mysqlTx struct {
 }
 
 func (tx *mysqlTx) Commit() (err error) {
+	if tx.mc == nil || tx.mc.netConn == nil {
+		return errInvalidConn
+	}
+
 	err = tx.mc.exec("COMMIT")
 	tx.mc = nil
 	return
 }
 
 func (tx *mysqlTx) Rollback() (err error) {
+	if tx.mc == nil || tx.mc.netConn == nil {
+		return errInvalidConn
+	}
+
 	err = tx.mc.exec("ROLLBACK")
 	tx.mc = nil
 	return