Pārlūkot izejas kodu

Update ConvertValue() to match the database/sql/driver implementation except for uint64 (#760)

This simply copies recent changes to ConvertValue from
database/sql/driver to ensure that our behaviour only differs for
uint64.

Fixes #739
Andrew Reid 7 gadi atpakaļ
vecāks
revīzija
e8153fbb24
3 mainītis faili ar 46 papildinājumiem un 4 dzēšanām
  1. 1 0
      AUTHORS
  2. 8 0
      driver_go18_test.go
  3. 37 4
      statement.go

+ 1 - 0
AUTHORS

@@ -14,6 +14,7 @@
 Aaron Hopkins <go-sql-driver at die.net>
 Achille Roussel <achille.roussel at gmail.com>
 Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
+Andrew Reid <andrew.reid at tixtrack.com>
 Arne Hormann <arnehormann at gmail.com>
 Asta Xie <xiemengjun at gmail.com>
 Bulat Gaifullin <gaifullinbf at gmail.com>

+ 8 - 0
driver_go18_test.go

@@ -796,3 +796,11 @@ func TestRowsColumnTypes(t *testing.T) {
 		})
 	}
 }
+
+func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (value VARCHAR(255))")
+		dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil))
+		// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
+	})
+}

+ 37 - 4
statement.go

@@ -132,15 +132,25 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
 
 type converter struct{}
 
+// ConvertValue mirrors the reference/default converter in database/sql/driver
+// with _one_ exception.  We support uint64 with their high bit and the default
+// implementation does not.  This function should be kept in sync with
+// database/sql/driver defaultConverter.ConvertValue() except for that
+// deliberate difference.
 func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
 	if driver.IsValue(v) {
 		return v, nil
 	}
 
-	if v != nil {
-		if valuer, ok := v.(driver.Valuer); ok {
-			return valuer.Value()
+	if vr, ok := v.(driver.Valuer); ok {
+		sv, err := callValuerValue(vr)
+		if err != nil {
+			return nil, err
+		}
+		if !driver.IsValue(sv) {
+			return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
 		}
+		return sv, nil
 	}
 
 	rv := reflect.ValueOf(v)
@@ -149,8 +159,9 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
 		// indirect pointers
 		if rv.IsNil() {
 			return nil, nil
+		} else {
+			return c.ConvertValue(rv.Elem().Interface())
 		}
-		return c.ConvertValue(rv.Elem().Interface())
 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 		return rv.Int(), nil
 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
@@ -176,3 +187,25 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
 	}
 	return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
 }
+
+var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
+
+// callValuerValue returns vr.Value(), with one exception:
+// If vr.Value is an auto-generated method on a pointer type and the
+// pointer is nil, it would panic at runtime in the panicwrap
+// method. Treat it like nil instead.
+//
+// This is so people can implement driver.Value on value types and
+// still use nil pointers to those types to mean nil/NULL, just like
+// string/*string.
+//
+// This is an exact copy of the same-named unexported function from the
+// database/sql package.
+func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
+	if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
+		rv.IsNil() &&
+		rv.Type().Elem().Implements(valuerReflectType) {
+		return nil, nil
+	}
+	return vr.Value()
+}