Browse Source

Add TLS-Support

Fixes #25
Julien Schmidt 12 years ago
parent
commit
e288006499
7 changed files with 159 additions and 61 deletions
  1. 12 13
      connection.go
  2. 2 5
      driver.go
  3. 47 11
      driver_test.go
  4. 3 1
      errors.go
  5. 49 20
      packets.go
  6. 34 2
      utils.go
  7. 12 9
      utils_test.go

+ 12 - 13
connection.go

@@ -10,6 +10,7 @@
 package mysql
 package mysql
 
 
 import (
 import (
+	"crypto/tls"
 	"database/sql/driver"
 	"database/sql/driver"
 	"errors"
 	"errors"
 	"net"
 	"net"
@@ -35,13 +36,15 @@ type mysqlConn struct {
 }
 }
 
 
 type config struct {
 type config struct {
-	user   string
-	passwd string
-	net    string
-	addr   string
-	dbname string
-	params map[string]string
-	loc    *time.Location
+	user    string
+	passwd  string
+	net     string
+	addr    string
+	dbname  string
+	params  map[string]string
+	loc     *time.Location
+	timeout time.Duration
+	tls     *tls.Config
 }
 }
 
 
 // Handles parameters set in DSN
 // Handles parameters set in DSN
@@ -63,7 +66,7 @@ func (mc *mysqlConn) handleParams() (err error) {
 			}
 			}
 
 
 		// handled elsewhere
 		// handled elsewhere
-		case "timeout", "allowAllFiles", "loc", "clientFoundRows":
+		case "allowAllFiles", "clientFoundRows":
 			continue
 			continue
 
 
 		// time.Time parsing
 		// time.Time parsing
@@ -74,14 +77,10 @@ func (mc *mysqlConn) handleParams() (err error) {
 		case "strict":
 		case "strict":
 			mc.strict = readBool(val)
 			mc.strict = readBool(val)
 
 
-		// TLS-Encryption
-		case "tls":
-			err = errors.New("TLS-Encryption not implemented yet")
-			return
-
 		// Compression
 		// Compression
 		case "compress":
 		case "compress":
 			err = errors.New("Compression not implemented yet")
 			err = errors.New("Compression not implemented yet")
+			return
 
 
 		// System Vars
 		// System Vars
 		default:
 		default:

+ 2 - 5
driver.go

@@ -12,7 +12,6 @@ import (
 	"database/sql"
 	"database/sql"
 	"database/sql/driver"
 	"database/sql/driver"
 	"net"
 	"net"
-	"time"
 )
 )
 
 
 type mysqlDriver struct{}
 type mysqlDriver struct{}
@@ -34,11 +33,9 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
 	}
 	}
 
 
 	// Connect to Server
 	// Connect to Server
-	if _, ok := mc.cfg.params["timeout"]; ok { // with timeout
-		var timeout time.Duration
-		timeout, err = time.ParseDuration(mc.cfg.params["timeout"])
+	if mc.cfg.timeout > 0 { // with timeout
 		if err == nil {
 		if err == nil {
-			mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, timeout)
+			mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, mc.cfg.timeout)
 		}
 		}
 	} else { // no timeout
 	} else { // no timeout
 		mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr)
 		mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr)

+ 47 - 11
driver_test.go

@@ -807,6 +807,42 @@ func TestStrict(t *testing.T) {
 	})
 	})
 }
 }
 
 
