소스 검색

Specialize escape functions for string

benchmark                  old ns/op     new ns/op     delta
BenchmarkInterpolation     2463          2118          -14.01%

benchmark                  old allocs     new allocs     delta
BenchmarkInterpolation     3              2              -33.33%

benchmark                  old bytes     new bytes     delta
BenchmarkInterpolation     496           448           -9.68%
INADA Naoki 11 년 전
부모
커밋
43536c7d6d
3개의 변경된 파일96개의 추가작업 그리고 9개의 파일을 삭제
  1. 13 3
      connection.go
  2. 81 4
      utils.go
  3. 2 2
      utils_test.go

+ 13 - 3
connection.go

@@ -169,9 +169,19 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte {
 	buf = append(buf, '\'')
 	if mc.status&statusNoBackslashEscapes == 0 {
-		buf = escapeBackslash(buf, v)
+		buf = escapeBytesBackslash(buf, v)
 	} else {
-		buf = escapeQuotes(buf, v)
+		buf = escapeBytesQuotes(buf, v)
+	}
+	return append(buf, '\'')
+}
+
+func (mc *mysqlConn) escapeString(buf []byte, v string) []byte {
+	buf = append(buf, '\'')
+	if mc.status&statusNoBackslashEscapes == 0 {
+		buf = escapeStringBackslash(buf, v)
+	} else {
+		buf = escapeStringQuotes(buf, v)
 	}
 	return append(buf, '\'')
 }
@@ -293,7 +303,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
 				buf = mc.escapeBytes(buf, v)
 			}
 		case string:
-			buf = mc.escapeBytes(buf, []byte(v))
+			buf = mc.escapeString(buf, v)
 		default:
 			return "", driver.ErrSkip
 		}

+ 81 - 4
utils.go

@@ -807,12 +807,12 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
 		byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
 }
 
-// Escape string with backslashes (\)
+// escapeBytesBackslash escapes []byte with backslashes (\)
 // This escapes the contents of a string (provided as []byte) by adding backslashes before special
 // characters, and turning others into specific escape sequences, such as
 // turning newlines into \n and null bytes into \0.
 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
-func escapeBackslash(buf, v []byte) []byte {
+func escapeBytesBackslash(buf, v []byte) []byte {
 	pos := len(buf)
 	end := pos + len(v)*2
 	if cap(buf) < end {
@@ -861,12 +861,63 @@ func escapeBackslash(buf, v []byte) []byte {
 	return buf[:pos]
 }
 
-// Escape apostrophes by doubling them up
+// escapeStringBackslash is similar to escapeBytesBackslash but for string.
+func escapeStringBackslash(buf []byte, v string) []byte {
+	pos := len(buf)
+	end := pos + len(v)*2
+	if cap(buf) < end {
+		n := make([]byte, pos+end)
+		copy(n, buf)
+		buf = n
+	}
+	buf = buf[0:end]
+
+	for i := 0; i < len(v); i++ {
+		c := v[i]
+		switch c {
+		case '\x00':
+			buf[pos] = '\\'
+			buf[pos+1] = '0'
+			pos += 2
+		case '\n':
+			buf[pos] = '\\'
+			buf[pos+1] = 'n'
+			pos += 2
+		case '\r':
+			buf[pos] = '\\'
+			buf[pos+1] = 'r'
+			pos += 2
+		case '\x1a':
+			buf[pos] = '\\'
+			buf[pos+1] = 'Z'
+			pos += 2
+		case '\'':
+			buf[pos] = '\\'
+			buf[pos+1] = '\''
+			pos += 2
+		case '"':
+			buf[pos] = '\\'
+			buf[pos+1] = '"'
+			pos += 2
+		case '\\':
+			buf[pos] = '\\'
+			buf[pos+1] = '\\'
+			pos += 2
+		default:
+			buf[pos] = c
+			pos += 1
+		}
+	}
+
+	return buf[:pos]
+}
+
+// escapeBytesQuotes escapes apostrophes in []byte by doubling them up.
 // This escapes the contents of a string by doubling up any apostrophes that
 // it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
 // effect on the server.
 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
-func escapeQuotes(buf, v []byte) []byte {
+func escapeBytesQuotes(buf, v []byte) []byte {
 	pos := len(buf)
 	end := pos + len(v)*2
 	if cap(buf) < end {
@@ -889,3 +940,29 @@ func escapeQuotes(buf, v []byte) []byte {
 
 	return buf[:pos]
 }
+
+// escapeStringQuotes is similar to escapeBytesQuotes but for string.
+func escapeStringQuotes(buf []byte, v string) []byte {
+	pos := len(buf)
+	end := pos + len(v)*2
+	if cap(buf) < end {
+		n := make([]byte, pos+end)
+		copy(n, buf)
+		buf = n
+	}
+	buf = buf[0:end]
+
+	for i := 0; i < len(v); i++ {
+		c := v[i]
+		if c == '\'' {
+			buf[pos] = '\''
+			buf[pos+1] = '\''
+			pos += 2
+		} else {
+			buf[pos] = c
+			pos++
+		}
+	}
+
+	return buf[:pos]
+}

+ 2 - 2
utils_test.go

@@ -255,7 +255,7 @@ func TestFormatBinaryDateTime(t *testing.T) {
 
 func TestEscapeBackslash(t *testing.T) {
 	expect := func(expected, value string) {
-		actual := string(escapeBackslash([]byte{}, []byte(value)))
+		actual := string(escapeBytesBackslash([]byte{}, []byte(value)))
 		if actual != expected {
 			t.Errorf(
 				"expected %s, got %s",
@@ -275,7 +275,7 @@ func TestEscapeBackslash(t *testing.T) {
 
 func TestEscapeQuotes(t *testing.T) {
 	expect := func(expected, value string) {
-		actual := string(escapeQuotes([]byte{}, []byte(value)))
+		actual := string(escapeBytesQuotes([]byte{}, []byte(value)))
 		if actual != expected {
 			t.Errorf(
 				"expected %s, got %s",