瀏覽代碼

Merge branch 'master' into simple-benchmarks

Arne Hormann 12 年之前
父節點
當前提交
f459d791fe
共有 4 個文件被更改,包括 105 次插入104 次删除
  1. 35 32
      buffer.go
  2. 1 0
      connection.go
  3. 66 67
      driver_test.go
  4. 3 5
      packets.go

+ 35 - 32
buffer.go

@@ -9,13 +9,9 @@
 
 package mysql
 
-import (
-	"io"
-)
+import "io"
 
-const (
-	defaultBufSize = 4096
-)
+const defaultBufSize = 4096
 
 type buffer struct {
 	buf    []byte
@@ -31,11 +27,19 @@ func newBuffer(rd io.Reader) *buffer {
 	}
 }
 
-// fill reads at least _need_ bytes in the buffer
-// existing data in the buffer gets lost
+// fill reads into the buffer until at least _need_ bytes are in it
 func (b *buffer) fill(need int) (err error) {
+	// move existing data to the beginning
+	if b.length > 0 && b.idx > 0 {
+		copy(b.buf[0:b.length], b.buf[b.idx:])
+	}
+
+	// grow buffer if necessary
+	if need > len(b.buf) {
+		b.grow(need)
+	}
+
 	b.idx = 0
-	b.length = 0
 
 	var n int
 	for b.length < need {
@@ -51,34 +55,33 @@ func (b *buffer) fill(need int) (err error) {
 	return
 }
 
-// read len(p) bytes
-func (b *buffer) read(p []byte) (err error) {
-	need := len(p)
-
-	if b.length < need {
-		if b.length > 0 {
-			copy(p[0:b.length], b.buf[b.idx:])
-			need -= b.length
-			p = p[b.length:]
-
-			b.idx = 0
-			b.length = 0
-		}
+// grow the buffer to at least the given size
+// credit for this code snippet goes to Maxim Khitrov
+// https://groups.google.com/forum/#!topic/golang-nuts/ETbw1ECDgRs
+func (b *buffer) grow(size int) {
+	// If append would be too expensive, alloc a new slice
+	if size > 2*cap(b.buf) {
+		newBuf := make([]byte, size)
+		copy(newBuf, b.buf)
+		b.buf = newBuf
+		return
+	}
 
-		if need >= len(b.buf) {
-			var n int
-			has := 0
-			for err == nil && need > has {
-				n, err = b.rd.Read(p[has:])
-				has += n
-			}
-			return
-		}
+	for cap(b.buf) < size {
+		b.buf = append(b.buf[:cap(b.buf)], 0)
+	}
+	b.buf = b.buf[:cap(b.buf)]
+}
 
+// returns next N bytes from buffer.
+// The returned slice is only guaranteed to be valid until the next read
+func (b *buffer) readNext(need int) (p []byte, err error) {
+	if b.length < need {
+		// refill
 		err = b.fill(need) // err deferred
 	}
 
-	copy(p, b.buf[b.idx:])
+	p = b.buf[b.idx : b.idx+need]
 	b.idx += need
 	b.length -= need
 	return

+ 1 - 0
connection.go

@@ -212,6 +212,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 }
 
 // Gets the value of the given MySQL System Variable
+// The returned byte slice is only valid until the next read
 func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
 	// Send command
 	err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name)

+ 66 - 67
driver_test.go

@@ -42,6 +42,55 @@ func init() {
 	}
 }
 
+type DBTest struct {
+	*testing.T
+	db *sql.DB
+}
+
+func runTests(t *testing.T, name, dsn string, tests ...func(dbt *DBTest)) {
+	if !available {
+		t.Logf("MySQL-Server not running on %s. Skipping %s", netAddr, name)
+		return
+	}
+
+	db, err := sql.Open("mysql", dsn)
+	if err != nil {
+		t.Fatalf("Error connecting: %v", err)
+	}
+	defer db.Close()
+
+	db.Exec("DROP TABLE IF EXISTS test")
+
+	dbt := &DBTest{t, db}
+	for _, test := range tests {
+		test(dbt)
+		dbt.db.Exec("DROP TABLE IF EXISTS test")
+	}
+}
+
+func (dbt *DBTest) fail(method, query string, err error) {
+	if len(query) > 300 {
+		query = "[query too large to print]"
+	}
+	dbt.Fatalf("Error on %s %s: %v", method, query, err)
+}
+
+func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
+	res, err := dbt.db.Exec(query, args...)
+	if err != nil {
+		dbt.fail("Exec", query, err)
+	}
+	return res
+}
+
+func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) {
+	rows, err := dbt.db.Query(query, args...)
+	if err != nil {
+		dbt.fail("Query", query, err)
+	}
+	return rows
+}
+
 func TestCharset(t *testing.T) {
 	mustSetCharset := func(charsetParam, expected string) {
 		db, err := sql.Open("mysql", strings.Replace(dsn, charset, charsetParam, 1))
@@ -101,57 +150,8 @@ func TestFailingCharset(t *testing.T) {
 	}
 }
 
-type DBTest struct {
-	*testing.T
-	db *sql.DB
-}
-
-func runTests(t *testing.T, name string, tests ...func(dbt *DBTest)) {
-	if !available {
-		t.Logf("MySQL-Server not running on %s. Skipping %s", netAddr, name)
-		return
-	}
-
-	db, err := sql.Open("mysql", dsn)
-	if err != nil {
-		t.Fatalf("Error connecting: %v", err)
-	}
-	defer db.Close()
-
-	db.Exec("DROP TABLE IF EXISTS test")
-
-	dbt := &DBTest{t, db}
-	for _, test := range tests {
-		test(dbt)
-		dbt.db.Exec("DROP TABLE IF EXISTS test")
-	}
-}
-
-func (dbt *DBTest) fail(method, query string, err error) {
-	if len(query) > 300 {
-		query = "[query too large to print]"
-	}
-	dbt.Fatalf("Error on %s %s: %v", method, query, err)
-}
-
-func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
-	res, err := dbt.db.Exec(query, args...)
-	if err != nil {
-		dbt.fail("Exec", query, err)
-	}
-	return res
-}
-
-func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) {
-	rows, err := dbt.db.Query(query, args...)
-	if err != nil {
-		dbt.fail("Query", query, err)
-	}
-	return rows
-}
-
 func TestRawBytesResultExceedsBuffer(t *testing.T) {
-	runTests(t, "TestRawBytesResultExceedsBuffer", func(dbt *DBTest) {
+	runTests(t, "TestRawBytesResultExceedsBuffer", dsn, func(dbt *DBTest) {
 		// defaultBufSize from buffer.go
 		expected := strings.Repeat("abc", defaultBufSize)
 		rows := dbt.mustQuery("SELECT '" + expected + "'")
@@ -168,7 +168,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) {
 }
 
 func TestCRUD(t *testing.T) {
-	runTests(t, "TestCRUD", func(dbt *DBTest) {
+	runTests(t, "TestCRUD", dsn, func(dbt *DBTest) {
 		// Create Table
 		dbt.mustExec("CREATE TABLE test (value BOOL)")
 
@@ -260,7 +260,7 @@ func TestCRUD(t *testing.T) {
 }
 
 func TestInt(t *testing.T) {
-	runTests(t, "TestInt", func(dbt *DBTest) {
+	runTests(t, "TestInt", dsn, func(dbt *DBTest) {
 		types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
 		in := int64(42)
 		var out int64
@@ -307,7 +307,7 @@ func TestInt(t *testing.T) {
 }
 
 func TestFloat(t *testing.T) {
-	runTests(t, "TestFloat", func(dbt *DBTest) {
+	runTests(t, "TestFloat", dsn, func(dbt *DBTest) {
 		types := [2]string{"FLOAT", "DOUBLE"}
 		in := float32(42.23)
 		var out float32
@@ -330,7 +330,7 @@ func TestFloat(t *testing.T) {
 }
 
 func TestString(t *testing.T) {
-	runTests(t, "TestString", func(dbt *DBTest) {
+	runTests(t, "TestString", dsn, func(dbt *DBTest) {
 		types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}
 		in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах  น่าฟังเอย"
 		var out string
@@ -470,18 +470,15 @@ func TestDateTime(t *testing.T) {
 		}
 	}
 
-	oldDsn := dsn
-	usedDsn := oldDsn + "&sql_mode=ALLOW_INVALID_DATES"
+	timeDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
 	for _, v := range setups {
 		s = v
-		dsn = usedDsn + s.dsnSuffix
-		runTests(t, "TestDateTime", testTime)
+		runTests(t, "TestDateTime", timeDsn+s.dsnSuffix, testTime)
 	}
-	dsn = oldDsn
 }
 
 func TestNULL(t *testing.T) {
-	runTests(t, "TestNULL", func(dbt *DBTest) {
+	runTests(t, "TestNULL", dsn, func(dbt *DBTest) {
 		nullStmt, err := dbt.db.Prepare("SELECT NULL")
 		if err != nil {
 			dbt.Fatal(err)
@@ -597,7 +594,7 @@ func TestNULL(t *testing.T) {
 }
 
 func TestLongData(t *testing.T) {
-	runTests(t, "TestLongData", func(dbt *DBTest) {
+	runTests(t, "TestLongData", dsn, func(dbt *DBTest) {
 		var maxAllowedPacketSize int
 		err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
 		if err != nil {
@@ -654,7 +651,7 @@ func TestLongData(t *testing.T) {
 }
 
 func TestLoadData(t *testing.T) {
-	runTests(t, "TestLoadData", func(dbt *DBTest) {
+	runTests(t, "TestLoadData", dsn, func(dbt *DBTest) {
 		verifyLoadDataResult := func() {
 			rows, err := dbt.db.Query("SELECT * FROM test")
 			if err != nil {
@@ -741,7 +738,9 @@ func TestLoadData(t *testing.T) {
 }
 
 func TestStrict(t *testing.T) {
-	runTests(t, "TestStrict", func(dbt *DBTest) {
+	// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
+	relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
+	runTests(t, "TestStrict", relaxedDsn, func(dbt *DBTest) {
 		dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
 
 		var queries = [...]struct {
@@ -808,7 +807,7 @@ func TestStrict(t *testing.T) {
 // Special cases
 
 func TestRowsClose(t *testing.T) {
-	runTests(t, "TestRowsClose", func(dbt *DBTest) {
+	runTests(t, "TestRowsClose", dsn, func(dbt *DBTest) {
 		rows, err := dbt.db.Query("SELECT 1")
 		if err != nil {
 			dbt.Fatal(err)
@@ -833,7 +832,7 @@ func TestRowsClose(t *testing.T) {
 // dangling statements
 // http://code.google.com/p/go/issues/detail?id=3865
 func TestCloseStmtBeforeRows(t *testing.T) {
-	runTests(t, "TestCloseStmtBeforeRows", func(dbt *DBTest) {
+	runTests(t, "TestCloseStmtBeforeRows", dsn, func(dbt *DBTest) {
 		stmt, err := dbt.db.Prepare("SELECT 1")
 		if err != nil {
 			dbt.Fatal(err)
@@ -874,7 +873,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
 // It is valid to have multiple Rows for the same Stmt
 // http://code.google.com/p/go/issues/detail?id=3734
 func TestStmtMultiRows(t *testing.T) {
-	runTests(t, "TestStmtMultiRows", func(dbt *DBTest) {
+	runTests(t, "TestStmtMultiRows", dsn, func(dbt *DBTest) {
 		stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0")
 		if err != nil {
 			dbt.Fatal(err)
@@ -989,7 +988,7 @@ func TestConcurrent(t *testing.T) {
 		t.Log("CONCURRENT env var not set. Skipping TestConcurrent")
 		return
 	}
-	runTests(t, "TestConcurrent", func(dbt *DBTest) {
+	runTests(t, "TestConcurrent", dsn, func(dbt *DBTest) {
 		var max int
 		err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max)
 		if err != nil {

+ 3 - 5
packets.go

@@ -26,15 +26,14 @@ import (
 // Read packet to buffer 'data'
 func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	// Read packet header
-	data = make([]byte, 4)
-	err = mc.buf.read(data)
+	data, err = mc.buf.readNext(4)
 	if err != nil {
 		errLog.Print(err.Error())
 		return nil, driver.ErrBadConn
 	}
 
 	// Packet Length [24 bit]
-	pktLen := uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16
+	pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
 
 	if pktLen < 1 {
 		errLog.Print(errMalformPkt.Error())
@@ -52,8 +51,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	mc.sequence++
 
 	// Read packet body [pktLen bytes]
-	data = make([]byte, pktLen)
-	err = mc.buf.read(data)
+	data, err = mc.buf.readNext(pktLen)
 	if err == nil {
 		if pktLen < maxPacketSize {
 			return data, nil