浏览代码

made writeExecutePacket independent of number of arguments

Arne Hormann 12 年之前
父节点
当前提交
a3fcad9b2b
共有 2 个文件被更改,包括 29 次插入17 次删除
  1. 1 1
      driver_test.go
  2. 28 16
      packets.go

+ 1 - 1
driver_test.go

@@ -1211,7 +1211,7 @@ func TestStmtMultiRows(t *testing.T) {
 }
 
 func TestPreparedManyCols(t *testing.T) {
-	const repetitions = 1024
+	const repetitions = 32 // defaultBufSize
 	runTests(t, dsn, func(dbt *DBTest) {
 		query := "SELECT ?" + strings.Repeat(",?", repetitions-1)
 		values := make([]sql.NullString, repetitions)

+ 28 - 16
packets.go

@@ -750,6 +750,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		)
 	}
 
+	const minPktLen = 4 + 1 + 4 + 1 + 4
 	mc := stmt.mc
 
 	// Reset packet-sequence
@@ -758,7 +759,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	var data []byte
 
 	if len(args) == 0 {
-		data = mc.buf.takeBuffer(4 + 1 + 4 + 1 + 4)
+		data = mc.buf.takeBuffer(minPktLen)
 	} else {
 		data = mc.buf.takeCompleteBuffer()
 	}
@@ -787,10 +788,26 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	data[13] = 0x00
 
 	if len(args) > 0 {
+		pos := minPktLen
 		// NULL-bitmap [(len(args)+7)/8 bytes]
-		nullMask := uint64(0)
-
-		pos := 4 + 1 + 4 + 1 + 4 + ((len(args) + 7) >> 3)
+		var nullMask []byte
+		if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
+			// buffer has to be extended but we don't know by how much
+			// so we depend on append after nullMask fits.
+			// The default size didn't suffice and we have to deal with a lot of columns,
+			// so allocation size is hard to guess.
+			tmp := make([]byte, pos+maskLen+typesLen)
+			copy(tmp[:pos], data[:pos])
+			data = tmp
+			nullMask = data[pos : pos+maskLen]
+			pos += maskLen
+		} else {
+			nullMask = data[pos : pos+maskLen]
+			for i := 0; i < maskLen; i++ {
+				nullMask[i] = 0
+			}
+			pos += maskLen
+		}
 
 		// newParameterBoundFlag 1 [1 byte]
 		data[pos] = 0x01
@@ -798,23 +815,23 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 
 		// type of each parameter [len(args)*2 bytes]
 		paramTypes := data[pos:]
-		pos += (len(args) << 1)
+		pos += len(args) * 2
 
 		// value of each parameter [n bytes]
 		paramValues := data[pos:pos]
 		valuesCap := cap(paramValues)
 
-		for i := range args {
+		for i, arg := range args {
 			// build NULL-bitmap
-			if args[i] == nil {
-				nullMask |= 1 << uint(i)
+			if arg == nil {
+				nullMask[i/8] |= 1 << (uint(i) & 7) // |= 1 << uint(i)
 				paramTypes[i+i] = fieldTypeNULL
 				paramTypes[i+i+1] = 0x00
 				continue
 			}
 
 			// cache types and values
-			switch v := args[i].(type) {
+			switch v := arg.(type) {
 			case int64:
 				paramTypes[i+i] = fieldTypeLongLong
 				paramTypes[i+i+1] = 0x00
@@ -877,7 +894,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				}
 
 				// Handle []byte(nil) as a NULL value
-				nullMask |= 1 << uint(i)
+				nullMask[i/8] |= 1 << (uint(i) & 7) // |= 1 << uint(i)
 				paramTypes[i+i] = fieldTypeNULL
 				paramTypes[i+i+1] = 0x00
 
@@ -913,7 +930,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				paramValues = append(paramValues, val...)
 
 			default:
-				return fmt.Errorf("Can't convert type: %T", args[i])
+				return fmt.Errorf("Can't convert type: %T", arg)
 			}
 		}
 
@@ -926,11 +943,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 
 		pos += len(paramValues)
 		data = data[:pos]
-
-		// Convert nullMask to bytes
-		for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ {
-			data[i+14] = byte(nullMask >> uint(i<<3))
-		}
 	}
 
 	return mc.writePacket(data)