Quellcode durchsuchen

Experimental nil-Values-Patch

Julien Schmidt vor 13 Jahren
Ursprung
Commit
e94b728b48
4 geänderte Dateien mit 55 neuen und 32 gelöschten Zeilen
  1. 2 2
      connection.go
  2. 41 22
      packets.go
  3. 6 2
      rows.go
  4. 6 6
      utils.go

+ 2 - 2
connection.go

@@ -258,13 +258,13 @@ func (mc *mysqlConn) getSystemVar(name string) (val string, e error) {
 			return
 		}
 
-		var rows []*[][]byte
+		var rows []*[]*[]byte
 		rows, e = mc.readRows(int(n))
 		if e != nil {
 			return
 		}
 
-		val = string((*rows[0])[0])
+		val = string(*(*rows[0])[0])
 	}
 
 	return

+ 41 - 22
packets.go

@@ -484,7 +484,7 @@ func (mc *mysqlConn) readColumns(n int) (columns []mysqlField, e error) {
 		}
 
 		var pos, n int
-		var name []byte
+		var name *[]byte
 		//var catalog, database, table, orgTable, name, orgName []byte
 		//var defaultVal uint64
 
@@ -563,14 +563,14 @@ func (mc *mysqlConn) readColumns(n int) (columns []mysqlField, e error) {
 		//	defaultVal, _, e = bytesToLengthCodedBinary(data[pos:])
 		//}
 
-		columns = append(columns, mysqlField{name: string(name), fieldType: fieldType, flags: flags})
+		columns = append(columns, mysqlField{name: string(*name), fieldType: fieldType, flags: flags})
 	}
 
 	return
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
-func (mc *mysqlConn) readRows(columnsCount int) (rows []*[][]byte, e error) {
+func (mc *mysqlConn) readRows(columnsCount int) (rows []*[]*[]byte, e error) {
 	var data []byte
 	var i, pos, n int
 	var isNull bool
@@ -587,7 +587,7 @@ func (mc *mysqlConn) readRows(columnsCount int) (rows []*[][]byte, e error) {
 		}
 
 		// RowSet Packet
-		row := make([][]byte, columnsCount)
+		row := make([]*[]byte, columnsCount)
 		pos = 0
 		for i = 0; i < columnsCount; i++ {
 			// Read bytes and convert to string
@@ -862,7 +862,7 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 		pos++
 
 		// BinaryRowSet Packet
-		row := make([][]byte, columnsCount)
+		row := make([]*[]byte, columnsCount)
 
 		nullBitMap = data[pos : pos+(columnsCount+7+2)/8]
 		pos += (columnsCount + 7 + 2) / 8
@@ -883,47 +883,60 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 
 			// Numeric Typs
 			case FIELD_TYPE_TINY:
+				var val []byte
 				if unsigned {
-					row[i] = uintToByteStr(uint64(byteToUint8(data[pos])))
+					val = uintToByteStr(uint64(byteToUint8(data[pos])))
 				} else {
-					row[i] = intToByteStr(int64(int8(byteToUint8(data[pos]))))
+					val = intToByteStr(int64(int8(byteToUint8(data[pos]))))
 				}
+				row[i] = &val
 				pos++
 
 			case FIELD_TYPE_SHORT, FIELD_TYPE_YEAR:
+				var val []byte
 				if unsigned {
-					row[i] = uintToByteStr(uint64(bytesToUint16(data[pos : pos+2])))
+					val = uintToByteStr(uint64(bytesToUint16(data[pos : pos+2])))
 				} else {
-					row[i] = intToByteStr(int64(int16(bytesToUint16(data[pos : pos+2]))))
+					val = intToByteStr(int64(int16(bytesToUint16(data[pos : pos+2]))))
 				}
+				row[i] = &val
 				pos += 2
 
 			case FIELD_TYPE_INT24, FIELD_TYPE_LONG:
+				var val []byte
 				if unsigned {
-					row[i] = uintToByteStr(uint64(bytesToUint32(data[pos : pos+4])))
+					val = uintToByteStr(uint64(bytesToUint32(data[pos : pos+4])))
 				} else {
-					row[i] = intToByteStr(int64(int32(bytesToUint32(data[pos : pos+4]))))
+					val = intToByteStr(int64(int32(bytesToUint32(data[pos : pos+4]))))
 				}
+				row[i] = &val
 				pos += 4
 
 			case FIELD_TYPE_LONGLONG:
+				var val []byte
 				if unsigned {
-					row[i] = uintToByteStr(bytesToUint64(data[pos : pos+8]))
+					val = uintToByteStr(bytesToUint64(data[pos : pos+8]))
 				} else {
-					row[i] = intToByteStr(int64(bytesToUint64(data[pos : pos+8])))
+					val = intToByteStr(int64(bytesToUint64(data[pos : pos+8])))
 				}
+				row[i] = &val
 				pos += 8
 
 			case FIELD_TYPE_FLOAT:
-				row[i] = float32ToByteStr(bytesToFloat32(data[pos : pos+4]))
+				var val []byte
+				val = float32ToByteStr(bytesToFloat32(data[pos : pos+4]))
+				row[i] = &val
 				pos += 4
 
 			case FIELD_TYPE_DOUBLE:
-				row[i] = float64ToByteStr(bytesToFloat64(data[pos : pos+8]))
+				var val []byte
+				val = float64ToByteStr(bytesToFloat64(data[pos : pos+8]))
+				row[i] = &val
 				pos += 8
 
 			case FIELD_TYPE_DECIMAL, FIELD_TYPE_NEWDECIMAL:
 				row[i], n, isNull, e = readLengthCodedBinary(data[pos:])
+
 				if e != nil {
 					return
 				}
@@ -957,14 +970,16 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 				}
 				pos += n
 
+				var val []byte
 				if num == 0 {
-					row[i] = []byte("0000-00-00")
+					val = []byte("0000-00-00")
 				} else {
-					row[i] = []byte(fmt.Sprintf("%04d-%02d-%02d",
+					val = []byte(fmt.Sprintf("%04d-%02d-%02d",
 						bytesToUint16(data[pos:pos+2]),
 						data[pos+2],
 						data[pos+3]))
 				}
+				row[i] = &val
 				pos += int(num)
 
 			// Time HH:MM:SS
@@ -975,14 +990,16 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 					return
 				}
 
+				var val []byte
 				if num == 0 {
-					row[i] = []byte("00:00:00")
+					val = []byte("00:00:00")
 				} else {
-					row[i] = []byte(fmt.Sprintf("%02d:%02d:%02d",
+					val = []byte(fmt.Sprintf("%02d:%02d:%02d",
 						data[pos+6],
 						data[pos+7],
 						data[pos+8]))
 				}
+				row[i] = &val
 				pos += n + int(num)
 
 			// Timestamp YYYY-MM-DD HH:MM:SS
@@ -994,11 +1011,12 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 				}
 				pos += n
 
+				var val []byte
 				switch num {
 				case 0:
-					row[i] = []byte("0000-00-00 00:00:00")
+					val = []byte("0000-00-00 00:00:00")
 				case 4:
-					row[i] = []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
+					val = []byte(fmt.Sprintf("%04d-%02d-%02d 00:00:00",
 						bytesToUint16(data[pos:pos+2]),
 						data[pos+2],
 						data[pos+3]))
@@ -1006,7 +1024,7 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 					if num < 7 {
 						return fmt.Errorf("Invalid datetime-packet length %d", num)
 					}
-					row[i] = []byte(fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
+					val = []byte(fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d",
 						bytesToUint16(data[pos:pos+2]),
 						data[pos+2],
 						data[pos+3],
@@ -1014,6 +1032,7 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 						data[pos+5],
 						data[pos+6]))
 				}
+				row[i] = &val
 				pos += int(num)
 
 			// Please report if this happens!

+ 6 - 2
rows.go

@@ -22,7 +22,7 @@ type mysqlField struct {
 
 type rowsContent struct {
 	columns []mysqlField
-	rows    []*[][]byte
+	rows    []*[]*[]byte
 }
 
 type mysqlRows struct {
@@ -49,7 +49,11 @@ func (rows mysqlRows) Close() error {
 func (rows mysqlRows) Next(dest []driver.Value) error {
 	if len(rows.content.rows) > 0 {
 		for i := 0; i < cap(dest); i++ {
-			dest[i] = (*rows.content.rows[0])[i]
+			if (*rows.content.rows[0])[i] == nil {
+				dest[i] = nil
+			} else {
+				dest[i] = *(*rows.content.rows[0])[i]
+			}
 		}
 		rows.content.rows = rows.content.rows[1:]
 	} else {

+ 6 - 6
utils.go

@@ -132,20 +132,20 @@ func readSlice(data []byte, delim byte) (slice []byte, e error) {
 	return
 }
 
-func readLengthCodedBinary(data []byte) (b []byte, n int, isNull bool, e error) {
+func readLengthCodedBinary(data []byte) (*[]byte, int, bool, error) {
 	// Get length
 	num, n, e := bytesToLengthCodedBinary(data)
 	if e != nil {
-		return
+		return nil, n, true, e
 	}
 
 	// Check data length
 	if len(data) < n+int(num) {
-		e = io.EOF
-		return
+		return nil, n, true, io.EOF
 	}
 
 	// Check if null
+	var isNull bool
 	if data[0] == 251 {
 		isNull = true
 	} else {
@@ -153,9 +153,9 @@ func readLengthCodedBinary(data []byte) (b []byte, n int, isNull bool, e error)
 	}
 
 	// Get bytes
-	b = data[n : n+int(num)]
+	b := data[n : n+int(num)]
 	n += int(num)
-	return
+	return &b, n, isNull, e
 }
 
 func readAndDropLengthCodedBinary(data []byte) (n int, e error) {