Browse Source

Finish collation implementation

Julien Schmidt 11 years ago
parent
commit
8636b6ca08
7 changed files with 53 additions and 63 deletions
  1. 4 6
      collations.go
  2. 3 17
      connection.go
  3. 0 1
      driver.go
  4. 9 14
      driver_test.go
  5. 1 1
      packets.go
  6. 24 12
      utils.go
  7. 12 12
      utils_test.go

+ 4 - 6
collations.go

@@ -8,12 +8,10 @@
 
 package mysql
 
-const collationUtf8GeneralCi = 33
+const defaultCollation byte = 33 // utf8_general_ci
 
-const defaultCollation byte = collationUtf8GeneralCi
-
-// A list of available collations and associated charsets to update this map
-// is available in MySQL with the query
+// A list of available collations mapped to the internal ID.
+// To update this map use the following MySQL query:
 //     SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS
 var collations = map[string]byte{
 	"big5_chinese_ci":          1,
@@ -47,7 +45,7 @@ var collations = map[string]byte{
 	"latin5_turkish_ci":        30,
 	"latin1_german2_ci":        31,
 	"armscii8_general_ci":      32,
-	"utf8_general_ci":          collationUtf8GeneralCi,
+	"utf8_general_ci":          33,
 	"cp1250_czech_cs":          34,
 	"ucs2_general_ci":          35,
 	"cp866_general_ci":         36,

+ 3 - 17
connection.go

@@ -27,7 +27,6 @@ type mysqlConn struct {
 	maxWriteSize     int
 	flags            clientFlag
 	sequence         uint8
-	collation        byte
 	parseTime        bool
 	strict           bool
 }
@@ -40,31 +39,18 @@ type config struct {
 	dbname            string
 	params            map[string]string
 	loc               *time.Location
-	timeout           time.Duration
 	tls               *tls.Config
+	timeout           time.Duration
+	collation         uint8
 	allowAllFiles     bool
 	allowOldPasswords bool
 	clientFoundRows   bool
 }
 
-// Handles parameters set in DSN
+// Handles parameters set in DSN after the connection is established
 func (mc *mysqlConn) handleParams() (err error) {
 	for param, val := range mc.cfg.params {
 		switch param {
-		// Collation
-		case "collation":
-			collation, ok := collations[val]
-			if !ok {
-				// Note possibility for false negatives:
-				// could be caused although the collation is valid
-				// if the collations map does not contain entries
-				// the server supports.
-				err = errors.New("unknown collation")
-				return
-			}
-			mc.collation = collation
-			break
-
 		// Charset
 		case "charset":
 			charsets := strings.Split(val, ",")

+ 0 - 1
driver.go

@@ -40,7 +40,6 @@ func (d *MySQLDriver) Open(dsn string) (driver.Conn, error) {
 	mc := &mysqlConn{
 		maxPacketAllowed: maxPacketSize,
 		maxWriteSize:     maxPacketSize - 1,
-		collation:        defaultCollation,
 	}
 	mc.cfg, err = parseDSN(dsn)
 	if err != nil {

+ 9 - 14
driver_test.go

@@ -944,33 +944,28 @@ func TestCollation(t *testing.T) {
 	}
 
 	defaultCollation := "utf8_general_ci"
-	tests := []string{
+	testCollations := []string{
 		"",               // do not set
 		defaultCollation, // driver default
 		"latin1_general_ci",
 		"binary",
+		"utf8_unicode_ci",
 		"utf8mb4_general_ci",
 	}
-	cdsn := dsn
-	for _, collation := range tests {
-		var expected string
+
+	for _, collation := range testCollations {
+		var expected, tdsn string
 		if collation != "" {
-			cdsn += "&collation=" + collation
+			tdsn = dsn + "&collation=" + collation
 			expected = collation
 		} else {
+			tdsn = dsn
 			expected = defaultCollation
 		}
-		runTests(t, cdsn, func(dbt *DBTest) {
-			rows := dbt.mustQuery("SELECT @@collation_connection")
-			defer rows.Close()
-
-			if !rows.Next() {
-				dbt.Fatalf("Error getting connection collation: %s", rows.Err())
-			}
 
+		runTests(t, tdsn, func(dbt *DBTest) {
 			var got string
-			err := rows.Scan(&got)
-			if err != nil {
+			if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil {
 				dbt.Fatal(err)
 			}
 

+ 1 - 1
packets.go

@@ -257,7 +257,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	data[11] = 0x00
 
 	// Charset [1 byte]
-	data[12] = mc.collation
+	data[12] = mc.cfg.collation
 
 	// SSL Connection Request Packet
 	// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest

+ 24 - 12
utils.go

@@ -72,7 +72,11 @@ func DeregisterTLSConfig(key string) {
 
 // parseDSN parses the DSN string to a config
 func parseDSN(dsn string) (cfg *config, err error) {
-	cfg = new(config)
+	// New config with some default values
+	cfg = &config{
+		loc:       time.UTC,
+		collation: defaultCollation,
+	}
 
 	// TODO: use strings.IndexByte when we can depend on Go 1.2
 
@@ -160,11 +164,6 @@ func parseDSN(dsn string) (cfg *config, err error) {
 
 	}
 
-	// Set default location if empty
-	if cfg.loc == nil {
-		cfg.loc = time.UTC
-	}
-
 	return
 }
 
@@ -188,22 +187,35 @@ func parseDSNParams(cfg *config, params string) (err error) {
 				return fmt.Errorf("Invalid Bool value: %s", value)
 			}
 
-		// Switch "rowsAffected" mode
-		case "clientFoundRows":
+		// Use old authentication mode (pre MySQL 4.1)
+		case "allowOldPasswords":
 			var isBool bool
-			cfg.clientFoundRows, isBool = readBool(value)
+			cfg.allowOldPasswords, isBool = readBool(value)
 			if !isBool {
 				return fmt.Errorf("Invalid Bool value: %s", value)
 			}
 
-		// Use old authentication mode (pre MySQL 4.1)
-		case "allowOldPasswords":
+		// Switch "rowsAffected" mode
+		case "clientFoundRows":
 			var isBool bool
-			cfg.allowOldPasswords, isBool = readBool(value)
+			cfg.clientFoundRows, isBool = readBool(value)
 			if !isBool {
 				return fmt.Errorf("Invalid Bool value: %s", value)
 			}
 
+		// Collation
+		case "collation":
+			collation, ok := collations[value]
+			if !ok {
+				// Note possibility for false negatives:
+				// could be triggered  although the collation is valid if the
+				// collations map does not contain entries the server supports.
+				err = errors.New("unknown collation")
+				return
+			}
+			cfg.collation = collation
+			break
+
 		// Time Location
 		case "loc":
 			if value, err = url.QueryUnescape(value); err != nil {

+ 12 - 12
utils_test.go

@@ -21,18 +21,18 @@ var testDSNs = []struct {
 	out string
 	loc *time.Location
 }{
-	{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:30000000000 tls:<nil> allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
-	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
-	{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
-	{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p timeout:0 tls:<nil> allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true}", time.UTC},
+	{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.Local},
+	{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
+	{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false}", time.UTC},
 }
 
 func TestDSNParser(t *testing.T) {