Browse Source

Merge pull request #124 from go-sql-driver/old_passwords_support

old_passwords support
Julien Schmidt 12 years ago
parent
commit
f4bf8e8e0a
7 changed files with 179 additions and 47 deletions
  1. 1 0
      README.md
  2. 12 12
      connection.go
  3. 22 5
      driver.go
  4. 1 1
      errors.go
  5. 33 10
      packets.go
  6. 101 10
      utils.go
  7. 9 9
      utils_test.go

+ 1 - 0
README.md

@@ -107,6 +107,7 @@ For Unix domain sockets the address is the absolute path to the MySQL-Server-soc
 
 Possible Parameters are:
   * `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. [*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
+  * `allowOldPasswords`: `allowAllFiles=true` allows the usage of the insecure old password method. This should be avoided, but is necessary in some cases. See also [the old_passwords wiki page](https://github.com/go-sql-driver/mysql/wiki/old_passwords).
   * `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
   * `clientFoundRows`: `clientFoundRows=true` causes an UPDATE to return the number of matching rows instead of the number of rows changed.
   * `loc`: Sets the location for time.Time values (when using `parseTime=true`). The default is `UTC`. *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.

+ 12 - 12
connection.go

@@ -20,7 +20,6 @@ import (
 type mysqlConn struct {
 	cfg              *config
 	flags            clientFlag
-	cipher           []byte
 	netConn          net.Conn
 	buf              *buffer
 	protocol         uint8
@@ -34,17 +33,18 @@ type mysqlConn struct {
 }
 
 type config struct {
-	user            string
-	passwd          string
-	net             string
-	addr            string
-	dbname          string
-	params          map[string]string
-	loc             *time.Location
-	timeout         time.Duration
-	tls             *tls.Config
-	allowAllFiles   bool
-	clientFoundRows bool
+	user              string
+	passwd            string
+	net               string
+	addr              string
+	dbname            string
+	params            map[string]string
+	loc               *time.Location
+	timeout           time.Duration
+	tls               *tls.Config
+	allowAllFiles     bool
+	allowOldPasswords bool
+	clientFoundRows   bool
 }
 
 // Handles parameters set in DSN

+ 22 - 5
driver.go

@@ -51,26 +51,42 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	mc.buf = newBuffer(mc.netConn)
 
 	// Reading Handshake Initialization Packet
-	err = mc.readInitPacket()
+	cipher, err := mc.readInitPacket()
 	if err != nil {
+		mc.Close()
 		return nil, err
 	}
 
 	// Send Client Authentication Packet
-	err = mc.writeAuthPacket()
-	if err != nil {
+	if err = mc.writeAuthPacket(cipher); err != nil {
+		mc.Close()
 		return nil, err
 	}
 
 	// Read Result Packet
 	err = mc.readResultOK()
 	if err != nil {
-		return nil, err
+		// Retry with old authentication method, if allowed
+		if mc.cfg.allowOldPasswords && err == errOldPassword {
+			if err = mc.writeOldAuthPacket(cipher); err != nil {
+				mc.Close()
+				return nil, err
+			}
+			if err = mc.readResultOK(); err != nil {
+				mc.Close()
+				return nil, err
+			}
+		} else {
+			mc.Close()
+			return nil, err
+		}
+
 	}
 
 	// Get max allowed packet size
 	maxap, err := mc.getSystemVar("max_allowed_packet")
 	if err != nil {
+		mc.Close()
 		return nil, err
 	}
 	mc.maxPacketAllowed = stringToInt(maxap) - 1
@@ -81,10 +97,11 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	// Handle DSN Params
 	err = mc.handleParams()
 	if err != nil {
+		mc.Close()
 		return nil, err
 	}
 
-	return mc, err
+	return mc, nil
 }
 
 func init() {

+ 1 - 1
errors.go

@@ -19,7 +19,7 @@ var (
 	errInvalidConn = errors.New("Invalid Connection")
 	errMalformPkt  = errors.New("Malformed Packet")
 	errNoTLS       = errors.New("TLS encryption requested but server does not support TLS")
-	errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
+	errOldPassword = errors.New("This server only supports the insecure old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
 	errOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
 	errPktSync     = errors.New("Commands out of sync. You can't run this command now")
 	errPktSyncMul  = errors.New("Commands out of sync. Did you run multiple statements at once?")

+ 33 - 10
packets.go

@@ -137,14 +137,14 @@ func (mc *mysqlConn) splitPacket(data []byte) (err error) {
 
 // Handshake Initialization Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
-func (mc *mysqlConn) readInitPacket() (err error) {
+func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) {
 	data, err := mc.readPacket()
 	if err != nil {
 		return
 	}
 
 	if data[0] == iERR {
-		return mc.handleErrorPacket(data)
+		return nil, mc.handleErrorPacket(data)
 	}
 
 	// protocol version [1 byte]
@@ -153,6 +153,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 			"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
 			data[0],
 			minProtocolVersion)
+		return
 	}
 
 	// server version [null terminated string]
@@ -160,7 +161,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
 
 	// first part of the password cipher [8 bytes]
-	mc.cipher = append(mc.cipher, data[pos:pos+8]...)
+	cipher = data[pos : pos+8]
 
 	// (filler) always 0x00 [1 byte]
 	pos += 8 + 1
@@ -168,10 +169,10 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 	// capability flags (lower 2 bytes) [2 bytes]
 	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
 	if mc.flags&clientProtocol41 == 0 {
-		err = errOldProtocol
+		return nil, errOldProtocol
 	}
 	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
-		return errNoTLS
+		return nil, errNoTLS
 	}
 	pos += 2
 
@@ -187,7 +188,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 		// The documentation is ambiguous about the length.
 		// The official Python library uses the fixed length 12
 		// which is not documented but seems to work.
-		mc.cipher = append(mc.cipher, data[pos:pos+12]...)
+		cipher = append(cipher, data[pos:pos+12]...)
 
 		// TODO: Verify string termination
 		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
@@ -205,7 +206,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 
 // Client Authentication Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
-func (mc *mysqlConn) writeAuthPacket() error {
+func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	// Adjust client flags based on server support
 	clientFlags := clientProtocol41 |
 		clientSecureConn |
@@ -224,8 +225,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	}
 
 	// User Password
-	scrambleBuff := scramblePassword(mc.cipher, []byte(mc.cfg.passwd))
-	mc.cipher = nil
+	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd))
 
 	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
 
@@ -306,6 +306,28 @@ func (mc *mysqlConn) writeAuthPacket() error {
 	return mc.writePacket(data)
 }
 
+//  Client old authentication packet
+// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse
+func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
+	// User password
+	scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
+
+	// Calculate the packet lenght and add a tailing 0
+	pktLen := len(scrambleBuff) + 1
+	data := make([]byte, pktLen+4)
+
+	// Add the packet header  [24bit length + 1 byte sequence]
+	data[0] = byte(pktLen)
+	data[1] = byte(pktLen >> 8)
+	data[2] = byte(pktLen >> 16)
+	data[3] = mc.sequence
+
+	// Add the scrambled password [null terminated string]
+	copy(data[4:], scrambleBuff)
+
+	return mc.writePacket(data)
+}
+
 /******************************************************************************
 *                             Command Packets                                 *
 ******************************************************************************/
@@ -387,7 +409,8 @@ func (mc *mysqlConn) readResultOK() error {
 		case iOK:
 			return mc.handleOkPacket(data)
 
-		case iEOF: // someone is using old_passwords
+		case iEOF:
+			// someone is using old_passwords
 			return errOldPassword
 
 		default: // Error otherwise

+ 101 - 10
utils.go

@@ -124,6 +124,15 @@ func parseDSN(dsn string) (cfg *config, err error) {
 						return
 					}
 
+				// Use old authentication mode (pre MySQL 4.1)
+				case "allowOldPasswords":
+					var isBool bool
+					cfg.allowOldPasswords, isBool = readBool(value)
+					if !isBool {
+						err = fmt.Errorf("Invalid Bool value: %s", value)
+						return
+					}
+
 				// Time Location
 				case "loc":
 					cfg.loc, err = time.LoadLocation(value)
@@ -181,8 +190,25 @@ func parseDSN(dsn string) (cfg *config, err error) {
 	return
 }
 
+// Returns the bool value of the input.
+// The 2nd return value indicates if the input was a valid bool value
+func readBool(input string) (value bool, valid bool) {
+	switch input {
+	case "1", "true", "TRUE", "True":
+		return true, true
+	case "0", "false", "FALSE", "False":
+		return false, true
+	}
+
+	// Not a valid bool value
+	return
+}
+
+/******************************************************************************
+*                             Authentication                                  *
+******************************************************************************/
+
 // Encrypt password using 4.1+ method
-// http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
 func scramblePassword(scramble, password []byte) []byte {
 	if len(password) == 0 {
 		return nil
@@ -212,20 +238,85 @@ func scramblePassword(scramble, password []byte) []byte {
 	return scramble
 }
 
-// Returns the bool value of the input.
-// The 2nd return value indicates if the input was a valid bool value
-func readBool(input string) (value bool, valid bool) {
-	switch input {
-	case "1", "true", "TRUE", "True":
-		return true, true
-	case "0", "false", "FALSE", "False":
-		return false, true
+// Encrypt password using pre 4.1 (old password) method
+// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
+type myRnd struct {
+	seed1, seed2 uint32
+}
+
+const myRndMaxVal = 0x3FFFFFFF
+
+// Pseudo random number generator
+func newMyRnd(seed1, seed2 uint32) *myRnd {
+	return &myRnd{
+		seed1: seed1 % myRndMaxVal,
+		seed2: seed2 % myRndMaxVal,
 	}
+}
+
+// Tested to be equivalent to MariaDB's floating point variant
+// http://play.golang.org/p/QHvhd4qved
+// http://play.golang.org/p/RG0q4ElWDx
+func (r *myRnd) NextByte() byte {
+	r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
+	r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
+
+	return byte(uint64(r.seed1) * 31 / myRndMaxVal)
+}
+
+// Generate binary hash from byte string using insecure pre 4.1 method
+func pwHash(password []byte) (result [2]uint32) {
+	var add uint32 = 7
+	var tmp uint32
+
+	result[0] = 1345345333
+	result[1] = 0x12345671
+
+	for _, c := range password {
+		// skip spaces and tabs in password
+		if c == ' ' || c == '\t' {
+			continue
+		}
+
+		tmp = uint32(c)
+		result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
+		result[1] += (result[1] << 8) ^ result[0]
+		add += tmp
+	}
+
+	// Remove sign bit (1<<31)-1)
+	result[0] &= 0x7FFFFFFF
+	result[1] &= 0x7FFFFFFF
 
-	// Not a valid bool value
 	return
 }
 
+// Encrypt password using insecure pre 4.1 method
+func scrambleOldPassword(scramble, password []byte) []byte {
+	if len(password) == 0 {
+		return nil
+	}
+
+	scramble = scramble[:8]
+
+	hashPw := pwHash(password)
+	hashSc := pwHash(scramble)
+
+	r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
+
+	var out [8]byte
+	for i := range out {
+		out[i] = r.NextByte() + 64
+	}
+
+	mask := r.NextByte()
+	for i := range out {
+		out[i] ^= mask
+	}
+
+	return out[:]
+}
+
 /******************************************************************************
 *                           Time related utils                                *
 ******************************************************************************/

+ 9 - 9
utils_test.go

@@ -20,15 +20,15 @@ func TestDSNParser(t *testing.T) {
 		out string
 		loc *time.Location
 	}{
-		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true clientFoundRows:true}", time.UTC},
-		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.Local},
-		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
-		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false clientFoundRows:false}", time.UTC},
+		{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
+		{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
+		{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+		{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
 	}
 
 	var cfg *config