Browse Source

Merge pull request #309 from arvenil/placeholder

Placeholder interpolation
Julien Schmidt 10 years ago
parent
commit
200c80b773
11 changed files with 631 additions and 58 deletions
  1. 1 1
      .travis.yml
  2. 1 0
      AUTHORS
  3. 13 0
      README.md
  4. 38 1
      benchmark_test.go
  5. 14 0
      collations.go
  6. 191 35
      connection.go
  7. 22 0
      const.go
  8. 76 2
      driver_test.go
  9. 1 0
      packets.go
  10. 169 6
      utils.go
  11. 105 13
      utils_test.go

+ 1 - 1
.travis.yml

@@ -1,9 +1,9 @@
 sudo: false
 language: go
 go:
-  - 1.1
   - 1.2
   - 1.3
+  - 1.4
   - tip
 
 before_script:

+ 1 - 0
AUTHORS

@@ -24,6 +24,7 @@ INADA Naoki <songofacandy at gmail.com>
 James Harr <james.harr at gmail.com>
 Jian Zhen <zhenjl at gmail.com>
 Julien Schmidt <go-sql-driver at julienschmidt.com>
+Kamil Dziedzic <kamil at klecza.pl>
 Leonardo YongUk Kim <dalinaum at gmail.com>
 Lucas Liu <extrafliu at gmail.com>
 Luke Scott <luke at webconnex.com>

+ 13 - 0
README.md

@@ -182,6 +182,19 @@ SELECT u.id FROM users as u
 
 will return `u.id` instead of just `id` if `columnsWithAlias=true`.
 
