Browse Source

Implement placeholder substitution.

INADA Naoki 11 years ago
parent
commit
e35fa001b5
3 changed files with 143 additions and 14 deletions
  1. 120 14
      connection.go
  2. 22 0
      const.go
  3. 1 0
      packets.go

+ 120 - 14
connection.go

@@ -13,6 +13,7 @@ import (
 	"database/sql/driver"
 	"errors"
 	"net"
+	"strconv"
 	"strings"
 	"time"
 )
@@ -26,6 +27,7 @@ type mysqlConn struct {
 	maxPacketAllowed int
 	maxWriteSize     int
 	flags            clientFlag
+	status           statusFlag
 	sequence         uint8
 	parseTime        bool
 	strict           bool
@@ -161,28 +163,132 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	return stmt, err
 }
 
+func (mc *mysqlConn) escapeBytes(v []byte) string {
+	buf := make([]byte, len(v)*2)
+	buf[0] = '\''
+	pos := 1
+	if mc.status&statusNoBackslashEscapes == 0 {
+		for _, c := range v {
+			switch c {
+			case '\x00':
+				buf[pos] = '\\'
+				buf[pos+1] = '0'
+				pos += 2
+			case '\n':
+				buf[pos] = '\\'
+				buf[pos+1] = 'n'
+				pos += 2
+			case '\r':
+				buf[pos] = '\\'
+				buf[pos+1] = 'r'
+				pos += 2
+			case '\x1a':
+				buf[pos] = '\\'
+				buf[pos+1] = 'Z'
+				pos += 2
+			case '\'':
+				buf[pos] = '\\'
+				buf[pos+1] = '\''
+				pos += 2
+			case '"':
+				buf[pos] = '\\'
+				buf[pos+1] = '"'
+				pos += 2
+			case '\\':
+				buf[pos] = '\\'
+				buf[pos+1] = '\\'
+				pos += 2
+			default:
+				buf[pos] = c
+				pos += 1
+			}
+		}
+	} else {
+		for _, c := range v {
+			if c == '\'' {
+				buf[pos] = '\''
+				buf[pos+1] = '\''
+				pos += 2
+			} else {
+				buf[pos] = c
+				pos++
+			}
+		}
+	}
+	buf[pos] = '\''
+	return string(buf[:pos+1])
+}
+
+func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) {
+	chunks := strings.Split(query, "?")
+	if len(chunks) != len(args)+1 {
+		return "", driver.ErrSkip
+	}
+
+	parts := make([]string, len(chunks)+len(args))
+	parts[0] = chunks[0]
+
+	for i, arg := range args {
+		pos := i*2 + 1
+		parts[pos+1] = chunks[i+1]
+		if arg == nil {
+			parts[pos] = "NULL"
+			continue
+		}
+		switch v := arg.(type) {
+		case int64:
+			parts[pos] = strconv.FormatInt(v, 10)
+		case float64:
+			parts[pos] = strconv.FormatFloat(v, 'f', -1, 64)
+		case bool:
+			if v {
+				parts[pos] = "1"
+			} else {
+				parts[pos] = "0"
+			}
+		case time.Time:
+			if v.IsZero() {
+				parts[pos] = "'0000-00-00'"
+			} else {
+				fmt := "'2006-01-02 15:04:05.999999'"
+				parts[pos] = v.In(mc.cfg.loc).Format(fmt)
+			}
+		case []byte:
+			parts[pos] = mc.escapeBytes(v)
+		case string:
+			parts[pos] = mc.escapeBytes([]byte(v))
+		default:
+			return "", driver.ErrSkip
+		}
+	}
+	return strings.Join(parts, ""), nil
+}
+
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
 	if mc.netConn == nil {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
-	if len(args) == 0 { // no args, fastpath
-		mc.affectedRows = 0
-		mc.insertId = 0
-
-		err := mc.exec(query)
-		if err == nil {
-			return &mysqlResult{
-				affectedRows: int64(mc.affectedRows),
-				insertId:     int64(mc.insertId),
-			}, err
+	if len(args) != 0 {
+		// try client-side prepare to reduce roundtrip
+		prepared, err := mc.buildQuery(query, args)
+		if err != nil {
+			return nil, err
 		}
-		return nil, err
+		query = prepared
+		args = nil
 	}
+	mc.affectedRows = 0
+	mc.insertId = 0
 
-	// with args, must use prepared stmt
-	return nil, driver.ErrSkip
-
+	err := mc.exec(query)
+	if err == nil {
+		return &mysqlResult{
+			affectedRows: int64(mc.affectedRows),
+			insertId:     int64(mc.insertId),
+		}, err
+	}
+	return nil, err
 }
 
 // Internal function to execute commands

+ 22 - 0
const.go

@@ -130,3 +130,25 @@ const (
 	flagUnknown3
 	flagUnknown4
 )
+
+// http://dev.mysql.com/doc/internals/en/status-flags.html
+
+type statusFlag uint16
+
+const (
+	statusInTrans statusFlag = 1 << iota
+	statusInAutocommit
+	statusUnknown1
+	statusMoreResultsExists
+	statusNoGoodIndexUsed
+	statusNoIndexUsed
+	statusCursorExists
+	statusLastRowSent
+	statusDbDropped
+	statusNoBackslashEscapes
+	statusMetadataChanged
+	statusQueryWasSlow
+	statusPsOutParams
+	statusInTransReadonly
+	statusSessionStateChanged
+)

+ 1 - 0
packets.go

@@ -484,6 +484,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
 	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
 
 	// server_status [2 bytes]
+	mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
 
 	// warning count [2 bytes]
 	if !mc.strict {