浏览代码

Merge pull request #137 from go-sql-driver/refactoring

Refactoring
Julien Schmidt 12 年之前
父节点
当前提交
6a71f6982b
共有 5 个文件被更改,包括 25 次插入22 次删除
  1. 16 13
      buffer.go
  2. 3 4
      connection.go
  3. 3 2
      packets.go
  4. 2 2
      rows.go
  5. 1 1
      statement.go

+ 16 - 13
buffer.go

@@ -33,7 +33,7 @@ func newBuffer(rd io.Reader) *buffer {
 }
 
 // fill reads into the buffer until at least _need_ bytes are in it
-func (b *buffer) fill(need int) (err error) {
+func (b *buffer) fill(need int) error {
 	// move existing data to the beginning
 	if b.length > 0 && b.idx > 0 {
 		copy(b.buf[0:b.length], b.buf[b.idx:])
@@ -51,34 +51,37 @@ func (b *buffer) fill(need int) (err error) {
 
 	b.idx = 0
 
-	var n int
 	for {
-		n, err = b.rd.Read(b.buf[b.length:])
+		n, err := b.rd.Read(b.buf[b.length:])
 		b.length += n
 
-		if b.length < need && err == nil {
-			continue
+		if err == nil {
+			if b.length < need {
+				continue
+			}
+			return nil
 		}
-		return // err
+		if b.length >= need && err == io.EOF {
+			return nil
+		}
+		return err
 	}
-	return
 }
 
 // returns next N bytes from buffer.
 // The returned slice is only guaranteed to be valid until the next read
-func (b *buffer) readNext(need int) (p []byte, err error) {
+func (b *buffer) readNext(need int) ([]byte, error) {
 	if b.length < need {
 		// refill
-		err = b.fill(need) // err deferred
-		if err == io.EOF && b.length >= need {
-			err = nil
+		if err := b.fill(need); err != nil {
+			return nil, err
 		}
 	}
 
-	p = b.buf[b.idx : b.idx+need]
+	offset := b.idx
 	b.idx += need
 	b.length -= need
-	return
+	return b.buf[offset:b.idx], nil
 }
 
 // returns a buffer with the requested size.

+ 3 - 4
connection.go

@@ -199,7 +199,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 			var resLen int
 			resLen, err = mc.readResultSetHeaderPacket()
 			if err == nil {
-				rows := &mysqlRows{mc, false, nil, false}
+				rows := &mysqlRows{mc, nil, false, false}
 
 				if resLen > 0 {
 					// Columns
@@ -226,12 +226,11 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 	// Read Result
 	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil {
-		rows := &mysqlRows{mc, false, nil, false}
+		rows := &mysqlRows{mc, nil, false, false}
 
 		if resLen > 0 {
 			// Columns
-			rows.columns, err = mc.readColumns(resLen)
-			if err != nil {
+			if err := mc.readUntilEOF(); err != nil {
 				return nil, err
 			}
 		}

+ 3 - 2
packets.go

@@ -57,8 +57,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 			return data, nil
 		}
 
-		var buf []byte
-		buf = append(buf, data...)
+		// Make a copy since data becomes invalid with the next read
+		buf := make([]byte, len(data))
+		copy(buf, data)
 
 		// More data
 		data, err = mc.readPacket()

+ 2 - 2
rows.go

@@ -14,15 +14,15 @@ import (
 )
 
 type mysqlField struct {
-	name      string
 	fieldType byte
 	flags     fieldFlag
+	name      string
 }
 
 type mysqlRows struct {
 	mc      *mysqlConn
-	binary  bool
 	columns []mysqlField
+	binary  bool
 	eof     bool
 }
 

+ 1 - 1
statement.go

@@ -84,7 +84,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 		return nil, err
 	}
 
-	rows := &mysqlRows{mc, true, nil, false}
+	rows := &mysqlRows{mc, nil, true, false}
 
 	if resLen > 0 {
 		// Columns