Преглед на файлове

small code optimizations

Julien Schmidt преди 12 години
родител
ревизия
74a64527d2
променени са 6 файла, в които са добавени 73 реда и са изтрити 73 реда
  1. 6 1
      buffer.go
  2. 3 3
      connection.go
  3. 1 1
      const.go
  4. 60 65
      packets.go
  5. 1 1
      rows.go
  6. 2 2
      statement.go

+ 6 - 1
buffer.go

@@ -38,9 +38,14 @@ func (b *buffer) fill(need int) (err error) {
 	b.length = 0
 
 	n := 0
-	for err == nil && b.length < need {
+	for b.length < need {
 		n, err = b.rd.Read(b.buf[b.length:])
 		b.length += n
+
+		if err == nil {
+			continue
+		}
+		return // err
 	}
 
 	return

+ 3 - 3
connection.go

@@ -113,7 +113,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 		}
 
 		if columnCount > 0 {
-			_, err = stmt.mc.readUntilEOF()
+			err = stmt.mc.readUntilEOF()
 		}
 	}
 
@@ -159,12 +159,12 @@ func (mc *mysqlConn) exec(query string) (err error) {
 	var resLen int
 	resLen, err = mc.readResultSetHeaderPacket()
 	if err == nil && resLen > 0 {
-		_, err = mc.readUntilEOF()
+		err = mc.readUntilEOF()
 		if err != nil {
 			return
 		}
 
-		_, err = mc.readUntilEOF()
+		err = mc.readUntilEOF()
 	}
 
 	return

+ 1 - 1
const.go

@@ -10,7 +10,7 @@
 package mysql
 
 const (
-	MIN_PROTOCOL_VERSION = 10
+	MIN_PROTOCOL_VERSION byte = 10
 	//MAX_PACKET_SIZE      = 1<<24 - 1
 	TIME_FORMAT = "2006-01-02 15:04:05"
 )

+ 60 - 65
packets.go

@@ -29,7 +29,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	data = make([]byte, 4)
 	err = mc.buf.read(data)
 	if err != nil {
-		errLog.Print(err)
+		errLog.Print(err.Error())
 		return nil, driver.ErrBadConn
 	}
 
@@ -40,7 +40,8 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	pktLen |= uint32(data[2]) << 16
 
 	if pktLen == 0 {
-		return nil, err
+		errLog.Print(errMalformPkt.Error())
+		return nil, driver.ErrBadConn
 	}
 
 	// Check Packet Sync
@@ -59,7 +60,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
 	if err == nil {
 		return data, nil
 	}
-	errLog.Print(err)
+	errLog.Print(err.Error())
 	return nil, driver.ErrBadConn
 }
 
@@ -74,9 +75,9 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 	}
 
 	if err == nil { // n != len(data)
-		errLog.Print(errMalformPkt)
+		errLog.Print(errMalformPkt.Error())
 	} else {
-		errLog.Print(err)
+		errLog.Print(err.Error())
 	}
 	return driver.ErrBadConn
 }
@@ -103,7 +104,7 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 
 	// server version [null terminated string]
 	// connection id [4 bytes]
-	pos := 1 + (bytes.IndexByte(data[1:], 0x00) + 1) + 4
+	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
 
 	// first part of scramble buffer [8 bytes]
 	mc.scrambleBuff = data[pos : pos+8]
