Browse Source

Implement NamedValueChecker for mysqlConn (#690)

* Also add conversions for additional types in ConvertValue
  ref https://github.com/golang/go/commit/d7c0de98a96893e5608358f7578c85be7ba12b25
Justin Li 8 years ago
parent
commit
78d399c0b7
5 changed files with 156 additions and 7 deletions
  1. 1 0
      AUTHORS
  2. 5 0
      connection_go18.go
  3. 30 0
      connection_go18_test.go
  4. 8 0
      statement.go
  5. 112 7
      statement_test.go

+ 1 - 0
AUTHORS

@@ -40,6 +40,7 @@ Jian Zhen <zhenjl at gmail.com>
 Joshua Prunier <joshua.prunier at gmail.com>
 Julien Lefevre <julien.lefevr at gmail.com>
 Julien Schmidt <go-sql-driver at julienschmidt.com>
+Justin Li <jli at j-li.net>
 Justin Nuß <nuss.justin at gmail.com>
 Kamil Dziedzic <kamil at klecza.pl>
 Kevin Malachowski <kevin at chowski.com>

+ 5 - 0
connection_go18.go

@@ -195,3 +195,8 @@ func (mc *mysqlConn) startWatcher() {
 		}
 	}()
 }
+
+func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
+	nv.Value, err = converter{}.ConvertValue(nv.Value)
+	return
+}

+ 30 - 0
connection_go18_test.go

@@ -0,0 +1,30 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// +build go1.8
+
+package mysql
+
+import (
+	"database/sql/driver"
+	"testing"
+)
+
+func TestCheckNamedValue(t *testing.T) {
+	value := driver.NamedValue{Value: ^uint64(0)}
+	x := &mysqlConn{}
+	err := x.CheckNamedValue(&value)
+
+	if err != nil {
+		t.Fatal("uint64 high-bit not convertible", err)
+	}
+
+	if value.Value != "18446744073709551615" {
+		t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
+	}
+}

+ 8 - 0
statement.go

@@ -157,6 +157,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
 		return int64(u64), nil
 	case reflect.Float32, reflect.Float64:
 		return rv.Float(), nil
+	case reflect.Bool:
+		return rv.Bool(), nil
+	case reflect.Slice:
+		ek := rv.Type().Elem().Kind()
+		if ek == reflect.Uint8 {
+			return rv.Bytes(), nil
+		}
+		return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
 	case reflect.String:
 		return rv.String(), nil
 	}

+ 112 - 7
statement_test.go

@@ -8,14 +8,119 @@
 
 package mysql
 
-import "testing"
+import (
+	"bytes"
+	"testing"
+)
 
-type customString string
+func TestConvertDerivedString(t *testing.T) {
+	type derived string
 
-func TestConvertValueCustomTypes(t *testing.T) {
-	var cstr customString = "string"
-	c := converter{}
-	if _, err := c.ConvertValue(cstr); err != nil {
-		t.Errorf("custom string type should be valid")
+	output, err := converter{}.ConvertValue(derived("value"))
+	if err != nil {
+		t.Fatal("Derived string type not convertible", err)
+	}
+
+	if output != "value" {
+		t.Fatalf("Derived string type not converted, got %#v %T", output, output)
+	}
+}
+
+func TestConvertDerivedByteSlice(t *testing.T) {
+	type derived []uint8
+
+	output, err := converter{}.ConvertValue(derived("value"))
+	if err != nil {
+		t.Fatal("Byte slice not convertible", err)
+	}
+
+	if bytes.Compare(output.([]byte), []byte("value")) != 0 {
+		t.Fatalf("Byte slice not converted, got %#v %T", output, output)
+	}
+}
+
+func TestConvertDerivedUnsupportedSlice(t *testing.T) {
+	type derived []int
+
+	_, err := converter{}.ConvertValue(derived{1})
+	if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" {
+		t.Fatal("Unexpected error", err)
+	}
+}
+
+func TestConvertDerivedBool(t *testing.T) {
+	type derived bool
+
+	output, err := converter{}.ConvertValue(derived(true))
+	if err != nil {
+		t.Fatal("Derived bool type not convertible", err)
+	}
+
+	if output != true {
+		t.Fatalf("Derived bool type not converted, got %#v %T", output, output)
+	}
+}
+
+func TestConvertPointer(t *testing.T) {
+	str := "value"
+
+	output, err := converter{}.ConvertValue(&str)
+	if err != nil {
+		t.Fatal("Pointer type not convertible", err)
+	}
+
+	if output != "value" {
+		t.Fatalf("Pointer type not converted, got %#v %T", output, output)
+	}
+}
+
+func TestConvertSignedIntegers(t *testing.T) {
+	values := []interface{}{
+		int8(-42),
+		int16(-42),
+		int32(-42),
+		int64(-42),
+		int(-42),
+	}
+
+	for _, value := range values {
+		output, err := converter{}.ConvertValue(value)
+		if err != nil {
+			t.Fatalf("%T type not convertible %s", value, err)
+		}
+
+		if output != int64(-42) {
+			t.Fatalf("%T type not converted, got %#v %T", value, output, output)
+		}
+	}
+}
+
+func TestConvertUnsignedIntegers(t *testing.T) {
+	values := []interface{}{
+		uint8(42),
+		uint16(42),
+		uint32(42),
+		uint64(42),
+		uint(42),
+	}
+
+	for _, value := range values {
+		output, err := converter{}.ConvertValue(value)
+		if err != nil {
+			t.Fatalf("%T type not convertible %s", value, err)
+		}
+
+		if output != int64(42) {
+			t.Fatalf("%T type not converted, got %#v %T", value, output, output)
+		}
+	}
+
+	output, err := converter{}.ConvertValue(^uint64(0))
+	if err != nil {
+		t.Fatal("uint64 high-bit not convertible", err)
+	}
+
+	if output != "18446744073709551615" {
+		t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output)
 	}
 }