Browse Source

Support Go 1.10 Connector interface (#941)

Vicent Martí 6 years ago
parent
commit
89ec2a9ec8
8 changed files with 415 additions and 115 deletions
  1. 6 1
      appengine.go
  2. 143 0
      connector.go
  3. 27 111
      driver.go
  4. 37 0
      driver_go110.go
  5. 137 0
      driver_go110_test.go
  6. 4 3
      driver_test.go
  7. 21 0
      dsn.go
  8. 40 0
      dsn_test.go

+ 6 - 1
appengine.go

@@ -11,9 +11,14 @@
 package mysql
 
 import (
+	"context"
+
 	"google.golang.org/appengine/cloudsql"
 )
 
 func init() {
-	RegisterDial("cloudsql", cloudsql.Dial)
+	RegisterDialContext("cloudsql", func(_ context.Context, instance addr) (net.Conn, error) {
+		// XXX: the cloudsql driver still does not export a Context-aware dialer.
+		return cloudsql.Dial(instance)
+	})
 }

+ 143 - 0
connector.go

@@ -0,0 +1,143 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+	"context"
+	"database/sql/driver"
+	"net"
+)
+
+type connector struct {
+	cfg *Config // immutable private copy.
+}
+
+// Connect implements driver.Connector interface.
+// Connect returns a connection to the database.
+func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
+	var err error
+
+	// New mysqlConn
+	mc := &mysqlConn{
+		maxAllowedPacket: maxPacketSize,
+		maxWriteSize:     maxPacketSize - 1,
+		closech:          make(chan struct{}),
+		cfg:              c.cfg,
+	}
+	mc.parseTime = mc.cfg.ParseTime
+
+	// Connect to Server
+	dialsLock.RLock()
+	dial, ok := dials[mc.cfg.Net]
+	dialsLock.RUnlock()
+	if ok {
+		mc.netConn, err = dial(ctx, mc.cfg.Addr)
+	} else {
+		nd := net.Dialer{Timeout: mc.cfg.Timeout}
+		mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
+	}
+
+	if err != nil {
+		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
+			errLog.Print("net.Error from Dial()': ", nerr.Error())
+			return nil, driver.ErrBadConn
+		}
+		return nil, err
+	}
+
+	// Enable TCP Keepalives on TCP connections
+	if tc, ok := mc.netConn.(*net.TCPConn); ok {
+		if err := tc.SetKeepAlive(true); err != nil {
+			// Don't send COM_QUIT before handshake.
+			mc.netConn.Close()
+			mc.netConn = nil
+			return nil, err
+		}
+	}
+
+	// Call startWatcher for context support (From Go 1.8)
+	mc.startWatcher()
+	if err := mc.watchCancel(ctx); err != nil {
+		return nil, err
+	}
+	defer mc.finish()
+
+	mc.buf = newBuffer(mc.netConn)
+
+	// Set I/O timeouts
+	mc.buf.timeout = mc.cfg.ReadTimeout
+	mc.writeTimeout = mc.cfg.WriteTimeout
+
+	// Reading Handshake Initialization Packet
+	authData, plugin, err := mc.readHandshakePacket()
+	if err != nil {
+		mc.cleanup()
+		return nil, err
+	}
+
+	if plugin == "" {
+		plugin = defaultAuthPlugin
+	}
+
+	// Send Client Authentication Packet
+	authResp, err := mc.auth(authData, plugin)
+	if err != nil {
+		// try the default auth plugin, if using the requested plugin failed
+		errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
+		plugin = defaultAuthPlugin
+		authResp, err = mc.auth(authData, plugin)
+		if err != nil {
+			mc.cleanup()
+			return nil, err
+		}
+	}
+	if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
+		mc.cleanup()
+		return nil, err
+	}
+
+	// Handle response to auth packet, switch methods if possible
+	if err = mc.handleAuthResult(authData, plugin); err != nil {
+		// Authentication failed and MySQL has already closed the connection
+		// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
+		// Do not send COM_QUIT, just cleanup and return the error.
+		mc.cleanup()
+		return nil, err
+	}
+
+	if mc.cfg.MaxAllowedPacket > 0 {
+		mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
+	} else {
+		// Get max allowed packet size
+		maxap, err := mc.getSystemVar("max_allowed_packet")
+		if err != nil {
+			mc.Close()
+			return nil, err
+		}
+		mc.maxAllowedPacket = stringToInt(maxap) - 1
+	}
+	if mc.maxAllowedPacket < maxPacketSize {
+		mc.maxWriteSize = mc.maxAllowedPacket
+	}
+
+	// Handle DSN Params
+	err = mc.handleParams()
+	if err != nil {
+		mc.Close()
+		return nil, err
+	}
+
+	return mc, nil
+}
+
+// Driver implements driver.Connector interface.
+// Driver returns &MySQLDriver{}.
+func (c *connector) Driver() driver.Driver {
+	return &MySQLDriver{}
+}