@@ -287,45 +288,43 @@ func (mc *mysqlConn) writeCommandPacketUint32(command commandType, arg uint32) e
 // Returns error if Packet is not an 'Result OK'-Packet
 func (mc *mysqlConn) readResultOK() error {
 	data, err := mc.readPacket()
-	if err != nil {
-		return err
-	}
-
-	switch data[0] {
-	// OK
-	case 0:
-		mc.handleOkPacket(data)
-		return nil
-	// EOF, someone is using old_passwords
-	case 254:
-		return errOldPassword
+	if err == nil {
+		switch data[0] {
+		// OK
+		case 0:
+			mc.handleOkPacket(data)
+			return nil
+		// EOF, someone is using old_passwords
+		case 254:
+			return errOldPassword
+		}
+		// ERROR
+		return mc.handleErrorPacket(data)
 	}
-	// ERROR
-	return mc.handleErrorPacket(data)
+	return err
 }
 
 // Result Set Header Packet
 // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
 func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 	data, err := mc.readPacket()
-	if err != nil {
-		return 0, err
-	}
+	if err == nil {
+		if data[0] == 0 {
+			mc.handleOkPacket(data)
+			return 0, nil
+		} else if data[0] == 255 {
+			return 0, mc.handleErrorPacket(data)
+		}
 
-	if data[0] == 0 {
-		mc.handleOkPacket(data)
-		return 0, nil
-	} else if data[0] == 255 {
-		return 0, mc.handleErrorPacket(data)
-	}
+		// column count
+		num, _, n := readLengthEncodedInteger(data)
+		if n-len(data) == 0 {
+			return int(num), nil
+		}
 
-	// column count
-	num, _, n := readLengthEncodedInteger(data)
-	if n-len(data) == 0 {
-		return int(num), nil
+		return 0, errMalformPkt
 	}
-
-	return 0, errMalformPkt
+	return 0, err
 }
 
 // Error Packet
@@ -487,18 +486,17 @@ func (rows *mysqlRows) readRow(dest []driver.Value) (err error) {
 }
 
 // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
-func (mc *mysqlConn) readUntilEOF() (count uint64, err error) {
+func (mc *mysqlConn) readUntilEOF() (err error) {
 	var data []byte
 
 	for {
 		data, err = mc.readPacket()
 
-		// Err or EOF Packet
-		if err != nil || (data[0] == 254 && len(data) == 5) {
-			return
+		// No Err and no EOF Packet
+		if err == nil && (data[0] != 254 || len(data) != 5) {
+			continue
 		}
-
-		count++
+		return
 	}
 	return
 }
@@ -511,35 +509,32 @@ func (mc *mysqlConn) readUntilEOF() (count uint64, err error) {
 // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
 func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
 	data, err := stmt.mc.readPacket()
-	if err != nil {
-		return
-	}
-
-	// Position
-	pos := 0
-
-	// packet marker [1 byte]
-	if data[pos] != 0 { // not OK (0) ?
-		err = stmt.mc.handleErrorPacket(data)
-		return
-	}
-	pos++
+	if err == nil {
+		// Position
+		pos := 0
 
-	// statement id [4 bytes]
-	stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
-	pos += 4
+		// packet marker [1 byte]
+		if data[pos] != 0 { // not OK (0) ?
+			err = stmt.mc.handleErrorPacket(data)
+			return
+		}
+		pos++
 
-	// Column count [16 bit uint]
-	columnCount = binary.LittleEndian.Uint16(data[pos : pos+2])
-	pos += 2
+		// statement id [4 bytes]
+		stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
+		pos += 4
 
-	// Param count [16 bit uint]
-	stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2]))
-	pos += 2
+		// Column count [16 bit uint]
+		columnCount = binary.LittleEndian.Uint16(data[pos : pos+2])
+		pos += 2
 
-	// Warning count [16 bit uint]
-	// bytesToUint16(data[pos : pos+2])
+		// Param count [16 bit uint]
+		stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2]))
+		pos += 2
 
+		// Warning count [16 bit uint]
+		// bytesToUint16(data[pos : pos+2])
+	}
 	return
 }
 

+ 1 - 1
rows.go

@@ -47,7 +47,7 @@ func (rows *mysqlRows) Close() (err error) {
 			return errors.New("Invalid Connection")
 		}
 
-		_, err = rows.mc.readUntilEOF()
+		err = rows.mc.readUntilEOF()
 	}
 
 	return

+ 2 - 2
statement.go

@@ -46,13 +46,13 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
 	if err == nil {
 		if resLen > 0 {
 			// Columns
-			_, err = stmt.mc.readUntilEOF()
+			err = stmt.mc.readUntilEOF()
 			if err != nil {
 				return nil, err
 			}
 
 			// Rows
-			_, err = stmt.mc.readUntilEOF()
+			err = stmt.mc.readUntilEOF()
 		}
 		if err == nil {
 			return &mysqlResult{