Просмотр исходного кода

Merge pull request #55 from go-sql-driver/zero-copy

zero-copy buffer
Julien Schmidt 12 лет назад
Родитель
Сommit
1caf647408
3 измененных файлов с 39 добавлено и 37 удалено
  1. 35 32
      buffer.go
  2. 1 0
      connection.go
  3. 3 5
      packets.go

+ 35 - 32
buffer.go

@@ -9,13 +9,9 @@
 
 package mysql
 
-import (
-	"io"
-)
+import "io"
 
-const (
-	defaultBufSize = 4096
-)
+const defaultBufSize = 4096
 
 type buffer struct {
 	buf    []byte
@@ -31,11 +27,19 @@ func newBuffer(rd io.Reader) *buffer {
 	}
 }
 
-// fill reads at least _need_ bytes in the buffer
-// existing data in the buffer gets lost
+// fill reads into the buffer until at least _need_ bytes are in it
 func (b *buffer) fill(need int) (err error) {
+	// move existing data to the beginning
+	if b.length > 0 && b.idx > 0 {
+		copy(b.buf[0:b.length], b.buf[b.idx:])
+	}
+
+	// grow buffer if necessary
+	if need > len(b.buf) {
+		b.grow(need)
+	}
+
 	b.idx = 0
-	b.length = 0
 
 	var n int
 	for b.length < need {
@@ -51,34 +55,33 @@ func (b *buffer) fill(need int) (err error) {
 	return
 }
 
-// read len(p) bytes
-func (b *buffer) read(p []byte) (err error) {
-	need := len(p)
-
-	if b.length < need {
-		if b.length > 0 {
-			copy(p[0:b.length], b.buf[b.idx:])
-			need -= b.length
-			p = p[b.length:]
-
-			b.idx = 0
-			b.length = 0
-		}
+// grow the buffer to at least the given size
+// credit for this code snippet goes to Maxim Khitrov
+// https://groups.google.com/forum/#!topic/golang-nuts/ETbw1ECDgRs
+func (b *buffer) grow(size int) {
+	// If append would be too expensive, alloc a new slice
+	if size > 2*cap(b.buf) {
+		newBuf := make([]byte, size)
+		copy(newBuf, b.buf)
+		b.buf = newBuf
+		return
+	}
 
-		if need >= len(b.buf) {
-			var n int
-			has := 0
-			for err == nil && need > has {
-				n, err = b.rd.Read(p[has:])
-				has += n
-			}
-			return
-		}
+	for cap(b.buf) < size {
+		b.buf = append(b.buf[:cap(b.buf)], 0)
+	}
+	b.buf = b.buf[:cap(b.buf)]
+}
 
+// returns next N bytes from buffer.
+// The returned slice is only guaranteed to be valid until the next read
+func (b *buffer) readNext(need int) (p []byte, err error) {
+	if b.length < need {
+		// refill
 		err = b.fill(need) // err deferred
 	}
 
-	copy(p, b.buf[b.idx:])
+	p = b.buf[b.idx : b.idx+need]
 	b.idx += need
 	b.length -= need
 	return

+ 1 - 0
connection.go

@@ -212,6 +212,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 }
 
 // Gets the value of the given MySQL System Variable
+// The returned byte slice is only valid until the next read
 func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
 	// Send command
 	err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name)

+ 3 - 5
packets.go

@@ -26,15 +26,14 @@ import (
 // Read packet to buffer 'data'
 func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	// Read packet header
-	data = make([]byte, 4)
-	err = mc.buf.read(data)
+	data, err = mc.buf.readNext(4)
 	if err != nil {
 		errLog.Print(err.Error())
 		return nil, driver.ErrBadConn
 	}
 
 	// Packet Length [24 bit]
-	pktLen := uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16
+	pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
 
 	if pktLen < 1 {
 		errLog.Print(errMalformPkt.Error())
@@ -52,8 +51,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	mc.sequence++
 
 	// Read packet body [pktLen bytes]
-	data = make([]byte, pktLen)
-	err = mc.buf.read(data)
+	data, err = mc.buf.readNext(pktLen)
 	if err == nil {
 		if pktLen < maxPacketSize {
 			return data, nil