Browse Source

Merge pull request #164 from cxmcc/writePacket

Refactorization of writePacket
Julien Schmidt 12 years ago
parent
commit
2d629c0564
3 changed files with 28 additions and 111 deletions
  1. 1 0
      AUTHORS
  2. 0 4
      infile.go
  3. 27 107
      packets.go

+ 1 - 0
AUTHORS

@@ -23,6 +23,7 @@ Luke Scott <luke at webconnex.com>
 Michael Woolnough <michael.woolnough at gmail.com>
 Nicola Peduzzi <thenikso at gmail.com>
 Xiaobing Jiang <s7v7nislands at gmail.com>
+Xiuming Chen <cc at cxm.cc>
 
 # Organizations
 

+ 0 - 4
infile.go

@@ -114,10 +114,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
 		for err == nil && ioErr == nil {
 			n, err = rdr.Read(data[4:])
 			if n > 0 {
-				data[0] = byte(n)
-				data[1] = byte(n >> 8)
-				data[2] = byte(n >> 16)
-				data[3] = mc.sequence
 				ioErr = mc.writePacket(data[:4+n])
 			}
 		}

+ 27 - 107
packets.go

@@ -75,48 +75,37 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
 }
 
 // Write packet buffer 'data'
-// The packet header must be already included
 func (mc *mysqlConn) writePacket(data []byte) error {
-	if len(data)-4 <= mc.maxWriteSize { // Can send data at once
-		// Write packet
-		n, err := mc.netConn.Write(data)
-		if err == nil && n == len(data) {
-			mc.sequence++
-			return nil
-		}
-
-		// Handle error
-		if err == nil { // n != len(data)
-			errLog.Print(errMalformPkt.Error())
-		} else {
-			errLog.Print(err.Error())
-		}
-		return driver.ErrBadConn
-	}
-
-	// Must split packet
-	return mc.splitPacket(data)
-}
-
-func (mc *mysqlConn) splitPacket(data []byte) error {
 	pktLen := len(data) - 4
 
 	if pktLen > mc.maxPacketAllowed {
 		return errPktTooLarge
 	}
 
-	for pktLen >= maxPacketSize {
-		data[0] = 0xff
-		data[1] = 0xff
-		data[2] = 0xff
+	for {
+		var size int
+		if pktLen >= maxPacketSize {
+			data[0] = 0xff
+			data[1] = 0xff
+			data[2] = 0xff
+			size = maxPacketSize
+		} else {
+			data[0] = byte(pktLen)
+			data[1] = byte(pktLen >> 8)
+			data[2] = byte(pktLen >> 16)
+			size = pktLen
+		}
 		data[3] = mc.sequence
 
 		// Write packet
-		n, err := mc.netConn.Write(data[:4+maxPacketSize])
-		if err == nil && n == 4+maxPacketSize {
+		n, err := mc.netConn.Write(data[:4+size])
+		if err == nil && n == 4+size {
 			mc.sequence++
-			data = data[maxPacketSize:]
-			pktLen -= maxPacketSize
+			if size != maxPacketSize {
+				return nil
+			}
+			pktLen -= size
+			data = data[size:]
 			continue
 		}
 
@@ -128,12 +117,6 @@ func (mc *mysqlConn) splitPacket(data []byte) error {
 		}
 		return driver.ErrBadConn
 	}
-
-	data[0] = byte(pktLen)
-	data[1] = byte(pktLen >> 8)
-	data[2] = byte(pktLen >> 16)
-	data[3] = mc.sequence
-	return mc.writePacket(data)
 }
 
 /******************************************************************************
@@ -265,12 +248,6 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	// SSL Connection Request Packet
 	// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
 	if mc.cfg.tls != nil {
-		// Packet header  [24bit length + 1 byte sequence]
-		data[0] = byte((4 + 4 + 1 + 23))
-		data[1] = byte((4 + 4 + 1 + 23) >> 8)
-		data[2] = byte((4 + 4 + 1 + 23) >> 16)
-		data[3] = mc.sequence
-
 		// Send TLS / SSL request packet
 		if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
 			return err
@@ -285,12 +262,6 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 		mc.buf.rd = tlsConn
 	}
 
-	// Add the packet header  [24bit length + 1 byte sequence]
-	data[0] = byte(pktLen)
-	data[1] = byte(pktLen >> 8)
-	data[2] = byte(pktLen >> 16)
-	data[3] = mc.sequence
-
 	// Filler [23 bytes] (all 0x00)
 	pos := 13 + 23
 
@@ -330,12 +301,6 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
 		return driver.ErrBadConn
 	}
 
-	// Add the packet header  [24bit length + 1 byte sequence]
-	data[0] = byte(pktLen)
-	data[1] = byte(pktLen >> 8)
-	data[2] = byte(pktLen >> 16)
-	data[3] = mc.sequence
-
 	// Add the scrambled password [null terminated string]
 	copy(data[4:], scrambleBuff)
 
@@ -357,12 +322,6 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
 		return driver.ErrBadConn
 	}
 
-	// Add the packet header [24bit length + 1 byte sequence]
-	data[0] = 0x01 // 1 byte long
-	data[1] = 0x00
-	data[2] = 0x00
-	data[3] = 0x00 // new command, sequence id is always 0
-
 	// Add command byte
 	data[4] = command
 
@@ -382,12 +341,6 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
 		return driver.ErrBadConn
 	}
 
-	// Add the packet header [24bit length + 1 byte sequence]
-	data[0] = byte(pktLen)
-	data[1] = byte(pktLen >> 8)
-	data[2] = byte(pktLen >> 16)
-	data[3] = 0x00 // new command, sequence id is always 0
-
 	// Add command byte
 	data[4] = command
 
@@ -409,12 +362,6 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
 		return driver.ErrBadConn
 	}
 
-	// Add the packet header [24bit length + 1 byte sequence]
-	data[0] = 0x05 // 5 bytes long
-	data[1] = 0x00
-	data[2] = 0x00
-	data[3] = 0x00 // new command, sequence id is always 0
-
 	// Add command byte
 	data[4] = command
 
@@ -748,12 +695,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 			pktLen = dataOffset + argLen
 		}
 
-		// Add the packet header [24bit length + 1 byte sequence]
-		data[0] = byte(pktLen)
-		data[1] = byte(pktLen >> 8)
-		data[2] = byte(pktLen >> 16)
-		data[3] = 0x00 // mc.sequence
-
+		stmt.mc.sequence = 0
 		// Add command byte [1 byte]
 		data[4] = comStmtSendLongData
 
@@ -801,28 +743,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	var data []byte
 
 	if len(args) == 0 {
-		const pktLen = 1 + 4 + 1 + 4
-		data = mc.buf.takeBuffer(4 + pktLen)
-		if data == nil {
-			// can not take the buffer. Something must be wrong with the connection
-			errLog.Print("Busy buffer")
-			return driver.ErrBadConn
-		}
-
-		// packet header [4 bytes]
-		data[0] = byte(pktLen)
-		data[1] = byte(pktLen >> 8)
-		data[2] = byte(pktLen >> 16)
-		data[3] = 0x00 // new command, sequence id is always 0
+		data = mc.buf.takeBuffer(4 + 1 + 4 + 1 + 4)
 	} else {
 		data = mc.buf.takeCompleteBuffer()
-		if data == nil {
-			// can not take the buffer. Something must be wrong with the connection
-			errLog.Print("Busy buffer")
-			return driver.ErrBadConn
-		}
-
-		// header (bytes 0-3) is added after we know the packet size
+	}
+	if data == nil {
+		// can not take the buffer. Something must be wrong with the connection
+		errLog.Print("Busy buffer")
+		return driver.ErrBadConn
 	}
 
 	// command [1 byte]
@@ -984,14 +912,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		pos += len(paramValues)
 		data = data[:pos]
 
-		pktLen := pos - 4
-
-		// packet header [4 bytes]
-		data[0] = byte(pktLen)
-		data[1] = byte(pktLen >> 8)
-		data[2] = byte(pktLen >> 16)
-		data[3] = mc.sequence
-
 		// Convert nullMask to bytes
 		for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ {
 			data[i+14] = byte(nullMask >> uint(i<<3))