Browse Source

return byte slice directly from buffer

Updates #52
Julien Schmidt 12 years ago
parent
commit
96a4f13936
2 changed files with 30 additions and 22 deletions
  1. 27 17
      buffer.go
  2. 3 5
      packets.go

+ 27 - 17
buffer.go

@@ -51,35 +51,45 @@ func (b *buffer) fill(need int) (err error) {
 	return
 }
 
-// read len(p) bytes
-func (b *buffer) read(p []byte) (err error) {
-	need := len(p)
+// returns next N bytes from buffer.
+// The returned slice is only guaranteed to be valid until the next read
+func (b *buffer) readNext(need int) (p []byte, err error) {
+	// return slice from buffer if possible
+	if b.length >= need {
+		p = b.buf[b.idx : b.idx+need]
+		b.idx += need
+		b.length -= need
+		return
 
-	if b.length < need {
+	} else {
+		p = make([]byte, need)
+		has := 0
+
+		// copy data that is already in the buffer
 		if b.length > 0 {
 			copy(p[0:b.length], b.buf[b.idx:])
-			need -= b.length
-			p = p[b.length:]
-
+			has = b.length
+			need -= has
 			b.idx = 0
 			b.length = 0
 		}
 
-		if need >= len(b.buf) {
+		// does the data fit into the buffer?
+		if need < len(b.buf) {
+			err = b.fill(need) // err deferred
+			copy(p[has:has+need], b.buf[b.idx:])
+			b.idx += need
+			b.length -= need
+			return
+
+		} else {
 			var n int
-			has := 0
-			for err == nil && need > has {
+			for err == nil && need > 0 {
 				n, err = b.rd.Read(p[has:])
 				has += n
+				need -= n
 			}
-			return
 		}
-
-		err = b.fill(need) // err deferred
 	}
-
-	copy(p, b.buf[b.idx:])
-	b.idx += need
-	b.length -= need
 	return
 }

+ 3 - 5
packets.go

@@ -26,15 +26,14 @@ import (
 // Read packet to buffer 'data'
 func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	// Read packet header
-	data = make([]byte, 4)
-	err = mc.buf.read(data)
+	data, err = mc.buf.readNext(4)
 	if err != nil {
 		errLog.Print(err.Error())
 		return nil, driver.ErrBadConn
 	}
 
 	// Packet Length [24 bit]
-	pktLen := uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16
+	pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16)
 
 	if pktLen < 1 {
 		errLog.Print(errMalformPkt.Error())
@@ -52,8 +51,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	mc.sequence++
 
 	// Read packet body [pktLen bytes]
-	data = make([]byte, pktLen)
-	err = mc.buf.read(data)
+	data, err = mc.buf.readNext(pktLen)
 	if err == nil {
 		if pktLen < maxPacketSize {
 			return data, nil