Przeglądaj źródła

more refactoring

Try to remove unnecessary indirections and initialisations with zero.
Also update links to the MySQL doc
Julien Schmidt 12 lat temu
rodzic
commit
5975ca9212
4 zmienionych plików z 78 dodań i 79 usunięć
  1. 27 30
      connection.go
  2. 30 27
      packets.go
  3. 3 3
      rows.go
  4. 18 19
      statement.go

+ 27 - 30
connection.go

@@ -136,14 +136,14 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	columnCount, err := stmt.readPrepareResultPacket()
 	if err == nil {
 		if stmt.paramCount > 0 {
-			stmt.params, err = stmt.mc.readColumns(stmt.paramCount)
+			stmt.params, err = mc.readColumns(stmt.paramCount)
 			if err != nil {
 				return nil, err
 			}
 		}
 
 		if columnCount > 0 {
-			err = stmt.mc.readUntilEOF()
+			err = mc.readUntilEOF()
 		}
 	}
 
@@ -171,26 +171,24 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
 }
 
 // Internal function to execute commands
-func (mc *mysqlConn) exec(query string) (err error) {
+func (mc *mysqlConn) exec(query string) error {
 	// Send command
-	err = mc.writeCommandPacketStr(comQuery, query)
+	err := mc.writeCommandPacketStr(comQuery, query)
 	if err != nil {
-		return
+		return err
 	}
 
 	// Read Result
-	var resLen int
-	resLen, err = mc.readResultSetHeaderPacket()
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil && resLen > 0 {
-		err = mc.readUntilEOF()
-		if err != nil {
-			return
+		if err = mc.readUntilEOF(); err != nil {
+			return err
 		}
 
 		err = mc.readUntilEOF()
 	}
 
-	return
+	return err
 }
 
 func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
@@ -211,7 +209,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 				return rows, err
 			}
 		}
-
 		return nil, err
 	}
 
@@ -221,29 +218,29 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
 
 // Gets the value of the given MySQL System Variable
 // The returned byte slice is only valid until the next read
-func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
+func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
 	// Send command
-	err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name)
+	if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
+		return nil, err
+	}
+
+	// Read Result
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil {
-		// Read Result
-		var resLen int
-		resLen, err = mc.readResultSetHeaderPacket()
-		if err == nil {
-			rows := &mysqlRows{mc, false, nil, false}
+		rows := &mysqlRows{mc, false, nil, false}
 
-			if resLen > 0 {
-				// Columns
-				rows.columns, err = mc.readColumns(resLen)
+		if resLen > 0 {
+			// Columns
+			rows.columns, err = mc.readColumns(resLen)
+			if err != nil {
+				return nil, err
 			}
+		}
 
-			dest := make([]driver.Value, resLen)
-			err = rows.readRow(dest)
-			if err == nil {
-				val = dest[0].([]byte)
-				err = mc.readUntilEOF()
-			}
+		dest := make([]driver.Value, resLen)
+		if err = rows.readRow(dest); err == nil {
+			return dest[0].([]byte), mc.readUntilEOF()
 		}
 	}
-
-	return
+	return nil, err
 }

+ 30 - 27
packets.go

@@ -140,7 +140,7 @@ func (mc *mysqlConn) splitPacket(data []byte) error {
 ******************************************************************************/
 
 // Handshake Initialization Packet
-// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
 func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 	data, err := mc.readPacket()
 	if err != nil {
@@ -197,7 +197,6 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 		// TODO: Verify string termination
 		// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
 		// \NUL otherwise
-		// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
 		//
 		//if data[len(data)-1] == 0 {
 		//	return
@@ -209,7 +208,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
 }
 
 // Client Authentication Packet
-// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
 func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	// Adjust client flags based on server support
 	clientFlags := clientProtocol41 |
@@ -263,7 +262,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 	data[12] = collation_utf8_general_ci
 
 	// SSL Connection Request Packet
-	// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::SSLRequest
+	// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
 	if mc.cfg.tls != nil {
 		// Packet header  [24bit length + 1 byte sequence]
 		data[0] = byte((4 + 4 + 1 + 23))
@@ -316,7 +315,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
 }
 
 //  Client old authentication packet
-// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse
+// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
 func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
 	// User password
 	scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
@@ -454,7 +453,7 @@ func (mc *mysqlConn) readResultOK() error {
 }
 
 // Result Set Header Packet
-// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
 func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 	data, err := mc.readPacket()
 	if err == nil {
@@ -482,7 +481,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
 }
 
 // Error Packet
-// http://dev.mysql.com/doc/internals/en/overview.html#packet-ERR_Packet
+// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
 func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 	if data[0] != iERR {
 		return errMalformPkt
@@ -509,7 +508,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
 }
 
 // Ok Packet
-// http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
+// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
 func (mc *mysqlConn) handleOkPacket(data []byte) error {
 	var n, m int
 
@@ -536,7 +535,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
-// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
 	columns := make([]mysqlField, count)
 
@@ -619,9 +618,11 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
 }
 
 // Read Packets as Field Packets until EOF-Packet or an Error appears
-// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow
+// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
 func (rows *mysqlRows) readRow(dest []driver.Value) error {
-	data, err := rows.mc.readPacket()
+	mc := rows.mc
+
+	data, err := mc.readPacket()
 	if err != nil {
 		return err
 	}
@@ -642,7 +643,7 @@ func (rows *mysqlRows) readRow(dest []driver.Value) error {
 		pos += n
 		if err == nil {
 			if !isNull {
-				if !rows.mc.parseTime {
+				if !mc.parseTime {
 					continue
 				} else {
 					switch rows.columns[i].fieldType {
@@ -650,7 +651,7 @@ func (rows *mysqlRows) readRow(dest []driver.Value) error {
 						fieldTypeDate, fieldTypeNewDate:
 						dest[i], err = parseDateTime(
 							string(dest[i].([]byte)),
-							rows.mc.cfg.loc,
+							mc.cfg.loc,
 						)
 						if err == nil {
 							continue
@@ -689,7 +690,7 @@ func (mc *mysqlConn) readUntilEOF() error {
 ******************************************************************************/
 
 // Prepare Result Packets
-// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
+// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
 func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
 	data, err := stmt.mc.readPacket()
 	if err == nil {
@@ -723,7 +724,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
 	return 0, err
 }
 
-// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data
+// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
 func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
 	maxLen := stmt.mc.maxPacketAllowed - 1
 	pktLen := maxLen
@@ -785,14 +786,16 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		)
 	}
 
+	mc := stmt.mc
+
 	// Reset packet-sequence
-	stmt.mc.sequence = 0
+	mc.sequence = 0
 
 	var data []byte
 
 	if len(args) == 0 {
 		const pktLen = 1 + 4 + 1 + 4
-		data = stmt.mc.buf.writeBuffer(4 + pktLen)
+		data = mc.buf.writeBuffer(4 + pktLen)
 		if data == nil {
 			// can not take the buffer. Something must be wrong with the connection
 			errLog.Print("Busy buffer")
@@ -805,7 +808,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		data[2] = byte(pktLen >> 16)
 		data[3] = 0x00 // sequence is always 0
 	} else {
-		data = stmt.mc.buf.takeCompleteBuffer()
+		data = mc.buf.takeCompleteBuffer()
 		if data == nil {
 			// can not take the buffer. Something must be wrong with the connection
 			errLog.Print("Busy buffer")
@@ -902,7 +905,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				paramTypes[i+i] = fieldTypeString
 				paramTypes[i+i+1] = 0x00
 
-				if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
+				if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
 					paramValues = append(paramValues,
 						lengthEncodedIntegerToBytes(uint64(len(v)))...,
 					)
@@ -917,7 +920,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				paramTypes[i+i] = fieldTypeString
 				paramTypes[i+i+1] = 0x00
 
-				if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
+				if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
 					paramValues = append(paramValues,
 						lengthEncodedIntegerToBytes(uint64(len(v)))...,
 					)
@@ -936,7 +939,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 				if v.IsZero() {
 					val = []byte("0000-00-00")
 				} else {
-					val = []byte(v.In(stmt.mc.cfg.loc).Format(timeFormat))
+					val = []byte(v.In(mc.cfg.loc).Format(timeFormat))
 				}
 
 				paramValues = append(paramValues,
@@ -953,7 +956,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		// In that case we must build the data packet with the new values buffer
 		if valuesCap != cap(paramValues) {
 			data = append(data[:pos], paramValues...)
-			stmt.mc.buf.buf = data
+			mc.buf.buf = data
 		}
 
 		pos += len(paramValues)
@@ -965,18 +968,18 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		data[0] = byte(pktLen)
 		data[1] = byte(pktLen >> 8)
 		data[2] = byte(pktLen >> 16)
-		data[3] = stmt.mc.sequence
+		data[3] = mc.sequence
 
 		// Convert nullMask to bytes
-		for i, max := 14, 14+((stmt.paramCount+7)>>3); i < max; i++ {
-			data[i] = byte(nullMask >> uint((i-14)<<3))
+		for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ {
+			data[i+14] = byte(nullMask >> uint(i<<3))
 		}
 	}
 
-	return stmt.mc.writePacket(data)
+	return mc.writePacket(data)
 }
 
-// http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow
+// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
 func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error {
 	data, err := rows.mc.readPacket()
 	if err != nil {

+ 3 - 3
rows.go

@@ -26,12 +26,12 @@ type mysqlRows struct {
 	eof     bool
 }
 
-func (rows *mysqlRows) Columns() (columns []string) {
-	columns = make([]string, len(rows.columns))
+func (rows *mysqlRows) Columns() []string {
+	columns := make([]string, len(rows.columns))
 	for i := range columns {
 		columns[i] = rows.columns[i].name
 	}
-	return
+	return columns
 }
 
 func (rows *mysqlRows) Close() (err error) {

+ 18 - 19
statement.go

@@ -19,14 +19,14 @@ type mysqlStmt struct {
 	params     []mysqlField
 }
 
-func (stmt *mysqlStmt) Close() (err error) {
+func (stmt *mysqlStmt) Close() error {
 	if stmt.mc == nil || stmt.mc.netConn == nil {
 		return errInvalidConn
 	}
 
-	err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
+	err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
 	stmt.mc = nil
-	return
+	return err
 }
 
 func (stmt *mysqlStmt) NumInput() int {
@@ -34,33 +34,34 @@ func (stmt *mysqlStmt) NumInput() int {
 }
 
 func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
-	stmt.mc.affectedRows = 0
-	stmt.mc.insertId = 0
-
 	// Send command
 	err := stmt.writeExecutePacket(args)
 	if err != nil {
 		return nil, err
 	}
 
+	mc := stmt.mc
+
+	mc.affectedRows = 0
+	mc.insertId = 0
+
 	// Read Result
-	var resLen int
-	resLen, err = stmt.mc.readResultSetHeaderPacket()
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err == nil {
 		if resLen > 0 {
 			// Columns
-			err = stmt.mc.readUntilEOF()
+			err = mc.readUntilEOF()
 			if err != nil {
 				return nil, err
 			}
 
 			// Rows
-			err = stmt.mc.readUntilEOF()
+			err = mc.readUntilEOF()
 		}
 		if err == nil {
 			return &mysqlResult{
-				affectedRows: int64(stmt.mc.affectedRows),
-				insertId:     int64(stmt.mc.insertId),
+				affectedRows: int64(mc.affectedRows),
+				insertId:     int64(mc.insertId),
 			}, nil
 		}
 	}
@@ -75,21 +76,19 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
 		return nil, err
 	}
 
+	mc := stmt.mc
+
 	// Read Result
-	var resLen int
-	resLen, err = stmt.mc.readResultSetHeaderPacket()
+	resLen, err := mc.readResultSetHeaderPacket()
 	if err != nil {
 		return nil, err
 	}
 
-	rows := &mysqlRows{stmt.mc, true, nil, false}
+	rows := &mysqlRows{mc, true, nil, false}
 
 	if resLen > 0 {
 		// Columns
-		rows.columns, err = stmt.mc.readColumns(resLen)
-		if err != nil {
-			return nil, err
-		}
+		rows.columns, err = mc.readColumns(resLen)
 	}
 
 	return rows, err