Browse Source

Reduce allocs in interpolateParams.

benchmark                  old ns/op     new ns/op     delta
BenchmarkInterpolation     4065          2533          -37.69%

benchmark                  old allocs     new allocs     delta
BenchmarkInterpolation     15             6              -60.00%

benchmark                  old bytes     new bytes     delta
BenchmarkInterpolation     1144          560           -51.05%
INADA Naoki 11 năm trước cách đây
mục cha
commit
029731571e
3 tập tin đã thay đổi với 84 bổ sung36 xóa
  1. 62 28
      connection.go
  2. 20 6
      utils.go
  3. 2 2
      utils_test.go

+ 62 - 28
connection.go

@@ -166,73 +166,107 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 }
 }
 
 
 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156
 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156
-func (mc *mysqlConn) escapeBytes(v []byte) string {
-	var escape func([]byte) []byte
+func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte {
+	var escape func([]byte, []byte) []byte
 	if mc.status&statusNoBackslashEscapes == 0 {
 	if mc.status&statusNoBackslashEscapes == 0 {
-		escape = EscapeString
+		escape = escapeString
 	} else {
 	} else {
-		escape = EscapeQuotes
+		escape = escapeQuotes
 	}
 	}
-	return "'" + string(escape(v)) + "'"
+	buf = append(buf, '\'')
+	buf = escape(buf, v)
+	buf = append(buf, '\'')
+	return buf
+}
+
+func estimateParamLength(args []driver.Value) (int, bool) {
+	l := 0
+	for _, a := range args {
+		switch v := a.(type) {
+		case int64, float64:
+			l += 20
+		case bool:
+			l += 5
+		case time.Time:
+			l += 30
+		case string:
+			l += len(v)*2 + 2
+		case []byte:
+			l += len(v)*2 + 2
+		default:
+			return 0, false
+		}
+	}
+	return l, true
 }
 }
 
 
 func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
 func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
-	chunks := strings.Split(query, "?")
-	if len(chunks) != len(args)+1 {
+	estimated, ok := estimateParamLength(args)
+	if !ok {
 		return "", driver.ErrSkip
 		return "", driver.ErrSkip
 	}
 	}
+	estimated += len(query)
 
 
-	parts := make([]string, len(chunks)+len(args))
-	parts[0] = chunks[0]
+	buf := make([]byte, 0, estimated)
+	argPos := 0
+
+	// Go 1.5 will optimize range([]byte(string)) to skip allocation.
+	for _, c := range []byte(query) {
+		if c != '?' {
+			buf = append(buf, c)
+			continue
+		}
+
+		arg := args[argPos]
+		argPos++
 
 
-	for i, arg := range args {
-		pos := i*2 + 1
-		parts[pos+1] = chunks[i+1]
 		if arg == nil {
 		if arg == nil {
-			parts[pos] = "NULL"
+			buf = append(buf, []byte("NULL")...)
 			continue
 			continue
 		}
 		}
+
 		switch v := arg.(type) {
 		switch v := arg.(type) {
 		case int64:
 		case int64:
-			parts[pos] = strconv.FormatInt(v, 10)
+			buf = strconv.AppendInt(buf, v, 10)
 		case float64:
 		case float64:
-			parts[pos] = strconv.FormatFloat(v, 'g', -1, 64)
+			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
 		case bool:
 		case bool:
 			if v {
 			if v {
-				parts[pos] = "1"
+				buf = append(buf, '1')
 			} else {
 			} else {
-				parts[pos] = "0"
+				buf = append(buf, '0')
 			}
 			}
 		case time.Time:
 		case time.Time:
 			if v.IsZero() {
 			if v.IsZero() {
-				parts[pos] = "'0000-00-00'"
+				buf = append(buf, []byte("'0000-00-00'")...)
 			} else {
 			} else {
 				fmt := "'2006-01-02 15:04:05.999999'"
 				fmt := "'2006-01-02 15:04:05.999999'"
 				if v.Nanosecond() == 0 {
 				if v.Nanosecond() == 0 {
 					fmt = "'2006-01-02 15:04:05'"
 					fmt = "'2006-01-02 15:04:05'"
 				}
 				}
-				parts[pos] = v.In(mc.cfg.loc).Format(fmt)
+				s := v.In(mc.cfg.loc).Format(fmt)
+				buf = append(buf, []byte(s)...)
 			}
 			}
 		case []byte:
 		case []byte:
 			if v == nil {
 			if v == nil {
-				parts[pos] = "NULL"
+				buf = append(buf, []byte("NULL")...)
 			} else {
 			} else {
-				parts[pos] = mc.escapeBytes(v)
+				buf = mc.escapeBytes(buf, v)
 			}
 			}
 		case string:
 		case string:
-			parts[pos] = mc.escapeBytes([]byte(v))
+			buf = mc.escapeBytes(buf, []byte(v))
 		default:
 		default:
 			return "", driver.ErrSkip
 			return "", driver.ErrSkip
 		}
 		}
+
+		if len(buf)+4 > mc.maxPacketAllowed {
+			return "", driver.ErrSkip
+		}
 	}
 	}
