Browse Source

Auth refactoring and bug fixes (#807)

* log missing auth plugin name

* refactor auth handling

* auth: fix AllowNativePasswords

* auth: remove plugin name print

* packets: attempt to fix writePublicKeyAuthPacket

* packets: do not NUL-terminate auth switch packets

* move handleAuthResult to auth

* add old_password auth tests

* auth: add empty old_password test

* auth: add cleartext auth tests

* auth: add native auth tests

* auth: add caching_sha2 tests

* rename init and auth packets to documented names

* auth: fix plugin name for switched auth methods

* buffer: optimize default branches

* auth: add tests for switch to caching sha2

* auth: add tests for switch to cleartext password

* auth: add tests for switch to native password

* auth: sync NUL termination with official connectors

* packets: handle missing NUL bytes in AuthSwitchRequests

Updates #795
Julien Schmidt 7 years ago
parent
commit
affd4c9396
12 changed files with 1294 additions and 461 deletions
  1. 2 1
      AUTHORS
  2. 309 0
      auth.go
  3. 853 0
      auth_test.go
  4. 6 6
      buffer.go
  5. 6 9
      connection_go18.go
  6. 1 0
      const.go
  7. 14 92
      driver.go
  8. 1 2
      infile.go
  9. 71 158
      packets.go
  10. 31 12
      packets_test.go
  11. 0 143
      utils.go
  12. 0 38
      utils_test.go

+ 2 - 1
AUTHORS

@@ -20,6 +20,7 @@ Asta Xie <xiemengjun at gmail.com>
 Bulat Gaifullin <gaifullinbf at gmail.com>
 Carlos Nieto <jose.carlos at menteslibres.net>
 Chris Moos <chris at tech9computers.com>
+Craig Wilson <craiggwilson at gmail.com>
 Daniel Montoya <dsmontoyam at gmail.com>
 Daniel Nichter <nil at codenode.com>
 Daniël van Eeden <git at myname.nl>
@@ -55,7 +56,7 @@ Lion Yang <lion at aosc.xyz>
 Luca Looz <luca.looz92 at gmail.com>
 Lucas Liu <extrafliu at gmail.com>
 Luke Scott <luke at webconnex.com>
-Maciej Zimnoch <maciej.zimnoch@codilime.com>
+Maciej Zimnoch <maciej.zimnoch at codilime.com>
 Michael Woolnough <michael.woolnough at gmail.com>
 Nicola Peduzzi <thenikso at gmail.com>
 Olivier Mengué <dolmen at cpan.org>

+ 309 - 0
auth.go

@@ -0,0 +1,309 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/sha1"
+	"crypto/sha256"
+	"crypto/x509"
+	"encoding/pem"
+)
+
+// Hash password using pre 4.1 (old password) method
+// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
+type myRnd struct {
+	seed1, seed2 uint32
+}
+
+const myRndMaxVal = 0x3FFFFFFF
+
+// Pseudo random number generator
+func newMyRnd(seed1, seed2 uint32) *myRnd {
+	return &myRnd{
+		seed1: seed1 % myRndMaxVal,
+		seed2: seed2 % myRndMaxVal,
+	}
+}
+
+// Tested to be equivalent to MariaDB's floating point variant
+// http://play.golang.org/p/QHvhd4qved
+// http://play.golang.org/p/RG0q4ElWDx
+func (r *myRnd) NextByte() byte {
+	r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
+	r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
+
+	return byte(uint64(r.seed1) * 31 / myRndMaxVal)
+}
+
+// Generate binary hash from byte string using insecure pre 4.1 method
+func pwHash(password []byte) (result [2]uint32) {
+	var add uint32 = 7
+	var tmp uint32
+
+	result[0] = 1345345333
+	result[1] = 0x12345671
+
+	for _, c := range password {
+		// skip spaces and tabs in password
+		if c == ' ' || c == '\t' {
+			continue
+		}
+
+		tmp = uint32(c)
+		result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
+		result[1] += (result[1] << 8) ^ result[0]
+		add += tmp
+	}
+
+	// Remove sign bit (1<<31)-1)
+	result[0] &= 0x7FFFFFFF
+	result[1] &= 0x7FFFFFFF
+
+	return
+}
+
+// Hash password using insecure pre 4.1 method
+func scrambleOldPassword(scramble []byte, password string) []byte {
+	if len(password) == 0 {
+		return nil
+	}
+
+	scramble = scramble[:8]
+
+	hashPw := pwHash([]byte(password))
+	hashSc := pwHash(scramble)
+
+	r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
+
+	var out [8]byte
+	for i := range out {
+		out[i] = r.NextByte() + 64
+	}
+
+	mask := r.NextByte()
+	for i := range out {
+		out[i] ^= mask
+	}
+
+	return out[:]
+}
+
+// Hash password using 4.1+ method (SHA1)
+func scramblePassword(scramble []byte, password string) []byte {
+	if len(password) == 0 {
+		return nil
+	}
+
+	// stage1Hash = SHA1(password)
+	crypt := sha1.New()
+	crypt.Write([]byte(password))
+	stage1 := crypt.Sum(nil)
+
+	// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
+	// inner Hash
+	crypt.Reset()
+	crypt.Write(stage1)
+	hash := crypt.Sum(nil)
+
+	// outer Hash
+	crypt.Reset()
+	crypt.Write(scramble)
+	crypt.Write(hash)
+	scramble = crypt.Sum(nil)
+
+	// token = scrambleHash XOR stage1Hash
+	for i := range scramble {
+		scramble[i] ^= stage1[i]
+	}
+	return scramble
+}
+
+// Hash password using MySQL 8+ method (SHA256)
+func scrambleSHA256Password(scramble []byte, password string) []byte {
+	if len(password) == 0 {
+		return nil
+	}
+
+	// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
+
+	crypt := sha256.New()
+	crypt.Write([]byte(password))
+	message1 := crypt.Sum(nil)
+
+	crypt.Reset()
+	crypt.Write(message1)
+	message1Hash := crypt.Sum(nil)
+
+	crypt.Reset()
+	crypt.Write(message1Hash)
+	crypt.Write(scramble)
+	message2 := crypt.Sum(nil)
+
+	for i := range message1 {
+		message1[i] ^= message2[i]
+	}
+
+	return message1
+}
+
+func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
+	switch plugin {
+	case "caching_sha2_password":
+		authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
+		return authResp, (authResp == nil), nil
+
+	case "mysql_old_password":
+		if !mc.cfg.AllowOldPasswords {
+			return nil, false, ErrOldPassword
+		}
+		// Note: there are edge cases where this should work but doesn't;
+		// this is currently "wontfix":
+		// https://github.com/go-sql-driver/mysql/issues/184
+		authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
+		return authResp, true, nil
+
+	case "mysql_clear_password":
+		if !mc.cfg.AllowCleartextPasswords {
+			return nil, false, ErrCleartextPassword
+		}
+		// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
+		// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
+		return []byte(mc.cfg.Passwd), true, nil
+
+	case "mysql_native_password":
+		if !mc.cfg.AllowNativePasswords {
+			return nil, false, ErrNativePassword
+		}
+		// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
+		// Native password authentication only need and will need 20-byte challenge.
+		authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
+		return authResp, false, nil
+
+	default:
+		errLog.Print("unknown auth plugin:", plugin)
+		return nil, false, ErrUnknownPlugin
+	}
+}
+
+func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
+	// Read Result Packet
+	authData, newPlugin, err := mc.readAuthResult()
+	if err != nil {
+		return err
+	}
+
+	// handle auth plugin switch, if requested
+	if newPlugin != "" {
+		// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
+		// sent and we have to keep using the cipher sent in the init packet.
+		if authData == nil {
+			authData = oldAuthData
+		}
+
+		plugin = newPlugin
+
+		authResp, addNUL, err := mc.auth(authData, plugin)
+		if err != nil {
+			return err
+		}
+		if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
+			return err
+		}
+
+		// Read Result Packet
+		authData, newPlugin, err = mc.readAuthResult()
+		if err != nil {
+			return err
+		}
+		// Do not allow to change the auth plugin more than once
+		if newPlugin != "" {
+			return ErrMalformPkt
+		}
+	}
+
+	switch plugin {
+
+	// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
+	case "caching_sha2_password":
+		switch len(authData) {
+		case 0:
+			return nil // auth successful
+		case 1:
+			switch authData[0] {
+			case cachingSha2PasswordFastAuthSuccess:
+				if err = mc.readResultOK(); err == nil {
+					return nil // auth successful
+				}
+
+			case cachingSha2PasswordPerformFullAuthentication:
+				if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
+					// write cleartext auth packet
+					err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
+					if err != nil {
+						return err
+					}
+				} else {
+					seed := oldAuthData
+
+					// TODO: allow to specify a local file with the pub key via
+					// the DSN
+
+					// request public key
+					data := mc.buf.takeSmallBuffer(4 + 1)
+					data[4] = cachingSha2PasswordRequestPublicKey
+					mc.writePacket(data)
+
+					// parse public key
+					data, err := mc.readPacket()
+					if err != nil {
+						return err
+					}
+
+					block, _ := pem.Decode(data[1:])
+					pub, err := x509.ParsePKIXPublicKey(block.Bytes)
+					if err != nil {
+						return err
+					}
+
+					// send encrypted password
+					plain := make([]byte, len(mc.cfg.Passwd)+1)
+					copy(plain, mc.cfg.Passwd)
+					for i := range plain {
+						j := i % len(seed)
+						plain[i] ^= seed[j]
+					}
+					sha1 := sha1.New()
+					enc, err := rsa.EncryptOAEP(sha1, rand.Reader, pub.(*rsa.PublicKey), plain, nil)
+					if err != nil {
+						return err
+					}
+
+					if err = mc.writeAuthSwitchPacket(enc, false); err != nil {
+						return err
+					}
+				}
+				if err = mc.readResultOK(); err == nil {
+					return nil // auth successful
+				}
+
+			default:
+				return ErrMalformPkt
+			}
+		default:
+			return ErrMalformPkt
+		}
+
+	default:
+		return nil // auth successful
+	}
+
+	return err
+}

