Procházet zdrojové kódy

Merge pull request #2 from go-sql-driver/master

Pull recent changes from the main fork
Kevin Malachowski před 10 roky
rodič
revize
de5a0de5a3
6 změnil soubory, kde provedl 136 přidání a 45 odebrání
  1. 3 0
      AUTHORS
  2. 1 1
      README.md
  3. 17 8
      connection.go
  4. 41 29
      driver.go
  5. 52 3
      driver_test.go
  6. 22 4
      infile.go

+ 3 - 0
AUTHORS

@@ -15,15 +15,18 @@ Aaron Hopkins <go-sql-driver at die.net>
 Arne Hormann <arnehormann at gmail.com>
 Carlos Nieto <jose.carlos at menteslibres.net>
 Chris Moos <chris at tech9computers.com>
+Daniel Nichter <nil at codenode.com>
 DisposaBoy <disposaboy at dby.me>
 Frederick Mayle <frederickmayle at gmail.com>
 Gustavo Kristic <gkristic at gmail.com>
 Hanno Braun <mail at hannobraun.com>
 Henri Yandell <flamefew at gmail.com>
+Hirotaka Yamamoto <ymmt2005 at gmail.com>
 INADA Naoki <songofacandy at gmail.com>
 James Harr <james.harr at gmail.com>
 Jian Zhen <zhenjl at gmail.com>
 Joshua Prunier <joshua.prunier at gmail.com>
+Julien Lefevre <julien.lefevr at gmail.com>
 Julien Schmidt <go-sql-driver at julienschmidt.com>
 Kamil Dziedzic <kamil at klecza.pl>
 Kevin Malachowski <kevin at chowski.com>

+ 1 - 1
README.md

@@ -331,7 +331,7 @@ import "github.com/go-sql-driver/mysql"
 
 Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
 
-To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then.
+To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore.
 
 See the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details.
 

+ 17 - 8
connection.go

