|
@@ -166,73 +166,107 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156
|
|
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156
|
|
|
-func (mc *mysqlConn) escapeBytes(v []byte) string {
|
|
|
|
|
- var escape func([]byte) []byte
|
|
|
|
|
|
|
+func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte {
|
|
|
|
|
+ var escape func([]byte, []byte) []byte
|
|
|
if mc.status&statusNoBackslashEscapes == 0 {
|
|
if mc.status&statusNoBackslashEscapes == 0 {
|
|
|
- escape = EscapeString
|
|
|
|
|
|
|
+ escape = escapeString
|
|
|
} else {
|
|
} else {
|
|
|
- escape = EscapeQuotes
|
|
|
|
|
|
|
+ escape = escapeQuotes
|
|
|
}
|
|
}
|
|
|
- return "'" + string(escape(v)) + "'"
|
|
|
|
|
|
|
+ buf = append(buf, '\'')
|
|
|
|
|
+ buf = escape(buf, v)
|
|
|
|
|
+ buf = append(buf, '\'')
|
|
|
|
|
+ return buf
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func estimateParamLength(args []driver.Value) (int, bool) {
|
|
|
|
|
+ l := 0
|
|
|
|
|
+ for _, a := range args {
|
|
|
|
|
+ switch v := a.(type) {
|
|
|
|
|
+ case int64, float64:
|
|
|
|
|
+ l += 20
|
|
|
|
|
+ case bool:
|
|
|
|
|
+ l += 5
|
|
|
|
|
+ case time.Time:
|
|
|
|
|
+ l += 30
|
|
|
|
|
+ case string:
|
|
|
|
|
+ l += len(v)*2 + 2
|
|
|
|
|
+ case []byte:
|
|
|
|
|
+ l += len(v)*2 + 2
|
|
|
|
|
+ default:
|
|
|
|
|
+ return 0, false
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return l, true
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
|
|
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
|
|
|
- chunks := strings.Split(query, "?")
|
|
|
|
|
- if len(chunks) != len(args)+1 {
|
|
|
|
|
|
|
+ estimated, ok := estimateParamLength(args)
|
|
|
|
|
+ if !ok {
|
|
|
return "", driver.ErrSkip
|
|
return "", driver.ErrSkip
|
|
|
}
|
|
}
|
|
|
|
|
+ estimated += len(query)
|
|
|
|
|
|
|
|
- parts := make([]string, len(chunks)+len(args))
|
|
|
|
|
- parts[0] = chunks[0]
|
|
|
|
|
|
|
+ buf := make([]byte, 0, estimated)
|
|
|
|
|
+ argPos := 0
|
|
|
|
|
+
|
|
|
|
|
+ // Go 1.5 will optimize range([]byte(string)) to skip allocation.
|
|
|
|
|
+ for _, c := range []byte(query) {
|
|
|
|
|
+ if c != '?' {
|
|
|
|
|
+ buf = append(buf, c)
|
|
|
|
|
+ continue
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ arg := args[argPos]
|
|
|
|
|
+ argPos++
|
|
|
|
|
|
|
|
- for i, arg := range args {
|
|
|
|
|
- pos := i*2 + 1
|
|
|
|
|
- parts[pos+1] = chunks[i+1]
|
|
|
|
|
if arg == nil {
|
|
if arg == nil {
|
|
|
- parts[pos] = "NULL"
|
|
|
|
|
|
|
+ buf = append(buf, []byte("NULL")...)
|
|
|
continue
|
|
continue
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
switch v := arg.(type) {
|
|
switch v := arg.(type) {
|
|
|
case int64:
|
|
case int64:
|
|
|
- parts[pos] = strconv.FormatInt(v, 10)
|
|
|
|
|
|
|
+ buf = strconv.AppendInt(buf, v, 10)
|
|
|
case float64:
|
|
case float64:
|
|
|
- parts[pos] = strconv.FormatFloat(v, 'g', -1, 64)
|
|
|
|
|
|
|
+ buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
|
|
|
case bool:
|
|
case bool:
|
|
|
if v {
|
|
if v {
|
|
|
- parts[pos] = "1"
|
|
|
|
|
|
|
+ buf = append(buf, '1')
|
|
|
} else {
|
|
} else {
|
|
|
- parts[pos] = "0"
|
|
|
|
|
|
|
+ buf = append(buf, '0')
|
|
|
}
|
|
}
|
|
|
case time.Time:
|
|
case time.Time:
|
|
|
if v.IsZero() {
|
|
if v.IsZero() {
|
|
|
- parts[pos] = "'0000-00-00'"
|
|
|
|
|
|
|
+ buf = append(buf, []byte("'0000-00-00'")...)
|
|
|
} else {
|
|
} else {
|
|
|
fmt := "'2006-01-02 15:04:05.999999'"
|
|
fmt := "'2006-01-02 15:04:05.999999'"
|
|
|
if v.Nanosecond() == 0 {
|
|
if v.Nanosecond() == 0 {
|
|
|
fmt = "'2006-01-02 15:04:05'"
|
|
fmt = "'2006-01-02 15:04:05'"
|
|
|
}
|
|
}
|
|
|
- parts[pos] = v.In(mc.cfg.loc).Format(fmt)
|
|
|
|
|
|
|
+ s := v.In(mc.cfg.loc).Format(fmt)
|
|
|
|
|
+ buf = append(buf, []byte(s)...)
|
|
|
}
|
|
}
|
|
|
case []byte:
|
|
case []byte:
|
|
|
if v == nil {
|
|
if v == nil {
|
|
|
- parts[pos] = "NULL"
|
|
|
|
|
|
|
+ buf = append(buf, []byte("NULL")...)
|
|
|
} else {
|
|
} else {
|
|
|
- parts[pos] = mc.escapeBytes(v)
|
|
|
|
|
|
|
+ buf = mc.escapeBytes(buf, v)
|
|
|
}
|
|
}
|
|
|
case string:
|
|
case string:
|
|
|
- parts[pos] = mc.escapeBytes([]byte(v))
|
|
|
|
|
|
|
+ buf = mc.escapeBytes(buf, []byte(v))
|
|
|
default:
|
|
default:
|
|
|
return "", driver.ErrSkip
|
|
return "", driver.ErrSkip
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ if len(buf)+4 > mc.maxPacketAllowed {
|
|
|
|
|
+ return "", driver.ErrSkip
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
- pktSize := len(query) + 4 // 4 bytes for header.
|
|
|
|
|
- for _, p := range parts {
|
|
|
|
|
- pktSize += len(p)
|
|
|
|
|
- }
|
|
|
|
|
- if pktSize > mc.maxPacketAllowed {
|
|
|
|
|
|
|
+ if argPos != len(args) {
|
|
|
return "", driver.ErrSkip
|
|
return "", driver.ErrSkip
|
|
|
}
|
|
}
|
|
|
- return strings.Join(parts, ""), nil
|
|
|
|
|
|
|
+ return string(buf), nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
|
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|