Browse Source

various refactoring

Julien Schmidt 12 năm trước cách đây
mục cha
commit
33d6df2bf4
1 tập tin đã thay đổi với 65 bổ sung92 xóa
  1. 65 92
      packets.go

+ 65 - 92
packets.go

@@ -23,9 +23,9 @@ import (
 // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
 
 // Read packet to buffer 'data'
-func (mc *mysqlConn) readPacket() (data []byte, err error) {
+func (mc *mysqlConn) readPacket() ([]byte, error) {
 	// Read packet header
-	data, err = mc.buf.readNext(4)
+	data, err := mc.buf.readNext(4)
 	if err != nil {
 		errLog.Print(err.Error())
 		mc.Close()
@@ -97,7 +97,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
 	return mc.splitPacket(data)
 }
 
-func (mc *mysqlConn) splitPacket(data []byte) (err error) {
+func (mc *mysqlConn) splitPacket(data []byte) error {
 	pktLen := len(data) - 4
 
 	if pktLen > mc.maxPacketAllowed {
@@ -141,10 +141,10 @@ func (mc *mysqlConn) splitPacket(data []byte) (err error) {
 
 // Handshake Initialization Packet
 // http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
-func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) {
+func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 	data, err := mc.readPacket()
 	if err != nil {
-		return
+		return nil, err
 	}
 
 	if data[0] == iERR {
@@ -153,11 +153,11 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) {
 
 	// protocol version [1 byte]
 	if data[0] < minProtocolVersion {
-		err = fmt.Errorf(
+		return nil, fmt.Errorf(
 			"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
 			data[0],
-			minProtocolVersion)
-		return
+			minProtocolVersion,
+		)
 	}
 
 	// server version [null terminated string]
@@ -165,7 +165,7 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) {
 	pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
 
 	// first part of the password cipher [8 bytes]
-	cipher = data[pos : pos+8]
+	cipher := data[pos : pos+8]
 
 	// (filler) always 0x00 [1 byte]
 	pos += 8 + 1
@@ -205,7 +205,7 @@ func (mc *mysqlConn) readInitPacket() (cipher []byte, err error) {
 		//return errMalformPkt
 	}
 
-	return
+	return cipher, nil
 }
 
 // Client Authentication Packet
@@ -497,7 +497,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 
 	// SQL State [optional: # + 5bytes string]
 	//sqlstate := string(data[pos : pos+6])
-	if data[pos] == 0x23 {
+	if data[3] == 0x23 {
 		pos = 9
 	}
 
@@ -510,7 +510,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 
 // Ok Packet
 // http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
-func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
+func (mc *mysqlConn) handleOkPacket(data []byte) error {
 	var n, m int
 
 	// 0x00 [1 byte]
@@ -525,72 +525,66 @@ func (mc *mysqlConn) handleOkPacket(data []byte) (err error) {
 
 	// warning count [2 bytes]
 	if !mc.strict {
-		return
+		return nil
 	} else {
 		pos := 1 + n + m + 2
 		if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
-			err = mc.getWarnings()
+			return mc.getWarnings()
 		}
+		return nil
 	}
-
-	// message [until end of packet]
-	return
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
 // http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41
-func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
-	var data []byte
-	var i, pos, n int
-	var name []byte
-
-	columns = make([]mysqlField, count)
+func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
+	columns := make([]mysqlField, count)
 
-	for {
-		data, err = mc.readPacket()
+	for i := 0; ; i++ {
+		data, err := mc.readPacket()
 		if err != nil {
-			return
+			return nil, err
 		}
 
 		// EOF Packet
 		if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
-			if i != count {
-				err = fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
+			if i == count {
+				return columns, nil
 			}
-			return
+			return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
 		}
 
 		// Catalog
-		pos, err = skipLengthEnodedString(data)
+		pos, err := skipLengthEnodedString(data)
 		if err != nil {
-			return
+			return nil, err
 		}
 
 		// Database [len coded string]
-		n, err = skipLengthEnodedString(data[pos:])
+		n, err := skipLengthEnodedString(data[pos:])
 		if err != nil {
-			return
+			return nil, err
 		}
 		pos += n
 
 		// Table [len coded string]
 		n, err = skipLengthEnodedString(data[pos:])
 		if err != nil {
-			return
+			return nil, err
 		}
 		pos += n
 
 		// Original table [len coded string]
 		n, err = skipLengthEnodedString(data[pos:])
 		if err != nil {
-			return
+			return nil, err
 		}
 		pos += n
 
 		// Name [len coded string]
-		name, _, n, err = readLengthEnodedString(data[pos:])
+		name, _, n, err := readLengthEnodedString(data[pos:])
 		if err != nil {
-			return
+			return nil, err
 		}
 		columns[i].name = string(name)
 		pos += n
@@ -598,7 +592,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 		// Original name [len coded string]
 		n, err = skipLengthEnodedString(data[pos:])
 		if err != nil {
-			return
+			return nil, err
 		}
 
 		// Filler [1 byte]
@@ -621,8 +615,6 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 		//if pos < len(data) {
 		//	defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
 		//}
-
-		i++
 	}
 }
 
@@ -656,7 +648,10 @@ func (rows *mysqlRows) readRow(dest []driver.Value) error {
 					switch rows.columns[i].fieldType {
 					case fieldTypeTimestamp, fieldTypeDateTime,
 						fieldTypeDate, fieldTypeNewDate:
-						dest[i], err = parseDateTime(string(dest[i].([]byte)), rows.mc.cfg.loc)
+						dest[i], err = parseDateTime(
+							string(dest[i].([]byte)),
+							rows.mc.cfg.loc,
+						)
 						if err == nil {
 							continue
 						}
@@ -695,61 +690,52 @@ func (mc *mysqlConn) readUntilEOF() error {
 
 // Prepare Result Packets
 // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
-func (stmt *mysqlStmt) readPrepareResultPacket() (columnCount uint16, err error) {
+func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
 	data, err := stmt.mc.readPacket()
 	if err == nil {
-		// Position
-		pos := 0
-
 		// packet indicator [1 byte]
-		if data[pos] != iOK {
-			err = stmt.mc.handleErrorPacket(data)
-			return
+		if data[0] != iOK {
+			return 0, stmt.mc.handleErrorPacket(data)
 		}
-		pos++
 
 		// statement id [4 bytes]
-		stmt.id = binary.LittleEndian.Uint32(data[pos : pos+4])
-		pos += 4
+		stmt.id = binary.LittleEndian.Uint32(data[1 : 1+4])
 
 		// Column count [16 bit uint]
-		columnCount = binary.LittleEndian.Uint16(data[pos : pos+2])
-		pos += 2
+		columnCount := binary.LittleEndian.Uint16(data[1+4 : 1+4+2])
 
 		// Param count [16 bit uint]
-		stmt.paramCount = int(binary.LittleEndian.Uint16(data[pos : pos+2]))
-		pos += 2
+		stmt.paramCount = int(binary.LittleEndian.Uint16(data[1+4+2 : 1+4+2+2]))
 
 		// Reserved [8 bit]
-		pos++
 
 		// Warning count [16 bit uint]
 		if !stmt.mc.strict {
-			return
+			return columnCount, nil
 		} else {
 			// Check for warnings count > 0, only available in MySQL > 4.1
-			if len(data) >= 12 && binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
-				err = stmt.mc.getWarnings()
+			if len(data) >= 12 && binary.LittleEndian.Uint16(data[1+4+2+2+1:1+4+2+2+1+2]) > 0 {
+				return columnCount, stmt.mc.getWarnings()
 			}
+			return columnCount, nil
 		}
 	}
-	return
+	return 0, err
 }
 
 // http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data
 func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 	maxLen := stmt.mc.maxPacketAllowed - 1
 	pktLen := maxLen
-	argLen := len(arg)
 
 	// Can not use the write buffer since
 	// a) the buffer is too small
 	// b) it is in use
-	data := make([]byte, 4+1+4+2+argLen)
+	data := make([]byte, 4+1+4+2+len(arg))
 
 	copy(data[4+1+4+2:], arg)
 
-	for argLen > 0 {
+	for argLen := len(arg); argLen > 0; argLen -= pktLen - (1 + 4 + 2) {
 		if 1+4+2+argLen < maxLen {
 			pktLen = 1 + 4 + 2 + argLen
 		}
@@ -776,7 +762,6 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 		// Send CMD packet
 		err := stmt.mc.writePacket(data[:4+pktLen])
 		if err == nil {
-			argLen -= pktLen - (1 + 4 + 2)
 			data = data[pktLen-(1+4+2):]
 			continue
 		}
@@ -796,7 +781,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		return fmt.Errorf(
 			"Arguments count mismatch (Got: %d Has: %d)",
 			len(args),
-			stmt.paramCount)
+			stmt.paramCount,
+		)
 	}
 
 	// Reset packet-sequence
@@ -991,10 +977,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 }
 
 // http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow
-func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
+func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error {
 	data, err := rows.mc.readPacket()
 	if err != nil {
-		return
+		return err
 	}
 
 	// packet indicator [1 byte]
@@ -1010,22 +996,16 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 
 	// NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]
 	pos := 1 + (len(dest)+7+2)>>3
-	nullBitMap := data[1:pos]
-
-	// values [rest]
-	var n int
-	var unsigned bool
+	nullMask := data[1:pos]
 
 	for i := range dest {
 		// Field is NULL
 		// (byte >> bit-pos) % 2 == 1
-		if ((nullBitMap[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
+		if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 {
 			dest[i] = nil
 			continue
 		}
 
-		unsigned = rows.columns[i].flags&flagUnsigned != 0
-
 		// Convert to byte-coded string
 		switch rows.columns[i].fieldType {
 		case fieldTypeNULL:
@@ -1034,7 +1014,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 
 		// Numeric Types
 		case fieldTypeTiny:
-			if unsigned {
+			if rows.columns[i].flags&flagUnsigned != 0 {
 				dest[i] = int64(data[pos])
 			} else {
 				dest[i] = int64(int8(data[pos]))
@@ -1043,7 +1023,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			continue
 
 		case fieldTypeShort, fieldTypeYear:
-			if unsigned {
+			if rows.columns[i].flags&flagUnsigned != 0 {
 				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
 			} else {
 				dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
@@ -1052,7 +1032,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			continue
 
 		case fieldTypeInt24, fieldTypeLong:
-			if unsigned {
+			if rows.columns[i].flags&flagUnsigned != 0 {
 				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
 			} else {
 				dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
@@ -1061,7 +1041,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			continue
 
 		case fieldTypeLongLong:
-			if unsigned {
+			if rows.columns[i].flags&flagUnsigned != 0 {
 				val := binary.LittleEndian.Uint64(data[pos : pos+8])
 				if val > math.MaxInt64 {
 					dest[i] = uint64ToString(val)
@@ -1090,6 +1070,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
 			fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
 			var isNull bool
+			var n int
 			dest[i], isNull, n, err = readLengthEnodedString(data[pos:])
 			pos += n
 			if err == nil {
@@ -1100,14 +1081,11 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 					continue
 				}
 			}
-			return // err
+			return err
 
 		// Date YYYY-MM-DD
 		case fieldTypeDate, fieldTypeNewDate:
-			var num uint64
-			var isNull bool
-			num, isNull, n = readLengthEncodedInteger(data[pos:])
-
+			num, isNull, n := readLengthEncodedInteger(data[pos:])
 			pos += n
 
 			if isNull {
@@ -1130,10 +1108,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 
 		// Time [-][H]HH:MM:SS[.fractal]
 		case fieldTypeTime:
-			var num uint64
-			var isNull bool
-			num, isNull, n = readLengthEncodedInteger(data[pos:])
-
+			num, isNull, n := readLengthEncodedInteger(data[pos:])
 			pos += n
 
 			if num == 0 {
@@ -1179,9 +1154,7 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 
 		// Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
 		case fieldTypeTimestamp, fieldTypeDateTime:
-			var num uint64
-			var isNull bool
-			num, isNull, n = readLengthEncodedInteger(data[pos:])
+			num, isNull, n := readLengthEncodedInteger(data[pos:])
 
 			pos += n
 
@@ -1209,5 +1182,5 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 		}
 	}
 
-	return
+	return nil
 }