Переглянути джерело

Query() uses client-side placeholder substitution.

INADA Naoki 11 роки тому
батько
коміт
c8c9bb1ec8
1 змінених файлів з 31 додано та 23 видалено
  1. 31 23
      connection.go

+ 31 - 23
connection.go

@@ -164,7 +164,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 }
 }
 
 
 func (mc *mysqlConn) escapeBytes(v []byte) string {
 func (mc *mysqlConn) escapeBytes(v []byte) string {
-	buf := make([]byte, len(v)*2)
+	buf := make([]byte, len(v)*2+2)
 	buf[0] = '\''
 	buf[0] = '\''
 	pos := 1
 	pos := 1
 	if mc.status&statusNoBackslashEscapes == 0 {
 	if mc.status&statusNoBackslashEscapes == 0 {
@@ -254,7 +254,11 @@ func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, erro
 				parts[pos] = v.In(mc.cfg.loc).Format(fmt)
 				parts[pos] = v.In(mc.cfg.loc).Format(fmt)
 			}
 			}
 		case []byte:
 		case []byte:
-			parts[pos] = mc.escapeBytes(v)
+			if v == nil {
+				parts[pos] = "NULL"
+			} else {
+				parts[pos] = mc.escapeBytes(v)
+			}
 		case string:
 		case string:
 			parts[pos] = mc.escapeBytes([]byte(v))
 			parts[pos] = mc.escapeBytes([]byte(v))
 		default:
 		default:
@@ -317,31 +321,35 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 		errLog.Print(ErrInvalidConn)
 		errLog.Print(ErrInvalidConn)
 		return nil, driver.ErrBadConn
 		return nil, driver.ErrBadConn
 	}
 	}
-	if len(args) == 0 { // no args, fastpath
-		// Send command
-		err := mc.writeCommandPacketStr(comQuery, query)
+	if len(args) != 0 {
+		// try client-side prepare to reduce roundtrip
+		prepared, err := mc.buildQuery(query, args)
+		if err != nil {
+			return nil, err
+		}
+		query = prepared
+		args = nil
+	}
+	// Send command
+	err := mc.writeCommandPacketStr(comQuery, query)
+	if err == nil {
+		// Read Result
+		var resLen int
+		resLen, err = mc.readResultSetHeaderPacket()
 		if err == nil {
 		if err == nil {
-			// Read Result
-			var resLen int
-			resLen, err = mc.readResultSetHeaderPacket()
-			if err == nil {
-				rows := new(textRows)
-				rows.mc = mc
-
-				if resLen == 0 {
-					// no columns, no more data
-					return emptyRows{}, nil
-				}
-				// Columns
-				rows.columns, err = mc.readColumns(resLen)
-				return rows, err
+			rows := new(textRows)
+			rows.mc = mc
+
+			if resLen == 0 {
+				// no columns, no more data
+				return emptyRows{}, nil
 			}
 			}
+			// Columns
+			rows.columns, err = mc.readColumns(resLen)
+			return rows, err
 		}
 		}
-		return nil, err
 	}
 	}
-
-	// with args, must use prepared stmt
-	return nil, driver.ErrSkip
+	return nil, err
 }
 }
 
 
 // Gets the value of the given MySQL System Variable
 // Gets the value of the given MySQL System Variable