@@ -120,18 +120,27 @@ func (mc *mysqlConn) Close() (err error) {
 	// Makes Close idempotent
 	if mc.netConn != nil {
 		err = mc.writeCommandPacket(comQuit)
-		if err == nil {
-			err = mc.netConn.Close()
-		} else {
-			mc.netConn.Close()
+	}
+
+	mc.cleanup()
+
+	return
+}
+
+// Closes the network connection and unsets internal variables. Do not call this
+// function after successfully authentication, call Close instead. This function
+// is called before auth or on auth failure because MySQL will have already
+// closed the network connection.
+func (mc *mysqlConn) cleanup() {
+	// Makes cleanup idempotent
+	if mc.netConn != nil {
+		if err := mc.netConn.Close(); err != nil {
+			errLog.Print(err)
 		}
 		mc.netConn = nil
 	}
-
 	mc.cfg = nil
 	mc.buf.rd = nil
-
-	return
 }
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
@@ -253,7 +262,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
 			if v == nil {
 				buf = append(buf, "NULL"...)
 			} else {
-				buf = append(buf, '\'')
+				buf = append(buf, "_binary'"...)
 				if mc.status&statusNoBackslashEscapes == 0 {
 					buf = escapeBytesBackslash(buf, v)
 				} else {

+ 41 - 29
driver.go

@@ -84,43 +84,23 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	// Reading Handshake Initialization Packet
 	cipher, err := mc.readInitPacket()
 	if err != nil {
-		mc.Close()
+		mc.cleanup()
 		return nil, err
 	}
 
 	// Send Client Authentication Packet
 	if err = mc.writeAuthPacket(cipher); err != nil {
-		mc.Close()
+		mc.cleanup()
 		return nil, err
 	}
 
-	// Read Result Packet
-	err = mc.readResultOK()
-	if err != nil {
-		// Retry with old authentication method, if allowed
-		if mc.cfg != nil && mc.cfg.allowOldPasswords && err == ErrOldPassword {
-			if err = mc.writeOldAuthPacket(cipher); err != nil {
-				mc.Close()
-				return nil, err
-			}
-			if err = mc.readResultOK(); err != nil {
-				mc.Close()
-				return nil, err
-			}
-		} else if mc.cfg != nil && mc.cfg.allowCleartextPasswords && err == ErrCleartextPassword {
-			if err = mc.writeClearAuthPacket(); err != nil {
-				mc.Close()
-				return nil, err
-			}
-			if err = mc.readResultOK(); err != nil {
-				mc.Close()
-				return nil, err
-			}
-		} else {
-			mc.Close()
-			return nil, err
-		}
-
+	// Handle response to auth packet, switch methods if possible
+	if err = handleAuthResult(mc, cipher); err != nil {
+		// Authentication failed and MySQL has already closed the connection
+		// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
+		// Do not send COM_QUIT, just cleanup and return the error.
+		mc.cleanup()
+		return nil, err
 	}
 
 	// Get max allowed packet size
@@ -144,6 +124,38 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	return mc, nil
 }
 
+func handleAuthResult(mc *mysqlConn, cipher []byte) error {
+	// Read Result Packet
+	err := mc.readResultOK()
+	if err == nil {
+		return nil // auth successful
+	}
+
+	if mc.cfg == nil {
+		return err // auth failed and retry not possible
+	}
+
+	// Retry auth if configured to do so.
+	if mc.cfg.allowOldPasswords && err == ErrOldPassword {
+		// Retry with old authentication method. Note: there are edge cases
+		// where this should work but doesn't; this is currently "wontfix":
+		// https://github.com/go-sql-driver/mysql/issues/184
+		if err = mc.writeOldAuthPacket(cipher); err != nil {
+			return err
+		}
+		err = mc.readResultOK()
+	} else if mc.cfg.allowCleartextPasswords && err == ErrCleartextPassword {
+		// Retry with clear text password for
+		// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
+		// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
+		if err = mc.writeClearAuthPacket(); err != nil {
+			return err
+		}
+		err = mc.readResultOK()
+	}
+	return err
+}
+
 func init() {
 	sql.Register("mysql", &MySQLDriver{})
 }

+ 52 - 3
driver_test.go

@@ -9,12 +9,14 @@
 package mysql
 
 import (
+	"bytes"
 	"crypto/tls"
 	"database/sql"
 	"database/sql/driver"
 	"fmt"
 	"io"
 	"io/ioutil"
+	"log"
 	"net"
 	"net/url"
 	"os"
@@ -1018,7 +1020,7 @@ func TestFoundRows(t *testing.T) {
 
 func TestStrict(t *testing.T) {
 	// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
-	relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
+	relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'"
 	// make sure the MySQL version is recent enough with a separate connection
 	// before running the test
 	conn, err := MySQLDriver{}.Open(relaxedDsn)
@@ -1643,7 +1645,7 @@ func TestSqlInjection(t *testing.T) {
 
 	dsns := []string{
 		dsn,
-		dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
+		dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
 	}
 	for _, testdsn := range dsns {
 		runTests(t, testdsn, createTest("1 OR 1=1"))
@@ -1673,9 +1675,56 @@ func TestInsertRetrieveEscapedData(t *testing.T) {
 
 	dsns := []string{
 		dsn,
-		dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
+		dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
 	}
 	for _, testdsn := range dsns {
 		runTests(t, testdsn, testData)
 	}
 }
+
+func TestUnixSocketAuthFail(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		// Save the current logger so we can restore it.
+		oldLogger := errLog
+
+		// Set a new logger so we can capture its output.
+		buffer := bytes.NewBuffer(make([]byte, 0, 64))
+		newLogger := log.New(buffer, "prefix: ", 0)
+		SetLogger(newLogger)
+
+		// Restore the logger.
+		defer SetLogger(oldLogger)
+
+		// Make a new DSN that uses the MySQL socket file and a bad password, which
+		// we can make by simply appending any character to the real password.
+		badPass := pass + "x"
+		socket := ""
+		if prot == "unix" {
+			socket = addr
+		} else {
+			// Get socket file from MySQL.
+			err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket)
+			if err != nil {
+				t.Fatalf("Error on SELECT @@socket: %s", err.Error())
+			}
+		}
+		t.Logf("socket: %s", socket)
+		badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname)
+		db, err := sql.Open("mysql", badDSN)
+		if err != nil {
+			t.Fatalf("Error connecting: %s", err.Error())
+		}
+		defer db.Close()
+
+		// Connect to MySQL for real. This will cause an auth failure.
+		err = db.Ping()
+		if err == nil {
+			t.Error("expected Ping() to return an error")
+		}
+
+		// The driver should not log anything.
+		if actual := buffer.String(); actual != "" {
+			t.Errorf("expected no output, got %q", actual)
+		}
+	})
+}

+ 22 - 4
infile.go

@@ -13,11 +13,14 @@ import (
 	"io"
 	"os"
 	"strings"
+	"sync"
 )
 
 var (
-	fileRegister   map[string]bool
-	readerRegister map[string]func() io.Reader
+	fileRegister       map[string]bool
+	fileRegisterLock   sync.RWMutex
+	readerRegister     map[string]func() io.Reader
+	readerRegisterLock sync.RWMutex
 )
 
 // RegisterLocalFile adds the given file to the file whitelist,
@@ -32,17 +35,21 @@ var (
 //  ...
 //
 func RegisterLocalFile(filePath string) {
+	fileRegisterLock.Lock()
 	// lazy map init
 	if fileRegister == nil {
 		fileRegister = make(map[string]bool)
 	}
 
 	fileRegister[strings.Trim(filePath, `"`)] = true
+	fileRegisterLock.Unlock()
 }
 
 // DeregisterLocalFile removes the given filepath from the whitelist.
 func DeregisterLocalFile(filePath string) {
+	fileRegisterLock.Lock()
 	delete(fileRegister, strings.Trim(filePath, `"`))
+	fileRegisterLock.Unlock()
 }
 
 // RegisterReaderHandler registers a handler function which is used
@@ -61,18 +68,22 @@ func DeregisterLocalFile(filePath string) {
 //  ...
 //
 func RegisterReaderHandler(name string, handler func() io.Reader) {
+	readerRegisterLock.Lock()
 	// lazy map init
 	if readerRegister == nil {
 		readerRegister = make(map[string]func() io.Reader)
 	}
 
 	readerRegister[name] = handler
+	readerRegisterLock.Unlock()
 }
 
 // DeregisterReaderHandler removes the ReaderHandler function with
 // the given name from the registry.
 func DeregisterReaderHandler(name string) {
+	readerRegisterLock.Lock()
 	delete(readerRegister, name)
+	readerRegisterLock.Unlock()
 }
 
 func deferredClose(err *error, closer io.Closer) {
@@ -90,7 +101,11 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 		// The server might return an an absolute path. See issue #355.
 		name = name[idx+8:]
 
-		if handler, inMap := readerRegister[name]; inMap {
+		readerRegisterLock.RLock()
+		handler, inMap := readerRegister[name]
+		readerRegisterLock.RUnlock()
+
+		if inMap {
 			rdr = handler()
 			if rdr != nil {
 				data = make([]byte, 4+mc.maxWriteSize)
@@ -106,7 +121,10 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 		}
 	} else { // File
 		name = strings.Trim(name, `"`)
-		if mc.cfg.allowAllFiles || fileRegister[name] {
+		fileRegisterLock.RLock()
+		fr := fileRegister[name]
+		fileRegisterLock.RUnlock()
+		if mc.cfg.allowAllFiles || fr {
 			var file *os.File
 			var fi os.FileInfo