فهرست منبع

Fix Auth Resnponse packet for cleartext password (#887)

Trailing NUL char should be in `string[n] auth-response`.
But NUL was after auth-response.
INADA Naoki 7 سال پیش
والد
کامیت
fb9c42fa52
4فایلهای تغییر یافته به همراه64 افزوده شده و 75 حذف شده
  1. 18 18
      auth.go
  2. 37 36
      auth_test.go
  3. 3 3
      driver.go
  4. 6 18
      packets.go

+ 18 - 18
auth.go

@@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro
 	if err != nil {
 		return err
 	}
-	return mc.writeAuthSwitchPacket(enc, false)
+	return mc.writeAuthSwitchPacket(enc)
 }
 
-func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
+func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
 	switch plugin {
 	case "caching_sha2_password":
 		authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
-		return authResp, false, nil
+		return authResp, nil
 
 	case "mysql_old_password":
 		if !mc.cfg.AllowOldPasswords {
-			return nil, false, ErrOldPassword
+			return nil, ErrOldPassword
 		}
 		// Note: there are edge cases where this should work but doesn't;
 		// this is currently "wontfix":
 		// https://github.com/go-sql-driver/mysql/issues/184
-		authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
-		return authResp, true, nil
+		authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
+		return authResp, nil
 
 	case "mysql_clear_password":
 		if !mc.cfg.AllowCleartextPasswords {
-			return nil, false, ErrCleartextPassword
+			return nil, ErrCleartextPassword
 		}
 		// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
 		// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
-		return []byte(mc.cfg.Passwd), true, nil
+		return append([]byte(mc.cfg.Passwd), 0), nil
 
 	case "mysql_native_password":
 		if !mc.cfg.AllowNativePasswords {
-			return nil, false, ErrNativePassword
+			return nil, ErrNativePassword
 		}
 		// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
 		// Native password authentication only need and will need 20-byte challenge.
 		authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
-		return authResp, false, nil
+		return authResp, nil
 
 	case "sha256_password":
 		if len(mc.cfg.Passwd) == 0 {
-			return nil, true, nil
+			return []byte{0}, nil
 		}
 		if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
 			// write cleartext auth packet
-			return []byte(mc.cfg.Passwd), true, nil
+			return append([]byte(mc.cfg.Passwd), 0), nil
 		}
 
 		pubKey := mc.cfg.pubKey
 		if pubKey == nil {
 			// request public key from server
-			return []byte{1}, false, nil
+			return []byte{1}, nil
 		}
 
 		// encrypted password
 		enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
-		return enc, false, err
+		return enc, err
 
 	default:
 		errLog.Print("unknown auth plugin:", plugin)
-		return nil, false, ErrUnknownPlugin
+		return nil, ErrUnknownPlugin
 	}
 }
 
@@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
 
 		plugin = newPlugin
 
-		authResp, addNUL, err := mc.auth(authData, plugin)
+		authResp, err := mc.auth(authData, plugin)
 		if err != nil {
 			return err
 		}
-		if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
+		if err = mc.writeAuthSwitchPacket(authResp); err != nil {
 			return err
 		}
 
@@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
 			case cachingSha2PasswordPerformFullAuthentication:
 				if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
 					// write cleartext auth packet
-					err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
+					err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
 					if err != nil {
 						return err
 					}

+ 37 - 36
auth_test.go

@@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
 	plugin := "caching_sha2_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
 	plugin := "caching_sha2_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
 	plugin := "caching_sha2_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
 	plugin := "caching_sha2_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
 	plugin := "caching_sha2_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) {
 	plugin := "mysql_clear_password"
 
 	// Send Client Authentication Packet
-	_, _, err := mc.auth(authData, plugin)
+	_, err := mc.auth(authData, plugin)
 	if err != ErrCleartextPassword {
 		t.Errorf("expected ErrCleartextPassword, got %v", err)
 	}
@@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) {
 	plugin := "mysql_clear_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -367,8 +367,8 @@ func TestAuthFastCleartextPassword(t *testing.T) {
 	authRespEnd := authRespStart + 1 + len(authResp)
 	writtenAuthRespLen := conn.written[authRespStart]
 	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
-	expectedAuthResp := []byte{115, 101, 99, 114, 101, 116}
-	if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+	expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
+	if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
 		t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
 	}
 	conn.written = nil
@@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
 	plugin := "mysql_clear_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -410,9 +410,9 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
 	authRespEnd := authRespStart + 1 + len(authResp)
 	writtenAuthRespLen := conn.written[authRespStart]
 	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
