Pārlūkot izejas kodu

Merge pull request #89 from go-sql-driver/tls

Add TLS-Support
Julien Schmidt 12 gadi atpakaļ
vecāks
revīzija
55a708b5fe
8 mainītis faili ar 164 papildinājumiem un 65 dzēšanām
  1. 5 4
      README.md
  2. 12 13
      connection.go
  3. 2 5
      driver.go
  4. 47 11
      driver_test.go
  5. 3 1
      errors.go
  6. 49 20
      packets.go
  7. 34 2
      utils.go
  8. 12 9
      utils_test.go

+ 5 - 4
README.md

@@ -105,13 +105,14 @@ For Unix domain sockets the address is the absolute path to the MySQL-Server-soc
 ***Parameters are case-sensitive!***
 
 Possible Parameters are:
-  * `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
-  * `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
   * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
-  * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
+  * `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
+  * `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
+  * `parseTime`: `parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
   * `strict`: Enable strict mode. MySQL warnings are treated as errors.
-  * `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
+  * `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
+  * `tls`: `true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side)
 
 All other parameters are interpreted as system variables:
   * `autocommit`: *"SET autocommit=`value`"*

+ 12 - 13
connection.go

@@ -10,6 +10,7 @@
 package mysql
 
 import (
+	"crypto/tls"
 	"database/sql/driver"
 	"errors"
 	"net"
@@ -35,13 +36,15 @@ type mysqlConn struct {
 }
 
 type config struct {
-	user   string
-	passwd string
-	net    string
-	addr   string
-	dbname string
-	params map[string]string
-	loc    *time.Location
+	user    string
+	passwd  string
+	net     string
+	addr    string
+	dbname  string
+	params  map[string]string
+	loc     *time.Location
+	timeout time.Duration
+	tls     *tls.Config
 }
 
 // Handles parameters set in DSN
@@ -63,7 +66,7 @@ func (mc *mysqlConn) handleParams() (err error) {
 			}
 
 		// handled elsewhere
-		case "timeout", "allowAllFiles", "loc", "clientFoundRows":
+		case "allowAllFiles", "clientFoundRows":
 			continue
 
 		// time.Time parsing
@@ -74,14 +77,10 @@ func (mc *mysqlConn) handleParams() (err error) {
 		case "strict":
 			mc.strict = readBool(val)
 
-		// TLS-Encryption
-		case "tls":
-			err = errors.New("TLS-Encryption not implemented yet")
-			return
-
 		// Compression
 		case "compress":
 			err = errors.New("Compression not implemented yet")
+			return
 
 		// System Vars
 		default:

+ 2 - 5
driver.go

@@ -12,7 +12,6 @@ import (
 	"database/sql"
 	"database/sql/driver"
 	"net"
-	"time"
 )
 
 type mysqlDriver struct{}
@@ -34,11 +33,9 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
 	}
 
 	// Connect to Server
-	if _, ok := mc.cfg.params["timeout"]; ok { // with timeout
-		var timeout time.Duration
-		timeout, err = time.ParseDuration(mc.cfg.params["timeout"])
+	if mc.cfg.timeout > 0 { // with timeout
 		if err == nil {
-			mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, timeout)
+			mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, mc.cfg.timeout)
 		}
 	} else { // no timeout
 		mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr)

+ 47 - 11
driver_test.go

@@ -807,6 +807,42 @@ func TestStrict(t *testing.T) {
 	})
 }
 
+func TestTLS(t *testing.T) {
+	runTests(t, "TestTLS", dsn+"&tls=skip-verify", func(dbt *DBTest) {
+		/* TODO: GO 1.1 API */
+		/*if err := dbt.db.Ping(); err != nil {
+		    if err == errNoTLS {
+		        dbt.Skip("Server does not support TLS. Skipping TestTLS")
+		    } else {
+		        dbt.Fatalf("Error on Ping: %s", err.Error())
+		    }
+		}*/
+
+		/* GO 1.0 API */
+		if _, err := dbt.db.Exec("DO 1"); err != nil {
+			if err == errNoTLS {
+				dbt.Log("Server does not support TLS. Skipping TestTLS")
+				return
+			} else {
+				dbt.Fatalf("Error on Ping: %s", err.Error())
+			}
+		}
+
+		rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'")
+
+		var variable, value *sql.RawBytes
+		for rows.Next() {
+			if err := rows.Scan(&variable, &value); err != nil {
+				dbt.Fatal(err.Error())
+			}
+
+			if value == nil {
+				dbt.Fatal("No Cipher")
+			}
+		}
+	})
+}
+
 // Special cases
 
 func TestRowsClose(t *testing.T) {
@@ -1040,41 +1076,41 @@ func TestFoundRows(t *testing.T) {
 	runTests(t, "TestFoundRows1", dsn, func(dbt *DBTest) {
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
 		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
-		
+
 		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
 		count, err := res.RowsAffected()
 		if err != nil {
-				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-			}
+			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+		}
 		if count != 2 {
 			dbt.Fatalf("Expected 2 affected rows, got %d", count)
 		}
 		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
 		count, err = res.RowsAffected()
 		if err != nil {
-				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-			}
+			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+		}
 		if count != 2 {
 			dbt.Fatalf("Expected 2 affected rows, got %d", count)
 		}
 	})
-	runTests(t, "TestFoundRows2", dsn + "&clientFoundRows=true", func(dbt *DBTest) {
+	runTests(t, "TestFoundRows2", dsn+"&clientFoundRows=true", func(dbt *DBTest) {
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
 		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
-		
+
 		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
 		count, err := res.RowsAffected()
 		if err != nil {
-				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-			}
+			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+		}
 		if count != 2 {
 			dbt.Fatalf("Expected 2 matched rows, got %d", count)
 		}
 		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
 		count, err = res.RowsAffected()
 		if err != nil {
-				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-			}
+			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+		}
 		if count != 3 {
 			dbt.Fatalf("Expected 3 matched rows, got %d", count)
 		}

+ 3 - 1
errors.go

@@ -18,9 +18,11 @@ import (
 
 var (
 	errMalformPkt  = errors.New("Malformed Packet")
+	errNoTLS       = errors.New("TLS encryption requested but server does not support TLS")
+	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
+	errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
 	errPktSync     = errors.New("Commands out of sync. You can't run this command now")
 	errPktSyncMul  = errors.New("Commands out of sync. Did you run multiple statements at once?")
-	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 )
 

+ 49 - 20
packets.go

@@ -11,9 +11,9 @@ package mysql
 
 import (
 	"bytes"
+	"crypto/tls"
 	"database/sql/driver"
 	"encoding/binary"
-	"errors"
 	"fmt"
 	"io"
 	"math"
@@ -167,7 +167,10 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 	// capability flags (lower 2 bytes) [2 bytes]
 	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
 	if mc.flags&clientProtocol41 == 0 {
-		err = errors.New("MySQL-Server does not support required Protocol 41+")
+		err = errOldProtocol
+	}
+	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
+		return errNoTLS
 	}
 	pos += 2
 
@@ -205,19 +208,22 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
 func (mc *mysqlConn) writeAuthPacket() error {
 	// Adjust client flags based on server support
-	clientFlags := uint32(
-		clientProtocol41 |
-			clientSecureConn |
-			clientLongPassword |
-			clientTransactions |
-			clientLocalFiles,
-	)
-	if mc.flags&clientLongFlag > 0 {
-		clientFlags |= uint32(clientLongFlag)
-	}
+	clientFlags := clientProtocol41 |
+		clientSecureConn |
+		clientLongPassword |
+		clientTransactions |
+		clientLocalFiles |
+		mc.flags&clientLongFlag
+
 	if _, ok := mc.cfg.params["clientFoundRows"]; ok {
-		clientFlags |= uint32(clientFoundRows)
+		clientFlags |= clientFoundRows
 	}
+
+	// To enable TLS / SSL
+	if mc.cfg.tls != nil {
+		clientFlags |= clientSSL
+	}
+
 	// User Password
 	scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd))
 	mc.cipher = nil
@@ -226,19 +232,13 @@ func (mc *mysqlConn) writeAuthPacket() error {
 
 	// To specify a db name
 	if len(mc.cfg.dbname) > 0 {
-		clientFlags |= uint32(clientConnectWithDB)
+		clientFlags |= clientConnectWithDB
 		pktLen += len(mc.cfg.dbname) + 1
 	}
 
 	// Calculate packet length and make buffer with that size
 	data := make([]byte, pktLen+4)
 
-	// Add the packet header  [24bit length + 1 byte sequence]
-	data[0] = byte(pktLen)
-	data[1] = byte(pktLen >> 8)
-	data[2] = byte(pktLen >> 16)
-	data[3] = mc.sequence
-
 	// ClientFlags [32 bit]
 	data[4] = byte(clientFlags)
 	data[5] = byte(clientFlags >> 8)
@@ -254,6 +254,35 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	// Charset [1 byte]
 	data[12] = mc.charset
 
+	// SSL Connection Request Packet
+	// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::SSLRequest
+	if mc.cfg.tls != nil {
+		// Packet header  [24bit length + 1 byte sequence]
+		data[0] = byte((4 + 4 + 1 + 23))
+		data[1] = byte((4 + 4 + 1 + 23) >> 8)
+		data[2] = byte((4 + 4 + 1 + 23) >> 16)
+		data[3] = mc.sequence
+
+		// Send TLS / SSL request packet
+		if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
+			return err
+		}
+
+		// Switch to TLS
+		tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
+		if err := tlsConn.Handshake(); err != nil {
+			return err
+		}
+		mc.netConn = tlsConn
+		mc.buf.rd = tlsConn
+	}
+
+	// Add the packet header  [24bit length + 1 byte sequence]
+	data[0] = byte(pktLen)
+	data[1] = byte(pktLen >> 8)
+	data[2] = byte(pktLen >> 16)
+	data[3] = mc.sequence
+
 	// Filler [23 bytes] (all 0x00)
 	pos := 13 + 23
 

+ 34 - 2
utils.go

@@ -11,6 +11,7 @@ package mysql
 
 import (
 	"crypto/sha1"
+	"crypto/tls"
 	"database/sql/driver"
 	"encoding/binary"
 	"fmt"
@@ -119,7 +120,35 @@ func parseDSN(dsn string) (cfg *config, err error) {
 				if len(param) != 2 {
 					continue
 				}
-				cfg.params[param[0]] = param[1]
+
+				// cfg params
+				switch value := param[1]; param[0] {
+
+				// Time Location
+				case "loc":
+					cfg.loc, err = time.LoadLocation(value)
+					if err != nil {
+						return
+					}
+
+				// Dial Timeout
+				case "timeout":
+					cfg.timeout, err = time.ParseDuration(value)
+					if err != nil {
+						return
+					}
+
+				// TLS-Encryption
+				case "tls":
+					if readBool(value) {
+						cfg.tls = &tls.Config{}
+					} else if strings.ToLower(value) == "skip-verify" {
+						cfg.tls = &tls.Config{InsecureSkipVerify: true}
+					}
+
+				default:
+					cfg.params[param[0]] = value
+				}
 			}
 		}
 	}
@@ -134,7 +163,10 @@ func parseDSN(dsn string) (cfg *config, err error) {
 		cfg.addr = "127.0.0.1:3306"
 	}
 
-	cfg.loc, err = time.LoadLocation(cfg.params["loc"])
+	// Set default location if not set
+	if cfg.loc == nil {
+		cfg.loc = time.UTC
+	}
 
 	return
 }

+ 12 - 9
utils_test.go

@@ -21,15 +21,15 @@ func TestDSNParser(t *testing.T) {
 		out string
 		loc *time.Location
 	}{
-		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p}", time.UTC},
-		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p}", time.UTC},
-		{"user:password@/dbname?loc=UTC", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[loc:UTC] loc:%p}", time.UTC},
-		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[loc:Local] loc:%p}", time.Local},
-		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p}", time.UTC},
-		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
-		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
+		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user:password@/dbname?loc=UTC&timeout=30s", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil>}", time.UTC},
+		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.Local},
+		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
 	}
 
 	var cfg *config
@@ -42,6 +42,9 @@ func TestDSNParser(t *testing.T) {
 			t.Error(err.Error())
 		}
 
+		// pointer not static
+		cfg.tls = nil
+
 		res = fmt.Sprintf("%+v", cfg)
 		if res != fmt.Sprintf(tst.out, tst.loc) {
 			t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc))