瀏覽代碼

Merge pull request #245 from go-sql-driver/dial

Registration of custom dial functions
Julien Schmidt 11 年之前
父節點
當前提交
0183433227
共有 4 個文件被更改,包括 51 次插入15 次删除
  1. 1 0
      CHANGELOG.md
  2. 1 6
      appengine.go
  3. 16 4
      driver.go
  4. 33 5
      driver_test.go

+ 1 - 0
CHANGELOG.md

@@ -16,6 +16,7 @@ Changes:
 
 New Features:
 
+ - `RegisterDial` allows the usage of a custom dial function to establish the network connection
  - Setting the connection collation is possible with the `collation` DSN parameter. This parameter should be preferred over the `charset` parameter
  - Logging of critical errors is configurable with `SetLogger`
  - Google CloudSQL support

+ 1 - 6
appengine.go

@@ -16,10 +16,5 @@ import (
 )
 
 func init() {
-	if dials == nil {
-		dials = make(map[string]dialFunc)
-	}
-	dials["cloudsql"] = func(cfg *config) (net.Conn, error) {
-		return cloudsql.Dial(cfg.addr)
-	}
+	RegisterDial("cloudsql", cloudsql.Dial)
 }

+ 16 - 4
driver.go

@@ -26,9 +26,21 @@ import (
 // In general the driver is used via the database/sql package.
 type MySQLDriver struct{}
 
-type dialFunc func(*config) (net.Conn, error)
-
-var dials map[string]dialFunc
+// DialFunc is a function which can be used to establish the network connection.
+// Custom dial functions must be registered with RegisterDial
+type DialFunc func(addr string) (net.Conn, error)
+
+var dials map[string]DialFunc
+
+// RegisterDial registers a custom dial function. It can then be used by the
+// network address mynet(addr), where mynet is the registered new network.
+// addr is passed as a parameter to the dial function.
+func RegisterDial(net string, dial DialFunc) {
+	if dials == nil {
+		dials = make(map[string]DialFunc)
+	}
+	dials[net] = dial
+}
 
 // Open new Connection.
 // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
@@ -48,7 +60,7 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
 
 	// Connect to Server
 	if dial, ok := dials[mc.cfg.net]; ok {
-		mc.netConn, err = dial(mc.cfg)
+		mc.netConn, err = dial(mc.cfg.addr)
 	} else {
 		nd := net.Dialer{Timeout: mc.cfg.timeout}
 		mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)

+ 33 - 5
driver_test.go

@@ -26,6 +26,11 @@ import (
 )
 
 var (
+	user      string
+	pass      string
+	prot      string
+	addr      string
+	dbname    string
 	dsn       string
 	netAddr   string
 	available bool
@@ -43,17 +48,18 @@ var (
 
 // See https://github.com/go-sql-driver/mysql/wiki/Testing
 func init() {
+	// get environment variables
 	env := func(key, defaultValue string) string {
 		if value := os.Getenv(key); value != "" {
 			return value
 		}
 		return defaultValue
 	}
-	user := env("MYSQL_TEST_USER", "root")
-	pass := env("MYSQL_TEST_PASS", "")
-	prot := env("MYSQL_TEST_PROT", "tcp")
-	addr := env("MYSQL_TEST_ADDR", "localhost:3306")
-	dbname := env("MYSQL_TEST_DBNAME", "gotest")
+	user = env("MYSQL_TEST_USER", "root")
+	pass = env("MYSQL_TEST_PASS", "")
+	prot = env("MYSQL_TEST_PROT", "tcp")
+	addr = env("MYSQL_TEST_ADDR", "localhost:3306")
+	dbname = env("MYSQL_TEST_DBNAME", "gotest")
 	netAddr = fmt.Sprintf("%s(%s)", prot, addr)
 	dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
 	c, err := net.Dial(prot, addr)
@@ -1340,3 +1346,25 @@ func TestConcurrent(t *testing.T) {
 		dbt.Logf("Reached %d concurrent connections\r\n", succeeded)
 	})
 }
+
+// Tests custom dial functions
+func TestCustomDial(t *testing.T) {
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
+
+	// our custom dial function which justs wraps net.Dial here
+	RegisterDial("mydial", func(addr string) (net.Conn, error) {
+		return net.Dial(prot, addr)
+	})
+
+	db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
+	if err != nil {
+		t.Fatalf("Error connecting: %s", err.Error())
+	}
+	defer db.Close()
+
+	if _, err = db.Exec("DO 1"); err != nil {
+		t.Fatalf("Connection failed: %s", err.Error())
+	}
+}