Prechádzať zdrojové kódy

Merge pull request #143 from arnehormann/no-panic

no panic on closed connection reuse
Arne Hormann 12 rokov pred
rodič
commit
2a93d21064
3 zmenil súbory, kde vykonal 52 pridanie a 0 odobranie
  1. 12 0
      connection.go
  2. 34 0
      driver_test.go
  3. 6 0
      statement.go

+ 12 - 0
connection.go

@@ -99,6 +99,9 @@ func (mc *mysqlConn) handleParams() (err error) {
 }
 
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
 	err := mc.exec("START TRANSACTION")
 	if err == nil {
 		return &mysqlTx{mc}, err
@@ -122,6 +125,9 @@ func (mc *mysqlConn) Close() (err error) {
 }
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
 	// Send command
 	err := mc.writeCommandPacketStr(comStmtPrepare, query)
 	if err != nil {
@@ -150,6 +156,9 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 }
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
 	if len(args) == 0 { // no args, fastpath
 		mc.affectedRows = 0
 		mc.insertId = 0
@@ -191,6 +200,9 @@ func (mc *mysqlConn) exec(query string) error {
 }
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
 	if len(args) == 0 { // no args, fastpath
 		// Send command
 		err := mc.writeCommandPacketStr(comQuery, query)

+ 34 - 0
driver_test.go

@@ -108,6 +108,40 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
 	return rows
 }
 
+func TestReuseClosedConnection(t *testing.T) {
+	// this test does not use sql.database, it uses the driver directly
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
+	driver := &MySQLDriver{}
+	conn, err := driver.Open(dsn)
+	if err != nil {
+		t.Fatalf("Error connecting: %s", err.Error())
+	}
+	stmt, err := conn.Prepare("DO 1")
+	if err != nil {
+		t.Fatalf("Error preparing statement: %s", err.Error())
+	}
+	_, err = stmt.Exec(nil)
+	if err != nil {
+		t.Fatalf("Error executing statement: %s", err.Error())
+	}
+	err = conn.Close()
+	if err != nil {
+		t.Fatalf("Error closing connection: %s", err.Error())
+	}
+	defer func() {
+		if err := recover(); err != nil {
+			t.Errorf("Panic after reusing a closed connection: %v", err)
+		}
+	}()
+	_, err = stmt.Exec(nil)
+	if err != nil && err != errInvalidConn {
+		t.Errorf("Unexpected error '%s', expected '%s'",
+			err.Error(), errInvalidConn.Error())
+	}
+}
+
 func TestCharset(t *testing.T) {
 	if !available {
 		t.Skipf("MySQL-Server not running on %s", netAddr)

+ 6 - 0
statement.go

@@ -34,6 +34,9 @@ func (stmt *mysqlStmt) NumInput() int {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
+	if stmt.mc.netConn == nil {
+		return nil, errInvalidConn
+	}
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {
@@ -70,6 +73,9 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 }
 
 func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
+	if stmt.mc.netConn == nil {
+		return nil, errInvalidConn
+	}
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {