Prechádzať zdrojové kódy

Merge pull request #210 from arnehormann/fix-many-cols

support prepared statements with more than 32 parameters
Arne Hormann 12 rokov pred
rodič
commit
593ebcfb40
3 zmenil súbory, kde vykonal 58 pridanie a 17 odobranie
  1. 5 0
      CHANGELOG.md
  2. 24 0
      driver_test.go
  3. 29 17
      packets.go

+ 5 - 0
CHANGELOG.md

@@ -4,6 +4,11 @@ New Features:
 
  - Logging of critical errors is configurable with `SetLogger`
 
+Bugfixes:
+
+ - Allow more than 32 parameters in prepared statements
+
+
 ## Version 1.1 (2013-11-02)
 
 Changes:

+ 24 - 0
driver_test.go

@@ -1210,6 +1210,30 @@ func TestStmtMultiRows(t *testing.T) {
 	})
 }
 
+// Regression test for
+// * more than 32 NULL parameters (issue 209)
+// * more parameters than fit into the buffer (issue 201)
+func TestPreparedManyCols(t *testing.T) {
+	const numParams = defaultBufSize
+	runTests(t, dsn, func(dbt *DBTest) {
+		query := "SELECT ?" + strings.Repeat(",?", numParams-1)
+		stmt, err := dbt.db.Prepare(query)
+		if err != nil {
+			dbt.Fatal(err)
+		}
+		defer stmt.Close()
+		// create more parameters than fit into the buffer
+		// which will take nil-values
+		params := make([]interface{}, numParams)
+		rows, err := stmt.Query(params...)
+		if err != nil {
+			stmt.Close()
+			dbt.Fatal(err)
+		}
+		defer rows.Close()
+	})
+}
+
 func TestConcurrent(t *testing.T) {
 	if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
 		t.Skip("MYSQL_TEST_CONCURRENT env var not set")

+ 29 - 17
packets.go

@@ -750,6 +750,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		)
 	}
 
+	const minPktLen = 4 + 1 + 4 + 1 + 4
 	mc := stmt.mc
 
 	// Reset packet-sequence
@@ -758,7 +759,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	var data []byte
 
 	if len(args) == 0 {
-		data = mc.buf.takeBuffer(4 + 1 + 4 + 1 + 4)
+		data = mc.buf.takeBuffer(minPktLen)
 	} else {
 		data = mc.buf.takeCompleteBuffer()
 	}
@@ -787,10 +788,26 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	data[13] = 0x00
 
 	if len(args) > 0 {
-		// NULL-bitmap [(len(args)+7)/8 bytes]
-		nullMask := uint64(0)
-
-		pos := 4 + 1 + 4 + 1 + 4 + ((len(args) + 7) >> 3)
+		pos := minPktLen
+
+		var nullMask []byte
+		if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
+			// buffer has to be extended but we don't know by how much so
+			// we depend on append after all data with known sizes fit.
+			// We stop at that because we deal with a lot of columns here
+			// which makes the required allocation size hard to guess.
+			tmp := make([]byte, pos+maskLen+typesLen)
+			copy(tmp[:pos], data[:pos])
+			data = tmp
+			nullMask = data[pos : pos+maskLen]
+			pos += maskLen
+		} else {
+			nullMask = data[pos : pos+maskLen]
+			for i := 0; i < maskLen; i++ {
+				nullMask[i] = 0
+			}
+			pos += maskLen
+		}
 
 		// newParameterBoundFlag 1 [1 byte]
 		data[pos] = 0x01
@@ -798,23 +815,23 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 
 		// type of each parameter [len(args)*2 bytes]
 		paramTypes := data[pos:]
-		pos += (len(args) << 1)
+		pos += len(args) * 2
 
 		// value of each parameter [n bytes]
 		paramValues := data[pos:pos]
 		valuesCap := cap(paramValues)
 
-		for i := range args {
+		for i, arg := range args {
 			// build NULL-bitmap
-			if args[i] == nil {
-				nullMask |= 1 << uint(i)
+			if arg == nil {
+				nullMask[i/8] |= 1 << (uint(i) & 7)
 				paramTypes[i+i] = fieldTypeNULL
 				paramTypes[i+i+1] = 0x00
 				continue
 			}
 
 			// cache types and values
-			switch v := args[i].(type) {
+			switch v := arg.(type) {
 			case int64:
 				paramTypes[i+i] = fieldTypeLongLong
 				paramTypes[i+i+1] = 0x00
@@ -877,7 +894,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				}
 
 				// Handle []byte(nil) as a NULL value
-				nullMask |= 1 << uint(i)
+				nullMask[i/8] |= 1 << (uint(i) & 7)
 				paramTypes[i+i] = fieldTypeNULL
 				paramTypes[i+i+1] = 0x00
 
@@ -913,7 +930,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				paramValues = append(paramValues, val...)
 
 			default:
-				return fmt.Errorf("Can't convert type: %T", args[i])
+				return fmt.Errorf("Can't convert type: %T", arg)
 			}
 		}
 
@@ -926,11 +943,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 
 		pos += len(paramValues)
 		data = data[:pos]
-
-		// Convert nullMask to bytes
-		for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ {
-			data[i+14] = byte(nullMask >> uint(i<<3))
-		}
 	}
 
 	return mc.writePacket(data)