浏览代码

Merge pull request #148 from go-sql-driver/v1.0.3

Release v1.0.3
Julien Schmidt 12 年之前
父节点
当前提交
dfcba35e8d
共有 9 个文件被更改,包括 103 次插入16 次删除
  1. 1 1
      README.md
  2. 23 3
      connection.go
  3. 45 0
      driver_test.go
  4. 1 0
      errors.go
  5. 4 6
      packets.go
  6. 8 5
      rows.go
  7. 12 0
      statement.go
  8. 8 0
      transaction.go
  9. 1 1
      utils.go

+ 1 - 1
README.md

@@ -4,7 +4,7 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa
 
 ![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow")
 
-**Current tagged Release:** June 03, 2013 (Version 1.0.1)
+**Current tagged Release:** November 01, 2013 (Version 1.0.3)
 
 ---------------------------------------
   * [Features](#features)

+ 23 - 3
connection.go

@@ -96,6 +96,10 @@ func (mc *mysqlConn) handleParams() (err error) {
 }
 
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	err := mc.exec("START TRANSACTION")
 	if err == nil {
 		return &mysqlTx{mc}, err
@@ -105,15 +109,24 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
 }
 
 func (mc *mysqlConn) Close() (err error) {
-	mc.writeCommandPacket(comQuit)
+	// Makes Close idempotent
+	if mc.netConn != nil {
+		mc.writeCommandPacket(comQuit)
+		mc.netConn.Close()
+		mc.netConn = nil
+	}
+
 	mc.cfg = nil
 	mc.buf = nil
-	mc.netConn.Close()
-	mc.netConn = nil
+
 	return
 }
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	// Send command
 	err := mc.writeCommandPacketStr(comStmtPrepare, query)
 	if err != nil {
@@ -143,6 +156,10 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 }
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	if len(args) == 0 { // no args, fastpath
 		mc.affectedRows = 0
 		mc.insertId = 0
@@ -186,6 +203,9 @@ func (mc *mysqlConn) exec(query string) (err error) {
 }
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+	if mc.netConn == nil {
+		return nil, errInvalidConn
+	}
 	if len(args) == 0 { // no args, fastpath
 		// Send command
 		err := mc.writeCommandPacketStr(comQuery, query)

+ 45 - 0
driver_test.go

@@ -101,6 +101,41 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
 	return rows
 }
 
+func TestReuseClosedConnection(t *testing.T) {
+	// this test does not use sql.database, it uses the driver directly
+	if !available {
+		t.Logf("MySQL-Server not running on %s. Skipping TestReuseClosedConnection", netAddr)
+		return
+	}
+	driver := &mysqlDriver{}
+	conn, err := driver.Open(dsn)
+	if err != nil {
+		t.Fatalf("Error connecting: %s", err.Error())
+	}
+	stmt, err := conn.Prepare("DO 1")
+	if err != nil {
+		t.Fatalf("Error preparing statement: %s", err.Error())
+	}
+	_, err = stmt.Exec(nil)
+	if err != nil {
+		t.Fatalf("Error executing statement: %s", err.Error())
+	}
+	err = conn.Close()
+	if err != nil {
+		t.Fatalf("Error closing connection: %s", err.Error())
+	}
+	defer func() {
+		if err := recover(); err != nil {
+			t.Errorf("Panic after reusing a closed connection: %v", err)
+		}
+	}()
+	_, err = stmt.Exec(nil)
+	if err != nil && err != errInvalidConn {
+		t.Errorf("Unexpected error '%s', expected '%s'",
+			err.Error(), errInvalidConn.Error())
+	}
+}
+
 func TestCharset(t *testing.T) {
 	mustSetCharset := func(charsetParam, expected string) {
 		db, err := sql.Open("mysql", strings.Replace(dsn, charset, charsetParam, 1))
@@ -578,6 +613,16 @@ func TestNULL(t *testing.T) {
 			dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)")
 		}
 
+		// bytes
+		// Check input==output with input!=nil
+		b := []byte("")
+		if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
+			dbt.Fatal(err)
+		}
+		if b == nil {
+			dbt.Error("nil echo from non-nil input")
+		}
+
 		// Insert NULL
 		dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)")
 

+ 1 - 0
errors.go

@@ -17,6 +17,7 @@ import (
 )
 
 var (
+	errInvalidConn = errors.New("Invalid Connection")
 	errMalformPkt  = errors.New("Malformed Packet")
 	errPktSync     = errors.New("Commands out of sync. You can't run this command now")
 	errPktSyncMul  = errors.New("Commands out of sync. Did you run multiple statements at once?")

+ 4 - 6
packets.go

@@ -1014,16 +1014,15 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 				}
 			}
 
-			var sign byte
+			var sign string
 			if data[pos] == 1 {
-				sign = byte('-')
+				sign = "-"
 			}
 
 			switch num {
 			case 8:
 				dest[i] = []byte(fmt.Sprintf(
-					"%c%02d:%02d:%02d",
-					sign,
+					sign+"%02d:%02d:%02d",
 					uint16(data[pos+1])*24+uint16(data[pos+5]),
 					data[pos+6],
 					data[pos+7],
@@ -1032,8 +1031,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 				continue
 			case 12:
 				dest[i] = []byte(fmt.Sprintf(
-					"%c%02d:%02d:%02d.%06d",
-					sign,
+					sign+"%02d:%02d:%02d.%06d",
 					uint16(data[pos+1])*24+uint16(data[pos+5]),
 					data[pos+6],
 					data[pos+7],

+ 8 - 5
rows.go

@@ -11,7 +11,6 @@ package mysql
 
 import (
 	"database/sql/driver"
-	"errors"
 	"io"
 )
 
@@ -43,11 +42,15 @@ func (rows *mysqlRows) Close() (err error) {
 
 	// Remove unread packets from stream
 	if !rows.eof {
-		if rows.mc == nil {
-			return errors.New("Invalid Connection")
+		if rows.mc == nil || rows.mc.netConn == nil {
+			return errInvalidConn
 		}
 
 		err = rows.mc.readUntilEOF()
+
+		// explicitly set because readUntilEOF might return early in case of an
+		// error
+		rows.eof = true
 	}
 
 	return
@@ -58,8 +61,8 @@ func (rows *mysqlRows) Next(dest []driver.Value) error {
 		return io.EOF
 	}
 
-	if rows.mc == nil {
-		return errors.New("Invalid Connection")
+	if rows.mc == nil || rows.mc.netConn == nil {
+		return errInvalidConn
 	}
 
 	// Fetch next row from stream

+ 12 - 0
statement.go

@@ -21,6 +21,10 @@ type mysqlStmt struct {
 }
 
 func (stmt *mysqlStmt) Close() (err error) {
+	if stmt.mc == nil || stmt.mc.netConn == nil {
+		return errInvalidConn
+	}
+
 	err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
 	stmt.mc = nil
 	return
@@ -31,6 +35,10 @@ func (stmt *mysqlStmt) NumInput() int {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
+	if stmt.mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	stmt.mc.affectedRows = 0
 	stmt.mc.insertId = 0
 
@@ -66,6 +74,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 }
 
 func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
+	if stmt.mc.netConn == nil {
+		return nil, errInvalidConn
+	}
+
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {

+ 8 - 0
transaction.go

@@ -14,12 +14,20 @@ type mysqlTx struct {
 }
 
 func (tx *mysqlTx) Commit() (err error) {
+	if tx.mc == nil || tx.mc.netConn == nil {
+		return errInvalidConn
+	}
+
 	err = tx.mc.exec("COMMIT")
 	tx.mc = nil
 	return
 }
 
 func (tx *mysqlTx) Rollback() (err error) {
+	if tx.mc == nil || tx.mc.netConn == nil {
+		return errInvalidConn
+	}
+
 	err = tx.mc.exec("ROLLBACK")
 	tx.mc = nil
 	return

+ 1 - 1
utils.go

@@ -349,7 +349,7 @@ func readLengthEnodedString(b []byte) ([]byte, bool, int, error) {
 	// Get length
 	num, isNull, n := readLengthEncodedInteger(b)
 	if num < 1 {
-		return nil, isNull, n, nil
+		return b[n:n], isNull, n, nil
 	}
 
 	n += int(num)