+##### `interpolateParams`
+
+```
+Type:           bool
+Valid Values:   true, false
+Default:        false
+```
+
+If `interpolateParams` is true, placeholders (`?`) in calls to `db.Query()` and `db.Exec()` are interpolated into a single query string with given parameters. This reduces the number of roundtrips, since the driver has to prepare a statement, execute it with given parameters and close the statement again with `interpolateParams=false`.
+
+NOTE: *This may introduce a SQL injection vulnerability when connection encoding is multibyte encoding except for UTF-8 (e.g. CP932)!*
+(See http://stackoverflow.com/a/12118602/3430118)
+
 ##### `loc`
 
 ```

+ 38 - 1
benchmark_test.go

@@ -11,10 +11,13 @@ package mysql
 import (
 	"bytes"
 	"database/sql"
+	"database/sql/driver"
+	"math"
 	"strings"
 	"sync"
 	"sync/atomic"
 	"testing"
+	"time"
 )
 
 type TB testing.B
@@ -45,7 +48,11 @@ func initDB(b *testing.B, queries ...string) *sql.DB {
 	db := tb.checkDB(sql.Open("mysql", dsn))
 	for _, query := range queries {
 		if _, err := db.Exec(query); err != nil {
-			b.Fatalf("Error on %q: %v", query, err)
+			if w, ok := err.(MySQLWarnings); ok {
+				b.Logf("Warning on %q: %v", query, w)
+			} else {
+				b.Fatalf("Error on %q: %v", query, err)
+			}
 		}
 	}
 	return db
@@ -206,3 +213,33 @@ func BenchmarkRoundtripBin(b *testing.B) {
 		rows.Close()
 	}
 }
+
+func BenchmarkInterpolation(b *testing.B) {
+	mc := &mysqlConn{
+		cfg: &config{
+			interpolateParams: true,
+			loc:               time.UTC,
+		},
+		maxPacketAllowed: maxPacketSize,
+		maxWriteSize:     maxPacketSize - 1,
+	}
+
+	args := []driver.Value{
+		int64(42424242),
+		float64(math.Pi),
+		false,
+		time.Unix(1423411542, 807015000),
+		[]byte("bytes containing special chars ' \" \a \x00"),
+		"string containing special chars ' \" \a \x00",
+	}
+	q := "SELECT ?, ?, ?, ?, ?, ?"
+
+	b.ReportAllocs()
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, err := mc.interpolateParams(q, args)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}

+ 14 - 0
collations.go

@@ -234,3 +234,17 @@ var collations = map[string]byte{
 	"utf8mb4_unicode_520_ci":   246,
 	"utf8mb4_vietnamese_ci":    247,
 }
+
+// A blacklist of collations which is unsafe to interpolate parameters.
+// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
+var unsafeCollations = map[byte]bool{
+	1:  true, // big5_chinese_ci
+	13: true, // sjis_japanese_ci
+	28: true, // gbk_chinese_ci
+	84: true, // big5_bin
+	86: true, // gb2312_bin
+	87: true, // gbk_bin
+	88: true, // sjis_bin
+	95: true, // cp932_japanese_ci
+	96: true, // cp932_bin
+}

+ 191 - 35
connection.go

@@ -13,6 +13,7 @@ import (
 	"database/sql/driver"
 	"errors"
 	"net"
+	"strconv"
 	"strings"
 	"time"
 )
@@ -26,6 +27,7 @@ type mysqlConn struct {
 	maxPacketAllowed int
 	maxWriteSize     int
 	flags            clientFlag
+	status           statusFlag
 	sequence         uint8
 	parseTime        bool
 	strict           bool
@@ -46,6 +48,7 @@ type config struct {
 	allowOldPasswords bool
 	clientFoundRows   bool
 	columnsWithAlias  bool
+	interpolateParams bool
 }
 
 // Handles parameters set in DSN after the connection is established
@@ -162,28 +165,174 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	return stmt, err
 }
 
+// estimateParamLength calculates upper bound of string length from types.
+func estimateParamLength(args []driver.Value) (int, bool) {
+	l := 0
+	for _, a := range args {
+		switch v := a.(type) {
+		case int64, float64:
+			// 24 (-1.7976931348623157e+308) may be upper bound. But I'm not sure.
+			l += 25
+		case bool:
+			l += 1 // 0 or 1
+		case time.Time:
+			l += 30 // '1234-12-23 12:34:56.777777'
+		case string:
+			l += len(v)*2 + 2
+		case []byte:
+			l += len(v)*2 + 2
+		default:
+			return 0, false
+		}
+	}
+	return l, true
+}
+
+func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
+	estimated, ok := estimateParamLength(args)
+	if !ok {
+		return "", driver.ErrSkip
+	}
+	estimated += len(query)
+
+	buf := make([]byte, 0, estimated)
+	argPos := 0
+
+	for i := 0; i < len(query); i++ {
+		q := strings.IndexByte(query[i:], '?')
+		if q == -1 {
+			buf = append(buf, query[i:]...)
+			break
+		}
+		buf = append(buf, query[i:i+q]...)
+		i += q
+
+		arg := args[argPos]
+		argPos++
+
+		if arg == nil {
+			buf = append(buf, "NULL"...)
+			continue
+		}
+
+		switch v := arg.(type) {
+		case int64:
+			buf = strconv.AppendInt(buf, v, 10)
+		case float64:
+			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
+		case bool:
+			if v {
+				buf = append(buf, '1')
+			} else {
+				buf = append(buf, '0')
+			}
+		case time.Time:
+			if v.IsZero() {
+				buf = append(buf, "'0000-00-00'"...)
+			} else {
+				v := v.In(mc.cfg.loc)
+				v = v.Add(time.Nanosecond * 500) // To round under microsecond
+				year := v.Year()
+				year100 := year / 100
+				year1 := year % 100
+				month := v.Month()
+				day := v.Day()
+				hour := v.Hour()
+				minute := v.Minute()
+				second := v.Second()
+				micro := v.Nanosecond() / 1000
+
+				buf = append(buf, []byte{
+					'\'',
+					digits10[year100], digits01[year100],
+					digits10[year1], digits01[year1],
+					'-',
+					digits10[month], digits01[month],
+					'-',
+					digits10[day], digits01[day],
+					' ',
+					digits10[hour], digits01[hour],
+					':',
+					digits10[minute], digits01[minute],
+					':',
+					digits10[second], digits01[second],
+				}...)
+
+				if micro != 0 {
+					micro10000 := micro / 10000
+					micro100 := micro / 100 % 100
+					micro1 := micro % 100
+					buf = append(buf, []byte{
+						'.',
+						digits10[micro10000], digits01[micro10000],
+						digits10[micro100], digits01[micro100],
+						digits10[micro1], digits01[micro1],
+					}...)
+				}
+				buf = append(buf, '\'')
+			}
+		case []byte:
+			if v == nil {
+				buf = append(buf, "NULL"...)
+			} else {
+				buf = append(buf, '\'')
+				if mc.status&statusNoBackslashEscapes == 0 {
+					buf = escapeBytesBackslash(buf, v)
+				} else {
+					buf = escapeBytesQuotes(buf, v)
+				}
+				buf = append(buf, '\'')
+			}
+		case string:
+			buf = append(buf, '\'')
+			if mc.status&statusNoBackslashEscapes == 0 {
+				buf = escapeStringBackslash(buf, v)
+			} else {
+				buf = escapeStringQuotes(buf, v)
+			}
+			buf = append(buf, '\'')
+		default:
+			return "", driver.ErrSkip
+		}
+
+		if len(buf)+4 > mc.maxPacketAllowed {
+			return "", driver.ErrSkip
+		}
+	}
+	if argPos != len(args) {
+		return "", driver.ErrSkip
+	}
+	return string(buf), nil
+}
+
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
 	if mc.netConn == nil {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
-	if len(args) == 0 { // no args, fastpath
-		mc.affectedRows = 0
-		mc.insertId = 0
-
-		err := mc.exec(query)
-		if err == nil {
-			return &mysqlResult{
-				affectedRows: int64(mc.affectedRows),
-				insertId:     int64(mc.insertId),
-			}, err
+	if len(args) != 0 {
+		if !mc.cfg.interpolateParams {
+			return nil, driver.ErrSkip
 		}
-		return nil, err
+		// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
+		prepared, err := mc.interpolateParams(query, args)
+		if err != nil {
+			return nil, err
+		}
+		query = prepared
+		args = nil
 	}
+	mc.affectedRows = 0
+	mc.insertId = 0
 
-	// with args, must use prepared stmt
-	return nil, driver.ErrSkip
-
+	err := mc.exec(query)
+	if err == nil {
+		return &mysqlResult{
+			affectedRows: int64(mc.affectedRows),
+			insertId:     int64(mc.insertId),
+		}, err
+	}
+	return nil, err
 }
 
 // Internal function to execute commands
@@ -212,31 +361,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
-	if len(args) == 0 { // no args, fastpath
-		// Send command
-		err := mc.writeCommandPacketStr(comQuery, query)
+	if len(args) != 0 {
+		if !mc.cfg.interpolateParams {
+			return nil, driver.ErrSkip
+		}
+		// try client-side prepare to reduce roundtrip
+		prepared, err := mc.interpolateParams(query, args)
+		if err != nil {
+			return nil, err
+		}
+		query = prepared
+		args = nil
+	}
+	// Send command
+	err := mc.writeCommandPacketStr(comQuery, query)
+	if err == nil {
+		// Read Result
+		var resLen int
+		resLen, err = mc.readResultSetHeaderPacket()
 		if err == nil {
-			// Read Result
-			var resLen int
-			resLen, err = mc.readResultSetHeaderPacket()
-			if err == nil {
-				rows := new(textRows)
-				rows.mc = mc
-
-				if resLen == 0 {
-					// no columns, no more data
-					return emptyRows{}, nil
-				}
-				// Columns
-				rows.columns, err = mc.readColumns(resLen)
-				return rows, err
+			rows := new(textRows)
+			rows.mc = mc
+
+			if resLen == 0 {
+				// no columns, no more data
+				return emptyRows{}, nil
 			}
+			// Columns
+			rows.columns, err = mc.readColumns(resLen)
+			return rows, err
 		}
-		return nil, err
 	}
-
-	// with args, must use prepared stmt
-	return nil, driver.ErrSkip
+	return nil, err
 }
 
 // Gets the value of the given MySQL System Variable

+ 22 - 0
const.go

@@ -130,3 +130,25 @@ const (
 	flagUnknown3
 	flagUnknown4
 )
+
+// http://dev.mysql.com/doc/internals/en/status-flags.html
+
+type statusFlag uint16
+
+const (
+	statusInTrans statusFlag = 1 << iota
+	statusInAutocommit
+	statusReserved // Not in documentation
+	statusMoreResultsExists
+	statusNoGoodIndexUsed
+	statusNoIndexUsed
+	statusCursorExists
+	statusLastRowSent
+	statusDbDropped
+	statusNoBackslashEscapes
+	statusMetadataChanged
+	statusQueryWasSlow
+	statusPsOutParams
+	statusInTransReadonly
+	statusSessionStateChanged
+)

+ 76 - 2
driver_test.go

@@ -87,10 +87,25 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 
 	db.Exec("DROP TABLE IF EXISTS test")
 
+	dsn2 := dsn + "&interpolateParams=true"
+	var db2 *sql.DB
+	if _, err := parseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
+		db2, err = sql.Open("mysql", dsn2)
+		if err != nil {
+			t.Fatalf("Error connecting: %s", err.Error())
+		}
+		defer db2.Close()
+	}
+
 	dbt := &DBTest{t, db}
+	dbt2 := &DBTest{t, db2}
 	for _, test := range tests {
 		test(dbt)
 		dbt.db.Exec("DROP TABLE IF EXISTS test")
+		if db2 != nil {
+			test(dbt2)
+			dbt2.db.Exec("DROP TABLE IF EXISTS test")
+		}
 	}
 }
 
@@ -855,7 +870,7 @@ func TestLoadData(t *testing.T) {
 					dbt.Fatalf("%d != %d", i, id)
 				}
 				if values[i-1] != value {
-					dbt.Fatalf("%s != %s", values[i-1], value)
+					dbt.Fatalf("%q != %q", values[i-1], value)
 				}
 			}
 			err = rows.Err()
@@ -880,7 +895,7 @@ func TestLoadData(t *testing.T) {
 
 		// Local File
 		RegisterLocalFile(file.Name())
-		dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name()))
+		dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
 		verifyLoadDataResult()
 		// negative test
 		_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")
@@ -1538,3 +1553,62 @@ func TestCustomDial(t *testing.T) {
 		t.Fatalf("Connection failed: %s", err.Error())
 	}
 }
+
+func TestSqlInjection(t *testing.T) {
+	createTest := func(arg string) func(dbt *DBTest) {
+		return func(dbt *DBTest) {
+			dbt.mustExec("CREATE TABLE test (v INTEGER)")
+			dbt.mustExec("INSERT INTO test VALUES (?)", 1)
+
+			var v int
+			// NULL can't be equal to anything, the idea here is to inject query so it returns row
+			// This test verifies that escapeQuotes and escapeBackslash are working properly
+			err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v)
+			if err == sql.ErrNoRows {
+				return // success, sql injection failed
+			} else if err == nil {
+				dbt.Errorf("Sql injection successful with arg: %s", arg)
+			} else {
+				dbt.Errorf("Error running query with arg: %s; err: %s", arg, err.Error())
+			}
+		}
+	}
+
+	dsns := []string{
+		dsn,
+		dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
+	}
+	for _, testdsn := range dsns {
+		runTests(t, testdsn, createTest("1 OR 1=1"))
+		runTests(t, testdsn, createTest("' OR '1'='1"))
+	}
+}
+
+// Test if inserted data is correctly retrieved after being escaped
+func TestInsertRetrieveEscapedData(t *testing.T) {
+	testData := func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (v VARCHAR(255))")
+
+		// All sequences that are escaped by escapeQuotes and escapeBackslash
+		v := "foo \x00\n\r\x1a\"'\\"
+		dbt.mustExec("INSERT INTO test VALUES (?)", v)
+
+		var out string
+		err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out)
+		if err != nil {
+			dbt.Fatalf("%s", err.Error())
+		}
+
+		if out != v {
+			dbt.Errorf("%q != %q", out, v)
+		}
+	}
+
+	dsns := []string{
+		dsn,
+		dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
+	}
+	for _, testdsn := range dsns {
+		runTests(t, testdsn, testData)
+	}
+}

+ 1 - 0
packets.go

@@ -484,6 +484,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
 	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
 
 	// server_status [2 bytes]
+	mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
 
 	// warning count [2 bytes]
 	if !mc.strict {

+ 169 - 6
utils.go

@@ -25,9 +25,10 @@ import (
 var (
 	tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
 
-	errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
-	errInvalidDSNAddr      = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
-	errInvalidDSNNoSlash   = errors.New("Invalid DSN: Missing the slash separating the database name")
+	errInvalidDSNUnescaped       = errors.New("Invalid DSN: Did you forget to escape a param value?")
+	errInvalidDSNAddr            = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
+	errInvalidDSNNoSlash         = errors.New("Invalid DSN: Missing the slash separating the database name")
+	errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset")
 )
 
 func init() {
@@ -147,6 +148,10 @@ func parseDSN(dsn string) (cfg *config, err error) {
 		return nil, errInvalidDSNNoSlash
 	}
 
+	if cfg.interpolateParams && unsafeCollations[cfg.collation] {
+		return nil, errInvalidDSNUnsafeCollation
+	}
+
 	// Set default network if empty
 	if cfg.net == "" {
 		cfg.net = "tcp"
@@ -180,6 +185,14 @@ func parseDSNParams(cfg *config, params string) (err error) {
 		// cfg params
 		switch value := param[1]; param[0] {
 
+		// Enable client side placeholder substitution
+		case "interpolateParams":
+			var isBool bool
+			cfg.interpolateParams, isBool = readBool(value)
+			if !isBool {
+				return fmt.Errorf("Invalid Bool value: %s", value)
+			}
+
 		// Disable INFILE whitelist / enable all files
 		case "allowAllFiles":
 			var isBool bool
@@ -216,7 +229,7 @@ func parseDSNParams(cfg *config, params string) (err error) {
 			}
 			cfg.collation = collation
 			break
-		
+
 		case "columnsWithAlias":
 			var isBool bool
 			cfg.columnsWithAlias, isBool = readBool(value)
@@ -532,11 +545,12 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
 // The current behavior depends on database/sql copying the result.
 var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
 
+const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
+const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
+
 func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) {
 	// length expects the deterministic length of the zero value,
 	// negative time and 100+ hours are automatically added if needed
-	const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
-	const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
 	if len(src) == 0 {
 		if justTime {
 			return zeroDateTime[11 : 11+length], nil
@@ -798,3 +812,152 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
 	return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
 		byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
 }
+
+// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize.
+// If cap(buf) is not enough, reallocate new buffer.
+func reserveBuffer(buf []byte, appendSize int) []byte {
+	newSize := len(buf) + appendSize
+	if cap(buf) < newSize {
+		// Grow buffer exponentially
+		newBuf := make([]byte, len(buf)*2+appendSize)
+		copy(newBuf, buf)
+		buf = newBuf
+	}
+	return buf[:newSize]
+}
+
+// escapeBytesBackslash escapes []byte with backslashes (\)
+// This escapes the contents of a string (provided as []byte) by adding backslashes before special
+// characters, and turning others into specific escape sequences, such as
+// turning newlines into \n and null bytes into \0.
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
+func escapeBytesBackslash(buf, v []byte) []byte {
+	pos := len(buf)
+	buf = reserveBuffer(buf, len(v)*2)
+
+	for _, c := range v {
+		switch c {
+		case '\x00':
+			buf[pos] = '\\'
+			buf[pos+1] = '0'
+			pos += 2
+		case '\n':
+			buf[pos] = '\\'
+			buf[pos+1] = 'n'
+			pos += 2
+		case '\r':
+			buf[pos] = '\\'
+			buf[pos+1] = 'r'
+			pos += 2
+		case '\x1a':
+			buf[pos] = '\\'
+			buf[pos+1] = 'Z'
+			pos += 2
+		case '\'':
+			buf[pos] = '\\'
+			buf[pos+1] = '\''
+			pos += 2
+		case '"':
+			buf[pos] = '\\'
+			buf[pos+1] = '"'
+			pos += 2
+		case '\\':
+			buf[pos] = '\\'
+			buf[pos+1] = '\\'
+			pos += 2
+		default:
+			buf[pos] = c
+			pos += 1
+		}
+	}
+
+	return buf[:pos]
+}
+
+// escapeStringBackslash is similar to escapeBytesBackslash but for string.
+func escapeStringBackslash(buf []byte, v string) []byte {
+	pos := len(buf)
+	buf = reserveBuffer(buf, len(v)*2)
+
+	for i := 0; i < len(v); i++ {
+		c := v[i]
+		switch c {
+		case '\x00':
+			buf[pos] = '\\'
+			buf[pos+1] = '0'
+			pos += 2
+		case '\n':
+			buf[pos] = '\\'
+			buf[pos+1] = 'n'
+			pos += 2
+		case '\r':
+			buf[pos] = '\\'
+			buf[pos+1] = 'r'
+			pos += 2
+		case '\x1a':
+			buf[pos] = '\\'
+			buf[pos+1] = 'Z'
+			pos += 2
+		case '\'':
+			buf[pos] = '\\'
+			buf[pos+1] = '\''
+			pos += 2
+		case '"':
+			buf[pos] = '\\'
+			buf[pos+1] = '"'
+			pos += 2
+		case '\\':
+			buf[pos] = '\\'
+			buf[pos+1] = '\\'
+			pos += 2
+		default:
+			buf[pos] = c
+			pos += 1
+		}
+	}
+
+	return buf[:pos]
+}
+
+// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
+// This escapes the contents of a string by doubling up any apostrophes that
+// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
+// effect on the server.
+// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
+func escapeBytesQuotes(buf, v []byte) []byte {
+	pos := len(buf)
+	buf = reserveBuffer(buf, len(v)*2)
+
+	for _, c := range v {
+		if c == '\'' {
+			buf[pos] = '\''
+			buf[pos+1] = '\''
+			pos += 2
+		} else {
+			buf[pos] = c
+			pos++
+		}
+	}
+
+	return buf[:pos]
+}
+
+// escapeStringQuotes is similar to escapeBytesQuotes but for string.
+func escapeStringQuotes(buf []byte, v string) []byte {
+	pos := len(buf)
+	buf = reserveBuffer(buf, len(v)*2)
+
+	for i := 0; i < len(v); i++ {
+		c := v[i]
+		if c == '\'' {
+			buf[pos] = '\''
+			buf[pos+1] = '\''
+			pos += 2
+		} else {
+			buf[pos] = c
+			pos++
+		}
+	}
+
+	return buf[:pos]
+}

+ 105 - 13
utils_test.go

@@ -22,19 +22,19 @@ var testDSNs = []struct {
 	out string
 	loc *time.Location
 }{
-	{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true}", time.UTC},
-	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false}", time.UTC},
-	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.Local},
-	{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
+	{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true interpolateParams:false}", time.UTC},
+	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.Local},
+	{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
+	{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
 }
 
 func TestDSNParser(t *testing.T) {
@@ -116,6 +116,43 @@ func TestDSNWithCustomTLS(t *testing.T) {
 	DeregisterTLSConfig("utils_test")
 }
 
+func TestDSNUnsafeCollation(t *testing.T) {
+	_, err := parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
+	if err != errInvalidDSNUnsafeCollation {
+		t.Error("Expected %v, Got %v", errInvalidDSNUnsafeCollation, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=gbk_chinese_ci")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=ascii_bin&interpolateParams=true")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+
+	_, err = parseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true")
+	if err != nil {
+		t.Error("Expected %v, Got %v", nil, err)
+	}
+}
+
 func BenchmarkParseDSN(b *testing.B) {
 	b.ReportAllocs()
 
@@ -252,3 +289,58 @@ func TestFormatBinaryDateTime(t *testing.T) {
 	expect("1978-12-30 15:46:23", 7, 19)
 	expect("1978-12-30 15:46:23.987654", 11, 26)
 }
+
+func TestEscapeBackslash(t *testing.T) {
+	expect := func(expected, value string) {
+		actual := string(escapeBytesBackslash([]byte{}, []byte(value)))
+		if actual != expected {
+			t.Errorf(
+				"expected %s, got %s",
+				expected, actual,
+			)
+		}
+
+		actual = string(escapeStringBackslash([]byte{}, value))
+		if actual != expected {
+			t.Errorf(
+				"expected %s, got %s",
+				expected, actual,
+			)
+		}
+	}
+
+	expect("foo\\0bar", "foo\x00bar")
+	expect("foo\\nbar", "foo\nbar")
+	expect("foo\\rbar", "foo\rbar")
+	expect("foo\\Zbar", "foo\x1abar")
+	expect("foo\\\"bar", "foo\"bar")
+	expect("foo\\\\bar", "foo\\bar")
+	expect("foo\\'bar", "foo'bar")
+}
+
+func TestEscapeQuotes(t *testing.T) {
+	expect := func(expected, value string) {
+		actual := string(escapeBytesQuotes([]byte{}, []byte(value)))
+		if actual != expected {
+			t.Errorf(
+				"expected %s, got %s",
+				expected, actual,
+			)
+		}
+
+		actual = string(escapeStringQuotes([]byte{}, value))
+		if actual != expected {
+			t.Errorf(
+				"expected %s, got %s",
+				expected, actual,
+			)
+		}
+	}
+
+	expect("foo\x00bar", "foo\x00bar") // not affected
+	expect("foo\nbar", "foo\nbar")     // not affected
+	expect("foo\rbar", "foo\rbar")     // not affected
+	expect("foo\x1abar", "foo\x1abar") // not affected
+	expect("foo''bar", "foo'bar")      // affected
+	expect("foo\"bar", "foo\"bar")     // not affected
+}