+ 853 - 0
auth_test.go

@@ -0,0 +1,853 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+	"bytes"
+	"crypto/tls"
+	"fmt"
+	"testing"
+)
+
+var serverPubKey = []byte{1, 45, 45, 45, 45, 45, 66, 69, 71, 73, 78, 32, 80, 85,
+	66, 76, 73, 67, 32, 75, 69, 89, 45, 45, 45, 45, 45, 10, 77, 73, 73, 66, 73,
+	106, 65, 78, 66, 103, 107, 113, 104, 107, 105, 71, 57, 119, 48, 66, 65, 81,
+	69, 70, 65, 65, 79, 67, 65, 81, 56, 65, 77, 73, 73, 66, 67, 103, 75, 67, 65,
+	81, 69, 65, 51, 72, 115, 120, 83, 53, 80, 47, 72, 97, 88, 80, 118, 109, 51,
+	109, 50, 65, 68, 110, 10, 98, 117, 54, 71, 81, 102, 112, 83, 71, 111, 55,
+	104, 50, 103, 104, 56, 49, 112, 109, 97, 120, 107, 67, 110, 68, 67, 119,
+	102, 54, 109, 109, 101, 72, 55, 76, 75, 104, 115, 110, 89, 110, 78, 52, 81,
+	48, 99, 122, 49, 81, 69, 47, 98, 104, 100, 80, 117, 54, 106, 115, 43, 86,
+	97, 89, 52, 10, 67, 99, 77, 117, 98, 80, 78, 49, 103, 79, 75, 97, 89, 118,
+	78, 99, 103, 69, 87, 112, 116, 73, 67, 105, 50, 88, 84, 116, 116, 66, 55,
+	117, 104, 43, 118, 67, 77, 106, 76, 118, 106, 65, 77, 100, 54, 47, 68, 109,
+	120, 100, 98, 85, 66, 48, 122, 80, 71, 113, 68, 79, 103, 105, 76, 68, 10,
+	75, 82, 79, 79, 53, 113, 100, 55, 115, 104, 98, 55, 49, 82, 47, 88, 74, 69,
+	70, 118, 76, 120, 71, 88, 69, 70, 48, 90, 116, 104, 72, 101, 78, 111, 57,
+	102, 69, 118, 120, 70, 81, 111, 109, 98, 49, 107, 90, 57, 74, 56, 110, 66,
+	119, 116, 101, 53, 83, 70, 53, 89, 108, 113, 86, 50, 10, 66, 66, 53, 113,
+	108, 97, 122, 43, 51, 81, 83, 78, 118, 109, 67, 49, 105, 87, 102, 108, 106,
+	88, 98, 89, 53, 107, 51, 47, 97, 54, 109, 107, 77, 47, 76, 97, 87, 104, 97,
+	117, 78, 53, 80, 82, 51, 115, 67, 120, 53, 85, 117, 49, 77, 102, 100, 115,
+	86, 105, 107, 53, 102, 88, 77, 77, 10, 100, 120, 107, 102, 70, 43, 88, 51,
+	99, 104, 107, 65, 110, 119, 73, 51, 70, 117, 119, 119, 50, 87, 71, 109, 87,
+	79, 71, 98, 75, 116, 109, 73, 101, 85, 109, 51, 98, 73, 82, 109, 100, 70,
+	85, 113, 97, 108, 81, 105, 70, 104, 113, 101, 90, 50, 105, 107, 106, 104,
+	103, 86, 73, 57, 112, 76, 10, 119, 81, 73, 68, 65, 81, 65, 66, 10, 45, 45,
+	45, 45, 45, 69, 78, 68, 32, 80, 85, 66, 76, 73, 67, 32, 75, 69, 89, 45, 45,
+	45, 45, 45, 10}
+
+func TestScrambleOldPass(t *testing.T) {
+	scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2}
+	vectors := []struct {
+		pass string
+		out  string
+	}{
+		{" pass", "47575c5a435b4251"},
+		{"pass ", "47575c5a435b4251"},
+		{"123\t456", "575c47505b5b5559"},
+		{"C0mpl!ca ted#PASS123", "5d5d554849584a45"},
+	}
+	for _, tuple := range vectors {
+		ours := scrambleOldPassword(scramble, tuple.pass)
+		if tuple.out != fmt.Sprintf("%x", ours) {
+			t.Errorf("Failed old password %q", tuple.pass)
+		}
+	}
+}
+
+func TestScrambleSHA256Pass(t *testing.T) {
+	scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21}
+	vectors := []struct {
+		pass string
+		out  string
+	}{
+		{"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"},
+		{"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"},
+	}
+	for _, tuple := range vectors {
+		ours := scrambleSHA256Password(scramble, tuple.pass)
+		if tuple.out != fmt.Sprintf("%x", ours) {
+			t.Errorf("Failed SHA256 password %q", tuple.pass)
+		}
+	}
+}
+
+func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
+	conn, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = "secret"
+
+	authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69,
+		22, 41, 84, 32, 123, 43, 118}
+	plugin := "caching_sha2_password"
+
+	// Send Client Authentication Packet
+	authResp, addNUL, err := mc.auth(authData, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// check written auth response
+	authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+	authRespEnd := authRespStart + 1 + len(authResp)
+	writtenAuthRespLen := conn.written[authRespStart]
+	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+	expectedAuthResp := []byte{102, 32, 5, 35, 143, 161, 140, 241, 171, 232, 56,
+		139, 43, 14, 107, 196, 249, 170, 147, 60, 220, 204, 120, 178, 214, 15,
+		184, 150, 26, 61, 57, 235}
+	if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+		t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+	}
+	conn.written = nil
+
+	// auth response
+	conn.data = []byte{
+		2, 0, 0, 2, 1, 3, // Fast Auth Success
+		7, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, // OK
+	}
+	conn.maxReads = 1
+
+	// Handle response to auth packet
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+}
+
+func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
+	conn, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = ""
+
+	authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69,
+		22, 41, 84, 32, 123, 43, 118}
+	plugin := "caching_sha2_password"
+
+	// Send Client Authentication Packet
+	authResp, addNUL, err := mc.auth(authData, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// check written auth response
+	authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+	authRespEnd := authRespStart + 1 + len(authResp)
+	writtenAuthRespLen := conn.written[authRespStart]
+	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+	if writtenAuthRespLen != 0 {
+		t.Fatalf("unexpected written auth response (%d bytes): %v",
+			writtenAuthRespLen, writtenAuthResp)
+	}
+	conn.written = nil
+
+	// auth response
+	conn.data = []byte{
+		7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+	}
+	conn.maxReads = 1
+
+	// Handle response to auth packet
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+}
+
+func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
+	conn, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = "secret"
+
+	authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
+		62, 94, 83, 80, 52, 85}
+	plugin := "caching_sha2_password"
+
+	// Send Client Authentication Packet
+	authResp, addNUL, err := mc.auth(authData, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// check written auth response
+	authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+	authRespEnd := authRespStart + 1 + len(authResp)
+	writtenAuthRespLen := conn.written[authRespStart]
+	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+	expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165,
+		49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70,
+		110, 40, 139, 124, 41}
+	if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+		t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+	}
+	conn.written = nil
+
+	// auth response
+	conn.data = []byte{
+		2, 0, 0, 2, 1, 4, // Perform Full Authentication
+	}
+	conn.queuedReplies = [][]byte{
+		// pub key response
+		append([]byte{byte(len(serverPubKey)), 1, 0, 4}, serverPubKey...),
+
+		// OK
+		{7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0},
+	}
+	conn.maxReads = 3
+
+	// Handle response to auth packet
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	if !bytes.HasPrefix(conn.written, []byte{1, 0, 0, 3, 2, 0, 1, 0, 5}) {
+		t.Errorf("unexpected written data: %v", conn.written)
+	}
+}
+
+func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
+	conn, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = "secret"
+
+	authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81,
+		62, 94, 83, 80, 52, 85}
+	plugin := "caching_sha2_password"
+
+	// Send Client Authentication Packet
+	authResp, addNUL, err := mc.auth(authData, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Hack to make the caching_sha2_password plugin believe that the connection
+	// is secure
+	mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
+
+	// check written auth response
+	authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+	authRespEnd := authRespStart + 1 + len(authResp)
+	writtenAuthRespLen := conn.written[authRespStart]
+	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+	expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165,
+		49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70,
+		110, 40, 139, 124, 41}
+	if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+		t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+	}
+	conn.written = nil
+
+	// auth response
+	conn.data = []byte{
+		2, 0, 0, 2, 1, 4, // Perform Full Authentication
+	}
+	conn.queuedReplies = [][]byte{
+		// OK
+		{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0},
+	}
+	conn.maxReads = 3
+
+	// Handle response to auth packet
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	if !bytes.Equal(conn.written, []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}) {
+		t.Errorf("unexpected written data: %v", conn.written)
+	}
+}
+
+func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) {
+	_, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = "secret"
+
+	authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+		103, 26, 95, 81, 17, 24, 21}
+	plugin := "mysql_clear_password"
+
+	// Send Client Authentication Packet
+	_, _, err := mc.auth(authData, plugin)
+	if err != ErrCleartextPassword {
+		t.Errorf("expected ErrCleartextPassword, got %v", err)
+	}
+}
+
+func TestAuthFastCleartextPassword(t *testing.T) {
+	conn, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = "secret"
+	mc.cfg.AllowCleartextPasswords = true
+
+	authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+		103, 26, 95, 81, 17, 24, 21}
+	plugin := "mysql_clear_password"
+
+	// Send Client Authentication Packet
+	authResp, addNUL, err := mc.auth(authData, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// check written auth response
+	authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+	authRespEnd := authRespStart + 1 + len(authResp)
+	writtenAuthRespLen := conn.written[authRespStart]
+	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+	expectedAuthResp := []byte{115, 101, 99, 114, 101, 116}
+	if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+		t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+	}
+	conn.written = nil
+
+	// auth response
+	conn.data = []byte{
+		7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+	}
+	conn.maxReads = 1
+
+	// Handle response to auth packet
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+}
+
+func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
+	conn, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = ""
+	mc.cfg.AllowCleartextPasswords = true
+
+	authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+		103, 26, 95, 81, 17, 24, 21}
+	plugin := "mysql_clear_password"
+
+	// Send Client Authentication Packet
+	authResp, addNUL, err := mc.auth(authData, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// check written auth response
+	authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+	authRespEnd := authRespStart + 1 + len(authResp)
+	writtenAuthRespLen := conn.written[authRespStart]
+	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+	if writtenAuthRespLen != 0 {
+		t.Fatalf("unexpected written auth response (%d bytes): %v",
+			writtenAuthRespLen, writtenAuthResp)
+	}
+	conn.written = nil
+
+	// auth response
+	conn.data = []byte{
+		7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+	}
+	conn.maxReads = 1
+
+	// Handle response to auth packet
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+}
+
+func TestAuthFastNativePasswordNotAllowed(t *testing.T) {
+	_, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = "secret"
+	mc.cfg.AllowNativePasswords = false
+
+	authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+		103, 26, 95, 81, 17, 24, 21}
+	plugin := "mysql_native_password"
+
+	// Send Client Authentication Packet
+	_, _, err := mc.auth(authData, plugin)
+	if err != ErrNativePassword {
+		t.Errorf("expected ErrNativePassword, got %v", err)
+	}
+}
+
+func TestAuthFastNativePassword(t *testing.T) {
+	conn, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = "secret"
+
+	authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+		103, 26, 95, 81, 17, 24, 21}
+	plugin := "mysql_native_password"
+
+	// Send Client Authentication Packet
+	authResp, addNUL, err := mc.auth(authData, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// check written auth response
+	authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+	authRespEnd := authRespStart + 1 + len(authResp)
+	writtenAuthRespLen := conn.written[authRespStart]
+	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+	expectedAuthResp := []byte{53, 177, 140, 159, 251, 189, 127, 53, 109, 252,
+		172, 50, 211, 192, 240, 164, 26, 48, 207, 45}
+	if writtenAuthRespLen != 20 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+		t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
+	}
+	conn.written = nil
+
+	// auth response
+	conn.data = []byte{
+		7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+	}
+	conn.maxReads = 1
+
+	// Handle response to auth packet
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+}
+
+func TestAuthFastNativePasswordEmpty(t *testing.T) {
+	conn, mc := newRWMockConn(1)
+	mc.cfg.User = "root"
+	mc.cfg.Passwd = ""
+
+	authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126,
+		103, 26, 95, 81, 17, 24, 21}
+	plugin := "mysql_native_password"
+
+	// Send Client Authentication Packet
+	authResp, addNUL, err := mc.auth(authData, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// check written auth response
+	authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
+	authRespEnd := authRespStart + 1 + len(authResp)
+	writtenAuthRespLen := conn.written[authRespStart]
+	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
+	if writtenAuthRespLen != 0 {
+		t.Fatalf("unexpected written auth response (%d bytes): %v",
+			writtenAuthRespLen, writtenAuthResp)
+	}
+	conn.written = nil
+
+	// auth response
+	conn.data = []byte{
+		7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK
+	}
+	conn.maxReads = 1
+
+	// Handle response to auth packet
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+}
+
+func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.Passwd = "secret"
+
+	// auth switch request
+	conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
+		115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101,
+		11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84,
+		50, 0}
+
+	// auth response
+	conn.queuedReplies = [][]byte{
+		{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, // OK
+	}
+	conn.maxReads = 3
+
+	authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+		47, 43, 9, 41, 112, 67, 110}
+	plugin := "mysql_native_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReply := []byte{
+		// 1. Packet: Hash
+		32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128,
+		54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58,
+		153, 9, 130,
+	}
+	if !bytes.Equal(conn.written, expectedReply) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}
+
+func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.Passwd = ""
+
+	// auth switch request
+	conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
+		115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101,
+		11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84,
+		50, 0}
+
+	// auth response
+	conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+	conn.maxReads = 2
+
+	authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+		47, 43, 9, 41, 112, 67, 110}
+	plugin := "mysql_native_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReply := []byte{1, 0, 0, 3, 0}
+	if !bytes.Equal(conn.written, expectedReply) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}
+
+func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.Passwd = "secret"
+
+	// auth switch request
+	conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
+		115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101,
+		11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84,
+		50, 0}
+
+	conn.queuedReplies = [][]byte{
+		// Perform Full Authentication
+		{2, 0, 0, 4, 1, 4},
+
+		// Pub Key Response
+		append([]byte{byte(len(serverPubKey)), 1, 0, 6}, serverPubKey...),
+
+		// OK
+		{7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0},
+	}
+	conn.maxReads = 4
+
+	authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+		47, 43, 9, 41, 112, 67, 110}
+	plugin := "mysql_native_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReplyPrefix := []byte{
+		// 1. Packet: Hash
+		32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128,
+		54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58,
+		153, 9, 130,
+
+		// 2. Packet: Pub Key Request
+		1, 0, 0, 5, 2,
+
+		// 3. Packet: Encrypted Password
+		0, 1, 0, 7, // [changing bytes]
+	}
+	if !bytes.HasPrefix(conn.written, expectedReplyPrefix) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}
+
+func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.Passwd = "secret"
+
+	// Hack to make the caching_sha2_password plugin believe that the connection
+	// is secure
+	mc.cfg.tls = &tls.Config{InsecureSkipVerify: true}
+
+	// auth switch request
+	conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95,
+		115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101,
+		11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84,
+		50, 0}
+
+	// auth response
+	conn.queuedReplies = [][]byte{
+		{2, 0, 0, 4, 1, 4},                // Perform Full Authentication
+		{7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, // OK
+	}
+	conn.maxReads = 3
+
+	authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+		47, 43, 9, 41, 112, 67, 110}
+	plugin := "mysql_native_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReply := []byte{
+		// 1. Packet: Hash
+		32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128,
+		54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58,
+		153, 9, 130,
+
+		// 2. Packet: Cleartext password
+		7, 0, 0, 5, 115, 101, 99, 114, 101, 116, 0,
+	}
+	if !bytes.Equal(conn.written, expectedReply) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}
+
+func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+
+	conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108,
+		101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0}
+	conn.maxReads = 1
+	authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+		47, 43, 9, 41, 112, 67, 110}
+	plugin := "mysql_native_password"
+	err := mc.handleAuthResult(authData, plugin)
+	if err != ErrCleartextPassword {
+		t.Errorf("expected ErrCleartextPassword, got %v", err)
+	}
+}
+
+func TestAuthSwitchCleartextPassword(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.AllowCleartextPasswords = true
+	mc.cfg.Passwd = "secret"
+
+	// auth switch request
+	conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108,
+		101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0}
+
+	// auth response
+	conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+	conn.maxReads = 2
+
+	authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+		47, 43, 9, 41, 112, 67, 110}
+	plugin := "mysql_native_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReply := []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}
+	if !bytes.Equal(conn.written, expectedReply) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}
+
+func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.AllowCleartextPasswords = true
+	mc.cfg.Passwd = ""
+
+	// auth switch request
+	conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108,
+		101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0}
+
+	// auth response
+	conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+	conn.maxReads = 2
+
+	authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19,
+		47, 43, 9, 41, 112, 67, 110}
+	plugin := "mysql_native_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReply := []byte{1, 0, 0, 3, 0}
+	if !bytes.Equal(conn.written, expectedReply) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}
+
+func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.AllowNativePasswords = false
+
+	conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97,
+		116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96,
+		71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55,
+		31, 0}
+	conn.maxReads = 1
+	authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31,
+		48, 31, 89, 39, 55, 31}
+	plugin := "caching_sha2_password"
+	err := mc.handleAuthResult(authData, plugin)
+	if err != ErrNativePassword {
+		t.Errorf("expected ErrNativePassword, got %v", err)
+	}
+}
+
+func TestAuthSwitchNativePassword(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.AllowNativePasswords = true
+	mc.cfg.Passwd = "secret"
+
+	// auth switch request
+	conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97,
+		116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96,
+		71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55,
+		31, 0}
+
+	// auth response
+	conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+	conn.maxReads = 2
+
+	authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31,
+		48, 31, 89, 39, 55, 31}
+	plugin := "caching_sha2_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReply := []byte{20, 0, 0, 3, 202, 41, 195, 164, 34, 226, 49, 103,
+		21, 211, 167, 199, 227, 116, 8, 48, 57, 71, 149, 146}
+	if !bytes.Equal(conn.written, expectedReply) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}
+
+func TestAuthSwitchNativePasswordEmpty(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.AllowNativePasswords = true
+	mc.cfg.Passwd = ""
+
+	// auth switch request
+	conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97,
+		116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96,
+		71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55,
+		31, 0}
+
+	// auth response
+	conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}}
+	conn.maxReads = 2
+
+	authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31,
+		48, 31, 89, 39, 55, 31}
+	plugin := "caching_sha2_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReply := []byte{0, 0, 0, 3}
+	if !bytes.Equal(conn.written, expectedReply) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}
+
+func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+
+	conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108,
+		100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61,
+		49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0}
+	conn.maxReads = 1
+	authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
+		84, 96, 101, 92, 123, 121, 107}
+	plugin := "mysql_native_password"
+	err := mc.handleAuthResult(authData, plugin)
+	if err != ErrOldPassword {
+		t.Errorf("expected ErrOldPassword, got %v", err)
+	}
+}
+
+func TestAuthSwitchOldPassword(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.AllowOldPasswords = true
+	mc.cfg.Passwd = "secret"
+
+	// auth switch request
+	conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108,
+		100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61,
+		49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0}
+
+	// auth response
+	conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
+	conn.maxReads = 2
+
+	authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
+		84, 96, 101, 92, 123, 121, 107}
+	plugin := "mysql_native_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0}
+	if !bytes.Equal(conn.written, expectedReply) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}
+
+func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
+	conn, mc := newRWMockConn(2)
+	mc.cfg.AllowOldPasswords = true
+	mc.cfg.Passwd = ""
+
+	// auth switch request
+	conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108,
+		100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61,
+		49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0}
+
+	// auth response
+	conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
+	conn.maxReads = 2
+
+	authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
+		84, 96, 101, 92, 123, 121, 107}
+	plugin := "mysql_native_password"
+
+	if err := mc.handleAuthResult(authData, plugin); err != nil {
+		t.Errorf("got error: %v", err)
+	}
+
+	expectedReply := []byte{1, 0, 0, 3, 0}
+	if !bytes.Equal(conn.written, expectedReply) {
+		t.Errorf("got unexpected data: %v", conn.written)
+	}
+}

+ 6 - 6
buffer.go

@@ -130,18 +130,18 @@ func (b *buffer) takeBuffer(length int) []byte {
 // smaller than defaultBufSize
 // Only one buffer (total) can be used at a time.
 func (b *buffer) takeSmallBuffer(length int) []byte {
-	if b.length == 0 {
-		return b.buf[:length]
+	if b.length > 0 {
+		return nil
 	}
-	return nil
+	return b.buf[:length]
 }
 
 // takeCompleteBuffer returns the complete existing buffer.
 // This can be used if the necessary buffer size is unknown.
 // Only one buffer (total) can be used at a time.
 func (b *buffer) takeCompleteBuffer() []byte {
-	if b.length == 0 {
-		return b.buf
+	if b.length > 0 {
+		return nil
 	}
-	return nil
+	return b.buf
 }

+ 6 - 9
connection_go18.go

@@ -17,25 +17,22 @@ import (
 )
 
 // Ping implements driver.Pinger interface
-func (mc *mysqlConn) Ping(ctx context.Context) error {
+func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
 	if mc.closed.IsSet() {
 		errLog.Print(ErrInvalidConn)
 		return driver.ErrBadConn
 	}
 
-	if err := mc.watchCancel(ctx); err != nil {
-		return err
+	if err = mc.watchCancel(ctx); err != nil {
+		return
 	}
 	defer mc.finish()
 
-	if err := mc.writeCommandPacket(comPing); err != nil {
-		return err
-	}
-	if _, err := mc.readResultOK(); err != nil {
-		return err
+	if err = mc.writeCommandPacket(comPing); err != nil {
+		return
 	}
 
-	return nil
+	return mc.readResultOK()
 }
 
 // BeginTx implements driver.ConnBeginTx interface

+ 1 - 0
const.go

@@ -9,6 +9,7 @@
 package mysql
 
 const (
+	defaultAuthPlugin       = "mysql_native_password"
 	defaultMaxAllowedPacket = 4 << 20 // 4 MiB
 	minProtocolVersion      = 10
 	maxPacketSize           = 1<<24 - 1

+ 14 - 92
driver.go

@@ -107,20 +107,31 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	mc.writeTimeout = mc.cfg.WriteTimeout
 
 	// Reading Handshake Initialization Packet
-	cipher, pluginName, err := mc.readInitPacket()
+	authData, plugin, err := mc.readHandshakePacket()
 	if err != nil {
 		mc.cleanup()
 		return nil, err
 	}
 
 	// Send Client Authentication Packet
-	if err = mc.writeAuthPacket(cipher, pluginName); err != nil {
+	authResp, addNUL, err := mc.auth(authData, plugin)
+	if err != nil {
+		// try the default auth plugin, if using the requested plugin failed
+		errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
+		plugin = defaultAuthPlugin
+		authResp, addNUL, err = mc.auth(authData, plugin)
+		if err != nil {
+			mc.cleanup()
+			return nil, err
+		}
+	}
+	if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
 		mc.cleanup()
 		return nil, err
 	}
 
 	// Handle response to auth packet, switch methods if possible
-	if err = handleAuthResult(mc, cipher, pluginName); err != nil {
+	if err = mc.handleAuthResult(authData, plugin); err != nil {
 		// Authentication failed and MySQL has already closed the connection
 		// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
 		// Do not send COM_QUIT, just cleanup and return the error.
@@ -153,95 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	return mc, nil
 }
 
-func handleAuthResult(mc *mysqlConn, oldCipher []byte, pluginName string) error {
-	// Read Result Packet
-	cipher, err := mc.readResultOK()
-	if err == nil {
-		// handle caching_sha2_password
-		// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
-		if pluginName == "caching_sha2_password" {
-			if len(cipher) == 1 {
-				switch cipher[0] {
-				case cachingSha2PasswordFastAuthSuccess:
-					cipher, err = mc.readResultOK()
-					if err == nil {
-						return nil // auth successful
-					}
-
-				case cachingSha2PasswordPerformFullAuthentication:
-					if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
-						if err = mc.writeClearAuthPacket(); err != nil {
-							return err
-						}
-					} else {
-						if err = mc.writePublicKeyAuthPacket(oldCipher); err != nil {
-							return err
-						}
-					}
-					cipher, err = mc.readResultOK()
-					if err == nil {
-						return nil // auth successful
-					}
-
-				default:
-					return ErrMalformPkt
-				}
-			} else {
-				return ErrMalformPkt
-			}
-
-		} else {
-			return nil // auth successful
-		}
-	}
-
-	if mc.cfg == nil {
-		return err // auth failed and retry not possible
-	}
-
-	// Retry auth if configured to do so
-	switch err {
-	case ErrCleartextPassword:
-		if mc.cfg.AllowCleartextPasswords {
-			// Retry with clear text password for
-			// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
-			// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
-			if err = mc.writeClearAuthPacket(); err != nil {
-				return err
-			}
-			_, err = mc.readResultOK()
-		}
-
-	case ErrNativePassword:
-		if mc.cfg.AllowNativePasswords {
-			if err = mc.writeNativeAuthPacket(cipher); err != nil {
-				return err
-			}
-			_, err = mc.readResultOK()
-		}
-
-	case ErrOldPassword:
-		if mc.cfg.AllowOldPasswords {
-			// Retry with old authentication method. Note: there are edge cases
-			// where this should work but doesn't; this is currently "wontfix":
-			// https://github.com/go-sql-driver/mysql/issues/184
-
-			// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
-			// sent and we have to keep using the cipher sent in the init packet.
-			if cipher == nil {
-				cipher = oldCipher
-			}
-
-			if err = mc.writeOldAuthPacket(cipher); err != nil {
-				return err
-			}
-			_, err = mc.readResultOK()
-		}
-	}
-
-	return err
-}
-
 func init() {
 	sql.Register("mysql", &MySQLDriver{})
 }

+ 1 - 2
infile.go

@@ -174,8 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 
 	// read OK packet
 	if err == nil {
-		_, err = mc.readResultOK()
-		return err
+		return mc.readResultOK()
 	}
 
 	mc.readPacket()

+ 71 - 158
packets.go

@@ -10,14 +10,9 @@ package mysql
 
 import (
 	"bytes"
-	"crypto/rand"
-	"crypto/rsa"
-	"crypto/sha1"
 	"crypto/tls"
-	"crypto/x509"
 	"database/sql/driver"
 	"encoding/binary"
-	"encoding/pem"
 	"errors"
 	"fmt"
 	"io"
@@ -154,12 +149,12 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 }
 
 /******************************************************************************
-*                           Initialisation Process                            *
+*                           Initialization Process                            *
 ******************************************************************************/
 
 // Handshake Initialization Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
-func (mc *mysqlConn) readInitPacket() ([]byte, string, error) {
+func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
 	data, err := mc.readPacket()
 	if err != nil {
 		// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
@@ -188,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, string, error) {
 	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
 
 	// first part of the password cipher [8 bytes]
-	cipher := data[pos : pos+8]
+	authData := data[pos : pos+8]
 
 	// (filler) always 0x00 [1 byte]
 	pos += 8 + 1
@@ -203,7 +198,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, string, error) {
 	}
 	pos += 2
 
-	pluginName := "mysql_native_password"
+	plugin := ""
 	if len(data) > pos {
 		// character set [1 byte]
 		// status flags [2 bytes]
@@ -224,36 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, string, error) {
 		//
 		// The official Python library uses the fixed length 12
 		// which seems to work but technically could have a hidden bug.
-		cipher = append(cipher, data[pos:pos+12]...)
+		authData = append(authData, data[pos:pos+12]...)
 		pos += 13
 
 		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
 		// \NUL otherwise
 		if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
-			pluginName = string(data[pos : pos+end])
+			plugin = string(data[pos : pos+end])
 		} else {
-			pluginName = string(data[pos:])
+			plugin = string(data[pos:])
 		}
 
 		// make a memory safe copy of the cipher slice
 		var b [20]byte
-		copy(b[:], cipher)
-		return b[:], pluginName, nil
+		copy(b[:], authData)
+		return b[:], plugin, nil
 	}
 
+	plugin = defaultAuthPlugin
+
 	// make a memory safe copy of the cipher slice
 	var b [8]byte
-	copy(b[:], cipher)
-	return b[:], pluginName, nil
+	copy(b[:], authData)
+	return b[:], plugin, nil
 }
 
 // Client Authentication Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
-func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
-	if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" {
-		return fmt.Errorf("unknown authentication plugin name '%s'", pluginName)
-	}
-
+func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
 	// Adjust client flags based on server support
 	clientFlags := clientProtocol41 |
 		clientSecureConn |
@@ -277,17 +270,11 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
 		clientFlags |= clientMultiStatements
 	}
 
-	// User Password
-	var scrambleBuff []byte
-	switch pluginName {
-	case "mysql_native_password":
-		scrambleBuff = scramblePassword(cipher, []byte(mc.cfg.Passwd))
-	case "caching_sha2_password":
-		scrambleBuff = scrambleCachingSha2Password(cipher, []byte(mc.cfg.Passwd))
+	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(authResp) + 21 + 1
+	if addNUL {
+		pktLen++
 	}
 
-	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
-
 	// To specify a db name
 	if n := len(mc.cfg.DBName); n > 0 {
 		clientFlags |= clientConnectWithDB
@@ -297,7 +284,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
 	// Calculate packet length and get buffer with that size
 	data := mc.buf.takeSmallBuffer(pktLen + 4)
 	if data == nil {
-		// can not take the buffer. Something must be wrong with the connection
+		// cannot take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
 		return errBadConnNoWrite
 	}
@@ -354,9 +341,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
 	data[pos] = 0x00
 	pos++
 
-	// ScrambleBuffer [length encoded integer]
-	data[pos] = byte(len(scrambleBuff))
-	pos += 1 + copy(data[pos+1:], scrambleBuff)
+	// Auth Data [length encoded integer]
+	data[pos] = byte(len(authResp))
+	pos += 1 + copy(data[pos+1:], authResp)
+	if addNUL {
+		data[pos] = 0x00
+		pos++
+	}
 
 	// Databasename [null terminated string]
 	if len(mc.cfg.DBName) > 0 {
@@ -365,107 +356,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
 		pos++
 	}
 
-	pos += copy(data[pos:], pluginName)
+	pos += copy(data[pos:], plugin)
 	data[pos] = 0x00
 
 	// Send Auth packet
 	return mc.writePacket(data)
 }
 
-//  Client old authentication packet
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
-	// User password
-	// https://dev.mysql.com/doc/internals/en/old-password-authentication.html
-	// Old password authentication only need and will need 8-byte challenge.
-	scrambleBuff := scrambleOldPassword(cipher[:8], []byte(mc.cfg.Passwd))
-
-	// Calculate the packet length and add a tailing 0
-	pktLen := len(scrambleBuff) + 1
-	data := mc.buf.takeSmallBuffer(4 + pktLen)
-	if data == nil {
-		// can not take the buffer. Something must be wrong with the connection
-		errLog.Print(ErrBusyBuffer)
-		return errBadConnNoWrite
+func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
+	pktLen := 4 + len(authData)
+	if addNUL {
+		pktLen++
 	}
-
-	// Add the scrambled password [null terminated string]
-	copy(data[4:], scrambleBuff)
-	data[4+pktLen-1] = 0x00
-
-	return mc.writePacket(data)
-}
-
-//  Client clear text authentication packet
-// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeClearAuthPacket() error {
-	// Calculate the packet length and add a tailing 0
-	pktLen := len(mc.cfg.Passwd) + 1
-	data := mc.buf.takeSmallBuffer(4 + pktLen)
+	data := mc.buf.takeSmallBuffer(pktLen)
 	if data == nil {
-		// can not take the buffer. Something must be wrong with the connection
+		// cannot take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
 		return errBadConnNoWrite
 	}
 
-	// Add the clear password [null terminated string]
-	copy(data[4:], mc.cfg.Passwd)
-	data[4+pktLen-1] = 0x00
-
-	return mc.writePacket(data)
-}
-
-//  Native password authentication method
-// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
-	// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
-	// Native password authentication only need and will need 20-byte challenge.
-	scrambleBuff := scramblePassword(cipher[0:20], []byte(mc.cfg.Passwd))
-
-	// Calculate the packet length and add a tailing 0
-	pktLen := len(scrambleBuff)
-	data := mc.buf.takeSmallBuffer(4 + pktLen)
-	if data == nil {
-		// can not take the buffer. Something must be wrong with the connection
-		errLog.Print(ErrBusyBuffer)
-		return errBadConnNoWrite
+	// Add the auth data [EOF]
+	copy(data[4:], authData)
+	if addNUL {
+		data[pktLen-1] = 0x00
 	}
 
-	// Add the scramble
-	copy(data[4:], scrambleBuff)
-
-	return mc.writePacket(data)
-}
-
-//  Caching sha2 authentication. Public key request and send encrypted password
-// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writePublicKeyAuthPacket(cipher []byte) error {
-	// request public key
-	data := mc.buf.takeSmallBuffer(4 + 1)
-	data[4] = cachingSha2PasswordRequestPublicKey
-	mc.writePacket(data)
-
-	data, err := mc.readPacket()
-	if err != nil {
-		return err
-	}
-
-	block, _ := pem.Decode(data[1:])
-	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
-	if err != nil {
-		return err
-	}
-
-	plain := make([]byte, len(mc.cfg.Passwd)+1)
-	copy(plain, mc.cfg.Passwd)
-	for i := range plain {
-		j := i % len(cipher)
-		plain[i] ^= cipher[j]
-	}
-	sha1 := sha1.New()
-	enc, _ := rsa.EncryptOAEP(sha1, rand.Reader, pub.(*rsa.PublicKey), plain, nil)
-	data = mc.buf.takeSmallBuffer(4 + len(enc))
-	copy(data[4:], enc)
 	return mc.writePacket(data)
 }
 
@@ -479,7 +395,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
 
 	data := mc.buf.takeSmallBuffer(4 + 1)
 	if data == nil {
-		// can not take the buffer. Something must be wrong with the connection
+		// cannot take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
 		return errBadConnNoWrite
 	}
@@ -498,7 +414,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
 	pktLen := 1 + len(arg)
 	data := mc.buf.takeBuffer(pktLen + 4)
 	if data == nil {
-		// can not take the buffer. Something must be wrong with the connection
+		// cannot take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
 		return errBadConnNoWrite
 	}
@@ -519,7 +435,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
 
 	data := mc.buf.takeSmallBuffer(4 + 1 + 4)
 	if data == nil {
-		// can not take the buffer. Something must be wrong with the connection
+		// cannot take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
 		return errBadConnNoWrite
 	}
@@ -541,53 +457,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
 *                              Result Packets                                 *
 ******************************************************************************/
 
-func readAuthSwitch(data []byte) ([]byte, error) {
-	if len(data) > 1 {
-		pluginEndIndex := bytes.IndexByte(data, 0x00)
-		plugin := string(data[1:pluginEndIndex])
-		cipher := data[pluginEndIndex+1:]
-
-		switch plugin {
-		case "mysql_old_password":
-			// using old_passwords
-			return cipher, ErrOldPassword
-		case "mysql_clear_password":
-			// using clear text password
-			return cipher, ErrCleartextPassword
-		case "mysql_native_password":
-			// using mysql default authentication method
-			return cipher, ErrNativePassword
-		default:
-			return cipher, ErrUnknownPlugin
-		}
-	}
-
-	// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
-	return nil, ErrOldPassword
-}
-
-// Returns error if Packet is not an 'Result OK'-Packet
-func (mc *mysqlConn) readResultOK() ([]byte, error) {
+func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
 	data, err := mc.readPacket()
 	if err != nil {
-		return nil, err
+		return nil, "", err
 	}
 
 	// packet indicator
 	switch data[0] {
 
 	case iOK:
-		return nil, mc.handleOkPacket(data)
+		return nil, "", mc.handleOkPacket(data)
 
 	case iAuthMoreData:
-		return data[1:], nil
+		return data[1:], "", err
 
 	case iEOF:
-		return readAuthSwitch(data)
+		if len(data) < 1 {
+			// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
+			return nil, "mysql_old_password", nil
+		}
+		pluginEndIndex := bytes.IndexByte(data, 0x00)
+		if pluginEndIndex < 0 {
+			return nil, "", ErrMalformPkt
+		}
+		plugin := string(data[1:pluginEndIndex])
+		authData := data[pluginEndIndex+1:]
+		return authData, plugin, nil
 
 	default: // Error otherwise
-		return nil, mc.handleErrorPacket(data)
+		return nil, "", mc.handleErrorPacket(data)
+	}
+}
+
+// Returns error if Packet is not an 'Result OK'-Packet
+func (mc *mysqlConn) readResultOK() error {
+	data, err := mc.readPacket()
+	if err != nil {
+		return err
+	}
+
+	if data[0] == iOK {
+		return mc.handleOkPacket(data)
 	}
+	return mc.handleErrorPacket(data)
 }
 
 // Result Set Header Packet
@@ -921,7 +834,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 	// 2 bytes paramID
 	const dataOffset = 1 + 4 + 2
 
-	// Can not use the write buffer since
+	// Cannot use the write buffer since
 	// a) the buffer is too small
 	// b) it is in use
 	data := make([]byte, 4+1+4+2+len(arg))
@@ -993,7 +906,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		data = mc.buf.takeCompleteBuffer()
 	}
 	if data == nil {
-		// can not take the buffer. Something must be wrong with the connection
+		// cannot take the buffer. Something must be wrong with the connection
 		errLog.Print(ErrBusyBuffer)
 		return errBadConnNoWrite
 	}
@@ -1161,7 +1074,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				paramValues = append(paramValues, b...)
 
 			default:
-				return fmt.Errorf("can not convert type: %T", arg)
+				return fmt.Errorf("cannot convert type: %T", arg)
 			}
 		}
 

+ 31 - 12
packets_test.go

@@ -24,16 +24,17 @@ var (
 
 // struct to mock a net.Conn for testing purposes
 type mockConn struct {
-	laddr     net.Addr
-	raddr     net.Addr
-	data      []byte
-	closed    bool
-	read      int
-	written   int
-	reads     int
-	writes    int
-	maxReads  int
-	maxWrites int
+	laddr         net.Addr
+	raddr         net.Addr
+	data          []byte
+	written       []byte
+	queuedReplies [][]byte
+	closed        bool
+	read          int
+	reads         int
+	writes        int
+	maxReads      int
+	maxWrites     int
 }
 
 func (m *mockConn) Read(b []byte) (n int, err error) {
@@ -62,7 +63,12 @@ func (m *mockConn) Write(b []byte) (n int, err error) {
 	}
 
 	n = len(b)
-	m.written += n
+	m.written = append(m.written, b...)
+
+	if n > 0 && len(m.queuedReplies) > 0 {
+		m.data = m.queuedReplies[0]
+		m.queuedReplies = m.queuedReplies[1:]
+	}
 	return
 }
 func (m *mockConn) Close() error {
@@ -88,6 +94,19 @@ func (m *mockConn) SetWriteDeadline(t time.Time) error {
 // make sure mockConn implements the net.Conn interface
 var _ net.Conn = new(mockConn)
 
+func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
+	conn := new(mockConn)
+	mc := &mysqlConn{
+		buf:              newBuffer(conn),
+		cfg:              NewConfig(),
+		netConn:          conn,
+		closech:          make(chan struct{}),
+		maxAllowedPacket: defaultMaxAllowedPacket,
+		sequence:         sequence,
+	}
+	return conn, mc
+}
+
 func TestReadPacketSingleByte(t *testing.T) {
 	conn := new(mockConn)
 	mc := &mysqlConn{
@@ -300,7 +319,7 @@ func TestRegression801(t *testing.T) {
 		112, 97, 115, 115, 119, 111, 114, 100}
 	conn.maxReads = 1
 
-	authData, pluginName, err := mc.readInitPacket()
+	authData, pluginName, err := mc.readHandshakePacket()
 	if err != nil {
 		t.Fatalf("got error: %v", err)
 	}

+ 0 - 143
utils.go

@@ -9,8 +9,6 @@
 package mysql
 
 import (
-	"crypto/sha1"
-	"crypto/sha256"
 	"crypto/tls"
 	"database/sql/driver"
 	"encoding/binary"
@@ -99,147 +97,6 @@ func readBool(input string) (value bool, valid bool) {
 	return
 }
 
-/******************************************************************************
-*                             Authentication                                  *
-******************************************************************************/
-
-// Encrypt password using 4.1+ method
-func scramblePassword(scramble, password []byte) []byte {
-	if len(password) == 0 {
-		return nil
-	}
-
-	// stage1Hash = SHA1(password)
-	crypt := sha1.New()
-	crypt.Write(password)
-	stage1 := crypt.Sum(nil)
-
-	// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
-	// inner Hash
-	crypt.Reset()
-	crypt.Write(stage1)
-	hash := crypt.Sum(nil)
-
-	// outer Hash
-	crypt.Reset()
-	crypt.Write(scramble)
-	crypt.Write(hash)
-	scramble = crypt.Sum(nil)
-
-	// token = scrambleHash XOR stage1Hash
-	for i := range scramble {
-		scramble[i] ^= stage1[i]
-	}
-	return scramble
-}
-
-// Encrypt password using pre 4.1 (old password) method
-// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
-type myRnd struct {
-	seed1, seed2 uint32
-}
-
-const myRndMaxVal = 0x3FFFFFFF
-
-// Pseudo random number generator
-func newMyRnd(seed1, seed2 uint32) *myRnd {
-	return &myRnd{
-		seed1: seed1 % myRndMaxVal,
-		seed2: seed2 % myRndMaxVal,
-	}
-}
-
-// Tested to be equivalent to MariaDB's floating point variant
-// http://play.golang.org/p/QHvhd4qved
-// http://play.golang.org/p/RG0q4ElWDx
-func (r *myRnd) NextByte() byte {
-	r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
-	r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
-
-	return byte(uint64(r.seed1) * 31 / myRndMaxVal)
-}
-
-// Generate binary hash from byte string using insecure pre 4.1 method
-func pwHash(password []byte) (result [2]uint32) {
-	var add uint32 = 7
-	var tmp uint32
-
-	result[0] = 1345345333
-	result[1] = 0x12345671
-
-	for _, c := range password {
-		// skip spaces and tabs in password
-		if c == ' ' || c == '\t' {
-			continue
-		}
-
-		tmp = uint32(c)
-		result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
-		result[1] += (result[1] << 8) ^ result[0]
-		add += tmp
-	}
-
-	// Remove sign bit (1<<31)-1)
-	result[0] &= 0x7FFFFFFF
-	result[1] &= 0x7FFFFFFF
-
-	return
-}
-
-// Encrypt password using insecure pre 4.1 method
-func scrambleOldPassword(scramble, password []byte) []byte {
-	if len(password) == 0 {
-		return nil
-	}
-
-	scramble = scramble[:8]
-
-	hashPw := pwHash(password)
-	hashSc := pwHash(scramble)
-
-	r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
-
-	var out [8]byte
-	for i := range out {
-		out[i] = r.NextByte() + 64
-	}
-
-	mask := r.NextByte()
-	for i := range out {
-		out[i] ^= mask
-	}
-
-	return out[:]
-}
-
-// Encrypt password using 8.0 default method
-func scrambleCachingSha2Password(scramble, password []byte) []byte {
-	if len(password) == 0 {
-		return nil
-	}
-
-	// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
-
-	crypt := sha256.New()
-	crypt.Write(password)
-	message1 := crypt.Sum(nil)
-
-	crypt.Reset()
-	crypt.Write(message1)
-	message1Hash := crypt.Sum(nil)
-
-	crypt.Reset()
-	crypt.Write(message1Hash)
-	crypt.Write(scramble)
-	message2 := crypt.Sum(nil)
-
-	for i := range message1 {
-		message1[i] ^= message2[i]
-	}
-
-	return message1
-}
-
 /******************************************************************************
 *                           Time related utils                                *
 ******************************************************************************/

+ 0 - 38
utils_test.go

@@ -11,7 +11,6 @@ package mysql
 import (
 	"bytes"
 	"encoding/binary"
-	"fmt"
 	"testing"
 	"time"
 )
@@ -93,43 +92,6 @@ func TestLengthEncodedInteger(t *testing.T) {
 	}
 }
 
-func TestOldPass(t *testing.T) {
-	scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2}
-	vectors := []struct {
-		pass string
-		out  string
-	}{
-		{" pass", "47575c5a435b4251"},
-		{"pass ", "47575c5a435b4251"},
-		{"123\t456", "575c47505b5b5559"},
-		{"C0mpl!ca ted#PASS123", "5d5d554849584a45"},
-	}
-	for _, tuple := range vectors {
-		ours := scrambleOldPassword(scramble, []byte(tuple.pass))
-		if tuple.out != fmt.Sprintf("%x", ours) {
-			t.Errorf("Failed old password %q", tuple.pass)
-		}
-	}
-}
-
-func TestCachingSha2Pass(t *testing.T) {
-	scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21}
-	vectors := []struct {
-		pass string
-		out  string
-	}{
-		{"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"},
-		{"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"},
-	}
-	for _, tuple := range vectors {
-		ours := scrambleCachingSha2Password(scramble, []byte(tuple.pass))
-		if tuple.out != fmt.Sprintf("%x", ours) {
-			t.Errorf("Failed caching sha2 password %q", tuple.pass)
-		}
-	}
-
-}
-
 func TestFormatBinaryDateTime(t *testing.T) {
 	rawDate := [11]byte{}
 	binary.LittleEndian.PutUint16(rawDate[:2], 1978)   // years