فهرست منبع

more refactoring

Try to remove unnecessary indirections and initialisations with zero.
Also update links to the MySQL doc
Julien Schmidt 12 سال پیش
والد
کامیت
5975ca9212
4فایلهای تغییر یافته به همراه78 افزوده شده و 79 حذف شده
  1. 27 30
      connection.go
  2. 30 27
      packets.go
  3. 3 3
      rows.go
  4. 18 19
      statement.go

+ 27 - 30
connection.go

@@ -136,14 +136,14 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	columnCount, err := stmt.readPrepareResultPacket()
 	if err == nil {
 		if stmt.paramCount > 0 {
-			stmt.params, err = stmt.mc.readColumns(stmt.paramCount)
+			stmt.params, err = mc.readColumns(stmt.paramCount)
 			if err != nil {
 				return nil, err
 			}
 		}
 
 		if columnCount > 0 {
-			err = stmt.mc.readUntilEOF()
+			err = mc.readUntilEOF()
 		}
 	}
 
@@ -171,26 +171,24 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
 }
 
 // Internal function to execute commands
-func (mc *mysqlConn) exec(query string) (err error) {
+func (mc *mysqlConn) exec(query string) error {
 	// Send command
-	err = mc.writeCommandPacketStr(comQuery, query)
+	err := mc.writeCommandPacketStr(comQuery, query)
 	if err != nil {
-		return
+		return err
 	}
 
 	// Read Result
-	var resLen int
-	resLen, err = mc.readResultSetHeaderPacket()
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil && resLen > 0 {
-		err = mc.readUntilEOF()
-		if err != nil {
-			return
+		if err = mc.readUntilEOF(); err != nil {
+			return err
 		}
 
 		err = mc.readUntilEOF()
 	}
 
-	return
+	return err
 }
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
@@ -211,7 +209,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 				return rows, err
 			}
 		}
-
 		return nil, err
 	}
 
@@ -221,29 +218,29 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 
 // Gets the value of the given MySQL System Variable
 // The returned byte slice is only valid until the next read
