Przeglądaj źródła

Merge pull request #147 from go-sql-driver/byte_nil

Insert NULL for []byte(nil) values
Julien Schmidt 12 lat temu
rodzic
commit
46150062e4
4 zmienionych plików z 263 dodań i 221 usunięć
  1. 2 0
      CHANGELOG.md
  2. 224 191
      driver_test.go
  3. 20 11
      packets.go
  4. 17 19
      utils.go

+ 2 - 0
CHANGELOG.md

@@ -5,6 +5,7 @@ Changes:
   - Go-MySQL-Driver now requires Go 1.1
   - Connections now use the collation `utf8_general_ci` by default. Adding `&charset=UTF8` to the DSN should not be necessary anymore
   - Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors
+  - `byte(nil)` is now treated as a NULL value. Before it was treated like an empty string / `[]byte("")`.
   - New Logo
   - Changed the copyright header to include all contributors
   - Optimized the buffer for reading
@@ -28,6 +29,7 @@ Bugfixes:
   - Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification
   - Convert to DB timezone when inserting time.Time
   - Splitted packets (more than 16MB) are now merged correctly
+  - Fixed empty string producing false nil values
 
 
 ## 1.0 (2013-05-14)

+ 224 - 191
driver_test.go

@@ -108,143 +108,6 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
 	return rows
 }
 