-	pktSize := len(query) + 4 // 4 bytes for header.
-	for _, p := range parts {
-		pktSize += len(p)
-	}
-	if pktSize > mc.maxPacketAllowed {
+	if argPos != len(args) {
 		return "", driver.ErrSkip
 		return "", driver.ErrSkip
 	}
 	}
-	return strings.Join(parts, ""), nil
+	return string(buf), nil
 }
 }
 
 
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {

+ 20 - 6
utils.go

@@ -812,9 +812,16 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
 // characters, and turning others into specific escape sequences, such as
 // characters, and turning others into specific escape sequences, such as
 // turning newlines into \n and null bytes into \0.
 // turning newlines into \n and null bytes into \0.
 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
-func EscapeString(v []byte) []byte {
-	buf := make([]byte, len(v)*2)
-	pos := 0
+func escapeString(buf, v []byte) []byte {
+	pos := len(buf)
+	end := pos + len(v)*2
+	if cap(buf) < end {
+		n := make([]byte, pos+end)
+		copy(n, buf)
+		buf = n
+	}
+	buf = buf[0:end]
+
 	for _, c := range v {
 	for _, c := range v {
 		switch c {
 		switch c {
 		case '\x00':
 		case '\x00':
@@ -859,9 +866,16 @@ func EscapeString(v []byte) []byte {
 // it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
 // it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
 // effect on the server.
 // effect on the server.
 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
 // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
-func EscapeQuotes(v []byte) []byte {
-	buf := make([]byte, len(v)*2)
-	pos := 0
+func escapeQuotes(buf, v []byte) []byte {
+	pos := len(buf)
+	end := pos + len(v)*2
+	if cap(buf) < end {
+		n := make([]byte, pos+end)
+		copy(n, buf)
+		buf = n
+	}
+	buf = buf[0:end]
+
 	for _, c := range v {
 	for _, c := range v {
 		if c == '\'' {
 		if c == '\'' {
 			buf[pos] = '\''
 			buf[pos] = '\''

+ 2 - 2
utils_test.go

@@ -255,7 +255,7 @@ func TestFormatBinaryDateTime(t *testing.T) {
 
 
 func TestEscapeString(t *testing.T) {
 func TestEscapeString(t *testing.T) {
 	expect := func(expected, value string) {
 	expect := func(expected, value string) {
-		actual := string(EscapeString([]byte(value)))
+		actual := string(escapeString([]byte{}, []byte(value)))
 		if actual != expected {
 		if actual != expected {
 			t.Errorf(
 			t.Errorf(
 				"expected %s, got %s",
 				"expected %s, got %s",
@@ -275,7 +275,7 @@ func TestEscapeString(t *testing.T) {
 
 
 func TestEscapeQuotes(t *testing.T) {
 func TestEscapeQuotes(t *testing.T) {
 	expect := func(expected, value string) {
 	expect := func(expected, value string) {
-		actual := string(EscapeQuotes([]byte(value)))
+		actual := string(escapeQuotes([]byte{}, []byte(value)))
 		if actual != expected {
 		if actual != expected {
 			t.Errorf(
 			t.Errorf(
 				"expected %s, got %s",
 				"expected %s, got %s",