Browse Source

Fix prepared statement (#734)

* Fix prepared statement

When there are many args and maxAllowedPacket is not enough,
writeExecutePacket() attempted to use STMT_LONG_DATA even for
0byte string.
But writeCommandLongData() doesn't support 0byte data. So it
caused to send malfold packet.

This commit loosen threshold for using STMT_LONG_DATA.

* Change minimum size of LONG_DATA to 64byte

* Add test which reproduce issue 730

* TestPreparedManyCols test only numParams = 65535 case

* s/as possible//
INADA Naoki 8 years ago
parent
commit
2cc627ac8d
2 changed files with 22 additions and 5 deletions
  1. 14 3
      driver_test.go
  2. 8 2
      packets.go

+ 14 - 3
driver_test.go

@@ -1669,8 +1669,9 @@ func TestStmtMultiRows(t *testing.T) {
 // Regression test for
 // * more than 32 NULL parameters (issue 209)
 // * more parameters than fit into the buffer (issue 201)
+// * parameters * 64 > max_allowed_packet (issue 734)
 func TestPreparedManyCols(t *testing.T) {
-	const numParams = defaultBufSize
+	numParams := 65535
 	runTests(t, dsn, func(dbt *DBTest) {
 		query := "SELECT ?" + strings.Repeat(",?", numParams-1)
 		stmt, err := dbt.db.Prepare(query)
@@ -1678,15 +1679,25 @@ func TestPreparedManyCols(t *testing.T) {
 			dbt.Fatal(err)
 		}
 		defer stmt.Close()
+
 		// create more parameters than fit into the buffer
 		// which will take nil-values
 		params := make([]interface{}, numParams)
 		rows, err := stmt.Query(params...)
 		if err != nil {
-			stmt.Close()
 			dbt.Fatal(err)
 		}
-		defer rows.Close()
+		rows.Close()
+
+		// Create 0byte string which we can't send via STMT_LONG_DATA.
+		for i := 0; i < numParams; i++ {
+			params[i] = ""
+		}
+		rows, err = stmt.Query(params...)
+		if err != nil {
+			dbt.Fatal(err)
+		}
+		rows.Close()
 	})
 }
 

+ 8 - 2
packets.go

@@ -916,6 +916,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	const minPktLen = 4 + 1 + 4 + 1 + 4
 	mc := stmt.mc
 
+	// Determine threshould dynamically to avoid packet size shortage.
+	longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
+	if longDataSize < 64 {
+		longDataSize = 64
+	}
+
 	// Reset packet-sequence
 	mc.sequence = 0
 
@@ -1043,7 +1049,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 					paramTypes[i+i] = byte(fieldTypeString)
 					paramTypes[i+i+1] = 0x00
 
-					if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
+					if len(v) < longDataSize {
 						paramValues = appendLengthEncodedInteger(paramValues,
 							uint64(len(v)),
 						)
@@ -1065,7 +1071,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				paramTypes[i+i] = byte(fieldTypeString)
 				paramTypes[i+i+1] = 0x00
 
-				if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
+				if len(v) < longDataSize {
 					paramValues = appendLengthEncodedInteger(paramValues,
 						uint64(len(v)),
 					)