Browse Source

Merge writePacket and splitPacket into one function

Xiuming Chen 12 năm trước cách đây
mục cha
commit
b203832566
1 tập tin đã thay đổi với 22 bổ sung35 xóa
  1. 22 35
      packets.go

+ 22 - 35
packets.go

@@ -77,46 +77,36 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 // Write packet buffer 'data'
 // The packet header must be already included
 func (mc *mysqlConn) writePacket(data []byte) error {
-	if len(data)-4 <= mc.maxWriteSize { // Can send data at once
-		// Write packet
-		n, err := mc.netConn.Write(data)
-		if err == nil && n == len(data) {
-			mc.sequence++
-			return nil
-		}
-
-		// Handle error
-		if err == nil { // n != len(data)
-			errLog.Print(errMalformPkt.Error())
-		} else {
-			errLog.Print(err.Error())
-		}
-		return driver.ErrBadConn
-	}
-
-	// Must split packet
-	return mc.splitPacket(data)
-}
-
-func (mc *mysqlConn) splitPacket(data []byte) error {
 	pktLen := len(data) - 4
 
 	if pktLen > mc.maxPacketAllowed {
 		return errPktTooLarge
 	}
 
-	for pktLen >= maxPacketSize {
-		data[0] = 0xff
-		data[1] = 0xff
-		data[2] = 0xff
+	for {
+		var size int
+		if pktLen >= maxPacketSize {
+			data[0] = 0xff
+			data[1] = 0xff
+			data[2] = 0xff
+			size = maxPacketSize
+		} else {
+			data[0] = byte(pktLen)
+			data[1] = byte(pktLen >> 8)
+			data[2] = byte(pktLen >> 16)
+			size = pktLen
+		}
 		data[3] = mc.sequence
 
 		// Write packet
-		n, err := mc.netConn.Write(data[:4+maxPacketSize])
-		if err == nil && n == 4+maxPacketSize {
+		n, err := mc.netConn.Write(data[:4+size])
+		if err == nil && n == 4+size {
 			mc.sequence++
-			data = data[maxPacketSize:]
-			pktLen -= maxPacketSize
+			if size != maxPacketSize {
+				break
+			}
+			pktLen -= size
+			data = data[size:]
 			continue
 		}
 
@@ -129,11 +119,7 @@ func (mc *mysqlConn) splitPacket(data []byte) error {
 		return driver.ErrBadConn
 	}
 
-	data[0] = byte(pktLen)
-	data[1] = byte(pktLen >> 8)
-	data[2] = byte(pktLen >> 16)
-	data[3] = mc.sequence
-	return mc.writePacket(data)
+	return nil
 }
 
 /******************************************************************************
@@ -748,6 +734,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 			pktLen = dataOffset + argLen
 		}
 
+		stmt.mc.sequence = 0
 		// Add the packet header [24bit length + 1 byte sequence]
 		data[0] = byte(pktLen)
 		data[1] = byte(pktLen >> 8)