Pārlūkot izejas kodu

Add support for context.Context (#608)

* Add supports to context.Context

* add authors related context.Context support

* fix comment of mysqlContext.

- s/from/for/
- and start the comment with a space please

* closed is now just bool flag.

* drop read-only transactions support

* remove unused methods from mysqlContext interface.

* moved checking canceled logic into method of connection.

* add a section about context.Context to the README.

* use atomic variable for closed.

* short circuit for context.Background()

* fix illegal watching state

* set rows.finish directly.

* move namedValueToValue to utils_go18.go

* move static interface implementation checks to the top of the file

* add the new section about `context.Context` to the table of contents, and fix the section.

* mark unsupported features with TODO comments

* rename watcher to starter.

* use mc.error() instead of duplicated logics.
Ichinose Shogo 8 gadi atpakaļ
vecāks
revīzija
56226343bd
14 mainītis faili ar 738 papildinājumiem un 28 dzēšanām
  1. 4 0
      AUTHORS
  2. 4 0
      README.md
  3. 93 0
      benchmark_go18_test.go
  4. 84 11
      connection.go
  5. 207 0
      connection_go18.go
  6. 11 0
      driver.go
  7. 272 0
      driver_go18_test.go
  8. 8 0
      driver_test.go
  9. 11 0
      packets.go
  10. 2 1
      packets_test.go
  11. 16 10
      rows.go
  12. 7 3
      statement.go
  13. 2 2
      transaction.go
  14. 17 1
      utils_go18.go

+ 4 - 0
AUTHORS

@@ -14,6 +14,7 @@
 Aaron Hopkins <go-sql-driver at die.net>
 Arne Hormann <arnehormann at gmail.com>
 Asta Xie <xiemengjun at gmail.com>
+Bulat Gaifullin <gaifullinbf at gmail.com>
 Carlos Nieto <jose.carlos at menteslibres.net>
 Chris Moos <chris at tech9computers.com>
 Daniel Nichter <nil at codenode.com>
@@ -21,11 +22,13 @@ Daniël van Eeden <git at myname.nl>
 Dave Protasowski <dprotaso at gmail.com>
 DisposaBoy <disposaboy at dby.me>
 Egor Smolyakov <egorsmkv at gmail.com>
+Evan Shaw <evan at vendhq.com>
 Frederick Mayle <frederickmayle at gmail.com>
 Gustavo Kristic <gkristic at gmail.com>
 Hanno Braun <mail at hannobraun.com>
 Henri Yandell <flamefew at gmail.com>
 Hirotaka Yamamoto <ymmt2005 at gmail.com>
+ICHINOSE Shogo <shogo82148 at gmail.com>
 INADA Naoki <songofacandy at gmail.com>
 Jacek Szwec <szwec.jacek at gmail.com>
 James Harr <james.harr at gmail.com>
@@ -45,6 +48,7 @@ Luke Scott <luke at webconnex.com>
 Michael Woolnough <michael.woolnough at gmail.com>
 Nicola Peduzzi <thenikso at gmail.com>
 Olivier Mengué <dolmen at cpan.org>
+oscarzhao <oscarzhaosl at gmail.com>
 Paul Bonser <misterpib at gmail.com>
 Peter Schultz <peter.schultz at classmarkets.com>
 Rebecca Chin <rchin at pivotal.io>

+ 4 - 0
README.md

@@ -19,6 +19,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
     * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support)
     * [time.Time support](#timetime-support)
     * [Unicode support](#unicode-support)
+    * [context.Context Support](#contextcontext-support)
   * [Testing / Development](#testing--development)
   * [License](#license)
 
@@ -443,6 +444,9 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM
 
 See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support.
 
+## `context.Context` Support
+Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts.
+See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details.
 
 ## Testing / Development
 To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details.

+ 93 - 0
benchmark_go18_test.go

@@ -0,0 +1,93 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build go1.8
+
+package mysql
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+	"runtime"
+	"testing"
+)
+
+func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
+
+	tb := (*TB)(b)
+	stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
+	defer stmt.Close()
+
+	b.SetParallelism(p)
+	b.ReportAllocs()
+	b.ResetTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		var got string
+		for pb.Next() {
+			tb.check(stmt.QueryRow(1).Scan(&got))
+			if got != "one" {
+				b.Fatalf("query = %q; want one", got)
+			}
+		}
+	})
+}
+
+func BenchmarkQueryContext(b *testing.B) {
+	db := initDB(b,
+		"DROP TABLE IF EXISTS foo",
+		"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
+		`INSERT INTO foo VALUES (1, "one")`,
+		`INSERT INTO foo VALUES (2, "two")`,
+	)
+	defer db.Close()
+	for _, p := range []int{1, 2, 3, 4} {
+		b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
+			benchmarkQueryContext(b, db, p)
+		})
+	}
+}
+
+func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
+
+	tb := (*TB)(b)
+	stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
+	defer stmt.Close()
+
+	b.SetParallelism(p)
+	b.ReportAllocs()
+	b.ResetTimer()
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			if _, err := stmt.ExecContext(ctx); err != nil {
+				b.Fatal(err)
+			}
+		}
+	})
+}
+
+func BenchmarkExecContext(b *testing.B) {
+	db := initDB(b,
+		"DROP TABLE IF EXISTS foo",
+		"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
+		`INSERT INTO foo VALUES (1, "one")`,
+		`INSERT INTO foo VALUES (2, "two")`,
+	)
+	defer db.Close()
+	for _, p := range []int{1, 2, 3, 4} {
+		b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
+			benchmarkQueryContext(b, db, p)
+		})
+	}
+}

+ 84 - 11
connection.go

@@ -14,9 +14,21 @@ import (
 	"net"
 	"strconv"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"time"
 )
 
+// a copy of context.Context for Go 1.7 and later.
+type mysqlContext interface {
+	Done() <-chan struct{}
+	Err() error
+
+	// They are defined in context.Context, but go-mysql-driver does not use them.
+	// Deadline() (deadline time.Time, ok bool)
+	// Value(key interface{}) interface{}
+}
+
 type mysqlConn struct {
 	buf              buffer
 	netConn          net.Conn
@@ -31,6 +43,19 @@ type mysqlConn struct {
 	sequence         uint8
 	parseTime        bool
 	strict           bool
+
+	// for context support (From Go 1.8)
+	watching bool
+	watcher  chan<- mysqlContext
+	closech  chan struct{}
+	finished chan<- struct{}
+
+	// set non-zero when conn is closed, before closech is closed.
+	// accessed atomically.
+	closed int32
+
+	mu          sync.Mutex // guards following fields
+	canceledErr error      // set non-nil if conn is canceled
 }
 
 // Handles parameters set in DSN after the connection is established
@@ -64,7 +89,7 @@ func (mc *mysqlConn) handleParams() (err error) {
 }
 
 func (mc *mysqlConn) Begin() (driver.Tx, error) {
-	if mc.netConn == nil {
+	if mc.isBroken() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -78,7 +103,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
 
 func (mc *mysqlConn) Close() (err error) {
 	// Makes Close idempotent
-	if mc.netConn != nil {
+	if !mc.isBroken() {
 		err = mc.writeCommandPacket(comQuit)
 	}
 
@@ -92,19 +117,36 @@ func (mc *mysqlConn) Close() (err error) {
 // is called before auth or on auth failure because MySQL will have already
 // closed the network connection.
 func (mc *mysqlConn) cleanup() {
+	if atomic.SwapInt32(&mc.closed, 1) != 0 {
+		return
+	}
+
 	// Makes cleanup idempotent
-	if mc.netConn != nil {
-		if err := mc.netConn.Close(); err != nil {
-			errLog.Print(err)
+	close(mc.closech)
+	if mc.netConn == nil {
+		return
+	}
+	if err := mc.netConn.Close(); err != nil {
+		errLog.Print(err)
+	}
+}
+
+func (mc *mysqlConn) isBroken() bool {
+	return atomic.LoadInt32(&mc.closed) != 0
+}
+
+func (mc *mysqlConn) error() error {
+	if mc.isBroken() {
+		if err := mc.canceled(); err != nil {
+			return err
 		}
-		mc.netConn = nil
+		return ErrInvalidConn
 	}
-	mc.cfg = nil
-	mc.buf.nc = nil
+	return nil
 }
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
-	if mc.netConn == nil {
+	if mc.isBroken() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -258,7 +300,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
 }
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
-	if mc.netConn == nil {
+	if mc.isBroken() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -315,7 +357,11 @@ func (mc *mysqlConn) exec(query string) error {
 }
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
-	if mc.netConn == nil {
+	return mc.query(query, args)
+}
+
+func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
+	if mc.isBroken() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -387,3 +433,30 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 	}
 	return nil, err
 }
+
+// finish is called when the query has canceled.
+func (mc *mysqlConn) cancel(err error) {
+	mc.mu.Lock()
+	mc.canceledErr = err
+	mc.mu.Unlock()
+	mc.cleanup()
+}
+
+// canceled returns non-nil if the connection was closed due to context cancelation.
+func (mc *mysqlConn) canceled() error {
+	mc.mu.Lock()
+	defer mc.mu.Unlock()
+	return mc.canceledErr
+}
+
+// finish is called when the query has succeeded.
+func (mc *mysqlConn) finish() {
+	if !mc.watching || mc.finished == nil {
+		return
+	}
+	select {
+	case mc.finished <- struct{}{}:
+		mc.watching = false
+	case <-mc.closech:
+	}
+}

+ 207 - 0
connection_go18.go

@@ -0,0 +1,207 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build go1.8
+
+package mysql
+
+import (
+	"context"
+	"database/sql"
+	"database/sql/driver"
+	"errors"
+)
+
+// Ping implements driver.Pinger interface
+func (mc *mysqlConn) Ping(ctx context.Context) error {
+	if mc.isBroken() {
+		errLog.Print(ErrInvalidConn)
+		return driver.ErrBadConn
+	}
+
+	if err := mc.watchCancel(ctx); err != nil {
+		return err
+	}
+	defer mc.finish()
+
+	if err := mc.writeCommandPacket(comPing); err != nil {
+		return err
+	}
+	if _, err := mc.readResultOK(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// BeginTx implements driver.ConnBeginTx interface
+func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
+	if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
+		// TODO: support isolation levels
+		return nil, errors.New("mysql: isolation levels not supported")
+	}
+	if opts.ReadOnly {
+		// TODO: support read-only transactions
+		return nil, errors.New("mysql: read-only transactions not supported")
+	}
+
+	if err := mc.watchCancel(ctx); err != nil {
+		return nil, err
+	}
+
+	tx, err := mc.Begin()
+	mc.finish()
+	if err != nil {
+		return nil, err
+	}
+
+	select {
+	default:
+	case <-ctx.Done():
+		tx.Rollback()
+		return nil, ctx.Err()
+	}
+	return tx, err
+}
+
+func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
+	dargs, err := namedValueToValue(args)
+	if err != nil {
+		return nil, err
+	}
+
+	if err := mc.watchCancel(ctx); err != nil {
+		return nil, err
+	}
+
+	rows, err := mc.query(query, dargs)
+	if err != nil {
+		mc.finish()
+		return nil, err
+	}
+	rows.finish = mc.finish
+	return rows, err
+}
+
+func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
+	dargs, err := namedValueToValue(args)
+	if err != nil {
+		return nil, err
+	}
+
+	if err := mc.watchCancel(ctx); err != nil {
+		return nil, err
+	}
+	defer mc.finish()
+
+	return mc.Exec(query, dargs)
+}
+
+func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+	if err := mc.watchCancel(ctx); err != nil {
+		return nil, err
+	}
+
+	stmt, err := mc.Prepare(query)
+	mc.finish()
+	if err != nil {
+		return nil, err
+	}
+
+	select {
+	default:
+	case <-ctx.Done():
+		stmt.Close()
+		return nil, ctx.Err()
+	}
+	return stmt, nil
+}
+
+func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+	dargs, err := namedValueToValue(args)
+	if err != nil {
+		return nil, err
+	}
+
+	if err := stmt.mc.watchCancel(ctx); err != nil {
+		return nil, err
+	}
+
+	rows, err := stmt.query(dargs)
+	if err != nil {
+		stmt.mc.finish()
+		return nil, err
+	}
+	rows.finish = stmt.mc.finish
+	return rows, err
+}
+
+func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+	dargs, err := namedValueToValue(args)
+	if err != nil {
+		return nil, err
+	}
+
+	if err := stmt.mc.watchCancel(ctx); err != nil {
+		return nil, err
+	}
+	defer stmt.mc.finish()
+
+	return stmt.Exec(dargs)
+}
+
+func (mc *mysqlConn) watchCancel(ctx context.Context) error {
+	if mc.watching {
+		// Reach here if canceled,
+		// so the connection is already invalid
+		mc.cleanup()
+		return nil
+	}
+	if ctx.Done() == nil {
+		return nil
+	}
+
+	mc.watching = true
+	select {
+	default:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+	if mc.watcher == nil {
+		return nil
+	}
+
+	mc.watcher <- ctx
+
+	return nil
+}
+
+func (mc *mysqlConn) startWatcher() {
+	watcher := make(chan mysqlContext, 1)
+	mc.watcher = watcher
+	finished := make(chan struct{})
+	mc.finished = finished
+	go func() {
+		for {
+			var ctx mysqlContext
+			select {
+			case ctx = <-watcher:
+			case <-mc.closech:
+				return
+			}
+
+			select {
+			case <-ctx.Done():
+				mc.cancel(ctx.Err())
+			case <-finished:
+			case <-mc.closech:
+				return
+			}
+		}
+	}()
+}

+ 11 - 0
driver.go

@@ -22,6 +22,11 @@ import (
 	"net"
 )
 
+// watcher interface is used for context support (From Go 1.8)
+type watcher interface {
+	startWatcher()
+}
+
 // MySQLDriver is exported to make the driver directly accessible.
 // In general the driver is used via the database/sql package.
 type MySQLDriver struct{}
@@ -52,6 +57,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	mc := &mysqlConn{
 		maxAllowedPacket: maxPacketSize,
 		maxWriteSize:     maxPacketSize - 1,
+		closech:          make(chan struct{}),
 	}
 	mc.cfg, err = ParseDSN(dsn)
 	if err != nil {
@@ -60,6 +66,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	mc.parseTime = mc.cfg.ParseTime
 	mc.strict = mc.cfg.Strict
 
+	// Call startWatcher for context support (From Go 1.8)
+	if s, ok := interface{}(mc).(watcher); ok {
+		s.startWatcher()
+	}
+
 	// Connect to Server
 	if dial, ok := dials[mc.cfg.Net]; ok {
 		mc.netConn, err = dial(mc.cfg.Addr)

+ 272 - 0
driver_go18_test.go

@@ -11,10 +11,28 @@
 package mysql
 
 import (
+	"context"
 	"database/sql"
+	"database/sql/driver"
 	"fmt"
 	"reflect"
 	"testing"
+	"time"
+)
+
+// static interface implementation checks of mysqlConn
+var (
+	_ driver.ConnBeginTx        = &mysqlConn{}
+	_ driver.ConnPrepareContext = &mysqlConn{}
+	_ driver.ExecerContext      = &mysqlConn{}
+	_ driver.Pinger             = &mysqlConn{}
+	_ driver.QueryerContext     = &mysqlConn{}
+)
+
+// static interface implementation checks of mysqlStmt
+var (
+	_ driver.StmtExecContext  = &mysqlStmt{}
+	_ driver.StmtQueryContext = &mysqlStmt{}
 )
 
 func TestMultiResultSet(t *testing.T) {
@@ -196,3 +214,257 @@ func TestSkipResults(t *testing.T) {
 		}
 	})
 }
+
+func TestPingContext(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel()
+		if err := dbt.db.PingContext(ctx); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+	})
+}
+
+func TestContextCancelExec(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (v INTEGER)")
+		ctx, cancel := context.WithCancel(context.Background())
+
+		// Delay execution for just a bit until db.ExecContext has begun.
+		defer time.AfterFunc(100*time.Millisecond, cancel).Stop()
+
+		// This query will be canceled.
+		startTime := time.Now()
+		if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+		if d := time.Since(startTime); d > 500*time.Millisecond {
+			dbt.Errorf("too long execution time: %s", d)
+		}
+
+		// Wait for the INSERT query has done.
+		time.Sleep(time.Second)
+
+		// Check how many times the query is executed.
+		var v int
+		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+			dbt.Fatalf("%s", err.Error())
+		}
+		if v != 1 { // TODO: need to kill the query, and v should be 0.
+			dbt.Errorf("expected val to be 1, got %d", v)
+		}
+
+		// Context is already canceled, so error should come before execution.
+		if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil {
+			dbt.Error("expected error")
+		} else if err.Error() != "context canceled" {
+			dbt.Fatalf("unexpected error: %s", err)
+		}
+
+		// The second insert query will fail, so the table has no changes.
+		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+			dbt.Fatalf("%s", err.Error())
+		}
+		if v != 1 {
+			dbt.Errorf("expected val to be 1, got %d", v)
+		}
+	})
+}
+
+func TestContextCancelQuery(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (v INTEGER)")
+		ctx, cancel := context.WithCancel(context.Background())
+
+		// Delay execution for just a bit until db.ExecContext has begun.
+		defer time.AfterFunc(100*time.Millisecond, cancel).Stop()
+
+		// This query will be canceled.
+		startTime := time.Now()
+		if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+		if d := time.Since(startTime); d > 500*time.Millisecond {
+			dbt.Errorf("too long execution time: %s", d)
+		}
+
+		// Wait for the INSERT query has done.
+		time.Sleep(time.Second)
+
+		// Check how many times the query is executed.
+		var v int
+		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+			dbt.Fatalf("%s", err.Error())
+		}
+		if v != 1 { // TODO: need to kill the query, and v should be 0.
+			dbt.Errorf("expected val to be 1, got %d", v)
+		}
+
+		// Context is already canceled, so error should come before execution.
+		if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+
+		// The second insert query will fail, so the table has no changes.
+		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+			dbt.Fatalf("%s", err.Error())
+		}
+		if v != 1 {
+			dbt.Errorf("expected val to be 1, got %d", v)
+		}
+	})
+}
+
+func TestContextCancelQueryRow(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (v INTEGER)")
+		dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)")
+		ctx, cancel := context.WithCancel(context.Background())
+
+		rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test")
+		if err != nil {
+			dbt.Fatalf("%s", err.Error())
+		}
+
+		// the first row will be succeed.
+		var v int
+		if !rows.Next() {
+			dbt.Fatalf("unexpected end")
+		}
+		if err := rows.Scan(&v); err != nil {
+			dbt.Fatalf("%s", err.Error())
+		}
+
+		cancel()
+		// make sure the driver recieve cancel request.
+		time.Sleep(100 * time.Millisecond)
+
+		if rows.Next() {
+			dbt.Errorf("expected end, but not")
+		}
+		if err := rows.Err(); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+	})
+}
+
+func TestContextCancelPrepare(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		ctx, cancel := context.WithCancel(context.Background())
+		cancel()
+		if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+	})
+}
+
+func TestContextCancelStmtExec(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (v INTEGER)")
+		ctx, cancel := context.WithCancel(context.Background())
+		stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))")
+		if err != nil {
+			dbt.Fatalf("unexpected error: %v", err)
+		}
+
+		// Delay execution for just a bit until db.ExecContext has begun.
+		defer time.AfterFunc(100*time.Millisecond, cancel).Stop()
+
+		// This query will be canceled.
+		startTime := time.Now()
+		if _, err := stmt.ExecContext(ctx); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+		if d := time.Since(startTime); d > 500*time.Millisecond {
+			dbt.Errorf("too long execution time: %s", d)
+		}
+
+		// Wait for the INSERT query has done.
+		time.Sleep(time.Second)
+
+		// Check how many times the query is executed.
+		var v int
+		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+			dbt.Fatalf("%s", err.Error())
+		}
+		if v != 1 { // TODO: need to kill the query, and v should be 0.
+			dbt.Errorf("expected val to be 1, got %d", v)
+		}
+	})
+}
+
+func TestContextCancelStmtQuery(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (v INTEGER)")
+		ctx, cancel := context.WithCancel(context.Background())
+		stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))")
+		if err != nil {
+			dbt.Fatalf("unexpected error: %v", err)
+		}
+
+		// Delay execution for just a bit until db.ExecContext has begun.
+		defer time.AfterFunc(100*time.Millisecond, cancel).Stop()
+
+		// This query will be canceled.
+		startTime := time.Now()
+		if _, err := stmt.QueryContext(ctx); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+		if d := time.Since(startTime); d > 500*time.Millisecond {
+			dbt.Errorf("too long execution time: %s", d)
+		}
+
+		// Wait for the INSERT query has done.
+		time.Sleep(time.Second)
+
+		// Check how many times the query is executed.
+		var v int
+		if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
+			dbt.Fatalf("%s", err.Error())
+		}
+		if v != 1 { // TODO: need to kill the query, and v should be 0.
+			dbt.Errorf("expected val to be 1, got %d", v)
+		}
+	})
+}
+
+func TestContextCancelBegin(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (v INTEGER)")
+		ctx, cancel := context.WithCancel(context.Background())
+		tx, err := dbt.db.BeginTx(ctx, nil)
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		// Delay execution for just a bit until db.ExecContext has begun.
+		defer time.AfterFunc(100*time.Millisecond, cancel).Stop()
+
+		// This query will be canceled.
+		startTime := time.Now()
+		if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+		if d := time.Since(startTime); d > 500*time.Millisecond {
+			dbt.Errorf("too long execution time: %s", d)
+		}
+
+		// Transaction is canceled, so expect an error.
+		switch err := tx.Commit(); err {
+		case sql.ErrTxDone:
+			// because the transaction has already been rollbacked.
+			// the database/sql package watches ctx
+			// and rollbacks when ctx is canceled.
+		case context.Canceled:
+			// the database/sql package rollbacks on another goroutine,
+			// so the transaction may not be rollbacked depending on goroutine scheduling.
+		default:
+			dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err)
+		}
+
+		// Context is canceled, so cannot begin a transaction.
+		if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled {
+			dbt.Errorf("expected context.Canceled, got %v", err)
+		}
+	})
+}

+ 8 - 0
driver_test.go

@@ -1991,3 +1991,11 @@ func TestRejectReadOnly(t *testing.T) {
 		dbt.mustExec("DROP TABLE test")
 	})
 }
+
+func TestPing(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		if err := dbt.db.Ping(); err != nil {
+			dbt.fail("Ping", "Ping", err)
+		}
+	})
+}

+ 11 - 0
packets.go

@@ -30,6 +30,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 		// read packet header
 		data, err := mc.buf.readNext(4)
 		if err != nil {
+			if cerr := mc.canceled(); cerr != nil {
+				return nil, cerr
+			}
 			errLog.Print(err)
 			mc.Close()
 			return nil, driver.ErrBadConn
@@ -63,6 +66,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 		// read packet body [pktLen bytes]
 		data, err = mc.buf.readNext(pktLen)
 		if err != nil {
+			if cerr := mc.canceled(); cerr != nil {
+				return nil, cerr
+			}
 			errLog.Print(err)
 			mc.Close()
 			return nil, driver.ErrBadConn
@@ -125,8 +131,13 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 
 		// Handle error
 		if err == nil { // n != len(data)
+			mc.cleanup()
 			errLog.Print(ErrMalformPkt)
 		} else {
+			if cerr := mc.canceled(); cerr != nil {
+				return cerr
+			}
+			mc.cleanup()
 			errLog.Print(err)
 		}
 		return driver.ErrBadConn

+ 2 - 1
packets_test.go

@@ -244,7 +244,8 @@ func TestReadPacketSplit(t *testing.T) {
 func TestReadPacketFail(t *testing.T) {
 	conn := new(mockConn)
 	mc := &mysqlConn{
-		buf: newBuffer(conn),
+		buf:     newBuffer(conn),
+		closech: make(chan struct{}),
 	}
 
 	// illegal empty (stand-alone) packet

+ 16 - 10
rows.go

@@ -28,8 +28,9 @@ type resultSet struct {
 }
 
 type mysqlRows struct {
-	mc *mysqlConn
-	rs resultSet
+	mc     *mysqlConn
+	rs     resultSet
+	finish func()
 }
 
 type binaryRows struct {
@@ -65,12 +66,17 @@ func (rows *mysqlRows) Columns() []string {
 }
 
 func (rows *mysqlRows) Close() (err error) {
+	if f := rows.finish; f != nil {
+		f()
+		rows.finish = nil
+	}
+
 	mc := rows.mc
 	if mc == nil {
 		return nil
 	}
-	if mc.netConn == nil {
-		return ErrInvalidConn
+	if err := mc.error(); err != nil {
+		return err
 	}
 
 	// Remove unread packets from stream
@@ -98,8 +104,8 @@ func (rows *mysqlRows) nextResultSet() (int, error) {
 	if rows.mc == nil {
 		return 0, io.EOF
 	}
-	if rows.mc.netConn == nil {
-		return 0, ErrInvalidConn
+	if err := rows.mc.error(); err != nil {
+		return 0, err
 	}
 
 	// Remove unread packets from stream
@@ -145,8 +151,8 @@ func (rows *binaryRows) NextResultSet() error {
 
 func (rows *binaryRows) Next(dest []driver.Value) error {
 	if mc := rows.mc; mc != nil {
-		if mc.netConn == nil {
-			return ErrInvalidConn
+		if err := mc.error(); err != nil {
+			return err
 		}
 
 		// Fetch next row from stream
@@ -167,8 +173,8 @@ func (rows *textRows) NextResultSet() (err error) {
 
 func (rows *textRows) Next(dest []driver.Value) error {
 	if mc := rows.mc; mc != nil {
-		if mc.netConn == nil {
-			return ErrInvalidConn
+		if err := mc.error(); err != nil {
+			return err
 		}
 
 		// Fetch next row from stream

+ 7 - 3
statement.go

@@ -23,7 +23,7 @@ type mysqlStmt struct {
 }
 
 func (stmt *mysqlStmt) Close() error {
-	if stmt.mc == nil || stmt.mc.netConn == nil {
+	if stmt.mc == nil || stmt.mc.isBroken() {
 		// driver.Stmt.Close can be called more than once, thus this function
 		// has to be idempotent.
 		// See also Issue #450 and golang/go#16019.
@@ -45,7 +45,7 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
-	if stmt.mc.netConn == nil {
+	if stmt.mc.isBroken() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
@@ -89,7 +89,11 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 }
 
 func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
-	if stmt.mc.netConn == nil {
+	return stmt.query(args)
+}
+
+func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
+	if stmt.mc.isBroken() {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}

+ 2 - 2
transaction.go

@@ -13,7 +13,7 @@ type mysqlTx struct {
 }
 
 func (tx *mysqlTx) Commit() (err error) {
-	if tx.mc == nil || tx.mc.netConn == nil {
+	if tx.mc == nil || tx.mc.isBroken() {
 		return ErrInvalidConn
 	}
 	err = tx.mc.exec("COMMIT")
@@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
 }
 
 func (tx *mysqlTx) Rollback() (err error) {
-	if tx.mc == nil || tx.mc.netConn == nil {
+	if tx.mc == nil || tx.mc.isBroken() {
 		return ErrInvalidConn
 	}
 	err = tx.mc.exec("ROLLBACK")

+ 17 - 1
utils_go18.go

@@ -10,8 +10,24 @@
 
 package mysql
 
-import "crypto/tls"
+import (
+	"crypto/tls"
+	"database/sql/driver"
+	"errors"
+)
 
 func cloneTLSConfig(c *tls.Config) *tls.Config {
 	return c.Clone()
 }
+
+func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
+	dargs := make([]driver.Value, len(named))
+	for n, param := range named {
+		if len(param.Name) > 0 {
+			// TODO: support the use of Named Parameters #561
+			return nil, errors.New("mysql: driver does not support the use of Named Parameters")
+		}
+		dargs[n] = param.Value
+	}
+	return dargs, nil
+}