Bläddra i källkod

no panic on closed connection reuse

Arne Hormann 12 år sedan
förälder
incheckning
32e5ceed8e
5 ändrade filer med 59 tillägg och 4 borttagningar
  1. 15 0
      connection.go
  2. 34 0
      driver_test.go
  3. 2 2
      rows.go
  4. 6 0
      statement.go
  5. 2 2
      transaction.go

+ 15 - 0
connection.go

@@ -99,6 +99,9 @@ func (mc *mysqlConn) handleParams() (err error) {
 }
 
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
+	if mc.buf == nil {
+		return nil, errInvalidConn
+	}
 	err := mc.exec("START TRANSACTION")
 	if err == nil {
 		return &mysqlTx{mc}, err
@@ -108,6 +111,9 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
 }
 
 func (mc *mysqlConn) Close() (err error) {
+	if mc.buf == nil {
+		return errInvalidConn
+	}
 	// Makes Close idempotent
 	if mc.netConn != nil {
 		mc.writeCommandPacket(comQuit)
@@ -122,6 +128,9 @@ func (mc *mysqlConn) Close() (err error) {
 }
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
+	if mc.buf == nil {
+		return nil, errInvalidConn
+	}
 	// Send command
 	err := mc.writeCommandPacketStr(comStmtPrepare, query)
 	if err != nil {
@@ -150,6 +159,9 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 }
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+	if mc.buf == nil {
+		return nil, errInvalidConn
+	}
 	if len(args) == 0 { // no args, fastpath
 		mc.affectedRows = 0
 		mc.insertId = 0
@@ -191,6 +203,9 @@ func (mc *mysqlConn) exec(query string) error {
 }
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+	if mc.buf == nil {
+		return nil, errInvalidConn
+	}
 	if len(args) == 0 { // no args, fastpath
 		// Send command
 		err := mc.writeCommandPacketStr(comQuery, query)

+ 34 - 0
driver_test.go

@@ -108,6 +108,40 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
 	return rows
 }
 
+func TestClosedConnection(t *testing.T) {
+	// this test does not use sql.database, it uses the driver directly
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
+	driver := &MySQLDriver{}
+	conn, err := driver.Open(dsn)
+	if err != nil {
+		t.Fatalf("Error connecting: %s", err.Error())
+	}
+	stmt, err := conn.Prepare("SET @tmpif := 1")
+	if err != nil {
+		t.Fatalf("Error preparing statement: %s", err.Error())
+	}
+	_, err = stmt.Exec(nil)
+	if err != nil {
+		t.Fatalf("Error executing statement: %s", err.Error())
+	}
+	err = conn.Close()
+	if err != nil {
+		t.Fatalf("Error closing connection: %s", err.Error())
+	}
+	defer func() {
+		if err := recover(); err != nil {
+			t.Errorf("Panic after reusing a closed connection: %v", err)
+		}
+	}()
+	_, err = stmt.Exec(nil)
+	if err != errInvalidConn {
+		t.Errorf("Unexpected error '%s', expected '%s'",
+			err.Error(), errInvalidConn.Error())
+	}
+}
+
 func TestCharset(t *testing.T) {
 	if !available {
 		t.Skipf("MySQL-Server not running on %s", netAddr)

+ 2 - 2
rows.go

@@ -37,7 +37,7 @@ func (rows *mysqlRows) Columns() []string {
 func (rows *mysqlRows) Close() (err error) {
 	// Remove unread packets from stream
 	if !rows.eof {
-		if rows.mc == nil || rows.mc.netConn == nil {
+		if rows.mc == nil || rows.mc.buf == nil {
 			return errInvalidConn
 		}
 
@@ -58,7 +58,7 @@ func (rows *mysqlRows) Next(dest []driver.Value) (err error) {
 		return io.EOF
 	}
 
-	if rows.mc == nil || rows.mc.netConn == nil {
+	if rows.mc == nil || rows.mc.buf == nil {
 		return errInvalidConn
 	}
 

+ 6 - 0
statement.go

@@ -34,6 +34,9 @@ func (stmt *mysqlStmt) NumInput() int {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
+	if stmt.mc.buf == nil {
+		return nil, errInvalidConn
+	}
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {
@@ -70,6 +73,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 }
 
 func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
+	if stmt.mc.buf == nil {
+		return nil, errInvalidConn
+	}
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {

+ 2 - 2
transaction.go

@@ -13,7 +13,7 @@ type mysqlTx struct {
 }
 
 func (tx *mysqlTx) Commit() (err error) {
-	if tx.mc == nil || tx.mc.netConn == nil {
+	if tx.mc == nil || tx.mc.buf == nil {
 		return errInvalidConn
 	}
 	err = tx.mc.exec("COMMIT")
@@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
 }
 
 func (tx *mysqlTx) Rollback() (err error) {
-	if tx.mc == nil || tx.mc.netConn == nil {
+	if tx.mc == nil || tx.mc.buf == nil {
 		return errInvalidConn
 	}
 	err = tx.mc.exec("ROLLBACK")