-	if writtenAuthRespLen != 0 {
-		t.Fatalf("unexpected written auth response (%d bytes): %v",
-			writtenAuthRespLen, writtenAuthResp)
+	expectedAuthResp := []byte{0}
+	if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+		t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
 	}
 	conn.written = nil
 
@@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) {
 	plugin := "mysql_native_password"
 
 	// Send Client Authentication Packet
-	_, _, err := mc.auth(authData, plugin)
+	_, err := mc.auth(authData, plugin)
 	if err != ErrNativePassword {
 		t.Errorf("expected ErrNativePassword, got %v", err)
 	}
@@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) {
 	plugin := "mysql_native_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) {
 	plugin := "mysql_native_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
 	plugin := "sha256_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -554,7 +554,8 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
 	authRespEnd := authRespStart + 1 + len(authResp)
 	writtenAuthRespLen := conn.written[authRespStart]
 	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
-	if writtenAuthRespLen != 0 {
+	expectedAuthResp := []byte{0}
+	if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
 		t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
 	}
 	conn.written = nil
@@ -587,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) {
 	plugin := "sha256_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -636,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
 	plugin := "sha256_password"
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -669,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
 	plugin := "sha256_password"
 
 	// send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -677,18 +678,18 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
 	// unset TLS config to prevent the actual establishment of a TLS wrapper
 	mc.cfg.tls = nil
 
-	err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
+	err = mc.writeHandshakeResponsePacket(authResp, plugin)
 	if err != nil {
 		t.Fatal(err)
 	}
 
 	// check written auth response
 	authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
-	authRespEnd := authRespStart + 1 + len(authResp) + 1
+	authRespEnd := authRespStart + 1 + len(authResp)
 	writtenAuthRespLen := conn.written[authRespStart]
 	writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
 	expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
-	if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
+	if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
 		t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
 	}
 	conn.written = nil

+ 3 - 3
driver.go

@@ -117,18 +117,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	}
 
 	// Send Client Authentication Packet
-	authResp, addNUL, err := mc.auth(authData, plugin)
+	authResp, err := mc.auth(authData, plugin)
 	if err != nil {
 		// try the default auth plugin, if using the requested plugin failed
 		errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
 		plugin = defaultAuthPlugin
-		authResp, addNUL, err = mc.auth(authData, plugin)
+		authResp, err = mc.auth(authData, plugin)
 		if err != nil {
 			mc.cleanup()
 			return nil, err
 		}
 	}
-	if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
+	if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
 		mc.cleanup()
 		return nil, err
 	}

+ 6 - 18
packets.go

@@ -243,7 +243,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
 
 // Client Authentication Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
-func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
+func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
 	// Adjust client flags based on server support
 	clientFlags := clientProtocol41 |
 		clientSecureConn |
@@ -269,7 +269,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
 
 	// encode length of the auth plugin data
 	var authRespLEIBuf [9]byte
-	authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
+	authRespLen := len(authResp)
+	authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
 	if len(authRespLEI) > 1 {
 		// if the length can not be written in 1 byte, it must be written as a
 		// length encoded integer
@@ -277,9 +278,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
 	}
 
 	pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
-	if addNUL {
-		pktLen++
-	}
 
 	// To specify a db name
 	if n := len(mc.cfg.DBName); n > 0 {
@@ -350,10 +348,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
 	// Auth Data [length encoded integer]
 	pos += copy(data[pos:], authRespLEI)
 	pos += copy(data[pos:], authResp)
-	if addNUL {
-		data[pos] = 0x00
-		pos++
-	}
 
 	// Databasename [null terminated string]
 	if len(mc.cfg.DBName) > 0 {
@@ -364,17 +358,15 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
 
 	pos += copy(data[pos:], plugin)
 	data[pos] = 0x00
+	pos++
 
 	// Send Auth packet
-	return mc.writePacket(data)
+	return mc.writePacket(data[:pos])
 }
 
 // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
-func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
+func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
 	pktLen := 4 + len(authData)
-	if addNUL {
-		pktLen++
-	}
 	data := mc.buf.takeSmallBuffer(pktLen)
 	if data == nil {
 		// cannot take the buffer. Something must be wrong with the connection
@@ -384,10 +376,6 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
 
 	// Add the auth data [EOF]
 	copy(data[4:], authData)
-	if addNUL {
-		data[pktLen-1] = 0x00
-	}
-
 	return mc.writePacket(data)
 }