+func TestTLS(t *testing.T) {
+	runTests(t, "TestTLS", dsn+"&tls=skip-verify", func(dbt *DBTest) {
+		/* TODO: GO 1.1 API */
+		/*if err := dbt.db.Ping(); err != nil {
+		    if err == errNoTLS {
+		        dbt.Skip("Server does not support TLS. Skipping TestTLS")
+		    } else {
+		        dbt.Fatalf("Error on Ping: %s", err.Error())
+		    }
+		}*/
+
+		/* GO 1.0 API */
+		if _, err := dbt.db.Exec("DO 1"); err != nil {
+			if err == errNoTLS {
+				dbt.Log("Server does not support TLS. Skipping TestTLS")
+				return
+			} else {
+				dbt.Fatalf("Error on Ping: %s", err.Error())
+			}
+		}
+
+		rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'")
+
+		var variable, value *sql.RawBytes
+		for rows.Next() {
+			if err := rows.Scan(&variable, &value); err != nil {
+				dbt.Fatal(err.Error())
+			}
+
+			if value == nil {
+				dbt.Fatal("No Cipher")
+			}
+		}
+	})
+}
+
 // Special cases
 // Special cases
 
 
 func TestRowsClose(t *testing.T) {
 func TestRowsClose(t *testing.T) {
@@ -1040,41 +1076,41 @@ func TestFoundRows(t *testing.T) {
 	runTests(t, "TestFoundRows1", dsn, func(dbt *DBTest) {
 	runTests(t, "TestFoundRows1", dsn, func(dbt *DBTest) {
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
 		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
 		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
-		
+
 		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
 		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
 		count, err := res.RowsAffected()
 		count, err := res.RowsAffected()
 		if err != nil {
 		if err != nil {
-				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-			}
+			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+		}
 		if count != 2 {
 		if count != 2 {
 			dbt.Fatalf("Expected 2 affected rows, got %d", count)
 			dbt.Fatalf("Expected 2 affected rows, got %d", count)
 		}
 		}
 		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
 		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
 		count, err = res.RowsAffected()
 		count, err = res.RowsAffected()
 		if err != nil {
 		if err != nil {
-				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-			}
+			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+		}
 		if count != 2 {
 		if count != 2 {
 			dbt.Fatalf("Expected 2 affected rows, got %d", count)
 			dbt.Fatalf("Expected 2 affected rows, got %d", count)
 		}
 		}
 	})
 	})
-	runTests(t, "TestFoundRows2", dsn + "&clientFoundRows=true", func(dbt *DBTest) {
+	runTests(t, "TestFoundRows2", dsn+"&clientFoundRows=true", func(dbt *DBTest) {
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
 		dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)")
 		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
 		dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)")
-		
+
 		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
 		res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0")
 		count, err := res.RowsAffected()
 		count, err := res.RowsAffected()
 		if err != nil {
 		if err != nil {
-				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-			}
+			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+		}
 		if count != 2 {
 		if count != 2 {
 			dbt.Fatalf("Expected 2 matched rows, got %d", count)
 			dbt.Fatalf("Expected 2 matched rows, got %d", count)
 		}
 		}
 		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
 		res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1")
 		count, err = res.RowsAffected()
 		count, err = res.RowsAffected()
 		if err != nil {
 		if err != nil {
-				dbt.Fatalf("res.RowsAffected() returned error: %v", err)
-			}
+			dbt.Fatalf("res.RowsAffected() returned error: %v", err)
+		}
 		if count != 3 {
 		if count != 3 {
 			dbt.Fatalf("Expected 3 matched rows, got %d", count)
 			dbt.Fatalf("Expected 3 matched rows, got %d", count)
 		}
 		}

+ 3 - 1
errors.go

@@ -18,9 +18,11 @@ import (
 
 
 var (
 var (
 	errMalformPkt  = errors.New("Malformed Packet")
 	errMalformPkt  = errors.New("Malformed Packet")
+	errNoTLS       = errors.New("TLS encryption requested but server does not support TLS")
+	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
+	errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
 	errPktSync     = errors.New("Commands out of sync. You can't run this command now")
 	errPktSync     = errors.New("Commands out of sync. You can't run this command now")
 	errPktSyncMul  = errors.New("Commands out of sync. Did you run multiple statements at once?")
 	errPktSyncMul  = errors.New("Commands out of sync. Did you run multiple statements at once?")
-	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 	errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
 )
 )
 
 

+ 49 - 20
packets.go

@@ -11,9 +11,9 @@ package mysql
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"crypto/tls"
 	"database/sql/driver"
 	"database/sql/driver"
 	"encoding/binary"
 	"encoding/binary"
-	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"math"
 	"math"
@@ -167,7 +167,10 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 	// capability flags (lower 2 bytes) [2 bytes]
 	// capability flags (lower 2 bytes) [2 bytes]
 	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
 	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
 	if mc.flags&clientProtocol41 == 0 {
 	if mc.flags&clientProtocol41 == 0 {
-		err = errors.New("MySQL-Server does not support required Protocol 41+")
+		err = errOldProtocol
+	}
+	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
+		return errNoTLS
 	}
 	}
 	pos += 2
 	pos += 2
 
 
@@ -205,19 +208,22 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
 // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
 func (mc *mysqlConn) writeAuthPacket() error {
 func (mc *mysqlConn) writeAuthPacket() error {
 	// Adjust client flags based on server support
 	// Adjust client flags based on server support
-	clientFlags := uint32(
-		clientProtocol41 |
-			clientSecureConn |
-			clientLongPassword |
-			clientTransactions |
-			clientLocalFiles,
-	)
-	if mc.flags&clientLongFlag > 0 {
-		clientFlags |= uint32(clientLongFlag)
-	}
+	clientFlags := clientProtocol41 |
+		clientSecureConn |
+		clientLongPassword |
+		clientTransactions |
+		clientLocalFiles |
+		mc.flags&clientLongFlag
+
 	if _, ok := mc.cfg.params["clientFoundRows"]; ok {
 	if _, ok := mc.cfg.params["clientFoundRows"]; ok {
-		clientFlags |= uint32(clientFoundRows)
+		clientFlags |= clientFoundRows
 	}
 	}
+
+	// To enable TLS / SSL
+	if mc.cfg.tls != nil {
+		clientFlags |= clientSSL
+	}
+
 	// User Password
 	// User Password
 	scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd))
 	scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd))
 	mc.cipher = nil
 	mc.cipher = nil
@@ -226,19 +232,13 @@ func (mc *mysqlConn) writeAuthPacket() error {
 
 
 	// To specify a db name
 	// To specify a db name
 	if len(mc.cfg.dbname) > 0 {
 	if len(mc.cfg.dbname) > 0 {
-		clientFlags |= uint32(clientConnectWithDB)
+		clientFlags |= clientConnectWithDB
 		pktLen += len(mc.cfg.dbname) + 1
 		pktLen += len(mc.cfg.dbname) + 1
 	}
 	}
 
 
 	// Calculate packet length and make buffer with that size
 	// Calculate packet length and make buffer with that size
 	data := make([]byte, pktLen+4)
 	data := make([]byte, pktLen+4)
 
 
-	// Add the packet header  [24bit length + 1 byte sequence]
-	data[0] = byte(pktLen)
-	data[1] = byte(pktLen >> 8)
-	data[2] = byte(pktLen >> 16)
-	data[3] = mc.sequence
-
 	// ClientFlags [32 bit]
 	// ClientFlags [32 bit]
 	data[4] = byte(clientFlags)
 	data[4] = byte(clientFlags)
 	data[5] = byte(clientFlags >> 8)
 	data[5] = byte(clientFlags >> 8)
@@ -254,6 +254,35 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	// Charset [1 byte]
 	// Charset [1 byte]
 	data[12] = mc.charset
 	data[12] = mc.charset
 
 
+	// SSL Connection Request Packet
+	// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::SSLRequest
+	if mc.cfg.tls != nil {
+		// Packet header  [24bit length + 1 byte sequence]
+		data[0] = byte((4 + 4 + 1 + 23))
+		data[1] = byte((4 + 4 + 1 + 23) >> 8)
+		data[2] = byte((4 + 4 + 1 + 23) >> 16)
+		data[3] = mc.sequence
+
+		// Send TLS / SSL request packet
+		if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
+			return err
+		}
+
+		// Switch to TLS
+		tlsConn := tls.Client(mc.netConn, mc.cfg.tls)
+		if err := tlsConn.Handshake(); err != nil {
+			return err
+		}
+		mc.netConn = tlsConn
+		mc.buf.rd = tlsConn
+	}
+
+	// Add the packet header  [24bit length + 1 byte sequence]
+	data[0] = byte(pktLen)
+	data[1] = byte(pktLen >> 8)
+	data[2] = byte(pktLen >> 16)
+	data[3] = mc.sequence
+
 	// Filler [23 bytes] (all 0x00)
 	// Filler [23 bytes] (all 0x00)
 	pos := 13 + 23
 	pos := 13 + 23
 
 

+ 34 - 2
utils.go

@@ -11,6 +11,7 @@ package mysql
 
 
 import (
 import (
 	"crypto/sha1"
 	"crypto/sha1"
+	"crypto/tls"
 	"database/sql/driver"
 	"database/sql/driver"
 	"encoding/binary"
 	"encoding/binary"
 	"fmt"
 	"fmt"
@@ -119,7 +120,35 @@ func parseDSN(dsn string) (cfg *config, err error) {
 				if len(param) != 2 {
 				if len(param) != 2 {
 					continue
 					continue
 				}
 				}
-				cfg.params[param[0]] = param[1]
+
+				// cfg params
+				switch value := param[1]; param[0] {
+
+				// Time Location
+				case "loc":
+					cfg.loc, err = time.LoadLocation(value)
+					if err != nil {
+						return
+					}
+
+				// Dial Timeout
+				case "timeout":
+					cfg.timeout, err = time.ParseDuration(value)
+					if err != nil {
+						return
+					}
+
+				// TLS-Encryption
+				case "tls":
+					if readBool(value) {
+						cfg.tls = &tls.Config{}
+					} else if strings.ToLower(value) == "skip-verify" {
+						cfg.tls = &tls.Config{InsecureSkipVerify: true}
+					}
+
+				default:
+					cfg.params[param[0]] = value
+				}
 			}
 			}
 		}
 		}
 	}
 	}
