浏览代码

read directly into dest slice

Julien Schmidt 12 年之前
父节点
当前提交
a8a04cc28d
共有 2 个文件被更改,包括 58 次插入96 次删除
  1. 53 76
      packets.go
  2. 5 20
      rows.go

+ 53 - 76
packets.go

@@ -564,38 +564,41 @@ func (mc *mysqlConn) readColumns(n int) (columns []mysqlField, err error) {
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
-func (mc *mysqlConn) readRow(columnsCount int) (*[]*[]byte, error) {
-	data, err := mc.readPacket()
+func (rows *mysqlRows) readRow(dest *[]driver.Value) (err error) {
+	data, err := rows.mc.readPacket()
 	if err != nil {
-		return nil, err
+		return
 	}
 
 	// EOF Packet
 	if data[0] == 254 && len(data) == 5 {
-		return nil, io.EOF
+		return io.EOF
 	}
 
 	// RowSet Packet
-	row := make([]*[]byte, columnsCount)
 	var n int
 	var isNull bool
+	var val *[]byte
+	columnsCount := len(*dest)
 	pos := 0
 
 	for i := 0; i < columnsCount; i++ {
 		// Read bytes and convert to string
-		row[i], n, isNull, err = readLengthCodedBinary(data[pos:])
+		val, n, isNull, err = readLengthCodedBinary(data[pos:])
 		if err != nil {
-			return nil, err
+			return
 		}
 
 		// nil if field is NULL
 		if isNull {
-			row[i] = nil
+			(*dest)[i] = nil
+		} else {
+			(*dest)[i] = *val
 		}
 		pos += n
 	}
 
-	return &row, nil
+	return
 }
 
 // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
@@ -830,23 +833,22 @@ func (stmt *mysqlStmt) buildExecutePacket(args *[]driver.Value) error {
 }
 
 // http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow
-func (mc *mysqlConn) readBinaryRow(rc *mysqlRows) (*[]*[]byte, error) {
-	data, err := mc.readPacket()
+func (rc *mysqlRows) readBinaryRow(dest *[]driver.Value) (err error) {
+	data, err := rc.mc.readPacket()
 	if err != nil {
-		return nil, err
+		return
 	}
 
 	pos := 0
 
 	// EOF Packet
 	if data[pos] == 254 && len(data) == 5 {
-		return nil, io.EOF
+		return io.EOF
 	}
 	pos++
 
 	// BinaryRowSet Packet
 	columnsCount := len(rc.columns)
-	row := make([]*[]byte, columnsCount)
 
 	nullBitMap := data[pos : pos+(columnsCount+7+2)>>3]
 	pos += (columnsCount + 7 + 2) >> 3
@@ -856,7 +858,7 @@ func (mc *mysqlConn) readBinaryRow(rc *mysqlRows) (*[]*[]byte, error) {
 	for i := 0; i < columnsCount; i++ {
 		// Field is NULL
 		if (nullBitMap[(i+2)>>3] >> uint((i+2)&7) & 1) == 1 {
-			row[i] = nil
+			(*dest)[i] = nil
 			continue
 		}
 
@@ -865,86 +867,67 @@ func (mc *mysqlConn) readBinaryRow(rc *mysqlRows) (*[]*[]byte, error) {
 		// Convert to byte-coded string
 		switch rc.columns[i].fieldType {
 		case FIELD_TYPE_NULL:
-			row[i] = nil
+			(*dest)[i] = nil
 
 		// Numeric Typs
 		case FIELD_TYPE_TINY:
-			var val []byte
 			if unsigned {
-				val = uintToByteStr(uint64(byteToUint8(data[pos])))
+				(*dest)[i] = uintToByteStr(uint64(byteToUint8(data[pos])))
 			} else {
-				val = intToByteStr(int64(int8(byteToUint8(data[pos]))))
+				(*dest)[i] = intToByteStr(int64(int8(byteToUint8(data[pos]))))
 			}
-			row[i] = &val
 			pos++
 
 		case FIELD_TYPE_SHORT, FIELD_TYPE_YEAR:
-			var val []byte
 			if unsigned {
-				val = uintToByteStr(uint64(bytesToUint16(data[pos : pos+2])))
+				(*dest)[i] = uintToByteStr(uint64(bytesToUint16(data[pos : pos+2])))
 			} else {
-				val = intToByteStr(int64(int16(bytesToUint16(data[pos : pos+2]))))
+				(*dest)[i] = intToByteStr(int64(int16(bytesToUint16(data[pos : pos+2]))))
 			}
-			row[i] = &val
 			pos += 2
 
 		case FIELD_TYPE_INT24, FIELD_TYPE_LONG:
-			var val []byte
 			if unsigned {
-				val = uintToByteStr(uint64(bytesToUint32(data[pos : pos+4])))
+				(*dest)[i] = uintToByteStr(uint64(bytesToUint32(data[pos : pos+4])))
 			} else {
-				val = intToByteStr(int64(int32(bytesToUint32(data[pos : pos+4]))))
+				(*dest)[i] = intToByteStr(int64(int32(bytesToUint32(data[pos : pos+4]))))
 			}
-			row[i] = &val
 			pos += 4
 
 		case FIELD_TYPE_LONGLONG:
-			var val []byte
 			if unsigned {
-				val = uintToByteStr(bytesToUint64(data[pos : pos+8]))
+				(*dest)[i] = uintToByteStr(bytesToUint64(data[pos : pos+8]))
 			} else {
-				val = intToByteStr(int64(bytesToUint64(data[pos : pos+8])))
+				(*dest)[i] = intToByteStr(int64(bytesToUint64(data[pos : pos+8])))
 			}
-			row[i] = &val
 			pos += 8
 
 		case FIELD_TYPE_FLOAT:
-			var val []byte
-			val = float32ToByteStr(bytesToFloat32(data[pos : pos+4]))
-			row[i] = &val
+			(*dest)[i] = float32ToByteStr(bytesToFloat32(data[pos : pos+4]))
 			pos += 4
 
 		case FIELD_TYPE_DOUBLE:
-			var val []byte
-			val = float64ToByteStr(bytesToFloat64(data[pos : pos+8]))
-			row[i] = &val
+			(*dest)[i] = float64ToByteStr(bytesToFloat64(data[pos : pos+8]))
 			pos += 8
 
-		case FIELD_TYPE_DECIMAL, FIELD_TYPE_NEWDECIMAL:
-			row[i], n, isNull, err = readLengthCodedBinary(data[pos:])
-
-			if err != nil {
-				return nil, err
-			}
-
-			if isNull && rc.columns[i].flags&FLAG_NOT_NULL == 0 {
-				row[i] = nil
-			}
-			pos += n
-
 		// Length coded Binary Strings
-		case FIELD_TYPE_VARCHAR, FIELD_TYPE_BIT, FIELD_TYPE_ENUM,
-			FIELD_TYPE_SET, FIELD_TYPE_TINY_BLOB, FIELD_TYPE_MEDIUM_BLOB,
-			FIELD_TYPE_LONG_BLOB, FIELD_TYPE_BLOB, FIELD_TYPE_VAR_STRING,
-			FIELD_TYPE_STRING, FIELD_TYPE_GEOMETRY:
-			row[i], n, isNull, err = readLengthCodedBinary(data[pos:])
+		case FIELD_TYPE_DECIMAL, FIELD_TYPE_NEWDECIMAL, FIELD_TYPE_VARCHAR,
+			FIELD_TYPE_BIT, FIELD_TYPE_ENUM, FIELD_TYPE_SET,
+			FIELD_TYPE_TINY_BLOB, FIELD_TYPE_MEDIUM_BLOB, FIELD_TYPE_LONG_BLOB,
+			FIELD_TYPE_BLOB, FIELD_TYPE_VAR_STRING, FIELD_TYPE_STRING,
+			FIELD_TYPE_GEOMETRY:
+			var val *[]byte
+			val, n, isNull, err = readLengthCodedBinary(data[pos:])
 			if err != nil {
-				return nil, err
+				return
 			}
 
 			if isNull && rc.columns[i].flags&FLAG_NOT_NULL == 0 {
-				row[i] = nil
+				(*dest)[i] = nil
+			} else {
+				(*dest)[i] = *val
 			}
+
 			pos += n
 
 		// Date YYYY-MM-DD
@@ -952,20 +935,18 @@ func (mc *mysqlConn) readBinaryRow(rc *mysqlRows) (*[]*[]byte, error) {
 			var num uint64
 			num, n, err = bytesToLengthCodedBinary(data[pos:])
 			if err != nil {
-				return nil, err
+				return
 			}
 			pos += n
 
-			var val []byte
 			if num == 0 {
-				val = []byte("0000-00-00")
+				(*dest)[i] = []byte("0000-00-00")
 			} else {
-				val = []byte(fmt.Sprintf("%04d-%02d-%02d",
+				(*dest)[i] = []byte(fmt.Sprintf("%04d-%02d-%02d",
 					bytesToUint16(data[pos:pos+2]),
 					data[pos+2],
 					data[pos+3]))
 			}
-			row[i] = &val
 			pos += int(num)
 
 		// Time HH:MM:SS
@@ -973,19 +954,17 @@ func (mc *mysqlConn) readBinaryRow(rc *mysqlRows) (*[]*[]byte, error) {
 			var num uint64
 			num, n, err = bytesToLengthCodedBinary(data[pos:])
 			if err != nil {
-				return nil, err
+				return
 			}
 
-			var val []byte
 			if num == 0 {
-				val = []byte("00:00:00")
+				(*dest)[i] = []byte("00:00:00")
 			} else {
-				val = []byte(fmt.Sprintf("%02d:%02d:%02d",
+				(*dest)[i] = []byte(fmt.Sprintf("%02d:%02d:%02d",
 					data[pos+6],
 					data[pos+7],
 					data[pos+8]))
 			}
-			row[i] = &val
 			pos += n + int(num)
 
 		// Timestamp YYYY-MM-DD HH:MM:SS
@@ -993,24 +972,23 @@ func (mc *mysqlConn) readBinaryRow(rc *mysqlRows) (*[]*[]byte, error) {
 			var num uint64
 			num, n, err = bytesToLengthCodedBinary(data[pos:])
 			if err != nil {
-				return nil, err
+				return
 			}
 			pos += n
 
-			var val []byte
 			switch num {
 			case 0:
-				val = []byte("0000-00-00 00:00:00")
+				(*dest)[i] = []byte("0000-00-00 00:00:00")
 			case 4:
-				val = []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
+				(*dest)[i] = []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
 					bytesToUint16(data[pos:pos+2]),
 					data[pos+2],
 					data[pos+3]))
 			default:
 				if num < 7 {
-					return nil, fmt.Errorf("Invalid datetime-packet length %d", num)
+					return fmt.Errorf("Invalid datetime-packet length %d", num)
 				}
-				val = []byte(fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
+				(*dest)[i] = []byte(fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
 					bytesToUint16(data[pos:pos+2]),
 					data[pos+2],
 					data[pos+3],
@@ -1018,14 +996,13 @@ func (mc *mysqlConn) readBinaryRow(rc *mysqlRows) (*[]*[]byte, error) {
 					data[pos+5],
 					data[pos+6]))
 			}
-			row[i] = &val
 			pos += int(num)
 
 		// Please report if this happens!
 		default:
-			return nil, fmt.Errorf("Unknown FieldType %d", rc.columns[i].fieldType)
+			return fmt.Errorf("Unknown FieldType %d", rc.columns[i].fieldType)
 		}
 	}
 
-	return &row, nil
+	return
 }

+ 5 - 20
rows.go

@@ -69,31 +69,16 @@ func (rows *mysqlRows) Next(dest []driver.Value) error {
 		return errors.New("Invalid Connection")
 	}
 
-	columnsCount := cap(dest)
-
 	// Fetch next row from stream
-	var row *[]*[]byte
 	var err error
 	if rows.binary {
-		row, err = rows.mc.readBinaryRow(rows)
+		err = rows.readBinaryRow(&dest)
 	} else {
-		row, err = rows.mc.readRow(columnsCount)
+		err = rows.readRow(&dest)
 	}
 
-	if err != nil {
-		if err == io.EOF {
-			rows.eof = true
-		}
-		return err
+	if err == io.EOF {
+		rows.eof = true
 	}
-
-	for i := 0; i < columnsCount; i++ {
-		if (*row)[i] == nil {
-			dest[i] = nil
-		} else {
-			dest[i] = *(*row)[i]
-		}
-	}
-
-	return nil
+	return err
 }