Browse Source

Differentiate between BINARY and CHAR (#724)

* Differentiate between BINARY and CHAR

When looking up the database type name, we now check the character set
for the following field types:
 * CHAR
 * VARCHAR
 * BLOB
 * TINYBLOB
 * MEDIUMBLOB
 * LONGBLOB

If the character set is 63 (which is the binary pseudo character set),
we return the binary names, which are (respectively):
 * BINARY
 * VARBINARY
 * BLOB
 * TINYBLOB
 * MEDIUMBLOB
 * LONGBLOB

If any other character set is in use, we return the text names, which
are (again, respectively):
 * CHAR
 * VARCHAR
 * TEXT
 * TINYTEXT
 * MEDIUMTEXT
 * LONGTEXT

To facilitate this, mysqlField has been extended to include a uint8
charSet field, which is read from the column definition packet.

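Column type tests have been updated to ensure coverage of binary and
text types. Note that these tests expect TINY, MEDIUM and LONG BLOB/TEXT
columns to be reported as plain BLOB and TEXT respectively.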

* Increase test coverage for column types
Kieron Woodhouse 8 years ago
parent
commit
9889442273
6 changed files with 112 additions and 36 deletions
  1. 2 0
      AUTHORS
  2. 1 0
      collations.go
  3. 20 2
      driver_go18_test.go
  4. 83 29
      fields.go
  5. 5 1
      packets.go
  6. 1 4
      rows.go

+ 2 - 0
AUTHORS

@@ -44,6 +44,7 @@ Justin Li <jli at j-li.net>
 Justin Nuß <nuss.justin at gmail.com>
 Kamil Dziedzic <kamil at klecza.pl>
 Kevin Malachowski <kevin at chowski.com>
+Kieron Woodhouse <kieron.woodhouse at infosum.com>
 Lennart Rudolph <lrudolph at hmc.edu>
 Leonardo YongUk Kim <dalinaum at gmail.com>
 Linh Tran Tuan <linhduonggnu at gmail.com>
@@ -75,6 +76,7 @@ Zhenye Xie <xiezhenye at gmail.com>
 Barracuda Networks, Inc.
 Counting Ltd.
 Google Inc.
+InfoSum Ltd.
 Keybase Inc.
 Pivotal Inc.
 Stripe Inc.

+ 1 - 0
collations.go

@@ -9,6 +9,7 @@
 package mysql
 
 const defaultCollation = "utf8_general_ci"
+const binaryCollation = "binary"
 
 // A list of available collations mapped to the internal ID.
 // To update this map use the following MySQL query:

+ 20 - 2
driver_go18_test.go

@@ -588,10 +588,16 @@ func TestRowsColumnTypes(t *testing.T) {
 	nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
 	nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
 	nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
+	nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
+	nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
+	ndNULL := NullTime{Time: time.Time{}, Valid: false}
 	rbNULL := sql.RawBytes(nil)
 	rb0 := sql.RawBytes("0")
 	rb42 := sql.RawBytes("42")
 	rbTest := sql.RawBytes("Test")
+	rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00
+	rbx0 := sql.RawBytes("\x00")
+	rbx42 := sql.RawBytes("\x42")
 
 	var columns = []struct {
 		name             string
@@ -604,6 +610,7 @@ func TestRowsColumnTypes(t *testing.T) {
 		valuesIn         [3]string
 		valuesOut        [3]interface{}
 	}{
+		{"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}},
 		{"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}},
 		{"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}},
 		{"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
@@ -611,6 +618,7 @@ func TestRowsColumnTypes(t *testing.T) {
 		{"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
 		{"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
 		{"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}},
+		{"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}},
 		{"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}},
 		{"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}},
 		{"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}},
@@ -630,11 +638,21 @@ func TestRowsColumnTypes(t *testing.T) {
 		{"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}},
 		{"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
 		{"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
-		{"textnull", "TEXT", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
-		{"longtext", "LONGTEXT NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+		{"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}},
+		{"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+		{"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
+		{"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
+		{"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
+		{"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
+		{"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+		{"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+		{"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
+		{"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
 		{"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}},
 		{"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}},
 		{"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}},
+		{"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}},
+		{"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}},
 	}
 
 	schema := ""

+ 83 - 29
fields.go

@@ -13,35 +13,88 @@ import (
 	"reflect"
 )
 
-var typeDatabaseName = map[fieldType]string{
-	fieldTypeBit:        "BIT",
-	fieldTypeBLOB:       "BLOB",
-	fieldTypeDate:       "DATE",
-	fieldTypeDateTime:   "DATETIME",
-	fieldTypeDecimal:    "DECIMAL",
-	fieldTypeDouble:     "DOUBLE",
-	fieldTypeEnum:       "ENUM",
-	fieldTypeFloat:      "FLOAT",
-	fieldTypeGeometry:   "GEOMETRY",
-	fieldTypeInt24:      "MEDIUMINT",
-	fieldTypeJSON:       "JSON",
-	fieldTypeLong:       "INT",
-	fieldTypeLongBLOB:   "LONGBLOB",
-	fieldTypeLongLong:   "BIGINT",
-	fieldTypeMediumBLOB: "MEDIUMBLOB",
-	fieldTypeNewDate:    "DATE",
-	fieldTypeNewDecimal: "DECIMAL",
-	fieldTypeNULL:       "NULL",
-	fieldTypeSet:        "SET",
-	fieldTypeShort:      "SMALLINT",
-	fieldTypeString:     "CHAR",
-	fieldTypeTime:       "TIME",
-	fieldTypeTimestamp:  "TIMESTAMP",
-	fieldTypeTiny:       "TINYINT",
-	fieldTypeTinyBLOB:   "TINYBLOB",
-	fieldTypeVarChar:    "VARCHAR",
-	fieldTypeVarString:  "VARCHAR",
-	fieldTypeYear:       "YEAR",
+func (mf *mysqlField) typeDatabaseName() string {
+	switch mf.fieldType {
+	case fieldTypeBit:
+		return "BIT"
+	case fieldTypeBLOB:
+		if mf.charSet != collations[binaryCollation] {
+			return "TEXT"
+		}
+		return "BLOB"
+	case fieldTypeDate:
+		return "DATE"
+	case fieldTypeDateTime:
+		return "DATETIME"
+	case fieldTypeDecimal:
+		return "DECIMAL"
+	case fieldTypeDouble:
+		return "DOUBLE"
+	case fieldTypeEnum:
+		return "ENUM"
+	case fieldTypeFloat:
+		return "FLOAT"
+	case fieldTypeGeometry:
+		return "GEOMETRY"
+	case fieldTypeInt24:
+		return "MEDIUMINT"
+	case fieldTypeJSON:
+		return "JSON"
+	case fieldTypeLong:
+		return "INT"
+	case fieldTypeLongBLOB:
+		if mf.charSet != collations[binaryCollation] {
+			return "LONGTEXT"
+		}
+		return "LONGBLOB"
+	case fieldTypeLongLong:
+		return "BIGINT"
+	case fieldTypeMediumBLOB:
+		if mf.charSet != collations[binaryCollation] {
+			return "MEDIUMTEXT"
+		}
+		return "MEDIUMBLOB"
+	case fieldTypeNewDate:
+		return "DATE"
+	case fieldTypeNewDecimal:
+		return "DECIMAL"
+	case fieldTypeNULL:
+		return "NULL"
+	case fieldTypeSet:
+		return "SET"
+	case fieldTypeShort:
+		return "SMALLINT"
+	case fieldTypeString:
+		if mf.charSet == collations[binaryCollation] {
+			return "BINARY"
+		}
+		return "CHAR"
+	case fieldTypeTime:
+		return "TIME"
+	case fieldTypeTimestamp:
+		return "TIMESTAMP"
+	case fieldTypeTiny:
+		return "TINYINT"
+	case fieldTypeTinyBLOB:
+		if mf.charSet != collations[binaryCollation] {
+			return "TINYTEXT"
+		}
+		return "TINYBLOB"
+	case fieldTypeVarChar:
+		if mf.charSet == collations[binaryCollation] {
+			return "VARBINARY"
+		}
+		return "VARCHAR"
+	case fieldTypeVarString:
+		if mf.charSet == collations[binaryCollation] {
+			return "VARBINARY"
+		}
+		return "VARCHAR"
+	case fieldTypeYear:
+		return "YEAR"
+	default:
+		return ""
+	}
 }
 
 var (
@@ -69,6 +122,7 @@ type mysqlField struct {
 	flags     fieldFlag
 	fieldType fieldType
 	decimals  byte
+	charSet   uint8
 }
 
 func (mf *mysqlField) scanType() reflect.Type {

+ 5 - 1
packets.go

@@ -697,10 +697,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
 		if err != nil {
 			return nil, err
 		}
+		pos += n
 
 		// Filler [uint8]
+		pos++
+
 		// Charset [charset, collation uint8]
-		pos += n + 1 + 2
+		columns[i].charSet = data[pos]
+		pos += 2
 
 		// Length [uint32]
 		columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])

+ 1 - 4
rows.go

@@ -60,10 +60,7 @@ func (rows *mysqlRows) Columns() []string {
 }
 
 func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
-	if name, ok := typeDatabaseName[rows.rs.columns[i].fieldType]; ok {
-		return name
-	}
-	return ""
+	return rows.rs.columns[i].typeDatabaseName()
 }
 
 // func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {