Browse Source

new i/o buffer

Julien Schmidt 12 years ago
parent
commit
ccad956d64
4 changed files with 89 additions and 18 deletions
  1. 80 0
      buffer.go
  2. 2 3
      connection.go
  3. 1 2
      driver.go
  4. 6 13
      packets.go

+ 80 - 0
buffer.go

@@ -0,0 +1,80 @@
+// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
+//
+// Copyright 2013 Julien Schmidt. All rights reserved.
+// http://www.julienschmidt.com
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mysql
+
+import (
+	"io"
+)
+
+const (
+	defaultBufSize = 4096
+)
+
+type buffer struct {
+	buf    []byte
+	rd     io.Reader
+	idx    int
+	length int
+}
+
+func newBuffer(rd io.Reader) *buffer {
+	return &buffer{
+		buf: make([]byte, defaultBufSize),
+		rd:  rd,
+	}
+}
+
+// fill reads at least _need_ bytes in the buffer
+// existing data in the buffer gets lost
+func (b *buffer) fill(need int) (err error) {
+	b.idx = 0
+	b.length = 0
+
+	n := 0
+	for err == nil && b.length < need {
+		n, err = b.rd.Read(b.buf[b.length:])
+		b.length += n
+	}
+
+	return
+}
+
+// read len(p) bytes
+func (b *buffer) read(p []byte) (err error) {
+	need := len(p)
+
+	if b.length < need {
+		if b.length > 0 {
+			copy(p[0:b.length], b.buf[b.idx:])
+			need -= b.length
+			p = p[b.length:]
+
+			b.idx = 0
+			b.length = 0
+		}
+
+		if need >= len(b.buf) {
+			var n int
+			has := 0
+			for err == nil && need > has {
+				n, err = b.rd.Read(p[has:])
+				has += n
+			}
+			return
+		}
+
+		err = b.fill(need) // err deferred
+	}
+
+	copy(p, b.buf[b.idx:])
+	b.idx += need
+	b.length -= need
+	return
+}

+ 2 - 3
connection.go

@@ -10,7 +10,6 @@
 package mysql
 package mysql
 
 
 import (
 import (
-	"bufio"
 	"database/sql/driver"
 	"database/sql/driver"
 	"errors"
 	"errors"
 	"net"
 	"net"
@@ -21,7 +20,7 @@ type mysqlConn struct {
 	cfg          *config
 	cfg          *config
 	server       *serverSettings
 	server       *serverSettings
 	netConn      net.Conn
 	netConn      net.Conn
-	bufReader    *bufio.Reader
+	buf          *buffer
 	protocol     uint8
 	protocol     uint8
 	sequence     uint8
 	sequence     uint8
 	affectedRows uint64
 	affectedRows uint64
@@ -96,7 +95,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
 func (mc *mysqlConn) Close() (err error) {
 func (mc *mysqlConn) Close() (err error) {
 	mc.writeCommandPacket(COM_QUIT)
 	mc.writeCommandPacket(COM_QUIT)
 	mc.cfg = nil
 	mc.cfg = nil
-	mc.bufReader = nil
+	mc.buf = nil
 	mc.netConn.Close()
 	mc.netConn.Close()
 	mc.netConn = nil
 	mc.netConn = nil
 	return
 	return

+ 1 - 2
driver.go

@@ -9,7 +9,6 @@
 package mysql
 package mysql
 
 
 import (
 import (
-	"bufio"
 	"database/sql"
 	"database/sql"
 	"database/sql/driver"
 	"database/sql/driver"
 	"net"
 	"net"
@@ -32,7 +31,7 @@ func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	mc.bufReader = bufio.NewReader(mc.netConn)
+	mc.buf = newBuffer(mc.netConn)
 
 
 	// Reading Handshake Initialization Packet
 	// Reading Handshake Initialization Packet
 	err = mc.readInitPacket()
 	err = mc.readInitPacket()

+ 6 - 13
packets.go

@@ -26,10 +26,10 @@ import (
 func (mc *mysqlConn) readPacket() (data []byte, err error) {
 func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	// Read header
 	// Read header
 	data = make([]byte, 4)
 	data = make([]byte, 4)
-	var n, add int
-	for err == nil && n < 4 {
-		add, err = mc.bufReader.Read(data[n:])
-		n += add
+	err = mc.buf.read(data)
+	if err != nil {
+		errLog.Print(err)
+		return nil, driver.ErrBadConn
 	}
 	}
 
 
 	// Packet Length
 	// Packet Length
@@ -55,15 +55,8 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 
 
 	// Read rest of packet
 	// Read rest of packet
 	data = make([]byte, pktLen)
 	data = make([]byte, pktLen)
-	n = 0
-	for err == nil && n < int(pktLen) {
-		add, err = mc.bufReader.Read(data[n:])
-		n += add
-	}
-	if err != nil || n < int(pktLen) {
-		if err == nil {
-			err = fmt.Errorf("Length of read data (%d) does not match body length (%d)", n, pktLen)
-		}
+	err = mc.buf.read(data)
+	if err != nil {
 		errLog.Print(err)
 		errLog.Print(err)
 		return nil, driver.ErrBadConn
 		return nil, driver.ErrBadConn
 	}
 	}