Explorar el Código

Merge pull request #332 from arnehormann/uint64params

support uint64 parameters with high bit set
Arne Hormann hace 10 años
padre
commit
0cc29e9fe8
Se han modificado 2 ficheros con 80 adiciones y 0 borrados
  1. 43 0
      driver_test.go
  2. 37 0
      statement.go

+ 43 - 0
driver_test.go

@@ -780,6 +780,49 @@ func TestNULL(t *testing.T) {
 	})
 }
 
+func TestUint64(t *testing.T) {
+	const (
+		u0    = uint64(0)
+		uall  = ^u0
+		uhigh = uall >> 1
+		utop  = ^uhigh
+		s0    = int64(0)
+		sall  = ^s0
+		shigh = int64(uhigh)
+		stop  = ^shigh
+	)
+	runTests(t, dsn, func(dbt *DBTest) {
+		stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`)
+		if err != nil {
+			dbt.Fatal(err)
+		}
+		defer stmt.Close()
+		row := stmt.QueryRow(
+			u0, uhigh, utop, uall,
+			s0, shigh, stop, sall,
+		)
+
+		var ua, ub, uc, ud uint64
+		var sa, sb, sc, sd int64
+
+		err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd)
+		if err != nil {
+			dbt.Fatal(err)
+		}
+		switch {
+		case ua != u0,
+			ub != uhigh,
+			uc != utop,
+			ud != uall,
+			sa != s0,
+			sb != shigh,
+			sc != stop,
+			sd != sall:
+			dbt.Fatal("Unexpected result value")
+		}
+	})
+}
+
 func TestLongData(t *testing.T) {
 	runTests(t, dsn, func(dbt *DBTest) {
 		var maxAllowedPacketSize int

+ 37 - 0
statement.go

@@ -10,6 +10,8 @@ package mysql
 
 import (
 	"database/sql/driver"
+	"fmt"
+	"reflect"
 )
 
 type mysqlStmt struct {
@@ -34,6 +36,10 @@ func (stmt *mysqlStmt) NumInput() int {
 	return stmt.paramCount
 }
 
+func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
+	return converter{}
+}
+
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	if stmt.mc.netConn == nil {
 		errLog.Print(ErrInvalidConn)
@@ -110,3 +116,34 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 
 	return rows, err
 }
+
+type converter struct{}
+
+func (converter) ConvertValue(v interface{}) (driver.Value, error) {
+	if driver.IsValue(v) {
+		return v, nil
+	}
+
+	rv := reflect.ValueOf(v)
+	switch rv.Kind() {
+	case reflect.Ptr:
+		// indirect pointers
+		if rv.IsNil() {
+			return nil, nil
+		}
+		return driver.DefaultParameterConverter.ConvertValue(rv.Elem().Interface())
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return rv.Int(), nil
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
+		return int64(rv.Uint()), nil
+	case reflect.Uint64:
+		u64 := rv.Uint()
+		if u64 >= 1<<63 {
+			return fmt.Sprintf("%d", u64), nil
+		}
+		return int64(u64), nil
+	case reflect.Float32, reflect.Float64:
+		return rv.Float(), nil
+	}
+	return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
+}