Browse Source

fix return types

only return types which are allowed by database/sql/driver :
int64
float64
bool
[]byte
time.Time
Julien Schmidt 12 years ago
parent
commit
447c64a6f2
5 changed files with 45 additions and 20 deletions
  1. 2 4
      const.go
  2. 2 2
      driver_test.go
  3. 18 13
      packets.go
  4. 1 1
      rows.go
  5. 22 0
      utils.go

+ 2 - 4
const.go

@@ -74,10 +74,8 @@ const (
 	COM_STMT_FETCH
 )
 
-type FieldType byte
-
 const (
-	FIELD_TYPE_DECIMAL FieldType = iota
+	FIELD_TYPE_DECIMAL byte = iota
 	FIELD_TYPE_TINY
 	FIELD_TYPE_SHORT
 	FIELD_TYPE_LONG
@@ -96,7 +94,7 @@ const (
 	FIELD_TYPE_BIT
 )
 const (
-	FIELD_TYPE_NEWDECIMAL FieldType = iota + 0xf6
+	FIELD_TYPE_NEWDECIMAL byte = iota + 0xf6
 	FIELD_TYPE_ENUM
 	FIELD_TYPE_SET
 	FIELD_TYPE_TINY_BLOB

+ 2 - 2
driver_test.go

@@ -249,8 +249,8 @@ func TestFloat(t *testing.T) {
 	mustExec(t, db, "DROP TABLE IF EXISTS test")
 
 	types := [2]string{"FLOAT", "DOUBLE"}
-	in := float64(42.23)
-	var out float64
+	in := float32(42.23)
+	var out float32
 	var rows *sql.Rows
 
 	for _, v := range types {

+ 18 - 13
packets.go

@@ -434,7 +434,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 		pos += n + 1 + 2 + 4
 
 		// Field type [byte]
-		columns[i].fieldType = FieldType(data[pos])
+		columns[i].fieldType = data[pos]
 		pos++
 
 		// Flags [16 bit uint]
@@ -561,26 +561,26 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		// build NULL-bitmap
 		if args[i] == nil {
 			bitMask += 1 << uint(i)
-			paramTypes[i<<1] = byte(FIELD_TYPE_NULL)
+			paramTypes[i<<1] = FIELD_TYPE_NULL
 			continue
 		}
 
 		// cache types and values
 		switch args[i].(type) {
 		case int64:
-			paramTypes[i<<1] = byte(FIELD_TYPE_LONGLONG)
+			paramTypes[i<<1] = FIELD_TYPE_LONGLONG
 			paramValues[i] = uint64ToBytes(uint64(args[i].(int64)))
 			pktLen += 8
 			continue
 
 		case float64:
-			paramTypes[i<<1] = byte(FIELD_TYPE_DOUBLE)
+			paramTypes[i<<1] = FIELD_TYPE_DOUBLE
 			paramValues[i] = uint64ToBytes(math.Float64bits(args[i].(float64)))
 			pktLen += 8
 			continue
 
 		case bool:
-			paramTypes[i<<1] = byte(FIELD_TYPE_TINY)
+			paramTypes[i<<1] = FIELD_TYPE_TINY
 			pktLen++
 			if args[i].(bool) {
 				paramValues[i] = []byte{0x01}
@@ -590,7 +590,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 			continue
 
 		case []byte:
-			paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
+			paramTypes[i<<1] = FIELD_TYPE_STRING
 			val := args[i].([]byte)
 			paramValues[i] = append(
 				lengthEncodedIntegerToBytes(uint64(len(val))),
@@ -600,7 +600,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 			continue
 
 		case string:
-			paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
+			paramTypes[i<<1] = FIELD_TYPE_STRING
 			val := []byte(args[i].(string))
 			paramValues[i] = append(
 				lengthEncodedIntegerToBytes(uint64(len(val))),
@@ -610,7 +610,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 			continue
 
 		case time.Time:
-			paramTypes[i<<1] = byte(FIELD_TYPE_STRING)
+			paramTypes[i<<1] = FIELD_TYPE_STRING
 			val := []byte(args[i].(time.Time).Format(TIME_FORMAT))
 			paramValues[i] = append(
 				lengthEncodedIntegerToBytes(uint64(len(val))),
@@ -718,7 +718,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 		// Numeric Typs
 		case FIELD_TYPE_TINY:
 			if unsigned {
-				dest[i] = uint64(data[pos])
+				dest[i] = int64(data[pos])
 			} else {
 				dest[i] = int64(int8(data[pos]))
 			}
@@ -727,7 +727,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 
 		case FIELD_TYPE_SHORT, FIELD_TYPE_YEAR:
 			if unsigned {
-				dest[i] = uint64(binary.LittleEndian.Uint16(data[pos : pos+2]))
+				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
 			} else {
 				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
 			}
@@ -736,7 +736,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 
 		case FIELD_TYPE_INT24, FIELD_TYPE_LONG:
 			if unsigned {
-				dest[i] = uint64(binary.LittleEndian.Uint32(data[pos : pos+4]))
+				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
 			} else {
 				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
 			}
@@ -745,7 +745,12 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 
 		case FIELD_TYPE_LONGLONG:
 			if unsigned {
-				dest[i] = binary.LittleEndian.Uint64(data[pos : pos+8])
+				val := binary.LittleEndian.Uint64(data[pos : pos+8])
+				if val > math.MaxInt64 {
+					dest[i] = uint64ToString(val)
+				} else {
+					dest[i] = int64(val)
+				}
 			} else {
 				dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8]))
 			}
@@ -753,7 +758,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			continue
 
 		case FIELD_TYPE_FLOAT:
-			dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
+			dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
 			pos += 4
 			continue
 

+ 1 - 1
rows.go

@@ -17,7 +17,7 @@ import (
 
 type mysqlField struct {
 	name      string
-	fieldType FieldType
+	fieldType byte
 	flags     FieldFlag
 }
 

+ 22 - 0
utils.go

@@ -128,6 +128,28 @@ func uint64ToBytes(n uint64) []byte {
 	}
 }
 
+func uint64ToString(n uint64) []byte {
+	var a [20]byte
+	i := 20
+
+	// U+0030 = 0
+	// ...
+	// U+0039 = 9
+
+	var q uint64
+	for n >= 10 {
+		i--
+		q = n / 10
+		a[i] = uint8(n-q*10) + 0x30
+		n = q
+	}
+
+	i--
+	a[i] = uint8(n) + 0x30
+
+	return a[i:]
+}
+
 func readLengthEnodedString(b []byte) ([]byte, int, error) {
 	// Get length
 	num, _, n := readLengthEncodedInteger(b)