Browse Source

Require explicitly allowing old passwords

+ close connection if authentication fails
Julien Schmidt 12 years ago
parent
commit
83ed16b506
7 changed files with 67 additions and 46 deletions
  1. 1 0
      README.md
  2. 12 12
      connection.go
  3. 22 5
      driver.go
  4. 1 1
      errors.go
  5. 13 19
      packets.go
  6. 9 0
      utils.go
  7. 9 9
      utils_test.go

+ 1 - 0
README.md

@@ -107,6 +107,7 @@ For Unix domain sockets the address is the absolute path to the MySQL-Server-soc
 
 Possible Parameters are:
   * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. [*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
+  * `allowOldPasswords`: `allowAllFiles=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords).
   * `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
   * `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.

+ 12 - 12
connection.go

@@ -21,7 +21,6 @@ import (
 type mysqlConn struct {
 	cfg              *config
 	flags            clientFlag
-	cipher           []byte
 	netConn          net.Conn
 	buf              *buffer
 	protocol         uint8
@@ -35,17 +34,18 @@ type mysqlConn struct {
 }
 
 type config struct {
-	user            string
-	passwd          string
-	net             string
-	addr            string
-	dbname          string
-	params          map[string]string
-	loc             *time.Location
-	timeout         time.Duration
-	tls             *tls.Config
-	allowAllFiles   bool
-	clientFoundRows bool
+	user              string
+	passwd            string
+	net               string
+	addr              string
+	dbname            string
+	params            map[string]string
+	loc               *time.Location
+	timeout           time.Duration
+	tls               *tls.Config
+	allowAllFiles     bool
+	allowOldPasswords bool
+	clientFoundRows   bool
 }
 
 // Handles parameters set in DSN

+ 22 - 5
driver.go

@@ -52,26 +52,42 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	mc.buf = newBuffer(mc.netConn)
 
 	// Reading Handshake Initialization Packet
-	err = mc.readInitPacket()
+	cipher, err := mc.readInitPacket()
 	if err != nil {
+		mc.Close()
 		return nil, err
 	}
 
 	// Send Client Authentication Packet
-	err = mc.writeAuthPacket()
-	if err != nil {
+	if err = mc.writeAuthPacket(cipher); err != nil {
+		mc.Close()
 		return nil, err
 	}
 
 	// Read Result Packet
 	err = mc.readResultOK()
 	if err != nil {
-		return nil, err
+		// Retry with old authentication method, if allowed
+		if mc.cfg.allowOldPasswords && err == errOldPassword {
+			if err = mc.writeOldAuthPacket(cipher); err != nil {
+				mc.Close()
+				return nil, err
+			}
+			if err = mc.readResultOK(); err != nil {
+				mc.Close()
+				return nil, err
+			}
+		} else {
+			mc.Close()
+			return nil, err
+		}
+
 	}
 
 	// Get max allowed packet size
 	maxap, err := mc.getSystemVar("max_allowed_packet")
 	if err != nil {
+		mc.Close()
 		return nil, err
 	}
 	mc.maxPacketAllowed = stringToInt(maxap) - 1
@@ -82,10 +98,11 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	// Handle DSN Params
 	err = mc.handleParams()
 	if err != nil {
+		mc.Close()
 		return nil, err
 	}
 
-	return mc, err
+	return mc, nil
 }
 
 func init() {

+ 1 - 1
errors.go

@@ -20,7 +20,7 @@ var (
 	errInvalidConn = errors.New("Invalid Connection")
 	errMalformPkt  = errors.New("Malformed Packet")
 	errNoTLS       = errors.New("TLS encryption requested but server does not support TLS")
-	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
+	errOldPassword = errors.New("This server only supports the insecure old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
 	errPktSync     = errors.New("Commands out of sync. You can't run this command now")
 	errPktSyncMul  = errors.New("Commands out of sync. Did you run multiple statements at once?")

+ 13 - 19
packets.go

@@ -138,14 +138,14 @@ func (mc *mysqlConn) splitPacket(data []byte) (err error) {
 
 // Handshake Initialization Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
-func (mc *mysqlConn) readInitPacket() (err error) {
+func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) {
 	data, err := mc.readPacket()
 	if err != nil {
 		return
 	}
 
 	if data[0] == iERR {
-		return mc.handleErrorPacket(data)
+		return nil, mc.handleErrorPacket(data)
 	}
 
 	// protocol version [1 byte]
@@ -154,6 +154,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 			"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
 			data[0],
 			minProtocolVersion)
+		return
 	}
 
 	// server version [null terminated string]
@@ -161,7 +162,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
 
 	// first part of the password cipher [8 bytes]
-	mc.cipher = append(mc.cipher, data[pos:pos+8]...)
+	cipher = data[pos : pos+8]
 
 	// (filler) always 0x00 [1 byte]
 	pos += 8 + 1
@@ -169,10 +170,10 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 	// capability flags (lower 2 bytes) [2 bytes]
 	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
 	if mc.flags&clientProtocol41 == 0 {
-		err = errOldProtocol
+		return nil, errOldProtocol
 	}
 	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
-		return errNoTLS
+		return nil, errNoTLS
 	}
 	pos += 2
 
@@ -188,7 +189,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 		// The documentation is ambiguous about the length.
 		// The official Python library uses the fixed length 12
 		// which is not documented but seems to work.
-		mc.cipher = append(mc.cipher, data[pos:pos+12]...)
+		cipher = append(cipher, data[pos:pos+12]...)
 
 		// TODO: Verify string termination
 		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
@@ -206,7 +207,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 
 // Client Authentication Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
-func (mc *mysqlConn) writeAuthPacket() error {
+func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	// Adjust client flags based on server support
 	clientFlags := clientProtocol41 |
 		clientSecureConn |
@@ -225,7 +226,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	}
 
 	// User Password
-	scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd))
+	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd))
 
 	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
 
@@ -308,10 +309,9 @@ func (mc *mysqlConn) writeAuthPacket() error {
 
 //  Client old authentication packet
 // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeOldAuthPacket() error {
+func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
 	// User password
-	scrambleBuff := scrambleOldPassword(mc.cipher, []byte(mc.cfg.passwd))
-	mc.cipher = nil
+	scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
 
 	// Calculate the packet lenght and add a tailing 0
 	pktLen := len(scrambleBuff) + 1
@@ -323,7 +323,7 @@ func (mc *mysqlConn) writeOldAuthPacket() error {
 	data[2] = byte(pktLen >> 16)
 	data[3] = mc.sequence
 
-	// Add the scrambled password (it will be terminated by 0)
+	// Add the scrambled password [null terminated string]
 	copy(data[4:], scrambleBuff)
 
 	return mc.writePacket(data)
@@ -408,17 +408,11 @@ func (mc *mysqlConn) readResultOK() error {
 		switch data[0] {
 
 		case iOK:
-			// Remove the chipher in case of successfull authentication
-			mc.cipher = nil
 			return mc.handleOkPacket(data)
 
 		case iEOF:
 			// someone is using old_passwords
-			err = mc.writeOldAuthPacket()
-			if err != nil {
-				return err
-			}
-			return mc.readResultOK()
+			return errOldPassword
 
 		default: // Error otherwise
 			return mc.handleErrorPacket(data)

+ 9 - 0
utils.go

@@ -126,6 +126,15 @@ func parseDSN(dsn string) (cfg *config, err error) {
 						return
 					}
 
+				// Use old authentication mode (pre MySQL 4.1)
+				case "allowOldPasswords":
+					var isBool bool
+					cfg.allowOldPasswords, isBool = readBool(value)
+					if !isBool {
+						err = fmt.Errorf("Invalid Bool value: %s", value)
+						return
+					}
+
 				// Time Location
 				case "loc":
 					cfg.loc, err = time.LoadLocation(value)

+ 9 - 9
utils_test.go

@@ -21,15 +21,15 @@ func TestDSNParser(t *testing.T) {
 		out string
 		loc *time.Location
 	}{
-		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true clientFoundRows:true}", time.UTC},
-		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.Local},
-		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
+		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
+		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
 	}
 
 	var cfg *config