@@ -134,7 +163,10 @@ func parseDSN(dsn string) (cfg *config, err error) {
 		cfg.addr = "127.0.0.1:3306"
 		cfg.addr = "127.0.0.1:3306"
 	}
 	}
 
 
-	cfg.loc, err = time.LoadLocation(cfg.params["loc"])
+	// Set default location if not set
+	if cfg.loc == nil {
+		cfg.loc = time.UTC
+	}
 
 
 	return
 	return
 }
 }

+ 12 - 9
utils_test.go

@@ -21,15 +21,15 @@ func TestDSNParser(t *testing.T) {
 		out string
 		out string
 		loc *time.Location
 		loc *time.Location
 	}{
 	}{
-		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p}", time.UTC},
-		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p}", time.UTC},
-		{"user:password@/dbname?loc=UTC", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[loc:UTC] loc:%p}", time.UTC},
-		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[loc:Local] loc:%p}", time.Local},
-		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p}", time.UTC},
-		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
-		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p}", time.UTC},
+		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user:password@/dbname?loc=UTC&timeout=30s", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil>}", time.UTC},
+		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.Local},
+		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
+		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil>}", time.UTC},
 	}
 	}
 
 
 	var cfg *config
 	var cfg *config
@@ -42,6 +42,9 @@ func TestDSNParser(t *testing.T) {
 			t.Error(err.Error())
 			t.Error(err.Error())
 		}
 		}
 
 
+		// pointer not static
+		cfg.tls = nil
+
 		res = fmt.Sprintf("%+v", cfg)
 		res = fmt.Sprintf("%+v", cfg)
 		if res != fmt.Sprintf(tst.out, tst.loc) {
 		if res != fmt.Sprintf(tst.out, tst.loc) {
 			t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc))
 			t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc))