|
|
@@ -10,9 +10,14 @@ package mysql
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ "crypto/rand"
|
|
|
+ "crypto/rsa"
|
|
|
+ "crypto/sha1"
|
|
|
"crypto/tls"
|
|
|
+ "crypto/x509"
|
|
|
"database/sql/driver"
|
|
|
"encoding/binary"
|
|
|
+ "encoding/pem"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
@@ -154,24 +159,24 @@ func (mc *mysqlConn) writePacket(data []byte) error {
|
|
|
|
|
|
// Handshake Initialization Packet
|
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
|
|
|
-func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
|
|
+func (mc *mysqlConn) readInitPacket() ([]byte, string, error) {
|
|
|
data, err := mc.readPacket()
|
|
|
if err != nil {
|
|
|
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
|
|
|
// in connection initialization we don't risk retrying non-idempotent actions.
|
|
|
if err == ErrInvalidConn {
|
|
|
- return nil, driver.ErrBadConn
|
|
|
+ return nil, "", driver.ErrBadConn
|
|
|
}
|
|
|
- return nil, err
|
|
|
+ return nil, "", err
|
|
|
}
|
|
|
|
|
|
if data[0] == iERR {
|
|
|
- return nil, mc.handleErrorPacket(data)
|
|
|
+ return nil, "", mc.handleErrorPacket(data)
|
|
|
}
|
|
|
|
|
|
// protocol version [1 byte]
|
|
|
if data[0] < minProtocolVersion {
|
|
|
- return nil, fmt.Errorf(
|
|
|
+ return nil, "", fmt.Errorf(
|
|
|
"unsupported protocol version %d. Version %d or higher is required",
|
|
|
data[0],
|
|
|
minProtocolVersion,
|
|
|
@@ -191,13 +196,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
|
|
// capability flags (lower 2 bytes) [2 bytes]
|
|
|
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
if mc.flags&clientProtocol41 == 0 {
|
|
|
- return nil, ErrOldProtocol
|
|
|
+ return nil, "", ErrOldProtocol
|
|
|
}
|
|
|
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
|
|
|
- return nil, ErrNoTLS
|
|
|
+ return nil, "", ErrNoTLS
|
|
|
}
|
|
|
pos += 2
|
|
|
|
|
|
+ pluginName := ""
|
|
|
if len(data) > pos {
|
|
|
// character set [1 byte]
|
|
|
// status flags [2 bytes]
|
|
|
@@ -219,6 +225,8 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
|
|
// The official Python library uses the fixed length 12
|
|
|
// which seems to work but technically could have a hidden bug.
|
|
|
cipher = append(cipher, data[pos:pos+12]...)
|
|
|
+ pos += 13
|
|
|
+ pluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
|
|
|
|
|
|
// TODO: Verify string termination
|
|
|
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
|
|
|
@@ -232,18 +240,22 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
|
|
// make a memory safe copy of the cipher slice
|
|
|
var b [20]byte
|
|
|
copy(b[:], cipher)
|
|
|
- return b[:], nil
|
|
|
+ return b[:], pluginName, nil
|
|
|
}
|
|
|
|
|
|
// make a memory safe copy of the cipher slice
|
|
|
var b [8]byte
|
|
|
copy(b[:], cipher)
|
|
|
- return b[:], nil
|
|
|
+ return b[:], pluginName, nil
|
|
|
}
|
|
|
|
|
|
// Client Authentication Packet
|
|
|
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
|
|
-func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
|
|
+func (mc *mysqlConn) writeAuthPacket(cipher []byte, pluginName string) error {
|
|
|
+ if pluginName != "mysql_native_password" && pluginName != "caching_sha2_password" {
|
|
|
+ return fmt.Errorf("unknown authentication plugin name '%s'", pluginName)
|
|
|
+ }
|
|
|
+
|
|
|
// Adjust client flags based on server support
|
|
|
clientFlags := clientProtocol41 |
|
|
|
clientSecureConn |
|
|
|
@@ -268,7 +280,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
|
|
}
|
|
|
|
|
|
// User Password
|
|
|
- scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
|
|
|
+ var scrambleBuff []byte
|
|
|
+ switch pluginName {
|
|
|
+ case "mysql_native_password":
|
|
|
+ scrambleBuff = scramblePassword(cipher, []byte(mc.cfg.Passwd))
|
|
|
+ case "caching_sha2_password":
|
|
|
+ scrambleBuff = scrambleCachingSha2Password(cipher, []byte(mc.cfg.Passwd))
|
|
|
+ }
|
|
|
|
|
|
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
|
|
|
|
|
|
@@ -350,7 +368,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
|
|
}
|
|
|
|
|
|
// Assume native client during response
|
|
|
- pos += copy(data[pos:], "mysql_native_password")
|
|
|
+ pos += copy(data[pos:], pluginName)
|
|
|
data[pos] = 0x00
|
|
|
|
|
|
// Send Auth packet
|
|
|
@@ -422,6 +440,38 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
|
|
|
return mc.writePacket(data)
|
|
|
}
|
|
|
|
|
|
+// Caching sha2 authentication. Public key request and send encrypted password
|
|
|
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
|
|
+func (mc *mysqlConn) writePublicKeyAuthPacket(cipher []byte) error {
|
|
|
+ // request public key
|
|
|
+ data := mc.buf.takeSmallBuffer(4 + 1)
|
|
|
+ data[4] = cachingSha2PasswordRequestPublicKey
|
|
|
+ mc.writePacket(data)
|
|
|
+
|
|
|
+ data, err := mc.readPacket()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ block, _ := pem.Decode(data[1:])
|
|
|
+ pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ plain := make([]byte, len(mc.cfg.Passwd)+1)
|
|
|
+ copy(plain, mc.cfg.Passwd)
|
|
|
+ for i := range plain {
|
|
|
+ j := i % len(cipher)
|
|
|
+ plain[i] ^= cipher[j]
|
|
|
+ }
|
|
|
+ sha1 := sha1.New()
|
|
|
+ enc, _ := rsa.EncryptOAEP(sha1, rand.Reader, pub.(*rsa.PublicKey), plain, nil)
|
|
|
+ data = mc.buf.takeSmallBuffer(4 + len(enc))
|
|
|
+ copy(data[4:], enc)
|
|
|
+ return mc.writePacket(data)
|
|
|
+}
|
|
|
+
|
|
|
/******************************************************************************
|
|
|
* Command Packets *
|
|
|
******************************************************************************/
|
|
|
@@ -535,6 +585,16 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
+func (mc *mysqlConn) readCachingSha2PasswordAuthResult() (int, error) {
|
|
|
+ data, err := mc.readPacket()
|
|
|
+ if err == nil {
|
|
|
+ if data[0] != 1 {
|
|
|
+ return 0, ErrMalformPkt
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return int(data[1]), err
|
|
|
+}
|
|
|
+
|
|
|
// Result Set Header Packet
|
|
|
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
|
|
|
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
|