|
|
@@ -13,6 +13,7 @@ import (
|
|
|
"database/sql/driver"
|
|
|
"errors"
|
|
|
"net"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
"time"
|
|
|
)
|
|
|
@@ -26,6 +27,7 @@ type mysqlConn struct {
|
|
|
maxPacketAllowed int
|
|
|
maxWriteSize int
|
|
|
flags clientFlag
|
|
|
+ status statusFlag
|
|
|
sequence uint8
|
|
|
parseTime bool
|
|
|
strict bool
|
|
|
@@ -161,28 +163,132 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
|
|
return stmt, err
|
|
|
}
|
|
|
|
|
|
+func (mc *mysqlConn) escapeBytes(v []byte) string {
|
|
|
+ buf := make([]byte, len(v)*2)
|
|
|
+ buf[0] = '\''
|
|
|
+ pos := 1
|
|
|
+ if mc.status&statusNoBackslashEscapes == 0 {
|
|
|
+ for _, c := range v {
|
|
|
+ switch c {
|
|
|
+ case '\x00':
|
|
|
+ buf[pos] = '\\'
|
|
|
+ buf[pos+1] = '0'
|
|
|
+ pos += 2
|
|
|
+ case '\n':
|
|
|
+ buf[pos] = '\\'
|
|
|
+ buf[pos+1] = 'n'
|
|
|
+ pos += 2
|
|
|
+ case '\r':
|
|
|
+ buf[pos] = '\\'
|
|
|
+ buf[pos+1] = 'r'
|
|
|
+ pos += 2
|
|
|
+ case '\x1a':
|
|
|
+ buf[pos] = '\\'
|
|
|
+ buf[pos+1] = 'Z'
|
|
|
+ pos += 2
|
|
|
+ case '\'':
|
|
|
+ buf[pos] = '\\'
|
|
|
+ buf[pos+1] = '\''
|
|
|
+ pos += 2
|
|
|
+ case '"':
|
|
|
+ buf[pos] = '\\'
|
|
|
+ buf[pos+1] = '"'
|
|
|
+ pos += 2
|
|
|
+ case '\\':
|
|
|
+ buf[pos] = '\\'
|
|
|
+ buf[pos+1] = '\\'
|
|
|
+ pos += 2
|
|
|
+ default:
|
|
|
+ buf[pos] = c
|
|
|
+ pos += 1
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ for _, c := range v {
|
|
|
+ if c == '\'' {
|
|
|
+ buf[pos] = '\''
|
|
|
+ buf[pos+1] = '\''
|
|
|
+ pos += 2
|
|
|
+ } else {
|
|
|
+ buf[pos] = c
|
|
|
+ pos++
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ buf[pos] = '\''
|
|
|
+ return string(buf[:pos+1])
|
|
|
+}
|
|
|
+
|
|
|
+func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) {
|
|
|
+ chunks := strings.Split(query, "?")
|
|
|
+ if len(chunks) != len(args)+1 {
|
|
|
+ return "", driver.ErrSkip
|
|
|
+ }
|
|
|
+
|
|
|
+ parts := make([]string, len(chunks)+len(args))
|
|
|
+ parts[0] = chunks[0]
|
|
|
+
|
|
|
+ for i, arg := range args {
|
|
|
+ pos := i*2 + 1
|
|
|
+ parts[pos+1] = chunks[i+1]
|
|
|
+ if arg == nil {
|
|
|
+ parts[pos] = "NULL"
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ switch v := arg.(type) {
|
|
|
+ case int64:
|
|
|
+ parts[pos] = strconv.FormatInt(v, 10)
|
|
|
+ case float64:
|
|
|
+ parts[pos] = strconv.FormatFloat(v, 'f', -1, 64)
|
|
|
+ case bool:
|
|
|
+ if v {
|
|
|
+ parts[pos] = "1"
|
|
|
+ } else {
|
|
|
+ parts[pos] = "0"
|
|
|
+ }
|
|
|
+ case time.Time:
|
|
|
+ if v.IsZero() {
|
|
|
+ parts[pos] = "'0000-00-00'"
|
|
|
+ } else {
|
|
|
+ fmt := "'2006-01-02 15:04:05.999999'"
|
|
|
+ parts[pos] = v.In(mc.cfg.loc).Format(fmt)
|
|
|
+ }
|
|
|
+ case []byte:
|
|
|
+ parts[pos] = mc.escapeBytes(v)
|
|
|
+ case string:
|
|
|
+ parts[pos] = mc.escapeBytes([]byte(v))
|
|
|
+ default:
|
|
|
+ return "", driver.ErrSkip
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return strings.Join(parts, ""), nil
|
|
|
+}
|
|
|
+
|
|
|
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
|
|
if mc.netConn == nil {
|
|
|
errLog.Print(ErrInvalidConn)
|
|
|
return nil, driver.ErrBadConn
|
|
|
}
|
|
|
- if len(args) == 0 { // no args, fastpath
|
|
|
- mc.affectedRows = 0
|
|
|
- mc.insertId = 0
|
|
|
-
|
|
|
- err := mc.exec(query)
|
|
|
- if err == nil {
|
|
|
- return &mysqlResult{
|
|
|
- affectedRows: int64(mc.affectedRows),
|
|
|
- insertId: int64(mc.insertId),
|
|
|
- }, err
|
|
|
+ if len(args) != 0 {
|
|
|
+ // try client-side prepare to reduce roundtrip
|
|
|
+ prepared, err := mc.buildQuery(query, args)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
- return nil, err
|
|
|
+ query = prepared
|
|
|
+ args = nil
|
|
|
}
|
|
|
+ mc.affectedRows = 0
|
|
|
+ mc.insertId = 0
|
|
|
|
|
|
- // with args, must use prepared stmt
|
|
|
- return nil, driver.ErrSkip
|
|
|
-
|
|
|
+ err := mc.exec(query)
|
|
|
+ if err == nil {
|
|
|
+ return &mysqlResult{
|
|
|
+ affectedRows: int64(mc.affectedRows),
|
|
|
+ insertId: int64(mc.insertId),
|
|
|
+ }, err
|
|
|
+ }
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
// Internal function to execute commands
|