浏览代码

packets: YAR (yet another refactoring)

Julien Schmidt 12 年之前
父节点
当前提交
228ba3461b
共有 1 个文件被更改,包括 31 次插入19 次删除
  1. 31 19
      packets.go

+ 31 - 19
packets.go

@@ -360,7 +360,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
 	data[0] = 0x01 // 1 byte long
 	data[1] = 0x00
 	data[2] = 0x00
-	data[3] = 0x00 // sequence is always 0
+	data[3] = 0x00 // new command, sequence id is always 0
 
 	// Add command byte
 	data[4] = command
@@ -385,7 +385,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
 	data[0] = byte(pktLen)
 	data[1] = byte(pktLen >> 8)
 	data[2] = byte(pktLen >> 16)
-	data[3] = 0x00 // sequence is always 0
+	data[3] = 0x00 // new command, sequence id is always 0
 
 	// Add command byte
 	data[4] = command
@@ -412,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
 	data[0] = 0x05 // 5 bytes long
 	data[1] = 0x00
 	data[2] = 0x00
-	data[3] = 0x00 // sequence is always 0
+	data[3] = 0x00 // new command, sequence id is always 0
 
 	// Add command byte
 	data[4] = command
@@ -495,8 +495,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 	pos := 3
 
 	// SQL State [optional: # + 5bytes string]
-	//sqlstate := string(data[pos : pos+6])
 	if data[3] == 0x23 {
+		//sqlstate := string(data[4 : 4+5])
 		pos = 9
 	}
 
@@ -700,13 +700,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
 		}
 
 		// statement id [4 bytes]
-		stmt.id = binary.LittleEndian.Uint32(data[1 : 1+4])
+		stmt.id = binary.LittleEndian.Uint32(data[1:5])
 
 		// Column count [16 bit uint]
-		columnCount := binary.LittleEndian.Uint16(data[1+4 : 1+4+2])
+		columnCount := binary.LittleEndian.Uint16(data[5:7])
 
 		// Param count [16 bit uint]
-		stmt.paramCount = int(binary.LittleEndian.Uint16(data[1+4+2 : 1+4+2+2]))
+		stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9]))
 
 		// Reserved [8 bit]
 
@@ -715,7 +715,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
 			return columnCount, nil
 		} else {
 			// Check for warnings count > 0, only available in MySQL > 4.1
-			if len(data) >= 12 && binary.LittleEndian.Uint16(data[1+4+2+2+1:1+4+2+2+1+2]) > 0 {
+			if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
 				return columnCount, stmt.mc.getWarnings()
 			}
 			return columnCount, nil
@@ -729,16 +729,22 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 	maxLen := stmt.mc.maxPacketAllowed - 1
 	pktLen := maxLen
 
+	// After the header (bytes 0-3) follows before the data:
+	// 1 byte command
+	// 4 bytes stmtID
+	// 2 bytes paramID
+	const dataOffset = 1 + 4 + 2
+
 	// Can not use the write buffer since
 	// a) the buffer is too small
 	// b) it is in use
 	data := make([]byte, 4+1+4+2+len(arg))
 
-	copy(data[4+1+4+2:], arg)
+	copy(data[4+dataOffset:], arg)
 
-	for argLen := len(arg); argLen > 0; argLen -= pktLen - (1 + 4 + 2) {
-		if 1+4+2+argLen < maxLen {
-			pktLen = 1 + 4 + 2 + argLen
+	for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset {
+		if dataOffset+argLen < maxLen {
+			pktLen = dataOffset + argLen
 		}
 
 		// Add the packet header [24bit length + 1 byte sequence]
@@ -763,7 +769,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 		// Send CMD packet
 		err := stmt.mc.writePacket(data[:4+pktLen])
 		if err == nil {
-			data = data[pktLen-(1+4+2):]
+			data = data[pktLen-dataOffset:]
 			continue
 		}
 		return err
@@ -806,7 +812,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		data[0] = byte(pktLen)
 		data[1] = byte(pktLen >> 8)
 		data[2] = byte(pktLen >> 16)
-		data[3] = 0x00 // sequence is always 0
+		data[3] = 0x00 // new command, sequence id is always 0
 	} else {
 		data = mc.buf.takeCompleteBuffer()
 		if data == nil {
@@ -871,7 +877,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 
 				if cap(paramValues)-len(paramValues)-8 >= 0 {
 					paramValues = paramValues[:len(paramValues)+8]
-					binary.LittleEndian.PutUint64(paramValues[len(paramValues)-8:], uint64(v))
+					binary.LittleEndian.PutUint64(
+						paramValues[len(paramValues)-8:],
+						uint64(v),
+					)
 				} else {
 					paramValues = append(paramValues,
 						uint64ToBytes(uint64(v))...,
@@ -884,7 +893,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 
 				if cap(paramValues)-len(paramValues)-8 >= 0 {
 					paramValues = paramValues[:len(paramValues)+8]
-					binary.LittleEndian.PutUint64(paramValues[len(paramValues)-8:], math.Float64bits(v))
+					binary.LittleEndian.PutUint64(
+						paramValues[len(paramValues)-8:],
+						math.Float64bits(v),
+					)
 				} else {
 					paramValues = append(paramValues,
 						uint64ToBytes(math.Float64bits(v))...,
@@ -991,10 +1003,10 @@ func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error {
 		// EOF Packet
 		if data[0] == iEOF && len(data) == 5 {
 			return io.EOF
-		} else {
-			// Error otherwise
-			return rows.mc.handleErrorPacket(data)
 		}
+
+		// Error otherwise
+		return rows.mc.handleErrorPacket(data)
 	}
 
 	// NULL-bitmap,  [(column-count + 7 + 2) / 8 bytes]