Browse Source

Merge remote-tracking branch 'upstream/pr/297'

Conflicts:
	connection.go
	utils_test.go
arvenil 11 years ago
parent
commit
f3b82fdf7f
6 changed files with 230 additions and 62 deletions
  1. 177 49
      connection.go
  2. 22 0
      const.go
  3. 9 0
      driver_test.go
  4. 1 0
      packets.go
  5. 8 0
      utils.go
  6. 13 13
      utils_test.go

+ 177 - 49
connection.go

@@ -13,6 +13,7 @@ import (
 	"database/sql/driver"
 	"errors"
 	"net"
+	"strconv"
 	"strings"
 	"time"
 )
@@ -26,26 +27,28 @@ type mysqlConn struct {
 	maxPacketAllowed int
 	maxWriteSize     int
 	flags            clientFlag
+	status           statusFlag
 	sequence         uint8
 	parseTime        bool
 	strict           bool
 }
 
 type config struct {
-	user              string
-	passwd            string
-	net               string
-	addr              string
-	dbname            string
-	params            map[string]string
-	loc               *time.Location
-	tls               *tls.Config
-	timeout           time.Duration
-	collation         uint8
-	allowAllFiles     bool
-	allowOldPasswords bool
-	clientFoundRows   bool
-	columnsWithAlias  bool
+	user                  string
+	passwd                string
+	net                   string
+	addr                  string
+	dbname                string
+	params                map[string]string
+	loc                   *time.Location
+	tls                   *tls.Config
+	timeout               time.Duration
+	collation             uint8
+	allowAllFiles         bool
+	allowOldPasswords     bool
+	clientFoundRows       bool
+	columnsWithAlias      bool
+	substitutePlaceholder bool
 }
 
 // Handles parameters set in DSN after the connection is established
@@ -162,28 +165,146 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	return stmt, err
 }
 
+func (mc *mysqlConn) escapeBytes(v []byte) string {
+	buf := make([]byte, len(v)*2+2)
+	buf[0] = '\''
+	pos := 1
+	if mc.status&statusNoBackslashEscapes == 0 {
+		for _, c := range v {
+			switch c {
+			case '\x00':
+				buf[pos] = '\\'
+				buf[pos+1] = '0'
+				pos += 2
+			case '\n':
+				buf[pos] = '\\'
+				buf[pos+1] = 'n'
+				pos += 2
+			case '\r':
+				buf[pos] = '\\'
+				buf[pos+1] = 'r'
+				pos += 2
+			case '\x1a':
+				buf[pos] = '\\'
+				buf[pos+1] = 'Z'
+				pos += 2
+			case '\'':
+				buf[pos] = '\\'
+				buf[pos+1] = '\''
+				pos += 2
+			case '"':
+				buf[pos] = '\\'
+				buf[pos+1] = '"'
+				pos += 2
+			case '\\':
+				buf[pos] = '\\'
+				buf[pos+1] = '\\'
+				pos += 2
+			default:
+				buf[pos] = c
+				pos += 1
+			}
+		}
+	} else {
+		for _, c := range v {
+			if c == '\'' {
+				buf[pos] = '\''
+				buf[pos+1] = '\''
+				pos += 2
+			} else {
+				buf[pos] = c
+				pos++
+			}
+		}
+	}
+	buf[pos] = '\''
+	return string(buf[:pos+1])
+}
+
+func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) {
+	chunks := strings.Split(query, "?")
+	if len(chunks) != len(args)+1 {
+		return "", driver.ErrSkip
+	}
+
+	parts := make([]string, len(chunks)+len(args))
+	parts[0] = chunks[0]
+
+	for i, arg := range args {
+		pos := i*2 + 1
+		parts[pos+1] = chunks[i+1]
+		if arg == nil {
+			parts[pos] = "NULL"
+			continue
+		}
+		switch v := arg.(type) {
+		case int64:
+			parts[pos] = strconv.FormatInt(v, 10)
+		case float64:
+			parts[pos] = strconv.FormatFloat(v, 'f', -1, 64)
+		case bool:
+			if v {
+				parts[pos] = "1"
+			} else {
+				parts[pos] = "0"
+			}
+		case time.Time:
+			if v.IsZero() {
+				parts[pos] = "'0000-00-00'"
+			} else {
+				fmt := "'2006-01-02 15:04:05.999999'"
+				parts[pos] = v.In(mc.cfg.loc).Format(fmt)
+			}
+		case []byte:
+			if v == nil {
+				parts[pos] = "NULL"
+			} else {
+				parts[pos] = mc.escapeBytes(v)
+			}
+		case string:
+			parts[pos] = mc.escapeBytes([]byte(v))
+		default:
+			return "", driver.ErrSkip
+		}
+	}
+	pktSize := len(query) + 4 // 4 bytes for header.
+	for _, p := range parts {
+		pktSize += len(p)
+	}
+	if pktSize > mc.maxPacketAllowed {
+		return "", driver.ErrSkip
+	}
+	return strings.Join(parts, ""), nil
+}
+
 func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
 	if mc.netConn == nil {
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
-	if len(args) == 0 { // no args, fastpath
-		mc.affectedRows = 0
-		mc.insertId = 0
-
-		err := mc.exec(query)
-		if err == nil {
-			return &mysqlResult{
-				affectedRows: int64(mc.affectedRows),
-				insertId:     int64(mc.insertId),
-			}, err
+	if len(args) != 0 {
+		if !mc.cfg.substitutePlaceholder {
+			return nil, driver.ErrSkip
 		}
-		return nil, err
+		// try client-side prepare to reduce roundtrip
+		prepared, err := mc.buildQuery(query, args)
+		if err != nil {
+			return nil, err
+		}
+		query = prepared
+		args = nil
 	}
+	mc.affectedRows = 0
+	mc.insertId = 0
 
-	// with args, must use prepared stmt
-	return nil, driver.ErrSkip
-
+	err := mc.exec(query)
+	if err == nil {
+		return &mysqlResult{
+			affectedRows: int64(mc.affectedRows),
+			insertId:     int64(mc.insertId),
+		}, err
+	}
+	return nil, err
 }
 
 // Internal function to execute commands
@@ -212,31 +333,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 	}
-	if len(args) == 0 { // no args, fastpath
-		// Send command
-		err := mc.writeCommandPacketStr(comQuery, query)
+	if len(args) != 0 {
+		if !mc.cfg.substitutePlaceholder {
+			return nil, driver.ErrSkip
+		}
+		// try client-side prepare to reduce roundtrip
+		prepared, err := mc.buildQuery(query, args)
+		if err != nil {
+			return nil, err
+		}
+		query = prepared
+		args = nil
+	}
+	// Send command
+	err := mc.writeCommandPacketStr(comQuery, query)
+	if err == nil {
+		// Read Result
+		var resLen int
+		resLen, err = mc.readResultSetHeaderPacket()
 		if err == nil {
-			// Read Result
-			var resLen int
-			resLen, err = mc.readResultSetHeaderPacket()
-			if err == nil {
-				rows := new(textRows)
-				rows.mc = mc
-
-				if resLen == 0 {
-					// no columns, no more data
-					return emptyRows{}, nil
-				}
-				// Columns
-				rows.columns, err = mc.readColumns(resLen)
-				return rows, err
+			rows := new(textRows)
+			rows.mc = mc
+
+			if resLen == 0 {
+				// no columns, no more data
+				return emptyRows{}, nil
 			}
+			// Columns
+			rows.columns, err = mc.readColumns(resLen)
+			return rows, err
 		}
-		return nil, err
 	}
-
-	// with args, must use prepared stmt
-	return nil, driver.ErrSkip
+	return nil, err
 }
 
 // Gets the value of the given MySQL System Variable

+ 22 - 0
const.go

@@ -130,3 +130,25 @@ const (
 	flagUnknown3
 	flagUnknown4
 )
+
+// http://dev.mysql.com/doc/internals/en/status-flags.html
+
+type statusFlag uint16
+
+const (
+	statusInTrans statusFlag = 1 << iota
+	statusInAutocommit
+	statusUnknown1
+	statusMoreResultsExists
+	statusNoGoodIndexUsed
+	statusNoIndexUsed
+	statusCursorExists
+	statusLastRowSent
+	statusDbDropped
+	statusNoBackslashEscapes
+	statusMetadataChanged
+	statusQueryWasSlow
+	statusPsOutParams
+	statusInTransReadonly
+	statusSessionStateChanged
+)

+ 9 - 0
driver_test.go

@@ -87,10 +87,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
 
 	db.Exec("DROP TABLE IF EXISTS test")
 
+	dbp, err := sql.Open("mysql", dsn+"&substitutePlaceholder=true")
+	if err != nil {
+		t.Fatalf("Error connecting: %s", err.Error())
+	}
+	defer dbp.Close()
+
 	dbt := &DBTest{t, db}
+	dbtp := &DBTest{t, dbp}
 	for _, test := range tests {
 		test(dbt)
 		dbt.db.Exec("DROP TABLE IF EXISTS test")
+		test(dbtp)
+		dbtp.db.Exec("DROP TABLE IF EXISTS test")
 	}
 }
 

+ 1 - 0
packets.go

@@ -484,6 +484,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
 	mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
 
 	// server_status [2 bytes]
+	mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
 
 	// warning count [2 bytes]
 	if !mc.strict {

+ 8 - 0
utils.go

@@ -180,6 +180,14 @@ func parseDSNParams(cfg *config, params string) (err error) {
 		// cfg params
 		switch value := param[1]; param[0] {
 
+		// Enable client side placeholder substitution
+		case "substitutePlaceholder":
+			var isBool bool
+			cfg.substitutePlaceholder, isBool = readBool(value)
+			if !isBool {
+				return fmt.Errorf("Invalid Bool value: %s", value)
+			}
+
 		// Disable INFILE whitelist / enable all files
 		case "allowAllFiles":
 			var isBool bool

+ 13 - 13
utils_test.go

@@ -22,19 +22,19 @@ var testDSNs = []struct {
 	out string
 	loc *time.Location
 }{
-	{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true}", time.UTC},
-	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false}", time.UTC},
-	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.Local},
-	{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
-	{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
+	{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true substitutePlaceholder:false}", time.UTC},
+	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.Local},
+	{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
+	{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
 }
 
 func TestDSNParser(t *testing.T) {