Ver código fonte

Merge pull request #134 from go-sql-driver/write-buffer

Use the Connection Buffer for Writing
Julien Schmidt 12 anos atrás
pai
commit
c87f84b1ef
7 arquivos alterados com 445 adições e 329 exclusões
  1. 2 1
      CHANGELOG.md
  2. 5 2
      benchmark_test.go
  3. 49 2
      buffer.go
  4. 27 30
      connection.go
  5. 341 272
      packets.go
  6. 3 3
      rows.go
  7. 18 19
      statement.go

+ 2 - 1
CHANGELOG.md

@@ -7,7 +7,8 @@ Changes:
   - Made closing rows and connections error tolerant. This allows for example deferring rows.Close() without checking for errors
   - New Logo
   - Changed the copyright header to include all contributors
-  - Optimized the read buffer
+  - Optimized the buffer for reading
+  - Use the buffer also for writing. This results in zero allocations (by the driver) for most queries
   - Improved the LOAD INFILE documentation
   - The driver struct is now exported to make the driver directly accessible
   - Refactored the driver tests

+ 5 - 2
benchmark_test.go

@@ -68,23 +68,26 @@ func BenchmarkQuery(b *testing.B) {
 
 	stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
 	defer stmt.Close()
-	b.StartTimer()
 
 	remain := int64(b.N)
 	var wg sync.WaitGroup
 	wg.Add(concurrencyLevel)
 	defer wg.Wait()
+	b.StartTimer()
+
 	for i := 0; i < concurrencyLevel; i++ {
 		go func() {
-			defer wg.Done()
 			for {
 				if atomic.AddInt64(&remain, -1) < 0 {
+					wg.Done()
 					return
 				}
+
 				var got string
 				tb.check(stmt.QueryRow(1).Scan(&got))
 				if got != "one" {
 					b.Errorf("query = %q; want one", got)
+					wg.Done()
 					return
 				}
 			}

+ 49 - 2
buffer.go

@@ -12,7 +12,10 @@ import "io"
 
 const defaultBufSize = 4096
 
-// A read buffer similar to bufio.Reader but zero-copy-ish
+// A buffer which is used for both reading and writing.
+// This is possible since communication on each connection is synchronous.
+// In other words, we can't write and read simultaneously on the same connection.
+// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
 // Also highly optimized for this particular use case.
 type buffer struct {
 	buf    []byte
@@ -37,8 +40,11 @@ func (b *buffer) fill(need int) (err error) {
 	}
 
 	// grow buffer if necessary
+	// TODO: let the buffer shrink again at some point
+	//       Maybe keep the org buf slice and swap back?
 	if need > len(b.buf) {
-		newBuf := make([]byte, need)
+		// Round up to the next multiple of the default size
+		newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
 		copy(newBuf, b.buf)
 		b.buf = newBuf
 	}
@@ -74,3 +80,44 @@ func (b *buffer) readNext(need int) (p []byte, err error) {
 	b.length -= need
 	return
 }
+
+// returns a buffer with the requested size.
+// If possible, a slice from the existing buffer is returned.
+// Otherwise a bigger buffer is made.
+// Only one buffer (total) can be used at a time.
+func (b *buffer) takeBuffer(length int) []byte {
+	if b.length > 0 {
+		return nil
+	}
+
+	// test (cheap) general case first
+	if length <= defaultBufSize || length <= cap(b.buf) {
+		return b.buf[:length]
+	}
+
+	if length < maxPacketSize {
+		b.buf = make([]byte, length)
+		return b.buf
+	}
+	return make([]byte, length)
+}
+
+// shortcut which can be used if the requested buffer is guaranteed to be
+// smaller than defaultBufSize
+// Only one buffer (total) can be used at a time.
+func (b *buffer) takeSmallBuffer(length int) []byte {
+	if b.length == 0 {
+		return b.buf[:length]
+	}
+	return nil
+}
+
+// takeCompleteBuffer returns the complete existing buffer.
+// This can be used if the necessary buffer size is unknown.
+// Only one buffer (total) can be used at a time.
+func (b *buffer) takeCompleteBuffer() []byte {
+	if b.length == 0 {
+		return b.buf
+	}
+	return nil
+}

+ 27 - 30
connection.go

@@ -136,14 +136,14 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	columnCount, err := stmt.readPrepareResultPacket()
 	if err == nil {
 		if stmt.paramCount > 0 {
-			stmt.params, err = stmt.mc.readColumns(stmt.paramCount)
+			stmt.params, err = mc.readColumns(stmt.paramCount)
 			if err != nil {
 				return nil, err
 			}
 		}
 
 		if columnCount > 0 {
-			err = stmt.mc.readUntilEOF()
+			err = mc.readUntilEOF()
 		}
 	}
 
@@ -171,26 +171,24 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
 }
 
 // Internal function to execute commands
-func (mc *mysqlConn) exec(query string) (err error) {
+func (mc *mysqlConn) exec(query string) error {
 	// Send command
-	err = mc.writeCommandPacketStr(comQuery, query)
+	err := mc.writeCommandPacketStr(comQuery, query)
 	if err != nil {
-		return
+		return err
 	}
 
 	// Read Result
-	var resLen int
-	resLen, err = mc.readResultSetHeaderPacket()
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil && resLen > 0 {
-		err = mc.readUntilEOF()
-		if err != nil {
-			return
+		if err = mc.readUntilEOF(); err != nil {
+			return err
 		}
 
 		err = mc.readUntilEOF()
 	}
 
-	return
+	return err
 }
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
@@ -211,7 +209,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 				return rows, err
 			}
 		}
-
 		return nil, err
 	}
 
@@ -221,29 +218,29 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 
 // Gets the value of the given MySQL System Variable
 // The returned byte slice is only valid until the next read
-func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
+func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 	// Send command
-	err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name)
+	if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
+		return nil, err
+	}
+
+	// Read Result
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil {
-		// Read Result
-		var resLen int
-		resLen, err = mc.readResultSetHeaderPacket()
-		if err == nil {
-			rows := &mysqlRows{mc, false, nil, false}
+		rows := &mysqlRows{mc, false, nil, false}
 
-			if resLen > 0 {
-				// Columns
-				rows.columns, err = mc.readColumns(resLen)
+		if resLen > 0 {
+			// Columns
+			rows.columns, err = mc.readColumns(resLen)
+			if err != nil {
+				return nil, err
 			}
+		}
 
-			dest := make([]driver.Value, resLen)
-			err = rows.readRow(dest)
-			if err == nil {
-				val = dest[0].([]byte)
-				err = mc.readUntilEOF()
-			}
+		dest := make([]driver.Value, resLen)
+		if err = rows.readRow(dest); err == nil {
+			return dest[0].([]byte), mc.readUntilEOF()
 		}
 	}
-
-	return
+	return nil, err
 }

Diferenças do arquivo suprimidas por serem muito extensas
+ 341 - 272
packets.go


+ 3 - 3
rows.go

@@ -26,12 +26,12 @@ type mysqlRows struct {
 	eof     bool
 }
 
-func (rows *mysqlRows) Columns() (columns []string) {
-	columns = make([]string, len(rows.columns))
+func (rows *mysqlRows) Columns() []string {
+	columns := make([]string, len(rows.columns))
 	for i := range columns {
 		columns[i] = rows.columns[i].name
 	}
-	return
+	return columns
 }
 
 func (rows *mysqlRows) Close() (err error) {

+ 18 - 19
statement.go

@@ -19,14 +19,14 @@ type mysqlStmt struct {
 	params     []mysqlField
 }
 
-func (stmt *mysqlStmt) Close() (err error) {
+func (stmt *mysqlStmt) Close() error {
 	if stmt.mc == nil || stmt.mc.netConn == nil {
 		return errInvalidConn
 	}
 
-	err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
+	err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
 	stmt.mc = nil
-	return
+	return err
 }
 
 func (stmt *mysqlStmt) NumInput() int {
@@ -34,33 +34,34 @@ func (stmt *mysqlStmt) NumInput() int {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
-	stmt.mc.affectedRows = 0
-	stmt.mc.insertId = 0
-
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {
 		return nil, err
 	}
 
+	mc := stmt.mc
+
+	mc.affectedRows = 0
+	mc.insertId = 0
+
 	// Read Result
-	var resLen int
-	resLen, err = stmt.mc.readResultSetHeaderPacket()
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil {
 		if resLen > 0 {
 			// Columns
-			err = stmt.mc.readUntilEOF()
+			err = mc.readUntilEOF()
 			if err != nil {
 				return nil, err
 			}
 
 			// Rows
-			err = stmt.mc.readUntilEOF()
+			err = mc.readUntilEOF()
 		}
 		if err == nil {
 			return &mysqlResult{
-				affectedRows: int64(stmt.mc.affectedRows),
-				insertId:     int64(stmt.mc.insertId),
+				affectedRows: int64(mc.affectedRows),
+				insertId:     int64(mc.insertId),
 			}, nil
 		}
 	}
@@ -75,21 +76,19 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 		return nil, err
 	}
 
+	mc := stmt.mc
+
 	// Read Result
-	var resLen int
-	resLen, err = stmt.mc.readResultSetHeaderPacket()
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err != nil {
 		return nil, err
 	}
 
-	rows := &mysqlRows{stmt.mc, true, nil, false}
+	rows := &mysqlRows{mc, true, nil, false}
 
 	if resLen > 0 {
 		// Columns
-		rows.columns, err = stmt.mc.readColumns(resLen)
-		if err != nil {
-			return nil, err
-		}
+		rows.columns, err = mc.readColumns(resLen)
 	}
 
 	return rows, err

Alguns arquivos não foram mostrados porque muitos arquivos mudaram nesse diff