Browse Source

Added type-conversion for time.Time, []byte and float64 in stmt-params
+ even more clean up

Julien Schmidt 13 years ago
parent
commit
e4b10482cb
6 changed files with 81 additions and 55 deletions
  1. 4 5
      connection.go
  2. 1 1
      const.go
  3. 50 43
      packets.go
  4. 1 1
      rows.go
  5. 4 5
      statement.go
  6. 21 0
      utils.go

+ 4 - 5
connection.go

@@ -157,19 +157,18 @@ func (mc *mysqlConn) Prepare(query string) (ds driver.Stmt, e error) {
 	stmt.mc = mc
 
 	// Read Result
-	var columnCount, paramCount uint16
-	stmt.id, columnCount, paramCount, e = mc.readPrepareResultPacket()
+	var columnCount uint16
+	columnCount, e = stmt.readPrepareResultPacket()
 	if e != nil {
 		return
 	}
 
-	if paramCount > 0 {
-		stmt.params, e = stmt.mc.readColumns(int(paramCount))
+	if stmt.paramCount > 0 {
+		stmt.params, e = stmt.mc.readColumns(stmt.paramCount)
 		if e != nil {
 			return
 		}
 	}
-	stmt.paramCount = int(paramCount)
 
 	if columnCount > 0 {
 		_, e = stmt.mc.readColumns(int(columnCount))

+ 1 - 1
const.go

@@ -14,7 +14,7 @@ package mysql
 const (
 	MIN_PROTOCOL_VERSION = 10
 	MAX_PACKET_SIZE      = 1<<24 - 1
-	TIME_FORMAT          = "2006-01-02 15:04:05.000000000"
+	TIME_FORMAT          = "2006-01-02 15:04:05"
 )
 
 type ClientFlag uint32

+ 50 - 43
packets.go

@@ -17,6 +17,9 @@ import (
 	"time"
 )
 
+// Packets documentation:
+// http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
+
 // Read packet to buffer 'data'
 func (mc *mysqlConn) readPacket() (data []byte, e error) {
 	// Packet Length
@@ -366,7 +369,7 @@ n   (until end of packet)   message
 */
 func (mc *mysqlConn) handleOkPacket(data []byte) (e error) {
 	if data[0] != 0 {
-		e = errors.New("Wrong Packet-Type: Not a OK-Packet")
+		e = errors.New("Wrong Packet-Type: Not an OK-Packet")
 		return
 	}
 
@@ -451,32 +454,37 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
 		}
 
 		var pos, n int
-		var catalog, database, table, orgTable, name, orgName []byte
-		var defaultVal uint64
+		var name []byte
+		//var catalog, database, table, orgTable, name, orgName []byte
+		//var defaultVal uint64
 
 		// Catalog
-		catalog, n, _, e = readLengthCodedBinary(data)
+		//catalog, n, _, e = readLengthCodedBinary(data)
+		n, e = readAndDropLengthCodedBinary(data)
 		if e != nil {
 			return
 		}
 		pos += n
 
 		// Database [len coded string]
-		database, n, _, e = readLengthCodedBinary(data[pos:])
+		//database, n, _, e = readLengthCodedBinary(data[pos:])
+		n, e = readAndDropLengthCodedBinary(data[pos:])
 		if e != nil {
 			return
 		}
 		pos += n
 
 		// Table [len coded string]
-		table, n, _, e = readLengthCodedBinary(data[pos:])
+		//table, n, _, e = readLengthCodedBinary(data[pos:])
+		n, e = readAndDropLengthCodedBinary(data[pos:])
 		if e != nil {
 			return
 		}
 		pos += n
 
 		// Original table [len coded string]
-		orgTable, n, _, e = readLengthCodedBinary(data[pos:])
+		//orgTable, n, _, e = readLengthCodedBinary(data[pos:])
+		n, e = readAndDropLengthCodedBinary(data[pos:])
 		if e != nil {
 			return
 		}
@@ -490,7 +498,8 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
 		pos += n
 
 		// Original name [len coded string]
-		orgName, n, _, e = readLengthCodedBinary(data[pos:])
+		//orgName, n, _, e = readLengthCodedBinary(data[pos:])
+		n, e = readAndDropLengthCodedBinary(data[pos:])
 		if e != nil {
 			return
 		}
@@ -500,11 +509,11 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
 		pos++
 
 		// Charset [16 bit uint]
-		charsetNumber := bytesToUint16(data[pos : pos+2])
+		//charsetNumber := bytesToUint16(data[pos : pos+2])
 		pos += 2
 
 		// Length [32 bit uint]
-		length := bytesToUint32(data[pos : pos+4])
+		//length := bytesToUint32(data[pos : pos+4])
 		pos += 4
 
 		// Field type [byte]
@@ -513,18 +522,16 @@ func (mc *mysqlConn) readColumns(n int) (columns []*mysqlField, e error) {
 
 		// Flags [16 bit uint]
 		flags := FieldFlag(bytesToUint16(data[pos : pos+2]))
-		pos += 2
+		//pos += 2
 
 		// Decimals [8 bit uint]
-		decimals := data[pos]
-		pos++
+		//decimals := data[pos]
+		//pos++
 
 		// Default value [len coded binary]
-		if pos < len(data) {
-			defaultVal, _, e = bytesToLengthCodedBinary(data[pos:])
-		}
-
-		fmt.Printf("catalog=%s database=%s table=%s orgTable=%s name=%s orgName=%s charsetNumber=%d length=%d fieldType=%d flags=%d decimals=%d defaultVal=%d \n", catalog, database, table, orgTable, name, orgName, charsetNumber, length, fieldType, flags, decimals, defaultVal)
+		//if pos < len(data) {
+		//	defaultVal, _, e = bytesToLengthCodedBinary(data[pos:])
+		//}
 
 		columns = append(columns, &mysqlField{name: string(name), fieldType: fieldType, flags: flags})
 	}
@@ -628,8 +635,8 @@ Prepare OK Packet
         (EOF packet) 
 
 */
-func (mc *mysqlConn) readPrepareResultPacket() (stmtID uint32, columnCount uint16, paramCount uint16, e error) {
-	data, e := mc.readPacket()
+func (stmt mysqlStmt) readPrepareResultPacket() (columnCount uint16, e error) {
+	data, e := stmt.mc.readPacket()
 	if e != nil {
 		return
 	}
@@ -638,12 +645,12 @@ func (mc *mysqlConn) readPrepareResultPacket() (stmtID uint32, columnCount uint1
 	pos := 0
 
 	if data[pos] != 0 {
-		e = mc.handleErrorPacket(data)
+		e = stmt.mc.handleErrorPacket(data)
 		return
 	}
 	pos++
 
-	stmtID = bytesToUint32(data[pos : pos+4])
+	stmt.id = bytesToUint32(data[pos : pos+4])
 	pos += 4
 
 	// Column count [16 bit uint]
@@ -651,7 +658,7 @@ func (mc *mysqlConn) readPrepareResultPacket() (stmtID uint32, columnCount uint1
 	pos += 2
 
 	// Param count [16 bit uint]
-	paramCount = bytesToUint16(data[pos : pos+2])
+	stmt.paramCount = int(bytesToUint16(data[pos : pos+2]))
 	pos += 2
 
 	// Warning count [16 bit uint]
@@ -751,10 +758,20 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
 						byte(FIELD_TYPE_NULL),
 						0x0}...)
 					continue
+				
 				case []byte:
-					fmt.Println("[]byte", (*args)[i])
+					data = append(data, []byte{
+						byte(FIELD_TYPE_STRING),
+						0x0}...)
+					val := (*args)[i].([]byte)
+					paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
+					paramValues = append(paramValues, val...)
+					continue
+				
 				case time.Time:
-					fmt.Println("time.Time", (*args)[i])
+					// Format to string for time+date Fields
+					// Data is packed in case reflect.String below
+					(*args)[i] = (*args)[i].(time.Time).Format(TIME_FORMAT)
 				}
 
 				pv = reflect.ValueOf((*args)[i])
@@ -764,10 +781,14 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
 						byte(FIELD_TYPE_LONGLONG),
 						0x0}...)
 					paramValues = append(paramValues, int64ToBytes(pv.Int())...)
-					fmt.Println("int64", (*args)[i])
+					continue
 
 				case reflect.Float64:
-					fmt.Println("float64", (*args)[i])
+					data = append(data, []byte{
+						byte(FIELD_TYPE_DOUBLE),
+						0x0}...)
+					paramValues = append(paramValues, float64ToBytes(pv.Float())...)
+					continue
 
 				case reflect.Bool:
 					data = append(data, []byte{
@@ -779,7 +800,7 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
 					} else {
 						paramValues = append(paramValues, byte(0))
 					}
-					fmt.Println("bool", (*args)[i])
+					continue
 
 				case reflect.String:
 					data = append(data, []byte{
@@ -788,7 +809,7 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
 					val := pv.String()
 					paramValues = append(paramValues, lengthCodedBinaryToBytes(uint64(len(val)))...)
 					paramValues = append(paramValues, []byte(val)...)
-					fmt.Println("string", string([]byte(val)))
+					continue
 
 				default:
 					return fmt.Errorf("Invalid Value: %s", pv.Kind().String())
@@ -797,7 +818,6 @@ func (stmt mysqlStmt) buildExecutePacket(args *[]driver.Value) (e error) {
 
 			// append cached values
 			data = append(data, paramValues...)
-			fmt.Println("data", string(data))
 		}
 
 		// Save args
@@ -855,7 +875,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 					row[i] = intToByteStr(int64(int8(byteToUint8(data[pos]))))
 				}
 				pos++
-				fmt.Println("TINY", string(*row[i]))
 
 			case FIELD_TYPE_SHORT, FIELD_TYPE_YEAR:
 				if unsigned {
@@ -864,7 +883,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 					row[i] = intToByteStr(int64(int16(bytesToUint16(data[pos : pos+2]))))
 				}
 				pos += 2
-				fmt.Println("SHORT", string(*row[i]))
 
 			case FIELD_TYPE_INT24, FIELD_TYPE_LONG:
 				if unsigned {
@@ -873,7 +891,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 					row[i] = intToByteStr(int64(int32(bytesToUint32(data[pos : pos+4]))))
 				}
 				pos += 4
-				fmt.Println("LONG", string(*row[i]))
 
 			case FIELD_TYPE_LONGLONG:
 				if unsigned {
@@ -882,17 +899,14 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 					row[i] = intToByteStr(int64(bytesToUint64(data[pos : pos+8])))
 				}
 				pos += 8
-				fmt.Println("LONGLONG", string(*row[i]))
 
 			case FIELD_TYPE_FLOAT:
 				row[i] = float32ToByteStr(bytesToFloat32(data[pos : pos+4]))
 				pos += 4
-				fmt.Println("FLOAT", string(*row[i]))
 
 			case FIELD_TYPE_DOUBLE:
 				row[i] = float64ToByteStr(bytesToFloat64(data[pos : pos+8]))
 				pos += 8
-				fmt.Println("DOUBLE", string(*row[i]))
 
 			case FIELD_TYPE_DECIMAL, FIELD_TYPE_NEWDECIMAL:
 				var tmp []byte
@@ -903,10 +917,8 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 
 				if isNull && rc.columns[i].flags&FLAG_NOT_NULL == 0 {
 					row[i] = nil
-					fmt.Println("DECIMAL", nil)
 				} else {
 					row[i] = &tmp
-					fmt.Println("DECIMAL", string(tmp))
 				}
 				pos += n
 
@@ -923,10 +935,8 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 
 				if isNull && rc.columns[i].flags&FLAG_NOT_NULL == 0 {
 					row[i] = nil
-					fmt.Println("STRING", nil)
 				} else {
 					row[i] = &tmp
-					fmt.Println("STRING", string(tmp))
 				}
 				pos += n
 
@@ -950,7 +960,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 				}
 				row[i] = &tmp
 				pos += int(num)
-				fmt.Println("DATE", string(*row[i]))
 
 			// Time HH:MM:SS
 			case FIELD_TYPE_TIME:
@@ -971,7 +980,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 				}
 				row[i] = &tmp
 				pos += n + int(num)
-				fmt.Println("TIME", string(*row[i]))
 
 			// Timestamp YYYY-MM-DD HH:MM:SS
 			case FIELD_TYPE_TIMESTAMP, FIELD_TYPE_DATETIME:
@@ -996,7 +1004,6 @@ func (mc *mysqlConn) readBinaryRows(rc *rowsContent) (e error) {
 				}
 				row[i] = &tmp
 				pos += int(num)
-				fmt.Println("DATE", string(*row[i]))
 
 			// Please report if this happens!
 			default:

+ 1 - 1
rows.go

@@ -42,7 +42,7 @@ func (rows mysqlRows) Close() error {
 }
 
 // Next returns []driver.Value filled with either nil values for NULL entries
-// or []byte for every other entries. Type conversion is done on rows.scan(),
+// or []byte's for all other entries. Type conversion is done on rows.scan(),
 // when the dest. type is know, which makes type conversion easier and avoids 
 // unnecessary conversions.
 func (rows mysqlRows) Next(dest []driver.Value) error {

+ 4 - 5
statement.go

@@ -10,7 +10,6 @@ package mysql
 
 import (
 	"database/sql/driver"
-	"fmt"
 )
 
 type stmtContent struct {
@@ -118,7 +117,7 @@ func (stmt mysqlStmt) Query(args []driver.Value) (dr driver.Rows, e error) {
 // column index.  If the type of a specific column isn't known
 // or shouldn't be handled specially, DefaultValueConverter
 // can be returned.
-func (stmt mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
-	debug(fmt.Sprintf("ColumnConverter(%d)", idx))
-	return driver.DefaultParameterConverter
-}
+//func (stmt mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
+//	debug(fmt.Sprintf("ColumnConverter(%d)", idx))
+//	return driver.DefaultParameterConverter
+//}

+ 21 - 0
utils.go

@@ -145,6 +145,23 @@ func readLengthCodedBinary(data []byte) (b []byte, n int, isNull bool, e error)
 	return
 }
 
+func readAndDropLengthCodedBinary(data []byte) (n int, e error) {
+	// Get length
+	num, n, e := bytesToLengthCodedBinary(data)
+	if e != nil {
+		return
+	}
+
+	// Check data length
+	if len(data) < n+int(num) {
+		e = io.EOF
+		return
+	}
+
+	n += int(num)
+	return
+}
+
 /******************************************************************************
 *                       Convert from and to bytes                             *
 ******************************************************************************/
@@ -210,6 +227,10 @@ func bytesToFloat64(b []byte) float64 {
 	return math.Float64frombits(bytesToUint64(b))
 }
 
+func float64ToBytes(f float64) []byte {
+	return uint64ToBytes(math.Float64bits(f))
+}
+
 func bytesToLengthCodedBinary(b []byte) (length uint64, n int, e error) {
 	switch {