|
|
@@ -9,12 +9,14 @@
|
|
|
package mysql
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"crypto/tls"
|
|
|
"database/sql"
|
|
|
"database/sql/driver"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"io/ioutil"
|
|
|
+ "log"
|
|
|
"net"
|
|
|
"net/url"
|
|
|
"os"
|
|
|
@@ -1018,7 +1020,7 @@ func TestFoundRows(t *testing.T) {
|
|
|
|
|
|
func TestStrict(t *testing.T) {
|
|
|
// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
|
|
|
- relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
|
|
|
+ relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'"
|
|
|
// make sure the MySQL version is recent enough with a separate connection
|
|
|
// before running the test
|
|
|
conn, err := MySQLDriver{}.Open(relaxedDsn)
|
|
|
@@ -1643,7 +1645,7 @@ func TestSqlInjection(t *testing.T) {
|
|
|
|
|
|
dsns := []string{
|
|
|
dsn,
|
|
|
- dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
|
|
|
+ dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
|
|
|
}
|
|
|
for _, testdsn := range dsns {
|
|
|
runTests(t, testdsn, createTest("1 OR 1=1"))
|
|
|
@@ -1673,9 +1675,56 @@ func TestInsertRetrieveEscapedData(t *testing.T) {
|
|
|
|
|
|
dsns := []string{
|
|
|
dsn,
|
|
|
- dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
|
|
|
+ dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
|
|
|
}
|
|
|
for _, testdsn := range dsns {
|
|
|
runTests(t, testdsn, testData)
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func TestUnixSocketAuthFail(t *testing.T) {
|
|
|
+ runTests(t, dsn, func(dbt *DBTest) {
|
|
|
+ // Save the current logger so we can restore it.
|
|
|
+ oldLogger := errLog
|
|
|
+
|
|
|
+ // Set a new logger so we can capture its output.
|
|
|
+ buffer := bytes.NewBuffer(make([]byte, 0, 64))
|
|
|
+ newLogger := log.New(buffer, "prefix: ", 0)
|
|
|
+ SetLogger(newLogger)
|
|
|
+
|
|
|
+ // Restore the logger.
|
|
|
+ defer SetLogger(oldLogger)
|
|
|
+
|
|
|
+ // Make a new DSN that uses the MySQL socket file and a bad password, which
|
|
|
+ // we can make by simply appending any character to the real password.
|
|
|
+ badPass := pass + "x"
|
|
|
+ socket := ""
|
|
|
+ if prot == "unix" {
|
|
|
+ socket = addr
|
|
|
+ } else {
|
|
|
+ // Get socket file from MySQL.
|
|
|
+ err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Error on SELECT @@socket: %s", err.Error())
|
|
|
+ }
|
|
|
+ }
|
|
|
+ t.Logf("socket: %s", socket)
|
|
|
+ badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname)
|
|
|
+ db, err := sql.Open("mysql", badDSN)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("Error connecting: %s", err.Error())
|
|
|
+ }
|
|
|
+ defer db.Close()
|
|
|
+
|
|
|
+ // Connect to MySQL for real. This will cause an auth failure.
|
|
|
+ err = db.Ping()
|
|
|
+ if err == nil {
|
|
|
+ t.Error("expected Ping() to return an error")
|
|
|
+ }
|
|
|
+
|
|
|
+ // The driver should not log anything.
|
|
|
+ if actual := buffer.String(); actual != "" {
|
|
|
+ t.Errorf("expected no output, got %q", actual)
|
|
|
+ }
|
|
|
+ })
|
|
|
+}
|