Browse Source

Support caching_sha2_password (#794)

Hajime Nakagami 7 years ago
parent
commit
f557730784
7 changed files with 152 additions and 18 deletions
  1. 1 0
      AUTHORS
  2. 6 0
      const.go
  3. 24 4
      driver.go
  4. 2 2
      driver_test.go
  5. 72 12
      packets.go
  6. 29 0
      utils.go
  7. 18 0
      utils_test.go

+ 1 - 0
AUTHORS

@@ -29,6 +29,7 @@ Egor Smolyakov <egorsmkv at gmail.com>
 Evan Shaw <evan at vendhq.com>
 Frederick Mayle <frederickmayle at gmail.com>
 Gustavo Kristic <gkristic at gmail.com>
+Hajime Nakagami <nakagami at gmail.com>
 Hanno Braun <mail at hannobraun.com>
 Henri Yandell <flamefew at gmail.com>
 Hirotaka Yamamoto <ymmt2005 at gmail.com>

+ 6 - 0
const.go

@@ -164,3 +164,9 @@ const (
 	statusInTransReadonly
 	statusSessionStateChanged
 )
+
+const (
+	cachingSha2PasswordRequestPublicKey          = 2
+	cachingSha2PasswordFastAuthSuccess           = 3
+	cachingSha2PasswordPerformFullAuthentication = 4
+)

+ 24 - 4
driver.go

@@ -107,20 +107,20 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	mc.writeTimeout = mc.cfg.WriteTimeout
 
 	// Reading Handshake Initialization Packet
-	cipher, err := mc.readInitPacket()
+	cipher, pluginName, err := mc.readInitPacket()
 	if err != nil {
 		mc.cleanup()
 		return nil, err
 	}
 
 	// Send Client Authentication Packet
-	if err = mc.writeAuthPacket(cipher); err != nil {
+	if err = mc.writeAuthPacket(cipher, pluginName); err != nil {
 		mc.cleanup()
 		return nil, err
 	}
 
 	// Handle response to auth packet, switch methods if possible
-	if err = handleAuthResult(mc, cipher); err != nil {
+	if err = handleAuthResult(mc, cipher, pluginName); err != nil {
 		// Authentication failed and MySQL has already closed the connection
 		// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
 		// Do not send COM_QUIT, just cleanup and return the error.
@@ -153,7 +153,27 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	return mc, nil
 }
 
-func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
+func handleAuthResult(mc *mysqlConn, oldCipher []byte, pluginName string) error {
+
+	// handle caching_sha2_password
+	if pluginName == "caching_sha2_password" {
+		auth, err := mc.readCachingSha2PasswordAuthResult()
+		if err != nil {
+			return err
+		}
+		if auth == cachingSha2PasswordPerformFullAuthentication {
+			if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
+				if err = mc.writeClearAuthPacket(); err != nil {
+					return err
+				}
+			} else {
+				if err = mc.writePublicKeyAuthPacket(oldCipher); err != nil {
+					return err
+				}
+			}
+		}
+	}
+
 	// Read Result Packet
 	cipher, err := mc.readResultOK()
 	if err == nil {

+ 2 - 2
driver_test.go

@@ -1842,7 +1842,7 @@ func TestSQLInjection(t *testing.T) {
 
 	dsns := []string{
 		dsn,
-		dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
+		dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
 	}
 	for _, testdsn := range dsns {
 		runTests(t, testdsn, createTest("1 OR 1=1"))
@@ -1872,7 +1872,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) {
 
 	dsns := []string{
 		dsn,
-		dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
+		dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
 	}
 	for _, testdsn := range dsns {
 		runTests(t, testdsn, testData)

+ 72 - 12
packets.go

@@ -10,9 +10,14 @@ package mysql
 
 import (
 	"bytes"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/sha1"
 	"crypto/tls"
+	"crypto/x509"
 	"database/sql/driver"
 	"encoding/binary"
+	"encoding/pem"
 	"errors"
 	"fmt"
 	"io"
@@ -154,24 +159,24 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 
 // Handshake Initialization Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
-func (mc *mysqlConn) readInitPacket() ([]byte, error) {
+func (mc *mysqlConn) readInitPacket() ([]byte, string, error) {
 	data, err := mc.readPacket()
 	if err != nil {
 		// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
 		// in connection initialization we don't risk retrying non-idempotent actions.
 		if err == ErrInvalidConn {
-			return nil, driver.ErrBadConn
+			return nil, "", driver.ErrBadConn
 		}
-		return nil, err
+		return nil, "", err
 	}
 
 	if data[0] == iERR {
-		return nil, mc.handleErrorPacket(data)
+		return nil, "", mc.handleErrorPacket(data)
 	}
 
 	// protocol version [1 byte]
 	if data[0] < minProtocolVersion {
-		return nil, fmt.Errorf(
+		return nil, "", fmt.Errorf(
 			"unsupported protocol version %d. Version %d or higher is required",
 			data[0],
 			minProtocolVersion,
@@ -191,13 +196,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 	// capability flags (lower 2 bytes) [2 bytes]
 	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
 	if mc.flags&clientProtocol41 == 0 {
-		return nil, ErrOldProtocol
+		return nil, "", ErrOldProtocol
 	}
 	if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
-		return nil, ErrNoTLS
+		return nil, "", ErrNoTLS
 	}
 	pos += 2
 
+	pluginName := ""
 	if len(data) > pos {
 		// character set [1 byte]
 		// status flags [2 bytes]
@@ -219,6 +225,8 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 		// The official Python library uses the fixed length 12
 		// which seems to work but technically could have a hidden bug.
 		cipher = append(cipher, data[pos:pos+12]...)
+		pos += 13
+		pluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
 
 		// TODO: Verify string termination
 		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
@@ -232,18 +240,22 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 		// make a memory safe copy of the cipher slice
 		var b [20]byte
 		copy(b[:], cipher)
-		return b[:], nil
+		return b[:], pluginName, nil
 	}
 
 	// make a memory safe copy of the cipher slice
 	var b [8]byte
 	copy(b[:], cipher)
-	return b[:], nil
+	return b[:], pluginName, nil
 }
 
 // Client Authentication Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
-func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
+func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
+	if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" {
+		return fmt.Errorf("unknown authentication plugin name '%s'", pluginName)
+	}
+
 	// Adjust client flags based on server support
 	clientFlags := clientProtocol41 |
 		clientSecureConn |
@@ -268,7 +280,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	}
 
 	// User Password
-	scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
+	var scrambleBuff []byte
+	switch pluginName {
+	case "mysql_native_password":
+		scrambleBuff = scramblePassword(cipher, []byte(mc.cfg.Passwd))
+	case "caching_sha2_password":
+		scrambleBuff = scrambleCachingSha2Password(cipher, []byte(mc.cfg.Passwd))
+	}
 
 	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
 
@@ -350,7 +368,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	}
 
 	// Assume native client during response
-	pos += copy(data[pos:], "mysql_native_password")
+	pos += copy(data[pos:], pluginName)
 	data[pos] = 0x00
 
 	// Send Auth packet
@@ -422,6 +440,38 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
 	return mc.writePacket(data)
 }
 
+//  Caching sha2 authentication. Public key request and send encrypted password
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
+func (mc *mysqlConn) writePublicKeyAuthPacket(cipher []byte) error {
+	// request public key
+	data := mc.buf.takeSmallBuffer(4 + 1)
+	data[4] = cachingSha2PasswordRequestPublicKey
+	mc.writePacket(data)
+
+	data, err := mc.readPacket()
+	if err != nil {
+		return err
+	}
+
+	block, _ := pem.Decode(data[1:])
+	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
+	if err != nil {
+		return err
+	}
+
+	plain := make([]byte, len(mc.cfg.Passwd)+1)
+	copy(plain, mc.cfg.Passwd)
+	for i := range plain {
+		j := i % len(cipher)
+		plain[i] ^= cipher[j]
+	}
+	sha1 := sha1.New()
+	enc, _ := rsa.EncryptOAEP(sha1, rand.Reader, pub.(*rsa.PublicKey), plain, nil)
+	data = mc.buf.takeSmallBuffer(4 + len(enc))
+	copy(data[4:], enc)
+	return mc.writePacket(data)
+}
+
 /******************************************************************************
 *                             Command Packets                                 *
 ******************************************************************************/
@@ -535,6 +585,16 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
 	return nil, err
 }
 
+func (mc *mysqlConn) readCachingSha2PasswordAuthResult() (int, error) {
+	data, err := mc.readPacket()
+	if err == nil {
+		if data[0] != 1 {
+			return 0, ErrMalformPkt
+		}
+	}
+	return int(data[1]), err
+}
+
 // Result Set Header Packet
 // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
 func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {

+ 29 - 0
utils.go

@@ -10,6 +10,7 @@ package mysql
 
 import (
 	"crypto/sha1"
+	"crypto/sha256"
 	"crypto/tls"
 	"database/sql/driver"
 	"encoding/binary"
@@ -211,6 +212,34 @@ func scrambleOldPassword(scramble, password []byte) []byte {
 	return out[:]
 }
 
+// Encrypt password using 8.0 default method
+func scrambleCachingSha2Password(scramble, password []byte) []byte {
+	if len(password) == 0 {
+		return nil
+	}
+
+	// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
+
+	crypt := sha256.New()
+	crypt.Write(password)
+	message1 := crypt.Sum(nil)
+
+	crypt.Reset()
+	crypt.Write(message1)
+	message1Hash := crypt.Sum(nil)
+
+	crypt.Reset()
+	crypt.Write(message1Hash)
+	crypt.Write(scramble)
+	message2 := crypt.Sum(nil)
+
+	for i := range message1 {
+		message1[i] ^= message2[i]
+	}
+
+	return message1
+}
+
 /******************************************************************************
 *                           Time related utils                                *
 ******************************************************************************/

+ 18 - 0
utils_test.go

@@ -112,6 +112,24 @@ func TestOldPass(t *testing.T) {
 	}
 }
 
+func TestCachingSha2Pass(t *testing.T) {
+	scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21}
+	vectors := []struct {
+		pass string
+		out  string
+	}{
+		{"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"},
+		{"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"},
+	}
+	for _, tuple := range vectors {
+		ours := scrambleCachingSha2Password(scramble, []byte(tuple.pass))
+		if tuple.out != fmt.Sprintf("%x", ours) {
+			t.Errorf("Failed caching sha2 password %q", tuple.pass)
+		}
+	}
+
+}
+
 func TestFormatBinaryDateTime(t *testing.T) {
 	rawDate := [11]byte{}
 	binary.LittleEndian.PutUint16(rawDate[:2], 1978)   // years