+ 27 - 111
driver.go

@@ -17,6 +17,7 @@
 package mysql
 
 import (
+	"context"
 	"database/sql"
 	"database/sql/driver"
 	"net"
@@ -29,139 +30,54 @@ type MySQLDriver struct{}
 
 // DialFunc is a function which can be used to establish the network connection.
 // Custom dial functions must be registered with RegisterDial
+//
+// Deprecated: users should register a DialContextFunc instead
 type DialFunc func(addr string) (net.Conn, error)
 
+// DialContextFunc is a function which can be used to establish the network connection.
+// Custom dial functions must be registered with RegisterDialContext
+type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error)
+
 var (
 	dialsLock sync.RWMutex
-	dials     map[string]DialFunc
+	dials     map[string]DialContextFunc
 )
 
-// RegisterDial registers a custom dial function. It can then be used by the
+// RegisterDialContext registers a custom dial function. It can then be used by the
 // network address mynet(addr), where mynet is the registered new network.
-// addr is passed as a parameter to the dial function.
-func RegisterDial(net string, dial DialFunc) {
+// The current context for the connection and its address is passed to the dial function.
+func RegisterDialContext(net string, dial DialContextFunc) {
 	dialsLock.Lock()
 	defer dialsLock.Unlock()
 	if dials == nil {
-		dials = make(map[string]DialFunc)
+		dials = make(map[string]DialContextFunc)
 	}
 	dials[net] = dial
 }
 
+// RegisterDial registers a custom dial function. It can then be used by the
+// network address mynet(addr), where mynet is the registered new network.
+// addr is passed as a parameter to the dial function.
+//
+// Deprecated: users should call RegisterDialContext instead
+func RegisterDial(network string, dial DialFunc) {
+	RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) {
+		return dial(addr)
+	})
+}
+
 // Open new Connection.
 // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
 // the DSN string is formatted
 func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
-	var err error
-
-	// New mysqlConn
-	mc := &mysqlConn{
-		maxAllowedPacket: maxPacketSize,
-		maxWriteSize:     maxPacketSize - 1,
-		closech:          make(chan struct{}),
-	}
-	mc.cfg, err = ParseDSN(dsn)
-	if err != nil {
-		return nil, err
-	}
-	mc.parseTime = mc.cfg.ParseTime
-
-	// Connect to Server
-	dialsLock.RLock()
-	dial, ok := dials[mc.cfg.Net]
-	dialsLock.RUnlock()
-	if ok {
-		mc.netConn, err = dial(mc.cfg.Addr)
-	} else {
-		nd := net.Dialer{Timeout: mc.cfg.Timeout}
-		mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
-	}
-	if err != nil {
-		if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
-			errLog.Print("net.Error from Dial()': ", nerr.Error())
-			return nil, driver.ErrBadConn
-		}
-		return nil, err
-	}
-
-	// Enable TCP Keepalives on TCP connections
-	if tc, ok := mc.netConn.(*net.TCPConn); ok {
-		if err := tc.SetKeepAlive(true); err != nil {
-			// Don't send COM_QUIT before handshake.
-			mc.netConn.Close()
-			mc.netConn = nil
-			return nil, err
-		}
-	}
-
-	// Call startWatcher for context support (From Go 1.8)
-	mc.startWatcher()
-
-	mc.buf = newBuffer(mc.netConn)
-
-	// Set I/O timeouts
-	mc.buf.timeout = mc.cfg.ReadTimeout
-	mc.writeTimeout = mc.cfg.WriteTimeout
-
-	// Reading Handshake Initialization Packet
-	authData, plugin, err := mc.readHandshakePacket()
+	cfg, err := ParseDSN(dsn)
 	if err != nil {
-		mc.cleanup()
 		return nil, err
 	}
-	if plugin == "" {
-		plugin = defaultAuthPlugin
+	c := &connector{
+		cfg: cfg,
 	}
-
-	// Send Client Authentication Packet
-	authResp, err := mc.auth(authData, plugin)
-	if err != nil {
-		// try the default auth plugin, if using the requested plugin failed
-		errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
-		plugin = defaultAuthPlugin
-		authResp, err = mc.auth(authData, plugin)
-		if err != nil {
-			mc.cleanup()
-			return nil, err
-		}
-	}
-	if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
-		mc.cleanup()
-		return nil, err
-	}
-
-	// Handle response to auth packet, switch methods if possible
-	if err = mc.handleAuthResult(authData, plugin); err != nil {
-		// Authentication failed and MySQL has already closed the connection
-		// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
-		// Do not send COM_QUIT, just cleanup and return the error.
-		mc.cleanup()
-		return nil, err
-	}
-
-	if mc.cfg.MaxAllowedPacket > 0 {
-		mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
-	} else {
-		// Get max allowed packet size
-		maxap, err := mc.getSystemVar("max_allowed_packet")
-		if err != nil {
-			mc.Close()
-			return nil, err
-		}
-		mc.maxAllowedPacket = stringToInt(maxap) - 1
-	}
-	if mc.maxAllowedPacket < maxPacketSize {
-		mc.maxWriteSize = mc.maxAllowedPacket
-	}
-
-	// Handle DSN Params
-	err = mc.handleParams()
-	if err != nil {
-		mc.Close()
-		return nil, err
-	}
-
-	return mc, nil
+	return c.Connect(context.Background())
 }
 
 func init() {

+ 37 - 0
driver_go110.go

@@ -0,0 +1,37 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build go1.10
+
+package mysql
+
+import (
+	"database/sql/driver"
+)
+
+// NewConnector returns new driver.Connector.
+func NewConnector(cfg *Config) (driver.Connector, error) {
+	cfg = cfg.Clone()
+	// normalize the contents of cfg so calls to NewConnector have the same
+	// behavior as MySQLDriver.OpenConnector
+	if err := cfg.normalize(); err != nil {
+		return nil, err
+	}
+	return &connector{cfg: cfg}, nil
+}
+
+// OpenConnector implements driver.DriverContext.
+func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
+	cfg, err := ParseDSN(dsn)
+	if err != nil {
+		return nil, err
+	}
+	return &connector{
+		cfg: cfg,
+	}, nil
+}

+ 137 - 0
driver_go110_test.go

@@ -0,0 +1,137 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build go1.10
+
+package mysql
+
+import (
+	"context"
+	"database/sql"
+	"database/sql/driver"
+	"fmt"
+	"net"
+	"testing"
+	"time"
+)
+
+var _ driver.DriverContext = &MySQLDriver{}
+
+type dialCtxKey struct{}
+
+func TestConnectorObeysDialTimeouts(t *testing.T) {
+	if !available {
+		t.Skipf("MySQL server not running on %s", netAddr)
+	}
+
+	RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) {
+		var d net.Dialer
+		if !ctx.Value(dialCtxKey{}).(bool) {
+			return nil, fmt.Errorf("test error: query context is not propagated to our dialer")
+		}
+		return d.DialContext(ctx, prot, addr)
+	})
+
+	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
+	if err != nil {
+		t.Fatalf("error connecting: %s", err.Error())
+	}
+	defer db.Close()
+
+	ctx := context.WithValue(context.Background(), dialCtxKey{}, true)
+
+	_, err = db.ExecContext(ctx, "DO 1")
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func configForTests(t *testing.T) *Config {
+	if !available {
+		t.Skipf("MySQL server not running on %s", netAddr)
+	}
+
+	mycnf := NewConfig()
+	mycnf.User = user
+	mycnf.Passwd = pass
+	mycnf.Addr = addr
+	mycnf.Net = prot
+	mycnf.DBName = dbname
+	return mycnf
+}
+
+func TestNewConnector(t *testing.T) {
+	mycnf := configForTests(t)
+	conn, err := NewConnector(mycnf)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	db := sql.OpenDB(conn)
+	defer db.Close()
+
+	if err := db.Ping(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+type slowConnection struct {
+	net.Conn
+	slowdown time.Duration
+}
+
+func (sc *slowConnection) Read(b []byte) (int, error) {
+	time.Sleep(sc.slowdown)
+	return sc.Conn.Read(b)
+}
+
+type connectorHijack struct {
+	driver.Connector
+	connErr error
+}
+
+func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) {
+	var conn driver.Conn
+	conn, cw.connErr = cw.Connector.Connect(ctx)
+	return conn, cw.connErr
+}
+
+func TestConnectorTimeoutsDuringOpen(t *testing.T) {
+	RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) {
+		var d net.Dialer
+		conn, err := d.DialContext(ctx, prot, addr)
+		if err != nil {
+			return nil, err
+		}
+		return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil
+	})
+
+	mycnf := configForTests(t)
+	mycnf.Net = "slowconn"
+
+	conn, err := NewConnector(mycnf)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	hijack := &connectorHijack{Connector: conn}
+
+	db := sql.OpenDB(hijack)
+	defer db.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+	defer cancel()
+
+	_, err = db.ExecContext(ctx, "DO 1")
+	if err != context.DeadlineExceeded {
+		t.Fatalf("ExecContext should have timed out")
+	}
+	if hijack.connErr != context.DeadlineExceeded {
+		t.Fatalf("(*Connector).Connect should have timed out")
+	}
+}

+ 4 - 3
driver_test.go

@@ -1846,7 +1846,7 @@ func TestConcurrent(t *testing.T) {
 }
 
 func testDialError(t *testing.T, dialErr error, expectErr error) {
-	RegisterDial("mydial", func(addr string) (net.Conn, error) {
+	RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) {
 		return nil, dialErr
 	})
 
@@ -1884,8 +1884,9 @@ func TestCustomDial(t *testing.T) {
 	}
 
 	// our custom dial function which justs wraps net.Dial here
-	RegisterDial("mydial", func(addr string) (net.Conn, error) {
-		return net.Dial(prot, addr)
+	RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) {
+		var d net.Dialer
+		return d.DialContext(ctx, prot, addr)
 	})
 
 	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))

+ 21 - 0
dsn.go

@@ -14,6 +14,7 @@ import (
 	"crypto/tls"
 	"errors"
 	"fmt"
+	"math/big"
 	"net"
 	"net/url"
 	"sort"
@@ -72,6 +73,26 @@ func NewConfig() *Config {
 	}
 }
 
+func (cfg *Config) Clone() *Config {
+	cp := *cfg
+	if cp.tls != nil {
+		cp.tls = cfg.tls.Clone()
+	}
+	if len(cp.Params) > 0 {
+		cp.Params = make(map[string]string, len(cfg.Params))
+		for k, v := range cfg.Params {
+			cp.Params[k] = v
+		}
+	}
+	if cfg.pubKey != nil {
+		cp.pubKey = &rsa.PublicKey{
+			N: new(big.Int).Set(cfg.pubKey.N),
+			E: cfg.pubKey.E,
+		}
+	}
+	return &cp
+}
+
 func (cfg *Config) normalize() error {
 	if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
 		return errInvalidDSNUnsafeCollation

+ 40 - 0
dsn_test.go

@@ -318,6 +318,46 @@ func TestParamsAreSorted(t *testing.T) {
 	}
 }
 
+func TestCloneConfig(t *testing.T) {
+	RegisterServerPubKey("testKey", testPubKeyRSA)
+	defer DeregisterServerPubKey("testKey")
+
+	expectedServerName := "example.com"
+	dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey"
+	cfg, err := ParseDSN(dsn)
+	if err != nil {
+		t.Fatal(err.Error())
+	}
+
+	cfg2 := cfg.Clone()
+	if cfg == cfg2 {
+		t.Errorf("Config.Clone did not create a separate config struct")
+	}
+
+	if cfg2.tls.ServerName != expectedServerName {
+		t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
+	}
+
+	cfg2.tls.ServerName = "example2.com"
+	if cfg.tls.ServerName == cfg2.tls.ServerName {
+		t.Errorf("changed cfg.tls.Server name should not propagate to original Config")
+	}
+
+	if _, ok := cfg2.Params["foobar"]; !ok {
+		t.Errorf("cloned Config is missing custom params")
+	}
+
+	delete(cfg2.Params, "foobar")
+
+	if _, ok := cfg.Params["foobar"]; !ok {
+		t.Errorf("custom params in cloned Config should not propagate to original Config")
+	}
+
+	if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) {
+		t.Errorf("public key in Config should be identical")
+	}
+}
+
 func BenchmarkParseDSN(b *testing.B) {
 	b.ReportAllocs()