Sfoglia il codice sorgente

add an optional connecton timeout

Julien Schmidt 12 anni fa
parent
commit
a19c21848c
4 ha cambiato i file con 16 aggiunte e 2 eliminazioni
  1. 1 0
      README.md
  2. 4 0
      connection.go
  3. 10 1
      driver.go
  4. 1 1
      driver_test.go

+ 1 - 0
README.md

@@ -100,6 +100,7 @@ For Unix-sockets the address is the absolute path to the MySQL-Server-socket, e.
 **Parameters are case-sensitive!**
 
 Possible Parameters are:
+  * `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
   * `charset`: *"SET NAMES `value`"*. If multiple charsets are set (seperated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers.
   * _(pending)_ <s>`tls`</s>: will enable SSL/TLS-Encryption
   * _(pending)_ <s>`compress`</s>: will enable Compression

+ 4 - 0
connection.go

@@ -52,6 +52,10 @@ func (mc *mysqlConn) handleParams() (err error) {
 				}
 			}
 
+		// Timeout - already handled on connecting
+		case "timeout":
+			continue
+
 		// TLS-Encryption
 		case "tls":
 			err = errors.New("TLS-Encryption not implemented yet")

+ 10 - 1
driver.go

@@ -12,6 +12,7 @@ import (
 	"database/sql"
 	"database/sql/driver"
 	"net"
+	"time"
 )
 
 type mysqlDriver struct{}
@@ -27,7 +28,15 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
 	mc.cfg = parseDSN(dsn)
 
 	// Connect to Server
-	mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr)
+	if _, ok := mc.cfg.params["timeout"]; ok { // with timeout
+		var timeout time.Duration
+		timeout, err = time.ParseDuration(mc.cfg.params["timeout"])
+		if err == nil {
+			mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, timeout)
+		}
+	} else { // no timeout
+		mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr)
+	}
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
driver_test.go

@@ -42,7 +42,7 @@ func getEnv() bool {
 		}
 
 		netAddr = fmt.Sprintf("%s(%s)", prot, addr)
-		dsn = fmt.Sprintf("%s:%s@%s/%s?charset=utf8", user, pass, netAddr, dbname)
+		dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&charset=utf8", user, pass, netAddr, dbname)
 
 		c, err := net.Dial(prot, addr)
 		if err == nil {