|
|
@@ -26,6 +26,11 @@ import (
|
|
|
)
|
|
|
|
|
|
var (
|
|
|
+ user string
|
|
|
+ pass string
|
|
|
+ prot string
|
|
|
+ addr string
|
|
|
+ dbname string
|
|
|
dsn string
|
|
|
netAddr string
|
|
|
available bool
|
|
|
@@ -43,17 +48,18 @@ var (
|
|
|
|
|
|
// See https://github.com/go-sql-driver/mysql/wiki/Testing
|
|
|
func init() {
|
|
|
+ // get environment variables
|
|
|
env := func(key, defaultValue string) string {
|
|
|
if value := os.Getenv(key); value != "" {
|
|
|
return value
|
|
|
}
|
|
|
return defaultValue
|
|
|
}
|
|
|
- user := env("MYSQL_TEST_USER", "root")
|
|
|
- pass := env("MYSQL_TEST_PASS", "")
|
|
|
- prot := env("MYSQL_TEST_PROT", "tcp")
|
|
|
- addr := env("MYSQL_TEST_ADDR", "localhost:3306")
|
|
|
- dbname := env("MYSQL_TEST_DBNAME", "gotest")
|
|
|
+ user = env("MYSQL_TEST_USER", "root")
|
|
|
+ pass = env("MYSQL_TEST_PASS", "")
|
|
|
+ prot = env("MYSQL_TEST_PROT", "tcp")
|
|
|
+ addr = env("MYSQL_TEST_ADDR", "localhost:3306")
|
|
|
+ dbname = env("MYSQL_TEST_DBNAME", "gotest")
|
|
|
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
|
|
|
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
|
|
|
c, err := net.Dial(prot, addr)
|
|
|
@@ -1340,3 +1346,25 @@ func TestConcurrent(t *testing.T) {
|
|
|
dbt.Logf("Reached %d concurrent connections\r\n", succeeded)
|
|
|
})
|
|
|
}
|
|
|
+
|
|
|
+// Tests custom dial functions
|
|
|
+func TestCustomDial(t *testing.T) {
|
|
|
+ if !available {
|
|
|
+ t.Skipf("MySQL-Server not running on %s", netAddr)
|
|
|
+ }
|
|
|
+
|
|
|
+ // our custom dial function which justs wraps net.Dial here
|
|
|
+ RegisterDial("mydial", func(addr string) (net.Conn, error) {
|
|
|
+ return net.Dial(prot, addr)
|
|
|
+ })
|
|
|
+
|
|
|
+ db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Error connecting: %s", err.Error())
|
|
|
+ }
|
|
|
+ defer db.Close()
|
|
|
+
|
|
|
+ if _, err = db.Exec("DO 1"); err != nil {
|
|
|
+ t.Fatalf("Connection failed: %s", err.Error())
|
|
|
+ }
|
|
|
+}
|