Browse Source

Merge pull request #122 from go-sql-driver/badconn_close

Explicitly close connection on ErrBadConn
Julien Schmidt 12 years ago
parent
commit
04572b49b1
5 changed files with 30 additions and 7 deletions
  1. 8 3
      connection.go
  2. 6 2
      packets.go
  3. 6 2
      rows.go
  4. 4 0
      statement.go
  5. 6 0
      transaction.go

+ 8 - 3
connection.go

@@ -108,11 +108,16 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
 }
 
 func (mc *mysqlConn) Close() (err error) {
-	mc.writeCommandPacket(comQuit)
+	// Makes Close idempotent
+	if mc.netConn != nil {
+		mc.writeCommandPacket(comQuit)
+		mc.netConn.Close()
+		mc.netConn = nil
+	}
+
 	mc.cfg = nil
 	mc.buf = nil
-	mc.netConn.Close()
-	mc.netConn = nil
+
 	return
 }
 

+ 6 - 2
packets.go

@@ -28,6 +28,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	data, err = mc.buf.readNext(4)
 	if err != nil {
 		errLog.Print(err.Error())
+		mc.Close()
 		return nil, driver.ErrBadConn
 	}
 
@@ -36,6 +37,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 
 	if pktLen < 1 {
 		errLog.Print(errMalformPkt.Error())
+		mc.Close()
 		return nil, driver.ErrBadConn
 	}
 
@@ -50,8 +52,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	mc.sequence++
 
 	// Read packet body [pktLen bytes]
-	data, err = mc.buf.readNext(pktLen)
-	if err == nil {
+	if data, err = mc.buf.readNext(pktLen); err == nil {
 		if pktLen < maxPacketSize {
 			return data, nil
 		}
@@ -65,6 +66,9 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 			return append(buf, data...), nil
 		}
 	}
+
+	// err case
+	mc.Close()
 	errLog.Print(err.Error())
 	return nil, driver.ErrBadConn
 }

+ 6 - 2
rows.go

@@ -37,11 +37,15 @@ func (rows *mysqlRows) Columns() (columns []string) {
 func (rows *mysqlRows) Close() (err error) {
 	// Remove unread packets from stream
 	if !rows.eof {
-		if rows.mc == nil {
+		if rows.mc == nil || rows.mc.netConn == nil {
 			return errInvalidConn
 		}
 
 		err = rows.mc.readUntilEOF()
+
+		// explicitly set because readUntilEOF might return early in case of an
+		// error
+		rows.eof = true
 	}
 
 	rows.mc = nil
@@ -54,7 +58,7 @@ func (rows *mysqlRows) Next(dest []driver.Value) (err error) {
 		return io.EOF
 	}
 
-	if rows.mc == nil {
+	if rows.mc == nil || rows.mc.netConn == nil {
 		return errInvalidConn
 	}
 

+ 4 - 0
statement.go

@@ -20,6 +20,10 @@ type mysqlStmt struct {
 }
 
 func (stmt *mysqlStmt) Close() (err error) {
+	if stmt.mc == nil || stmt.mc.netConn == nil {
+		return errInvalidConn
+	}
+
 	err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
 	stmt.mc = nil
 	return

+ 6 - 0
transaction.go

@@ -13,12 +13,18 @@ type mysqlTx struct {
 }
 
 func (tx *mysqlTx) Commit() (err error) {
+	if tx.mc == nil || tx.mc.netConn == nil {
+		return errInvalidConn
+	}
 	err = tx.mc.exec("COMMIT")
 	tx.mc = nil
 	return
 }
 
 func (tx *mysqlTx) Rollback() (err error) {
+	if tx.mc == nil || tx.mc.netConn == nil {
+		return errInvalidConn
+	}
 	err = tx.mc.exec("ROLLBACK")
 	tx.mc = nil
 	return