-func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
+func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 	// Send command
-	err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name)
+	if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
+		return nil, err
+	}
+
+	// Read Result
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil {
-		// Read Result
-		var resLen int
-		resLen, err = mc.readResultSetHeaderPacket()
-		if err == nil {
-			rows := &mysqlRows{mc, false, nil, false}
+		rows := &mysqlRows{mc, false, nil, false}
 
-			if resLen > 0 {
-				// Columns
-				rows.columns, err = mc.readColumns(resLen)
+		if resLen > 0 {
+			// Columns
+			rows.columns, err = mc.readColumns(resLen)
+			if err != nil {
+				return nil, err
 			}
+		}
 
-			dest := make([]driver.Value, resLen)
-			err = rows.readRow(dest)
-			if err == nil {
-				val = dest[0].([]byte)
-				err = mc.readUntilEOF()
-			}
+		dest := make([]driver.Value, resLen)
+		if err = rows.readRow(dest); err == nil {
+			return dest[0].([]byte), mc.readUntilEOF()
 		}
 	}
-
-	return
+	return nil, err
 }

+ 30 - 27
packets.go

@@ -140,7 +140,7 @@ func (mc *mysqlConn) splitPacket(data []byte) error {
 ******************************************************************************/
 
 // Handshake Initialization Packet
-// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
 func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 	data, err := mc.readPacket()
 	if err != nil {
@@ -197,7 +197,6 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 		// TODO: Verify string termination
 		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
 		// \NUL otherwise
-		// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
 		//
 		//if data[len(data)-1] == 0 {
 		//	return
@@ -209,7 +208,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 }
 
 // Client Authentication Packet
-// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
 func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	// Adjust client flags based on server support
 	clientFlags := clientProtocol41 |
@@ -263,7 +262,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	data[12] = collation_utf8_general_ci
 
 	// SSL Connection Request Packet
-	// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::SSLRequest
+	// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
 	if mc.cfg.tls != nil {
 		// Packet header  [24bit length + 1 byte sequence]
 		data[0] = byte((4 + 4 + 1 + 23))
@@ -316,7 +315,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 }
 
 //  Client old authentication packet
-// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
 func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
 	// User password
 	scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
@@ -454,7 +453,7 @@ func (mc *mysqlConn) readResultOK() error {
 }
 
 // Result Set Header Packet
-// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
 func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 	data, err := mc.readPacket()
 	if err == nil {
@@ -482,7 +481,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 }
 
 // Error Packet
-// http://dev.mysql.com/doc/internals/en/overview.html#packet-ERR_Packet
+// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
 func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 	if data[0] != iERR {
 		return errMalformPkt
@@ -509,7 +508,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 }
 
 // Ok Packet
-// http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
+// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
 func (mc *mysqlConn) handleOkPacket(data []byte) error {
 	var n, m int
 
@@ -536,7 +535,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
-// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
 	columns := make([]mysqlField, count)
 
@@ -619,9 +618,11 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
-// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
 func (rows *mysqlRows) readRow(dest []driver.Value) error {
-	data, err := rows.mc.readPacket()
+	mc := rows.mc
+
+	data, err := mc.readPacket()
 	if err != nil {
 		return err
 	}
@@ -642,7 +643,7 @@ func (rows *mysqlRows) readRow(dest []driver.Value) error {
 		pos += n
 		if err == nil {
 			if !isNull {
-				if !rows.mc.parseTime {
+				if !mc.parseTime {
 					continue
 				} else {
 					switch rows.columns[i].fieldType {
@@ -650,7 +651,7 @@ func (rows *mysqlRows) readRow(dest []driver.Value) error {
 						fieldTypeDate, fieldTypeNewDate:
 						dest[i], err = parseDateTime(
 							string(dest[i].([]byte)),
-							rows.mc.cfg.loc,
+							mc.cfg.loc,
 						)
 						if err == nil {
 							continue
@@ -689,7 +690,7 @@ func (mc *mysqlConn) readUntilEOF() error {
 ******************************************************************************/
 
 // Prepare Result Packets
-// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
+// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
 func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
 	data, err := stmt.mc.readPacket()
 	if err == nil {
@@ -723,7 +724,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
 	return 0, err
 }
 
-// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data
+// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
 func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 	maxLen := stmt.mc.maxPacketAllowed - 1
 	pktLen := maxLen
@@ -785,14 +786,16 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		)
 	}
 
+	mc := stmt.mc
+
 	// Reset packet-sequence
-	stmt.mc.sequence = 0
+	mc.sequence = 0
 
 	var data []byte
 
 	if len(args) == 0 {
 		const pktLen = 1 + 4 + 1 + 4
-		data = stmt.mc.buf.writeBuffer(4 + pktLen)
+		data = mc.buf.writeBuffer(4 + pktLen)
 		if data == nil {
 			// can not take the buffer. Something must be wrong with the connection
 			errLog.Print("Busy buffer")
@@ -805,7 +808,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		data[2] = byte(pktLen >> 16)
 		data[3] = 0x00 // sequence is always 0
 	} else {
-		data = stmt.mc.buf.takeCompleteBuffer()
+		data = mc.buf.takeCompleteBuffer()
 		if data == nil {
 			// can not take the buffer. Something must be wrong with the connection
 			errLog.Print("Busy buffer")
@@ -902,7 +905,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				paramTypes[i+i] = fieldTypeString
 				paramTypes[i+i+1] = 0x00
 
-				if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
+				if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
 					paramValues = append(paramValues,
 						lengthEncodedIntegerToBytes(uint64(len(v)))...,
 					)
@@ -917,7 +920,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				paramTypes[i+i] = fieldTypeString
 				paramTypes[i+i+1] = 0x00
 
-				if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
+				if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
 					paramValues = append(paramValues,
 						lengthEncodedIntegerToBytes(uint64(len(v)))...,
 					)
@@ -936,7 +939,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				if v.IsZero() {
 					val = []byte("0000-00-00")
 				} else {
-					val = []byte(v.In(stmt.mc.cfg.loc).Format(timeFormat))
+					val = []byte(v.In(mc.cfg.loc).Format(timeFormat))
 				}
 
 				paramValues = append(paramValues,
@@ -953,7 +956,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		// In that case we must build the data packet with the new values buffer
 		if valuesCap != cap(paramValues) {
 			data = append(data[:pos], paramValues...)
-			stmt.mc.buf.buf = data
+			mc.buf.buf = data
 		}
 
 		pos += len(paramValues)
@@ -965,18 +968,18 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		data[0] = byte(pktLen)
 		data[1] = byte(pktLen >> 8)
 		data[2] = byte(pktLen >> 16)
-		data[3] = stmt.mc.sequence
+		data[3] = mc.sequence
 
 		// Convert nullMask to bytes
-		for i, max := 14, 14+((stmt.paramCount+7)>>3); i < max; i++ {
-			data[i] = byte(nullMask >> uint((i-14)<<3))
+		for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ {
+			data[i+14] = byte(nullMask >> uint(i<<3))
 		}
 	}
 
-	return stmt.mc.writePacket(data)
+	return mc.writePacket(data)
 }
 
-// http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow
+// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
 func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error {
 	data, err := rows.mc.readPacket()
 	if err != nil {

+ 3 - 3
rows.go

@@ -26,12 +26,12 @@ type mysqlRows struct {
 	eof     bool
 }
 
-func (rows *mysqlRows) Columns() (columns []string) {
-	columns = make([]string, len(rows.columns))
+func (rows *mysqlRows) Columns() []string {
+	columns := make([]string, len(rows.columns))
 	for i := range columns {
 		columns[i] = rows.columns[i].name
 	}
-	return
+	return columns
 }
 
 func (rows *mysqlRows) Close() (err error) {

+ 18 - 19
statement.go

@@ -19,14 +19,14 @@ type mysqlStmt struct {
 	params     []mysqlField
 }
 
-func (stmt *mysqlStmt) Close() (err error) {
+func (stmt *mysqlStmt) Close() error {
 	if stmt.mc == nil || stmt.mc.netConn == nil {
 		return errInvalidConn
 	}
 
-	err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
+	err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
 	stmt.mc = nil
-	return
+	return err
 }
 
 func (stmt *mysqlStmt) NumInput() int {
@@ -34,33 +34,34 @@ func (stmt *mysqlStmt) NumInput() int {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
-	stmt.mc.affectedRows = 0
-	stmt.mc.insertId = 0
-
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {
 		return nil, err
 	}
 
+	mc := stmt.mc
+
+	mc.affectedRows = 0
+	mc.insertId = 0
+
 	// Read Result
-	var resLen int
-	resLen, err = stmt.mc.readResultSetHeaderPacket()
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil {
 		if resLen > 0 {
 			// Columns
-			err = stmt.mc.readUntilEOF()
+			err = mc.readUntilEOF()
 			if err != nil {
 				return nil, err
 			}
 
 			// Rows
-			err = stmt.mc.readUntilEOF()
+			err = mc.readUntilEOF()
 		}
 		if err == nil {
 			return &mysqlResult{
-				affectedRows: int64(stmt.mc.affectedRows),
-				insertId:     int64(stmt.mc.insertId),
+				affectedRows: int64(mc.affectedRows),
+				insertId:     int64(mc.insertId),
 			}, nil
 		}
 	}
@@ -75,21 +76,19 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 		return nil, err
 	}
 
+	mc := stmt.mc
+
 	// Read Result
-	var resLen int
-	resLen, err = stmt.mc.readResultSetHeaderPacket()
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err != nil {
 		return nil, err
 	}
 
-	rows := &mysqlRows{stmt.mc, true, nil, false}
+	rows := &mysqlRows{mc, true, nil, false}
 
 	if resLen > 0 {
 		// Columns
-		rows.columns, err = stmt.mc.readColumns(resLen)
-		if err != nil {
-			return nil, err
-		}
+		rows.columns, err = mc.readColumns(resLen)
 	}
 
 	return rows, err