-func TestReuseClosedConnection(t *testing.T) {
-	// this test does not use sql.database, it uses the driver directly
-	if !available {
-		t.Skipf("MySQL-Server not running on %s", netAddr)
-	}
-	driver := &MySQLDriver{}
-	conn, err := driver.Open(dsn)
-	if err != nil {
-		t.Fatalf("Error connecting: %s", err.Error())
-	}
-	stmt, err := conn.Prepare("DO 1")
-	if err != nil {
-		t.Fatalf("Error preparing statement: %s", err.Error())
-	}
-	_, err = stmt.Exec(nil)
-	if err != nil {
-		t.Fatalf("Error executing statement: %s", err.Error())
-	}
-	err = conn.Close()
-	if err != nil {
-		t.Fatalf("Error closing connection: %s", err.Error())
-	}
-	defer func() {
-		if err := recover(); err != nil {
-			t.Errorf("Panic after reusing a closed connection: %v", err)
-		}
-	}()
-	_, err = stmt.Exec(nil)
-	if err != nil && err != errInvalidConn {
-		t.Errorf("Unexpected error '%s', expected '%s'",
-			err.Error(), errInvalidConn.Error())
-	}
-}
-
-func TestCharset(t *testing.T) {
-	if !available {
-		t.Skipf("MySQL-Server not running on %s", netAddr)
-	}
-
-	mustSetCharset := func(charsetParam, expected string) {
-		runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
-			rows := dbt.mustQuery("SELECT @@character_set_connection")
-			defer rows.Close()
-
-			if !rows.Next() {
-				dbt.Fatalf("Error getting connection charset: %s", rows.Err())
-			}
-
-			var got string
-			rows.Scan(&got)
-
-			if got != expected {
-				dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
-			}
-		})
-	}
-
-	// non utf8 test
-	mustSetCharset("charset=ascii", "ascii")
-
-	// when the first charset is invalid, use the second
-	mustSetCharset("charset=none,utf8", "utf8")
-
-	// when the first charset is valid, use it
-	mustSetCharset("charset=ascii,utf8", "ascii")
-	mustSetCharset("charset=utf8,ascii", "utf8")
-}
-
-func TestFailingCharset(t *testing.T) {
-	runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
-		// run query to really establish connection...
-		_, err := dbt.db.Exec("SELECT 1")
-		if err == nil {
-			dbt.db.Close()
-			t.Fatalf("Connection must not succeed without a valid charset")
-		}
-	})
-}
-
-func TestRawBytesResultExceedsBuffer(t *testing.T) {
-	runTests(t, dsn, func(dbt *DBTest) {
-		// defaultBufSize from buffer.go
-		expected := strings.Repeat("abc", defaultBufSize)
-		rows := dbt.mustQuery("SELECT '" + expected + "'")
-		defer rows.Close()
-		if !rows.Next() {
-			dbt.Error("expected result, got none")
-		}
-		var result sql.RawBytes
-		rows.Scan(&result)
-		if expected != string(result) {
-			dbt.Error("result did not match expected value")
-		}
-	})
-}
-
-func TestTimezoneConversion(t *testing.T) {
-
-	zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
-
-	// Regression test for timezone handling
-	tzTest := func(dbt *DBTest) {
-
-		// Create table
-		dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
-
-		// Insert local time into database (should be converted)
-		usCentral, _ := time.LoadLocation("US/Central")
-		now := time.Now().In(usCentral)
-		dbt.mustExec("INSERT INTO test VALUE (?)", now)
-
-		// Retrieve time from DB
-		rows := dbt.mustQuery("SELECT ts FROM test")
-		if !rows.Next() {
-			dbt.Fatal("Didn't get any rows out")
-		}
-
-		var nowDB time.Time
-		err := rows.Scan(&nowDB)
-		if err != nil {
-			dbt.Fatal("Err", err)
-		}
-
-		// Check that dates match
-		if now.Unix() != nowDB.Unix() {
-			dbt.Errorf("Times don't match.\n")
-			dbt.Errorf(" Now(%v)=%v\n", usCentral, now)
-			dbt.Errorf(" Now(UTC)=%v\n", nowDB)
-		}
-
-	}
-
-	for _, tz := range zones {
-		runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
-	}
-}
-
 func TestCRUD(t *testing.T) {
 	runTests(t, dsn, func(dbt *DBTest) {
 		// Create Table
@@ -548,44 +411,6 @@ func TestDateTime(t *testing.T) {
 	}
 }
 
-// This tests for https://github.com/go-sql-driver/mysql/pull/139
-//
-// An extra (invisible) nil byte was being added to the beginning of positive
-// time strings.
-func TestTimeSign(t *testing.T) {
-	runTests(t, dsn, func(dbt *DBTest) {
-		var sTimes = []struct {
-			value     string
-			fieldType string
-		}{
-			{"12:34:56", "TIME"},
-			{"-12:34:56", "TIME"},
-			// As described in http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html
-			// they *should* work, but only in 5.6+.
-			// { "12:34:56.789", "TIME(3)" },
-			// { "-12:34:56.789", "TIME(3)" },
-		}
-
-		for _, sTime := range sTimes {
-			dbt.db.Exec("DROP TABLE IF EXISTS test")
-			dbt.mustExec("CREATE TABLE test (id INT, time_field " + sTime.fieldType + ")")
-			dbt.mustExec("INSERT INTO test (id, time_field) VALUES(1, '" + sTime.value + "')")
-			rows := dbt.mustQuery("SELECT time_field FROM test WHERE id = ?", 1)
-			if rows.Next() {
-				var oTime string
-				rows.Scan(&oTime)
-				if oTime != sTime.value {
-					dbt.Errorf(`time values differ: got %q, expected %q.`, oTime, sTime.value)
-				}
-			} else {
-				dbt.Error("expecting at least one row.")
-			}
-		}
-
-	})
-
-}
-
 func TestNULL(t *testing.T) {
 	runTests(t, dsn, func(dbt *DBTest) {
 		nullStmt, err := dbt.db.Prepare("SELECT NULL")
@@ -603,16 +428,14 @@ func TestNULL(t *testing.T) {
 		// NullBool
 		var nb sql.NullBool
 		// Invalid
-		err = nullStmt.QueryRow().Scan(&nb)
-		if err != nil {
+		if err = nullStmt.QueryRow().Scan(&nb); err != nil {
 			dbt.Fatal(err)
 		}
 		if nb.Valid {
 			dbt.Error("Valid NullBool which should be invalid")
 		}
 		// Valid
-		err = nonNullStmt.QueryRow().Scan(&nb)
-		if err != nil {
+		if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
 			dbt.Fatal(err)
 		}
 		if !nb.Valid {
@@ -624,16 +447,14 @@ func TestNULL(t *testing.T) {
 		// NullFloat64
 		var nf sql.NullFloat64
 		// Invalid
-		err = nullStmt.QueryRow().Scan(&nf)
-		if err != nil {
+		if err = nullStmt.QueryRow().Scan(&nf); err != nil {
 			dbt.Fatal(err)
 		}
 		if nf.Valid {
 			dbt.Error("Valid NullFloat64 which should be invalid")
 		}
 		// Valid
-		err = nonNullStmt.QueryRow().Scan(&nf)
-		if err != nil {
+		if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
 			dbt.Fatal(err)
 		}
 		if !nf.Valid {
@@ -645,16 +466,14 @@ func TestNULL(t *testing.T) {
 		// NullInt64
 		var ni sql.NullInt64
 		// Invalid
-		err = nullStmt.QueryRow().Scan(&ni)
-		if err != nil {
+		if err = nullStmt.QueryRow().Scan(&ni); err != nil {
 			dbt.Fatal(err)
 		}
 		if ni.Valid {
 			dbt.Error("Valid NullInt64 which should be invalid")
 		}
 		// Valid
-		err = nonNullStmt.QueryRow().Scan(&ni)
-		if err != nil {
+		if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
 			dbt.Fatal(err)
 		}
 		if !ni.Valid {
@@ -666,16 +485,14 @@ func TestNULL(t *testing.T) {
 		// NullString
 		var ns sql.NullString
 		// Invalid
-		err = nullStmt.QueryRow().Scan(&ns)
-		if err != nil {
+		if err = nullStmt.QueryRow().Scan(&ns); err != nil {
 			dbt.Fatal(err)
 		}
 		if ns.Valid {
 			dbt.Error("Valid NullString which should be invalid")
 		}
 		// Valid
-		err = nonNullStmt.QueryRow().Scan(&ns)
-		if err != nil {
+		if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
 			dbt.Fatal(err)
 		}
 		if !ns.Valid {
@@ -684,6 +501,48 @@ func TestNULL(t *testing.T) {
 			dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)")
 		}
 
+		// nil-bytes
+		var b []byte
+		// Read nil
+		if err = nullStmt.QueryRow().Scan(&b); err != nil {
+			dbt.Fatal(err)
+		}
+		if b != nil {
+			dbt.Error("Non-nil []byte wich should be nil")
+		}
+		// Read non-nil
+		if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
+			dbt.Fatal(err)
+		}
+		if b == nil {
+			dbt.Error("Nil []byte wich should be non-nil")
+		}
+		// Insert nil
+		b = nil
+		success := false
+		if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil {
+			dbt.Fatal(err)
+		}
+		if !success {
+			dbt.Error("Inserting []byte(nil) as NULL failed")
+		}
+		// Check input==output with input==nil
+		b = nil
+		if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
+			dbt.Fatal(err)
+		}
+		if b != nil {
+			dbt.Error("Non-nil echo from nil input")
+		}
+		// Check input==output with input!=nil
+		b = []byte("")
+		if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil {
+			dbt.Fatal(err)
+		}
+		if b == nil {
+			dbt.Error("nil echo from non-nil input")
+		}
+
 		// Insert NULL
 		dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)")
 
@@ -995,6 +854,180 @@ func TestTLS(t *testing.T) {
 	runTests(t, dsn+"&tls=custom-skip-verify", tlsTest)
 }
 
+func TestReuseClosedConnection(t *testing.T) {
+	// this test does not use sql.database, it uses the driver directly
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
+
+	driver := &MySQLDriver{}
+	conn, err := driver.Open(dsn)
+	if err != nil {
+		t.Fatalf("Error connecting: %s", err.Error())
+	}
+	stmt, err := conn.Prepare("DO 1")
+	if err != nil {
+		t.Fatalf("Error preparing statement: %s", err.Error())
+	}
+	_, err = stmt.Exec(nil)
+	if err != nil {
+		t.Fatalf("Error executing statement: %s", err.Error())
+	}
+	err = conn.Close()
+	if err != nil {
+		t.Fatalf("Error closing connection: %s", err.Error())
+	}
+
+	defer func() {
+		if err := recover(); err != nil {
+			t.Errorf("Panic after reusing a closed connection: %v", err)
+		}
+	}()
+	_, err = stmt.Exec(nil)
+	if err != nil && err != errInvalidConn {
+		t.Errorf("Unexpected error '%s', expected '%s'",
+			err.Error(), errInvalidConn.Error())
+	}
+}
+
+func TestCharset(t *testing.T) {
+	if !available {
+		t.Skipf("MySQL-Server not running on %s", netAddr)
+	}
+
+	mustSetCharset := func(charsetParam, expected string) {
+		runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) {
+			rows := dbt.mustQuery("SELECT @@character_set_connection")
+			defer rows.Close()
+
+			if !rows.Next() {
+				dbt.Fatalf("Error getting connection charset: %s", rows.Err())
+			}
+
+			var got string
+			rows.Scan(&got)
+
+			if got != expected {
+				dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
+			}
+		})
+	}
+
+	// non utf8 test
+	mustSetCharset("charset=ascii", "ascii")
+
+	// when the first charset is invalid, use the second
+	mustSetCharset("charset=none,utf8", "utf8")
+
+	// when the first charset is valid, use it
+	mustSetCharset("charset=ascii,utf8", "ascii")
+	mustSetCharset("charset=utf8,ascii", "utf8")
+}
+
+func TestFailingCharset(t *testing.T) {
+	runTests(t, dsn+"&charset=none", func(dbt *DBTest) {
+		// run query to really establish connection...
+		_, err := dbt.db.Exec("SELECT 1")
+		if err == nil {
+			dbt.db.Close()
+			t.Fatalf("Connection must not succeed without a valid charset")
+		}
+	})
+}
+
+func TestRawBytesResultExceedsBuffer(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		// defaultBufSize from buffer.go
+		expected := strings.Repeat("abc", defaultBufSize)
+
+		rows := dbt.mustQuery("SELECT '" + expected + "'")
+		defer rows.Close()
+		if !rows.Next() {
+			dbt.Error("expected result, got none")
+		}
+		var result sql.RawBytes
+		rows.Scan(&result)
+		if expected != string(result) {
+			dbt.Error("result did not match expected value")
+		}
+	})
+}
+
+func TestTimezoneConversion(t *testing.T) {
+	zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
+
+	// Regression test for timezone handling
+	tzTest := func(dbt *DBTest) {
+
+		// Create table
+		dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
+
+		// Insert local time into database (should be converted)
+		usCentral, _ := time.LoadLocation("US/Central")
+		now := time.Now().In(usCentral)
+		dbt.mustExec("INSERT INTO test VALUE (?)", now)
+
+		// Retrieve time from DB
+		rows := dbt.mustQuery("SELECT ts FROM test")
+		if !rows.Next() {
+			dbt.Fatal("Didn't get any rows out")
+		}
+
+		var nowDB time.Time
+		err := rows.Scan(&nowDB)
+		if err != nil {
+			dbt.Fatal("Err", err)
+		}
+
+		// Check that dates match
+		if now.Unix() != nowDB.Unix() {
+			dbt.Errorf("Times don't match.\n")
+			dbt.Errorf(" Now(%v)=%v\n", usCentral, now)
+			dbt.Errorf(" Now(UTC)=%v\n", nowDB)
+		}
+	}
+
+	for _, tz := range zones {
+		runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest)
+	}
+}
+
+// This tests for https://github.com/go-sql-driver/mysql/pull/139
+//
+// An extra (invisible) nil byte was being added to the beginning of positive
+// time strings.
+func TestTimeSign(t *testing.T) {
+	runTests(t, dsn, func(dbt *DBTest) {
+		var sTimes = []struct {
+			value     string
+			fieldType string
+		}{
+			{"12:34:56", "TIME"},
+			{"-12:34:56", "TIME"},
+			// As described in http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html
+			// they *should* work, but only in 5.6+.
+			// { "12:34:56.789", "TIME(3)" },
+			// { "-12:34:56.789", "TIME(3)" },
+		}
+
+		for _, sTime := range sTimes {
+			dbt.db.Exec("DROP TABLE IF EXISTS test")
+			dbt.mustExec("CREATE TABLE test (id INT, time_field " + sTime.fieldType + ")")
+			dbt.mustExec("INSERT INTO test (id, time_field) VALUES(1, '" + sTime.value + "')")
+			rows := dbt.mustQuery("SELECT time_field FROM test WHERE id = ?", 1)
+			if rows.Next() {
+				var oTime string
+				rows.Scan(&oTime)
+				if oTime != sTime.value {
+					dbt.Errorf(`time values differ: got %q, expected %q.`, oTime, sTime.value)
+				}
+			} else {
+				dbt.Error("expecting at least one row.")
+			}
+		}
+	})
+}
+
 // Special cases
 
 func TestRowsClose(t *testing.T) {

+ 20 - 11
packets.go

@@ -915,20 +915,29 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				}
 
 			case []byte:
-				paramTypes[i+i] = fieldTypeString
-				paramTypes[i+i+1] = 0x00
-
-				if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
-					paramValues = appendLengthEncodedInteger(paramValues,
-						uint64(len(v)),
-					)
-					paramValues = append(paramValues, v...)
-				} else {
-					if err := stmt.writeCommandLongData(i, v); err != nil {
-						return err
+				// Common case (non-nil value) first
+				if v != nil {
+					paramTypes[i+i] = fieldTypeString
+					paramTypes[i+i+1] = 0x00
+
+					if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
+						paramValues = appendLengthEncodedInteger(paramValues,
+							uint64(len(v)),
+						)
+						paramValues = append(paramValues, v...)
+					} else {
+						if err := stmt.writeCommandLongData(i, v); err != nil {
+							return err
+						}
 					}
+					continue
 				}
 
+				// Handle []byte(nil) as a NULL value
+				nullMask |= 1 << uint(i)
+				paramTypes[i+i] = fieldTypeNULL
+				paramTypes[i+i+1] = 0x00
+
 			case string:
 				paramTypes[i+i] = fieldTypeString
 				paramTypes[i+i+1] = 0x00

+ 17 - 19
utils.go

@@ -600,11 +600,14 @@ func stringToInt(b []byte) int {
 	return val
 }
 
+// returns the string read as a bytes slice, wheter the value is NULL,
+// the number of bytes read and an error, in case the string is longer than
+// the input slice
 func readLengthEnodedString(b []byte) ([]byte, bool, int, error) {
 	// Get length
 	num, isNull, n := readLengthEncodedInteger(b)
 	if num < 1 {
-		return nil, isNull, n, nil
+		return b[n:n], isNull, n, nil
 	}
 
 	n += int(num)
@@ -616,6 +619,8 @@ func readLengthEnodedString(b []byte) ([]byte, bool, int, error) {
 	return nil, false, n, io.EOF
 }
 
+// returns the number of bytes skipped and an error, in case the string is
+// longer than the input slice
 func skipLengthEnodedString(b []byte) (int, error) {
 	// Get length
 	num, _, n := readLengthEncodedInteger(b)
@@ -632,42 +637,35 @@ func skipLengthEnodedString(b []byte) (int, error) {
 	return n, io.EOF
 }
 
-func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int) {
+// returns the number read, whether the value is NULL and the number of bytes read
+func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
 	switch b[0] {
 
 	// 251: NULL
 	case 0xfb:
-		n = 1
-		isNull = true
-		return
+		return 0, true, 1
 
 	// 252: value of following 2
 	case 0xfc:
-		num = uint64(b[1]) | uint64(b[2])<<8
-		n = 3
-		return
+		return uint64(b[1]) | uint64(b[2])<<8, false, 3
 
 	// 253: value of following 3
 	case 0xfd:
-		num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16
-		n = 4
-		return
+		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4
 
 	// 254: value of following 8
 	case 0xfe:
-		num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
-			uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
-			uint64(b[7])<<48 | uint64(b[8])<<54
-		n = 9
-		return
+		return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
+				uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
+				uint64(b[7])<<48 | uint64(b[8])<<54,
+			false, 9
 	}
 
 	// 0-250: value of first byte
-	num = uint64(b[0])
-	n = 1
-	return
+	return uint64(b[0]), false, 1
 }
 
+// encodes a uint64 value and appends it to the given bytes slice
 func appendLengthEncodedInteger(b []byte, n uint64) []byte {
 	switch {
 	case n <= 250: