Julien Schmidt il y a 12 ans
Parent
commit
d1deaee846
6 fichiers modifiés avec 153 ajouts et 170 suppressions
  1. 1 1
      buffer.go
  2. 19 35
      connection.go
  3. 94 94
      const.go
  4. 37 38
      packets.go
  5. 1 1
      rows.go
  6. 1 1
      statement.go

+ 1 - 1
buffer.go

@@ -37,7 +37,7 @@ func (b *buffer) fill(need int) (err error) {
 	b.idx = 0
 	b.length = 0
 
-	n := 0
+	var n int
 	for b.length < need {
 		n, err = b.rd.Read(b.buf[b.length:])
 		b.length += n

+ 19 - 35
connection.go

@@ -18,7 +18,7 @@ import (
 
 type mysqlConn struct {
 	cfg          *config
-	flags        ClientFlag
+	flags        clientFlag
 	charset      byte
 	cipher       []byte
 	netConn      net.Conn
@@ -87,7 +87,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
 }
 
 func (mc *mysqlConn) Close() (err error) {
-	mc.writeCommandPacket(COM_QUIT)
+	mc.writeCommandPacket(comQuit)
 	mc.cfg = nil
 	mc.buf = nil
 	mc.netConn.Close()
@@ -97,7 +97,7 @@ func (mc *mysqlConn) Close() (err error) {
 
 func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	// Send command
-	err := mc.writeCommandPacketStr(COM_STMT_PREPARE, query)
+	err := mc.writeCommandPacketStr(comStmtPrepare, query)
 	if err != nil {
 		return nil, err
 	}
@@ -124,37 +124,30 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
 	return stmt, err
 }
 
-func (mc *mysqlConn) Exec(query string, args []driver.Value) (_ driver.Result, err error) {
-	if len(args) > 0 { // with args, must use prepared stmt
-		var res driver.Result
-		var stmt driver.Stmt
-		stmt, err = mc.Prepare(query)
-		if err == nil {
-			res, err = stmt.Exec(args)
-			if err == nil {
-				return res, stmt.Close()
-			}
-		}
-	} else { // no args, fastpath
+func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+	if len(args) == 0 { // no args, fastpath
 		mc.affectedRows = 0
 		mc.insertId = 0
 
-		err = mc.exec(query)
+		err := mc.exec(query)
 		if err == nil {
 			return &mysqlResult{
 				affectedRows: int64(mc.affectedRows),
 				insertId:     int64(mc.insertId),
 			}, err
 		}
+		return nil, err
 	}
-	return nil, err
+
+	// with args, must use prepared stmt
+	return nil, driver.ErrSkip
 
 }
 
 // Internal function to execute commands
 func (mc *mysqlConn) exec(query string) (err error) {
 	// Send command
-	err = mc.writeCommandPacketStr(COM_QUERY, query)
+	err = mc.writeCommandPacketStr(comQuery, query)
 	if err != nil {
 		return
 	}
@@ -174,28 +167,16 @@ func (mc *mysqlConn) exec(query string) (err error) {
 	return
 }
 
-func (mc *mysqlConn) Query(query string, args []driver.Value) (_ driver.Rows, err error) {
-	if len(args) > 0 { // with args, must use prepared stmt
-		var rows driver.Rows
-		var stmt driver.Stmt
-		stmt, err = mc.Prepare(query)
-		if err == nil {
-			rows, err = stmt.Query(args)
-			if err == nil {
-				return rows, stmt.Close()
-			}
-		}
-		return
-	} else { // no args, fastpath
-		var rows *mysqlRows
+func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+	if len(args) == 0 { // no args, fastpath
 		// Send command
-		err = mc.writeCommandPacketStr(COM_QUERY, query)
+		err := mc.writeCommandPacketStr(comQuery, query)
 		if err == nil {
 			// Read Result
 			var resLen int
 			resLen, err = mc.readResultSetHeaderPacket()
 			if err == nil {
-				rows = &mysqlRows{mc, false, nil, false}
+				rows := &mysqlRows{mc, false, nil, false}
 
 				if resLen > 0 {
 					// Columns
@@ -204,7 +185,10 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (_ driver.Rows, er
 				return rows, err
 			}
 		}
+
+		return nil, err
 	}
 
-	return nil, err
+	// with args, must use prepared stmt
+	return nil, driver.ErrSkip
 }

+ 94 - 94
const.go

@@ -10,119 +10,119 @@
 package mysql
 
 const (
-	MIN_PROTOCOL_VERSION byte = 10
-	//MAX_PACKET_SIZE      = 1<<24 - 1
-	TIME_FORMAT = "2006-01-02 15:04:05"
+	minProtocolVersion byte = 10
+	//maxPacketSize      = 1<<24 - 1
+	timeFormat = "2006-01-02 15:04:05"
 )
 
 // MySQL constants documentation:
 // http://dev.mysql.com/doc/internals/en/client-server-protocol.html
 
-type ClientFlag uint32
+type clientFlag uint32
 
 const (
-	CLIENT_LONG_PASSWORD ClientFlag = 1 << iota
-	CLIENT_FOUND_ROWS
-	CLIENT_LONG_FLAG
-	CLIENT_CONNECT_WITH_DB
-	CLIENT_NO_SCHEMA
-	CLIENT_COMPRESS
-	CLIENT_ODBC
-	CLIENT_LOCAL_FILES
-	CLIENT_IGNORE_SPACE
-	CLIENT_PROTOCOL_41
-	CLIENT_INTERACTIVE
-	CLIENT_SSL
-	CLIENT_IGNORE_SIGPIPE
-	CLIENT_TRANSACTIONS
-	CLIENT_RESERVED
-	CLIENT_SECURE_CONN
-	CLIENT_MULTI_STATEMENTS
-	CLIENT_MULTI_RESULTS
+	clientLongPassword clientFlag = 1 << iota
+	clientFoundRows
+	clientLongFlag
+	clientConnectWithDB
+	clientNoSchema
+	clientCompress
+	clientODBC
+	clientLocalFiles
+	clientIgnoreSpace
+	clientProtocol41
+	clientInteractive
+	clientSSL
+	clientIgnoreSIGPIPE
+	clientTransactions
+	clientReserved
+	clientSecureConn
+	clientMultiStatements
+	clientMultiResults
 )
 
 type commandType byte
 
 const (
-	COM_QUIT commandType = iota + 1
-	COM_INIT_DB
-	COM_QUERY
-	COM_FIELD_LIST
-	COM_CREATE_DB
-	COM_DROP_DB
-	COM_REFRESH
-	COM_SHUTDOWN
-	COM_STATISTICS
-	COM_PROCESS_INFO
-	COM_CONNECT
-	COM_PROCESS_KILL
-	COM_DEBUG
-	COM_PING
-	COM_TIME
-	COM_DELAYED_INSERT
-	COM_CHANGE_USER
-	COM_BINLOG_DUMP
-	COM_TABLE_DUMP
-	COM_CONNECT_OUT
-	COM_REGISTER_SLAVE
-	COM_STMT_PREPARE
-	COM_STMT_EXECUTE
-	COM_STMT_SEND_LONG_DATA
-	COM_STMT_CLOSE
-	COM_STMT_RESET
-	COM_SET_OPTION
-	COM_STMT_FETCH
+	comQuit commandType = iota + 1
+	comInitDB
+	comQuery
+	comFieldList
+	comCreateDB
+	comDropDB
+	comRefresh
+	comShutdown
+	comStatistics
+	comProcessInfo
+	comConnect
+	comProcessKill
+	comDebug
+	comPing
+	comTime
+	comDelayedInsert
+	comChangeUser
+	comBinlogDump
+	comTableDump
+	comConnectOut
+	comRegiserSlave
+	comStmtPrepare
+	comStmtExecute
+	comStmtSendLongData
+	comStmtClose
+	comStmtReset
+	comSetOption
+	comStmtFetch
 )
 
 const (
-	FIELD_TYPE_DECIMAL byte = iota
-	FIELD_TYPE_TINY
-	FIELD_TYPE_SHORT
-	FIELD_TYPE_LONG
-	FIELD_TYPE_FLOAT
-	FIELD_TYPE_DOUBLE
-	FIELD_TYPE_NULL
-	FIELD_TYPE_TIMESTAMP
-	FIELD_TYPE_LONGLONG
-	FIELD_TYPE_INT24
-	FIELD_TYPE_DATE
-	FIELD_TYPE_TIME
-	FIELD_TYPE_DATETIME
-	FIELD_TYPE_YEAR
-	FIELD_TYPE_NEWDATE
-	FIELD_TYPE_VARCHAR
-	FIELD_TYPE_BIT
+	fieldTypeDecimal byte = iota
+	fieldTypeTiny
+	fieldTypeShort
+	fieldTypeLong
+	fieldTypeFloat
+	fieldTypeDouble
+	fieldTypeNULL
+	fieldTypeTimestamp
+	fieldTypeLongLong
+	fieldTypeInt24
+	fieldTypeDate
+	fieldTypeTime
+	fieldTypeDateTime
+	fieldTypeYear
+	fieldTypeNewDate
+	fieldTypeVarChar
+	fieldTypeBit
 )
 const (
-	FIELD_TYPE_NEWDECIMAL byte = iota + 0xf6
-	FIELD_TYPE_ENUM
-	FIELD_TYPE_SET
-	FIELD_TYPE_TINY_BLOB
-	FIELD_TYPE_MEDIUM_BLOB
-	FIELD_TYPE_LONG_BLOB
-	FIELD_TYPE_BLOB
-	FIELD_TYPE_VAR_STRING
-	FIELD_TYPE_STRING
-	FIELD_TYPE_GEOMETRY
+	fieldTypeNewDecimal byte = iota + 0xf6
+	fieldTypeEnum
+	fieldTypeSet
+	fieldTypeTinyBLOB
+	fieldTypeMediumBLOB
+	fieldTypeLongBLOB
+	fieldTypeBLOB
+	fieldTypeVarString
+	fieldTypeString
+	fieldTypeGeometry
 )
 
-type FieldFlag uint16
+type fieldFlag uint16
 
 const (
-	FLAG_NOT_NULL FieldFlag = 1 << iota
-	FLAG_PRI_KEY
-	FLAG_UNIQUE_KEY
-	FLAG_MULTIPLE_KEY
-	FLAG_BLOB
-	FLAG_UNSIGNED
-	FLAG_ZEROFILL
-	FLAG_BINARY
-	FLAG_ENUM
-	FLAG_AUTO_INCREMENT
-	FLAG_TIMESTAMP
-	FLAG_SET
-	FLAG_UNKNOWN_1
-	FLAG_UNKNOWN_2
-	FLAG_UNKNOWN_3
-	FLAG_UNKNOWN_4
+	flagNotNULL fieldFlag = 1 << iota
+	flagPriKey
+	flagUniqueKey
+	flagMultipleKey
+	flagBLOB
+	flagUnsigned
+	flagZeroFill
+	flagBinary
+	flagEnum
+	flagAutoIncrement
+	flagTimestamp
+	flagSet
+	flagUnknown1
+	flagUnknown2
+	flagUnknown3
+	flagUnknown4
 )

+ 37 - 38
packets.go

@@ -92,11 +92,11 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 	}
 
 	// protocol version [1 byte]
-	if data[0] < MIN_PROTOCOL_VERSION {
+	if data[0] < minProtocolVersion {
 		err = fmt.Errorf(
 			"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
 			data[0],
-			MIN_PROTOCOL_VERSION)
+			minProtocolVersion)
 	}
 
 	// server version [null terminated string]
@@ -110,8 +110,8 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 	pos += 8 + 1
 
 	// capability flags (lower 2 bytes) [2 bytes]
-	mc.flags = ClientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
-	if mc.flags&CLIENT_PROTOCOL_41 == 0 {
+	mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
+	if mc.flags&clientProtocol41 == 0 {
 		err = errors.New("MySQL-Server does not support required Protocol 41+")
 	}
 	pos += 2
@@ -146,13 +146,13 @@ func (mc *mysqlConn) readInitPacket() (err error) {
 func (mc *mysqlConn) writeAuthPacket() error {
 	// Adjust client flags based on server support
 	clientFlags := uint32(
-		CLIENT_PROTOCOL_41 |
-			CLIENT_SECURE_CONN |
-			CLIENT_LONG_PASSWORD |
-			CLIENT_TRANSACTIONS,
+		clientProtocol41 |
+			clientSecureConn |
+			clientLongPassword |
+			clientTransactions,
 	)
-	if mc.flags&CLIENT_LONG_FLAG > 0 {
-		clientFlags |= uint32(CLIENT_LONG_FLAG)
+	if mc.flags&clientLongFlag > 0 {
+		clientFlags |= uint32(clientLongFlag)
 	}
 
 	// User Password
@@ -163,7 +163,7 @@ func (mc *mysqlConn) writeAuthPacket() error {
 
 	// To specify a db name
 	if len(mc.cfg.dbname) > 0 {
-		clientFlags |= uint32(CLIENT_CONNECT_WITH_DB)
+		clientFlags |= uint32(clientConnectWithDB)
 		pktLen += len(mc.cfg.dbname) + 1
 	}
 
@@ -439,7 +439,7 @@ func (mc *mysqlConn) readColumns(count int) (columns []mysqlField, err error) {
 		pos++
 
 		// Flags [16 bit uint]
-		columns[i].flags = FieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
+		columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
 		//pos += 2
 
 		// Decimals [8 bit uint]
@@ -503,7 +503,7 @@ func (mc *mysqlConn) readUntilEOF() (err error) {
 		if err == nil && (data[0] != 254 || len(data) != 5) {
 			continue
 		}
-		return
+		return // Err or EOF
 	}
 	return
 }
@@ -568,26 +568,26 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 		// build NULL-bitmap
 		if args[i] == nil {
 			bitMask += 1 << uint(i)
-			paramTypes[i<<1] = FIELD_TYPE_NULL
+			paramTypes[i<<1] = fieldTypeNULL
 			continue
 		}
 
 		// cache types and values
 		switch args[i].(type) {
 		case int64:
-			paramTypes[i<<1] = FIELD_TYPE_LONGLONG
+			paramTypes[i<<1] = fieldTypeLongLong
 			paramValues[i] = uint64ToBytes(uint64(args[i].(int64)))
 			pktLen += 8
 			continue
 
 		case float64:
-			paramTypes[i<<1] = FIELD_TYPE_DOUBLE
+			paramTypes[i<<1] = fieldTypeDouble
 			paramValues[i] = uint64ToBytes(math.Float64bits(args[i].(float64)))
 			pktLen += 8
 			continue
 
 		case bool:
-			paramTypes[i<<1] = FIELD_TYPE_TINY
+			paramTypes[i<<1] = fieldTypeTiny
 			pktLen++
 			if args[i].(bool) {
 				paramValues[i] = []byte{0x01}
@@ -597,7 +597,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 			continue
 
 		case []byte:
-			paramTypes[i<<1] = FIELD_TYPE_STRING
+			paramTypes[i<<1] = fieldTypeString
 			val := args[i].([]byte)
 			paramValues[i] = append(
 				lengthEncodedIntegerToBytes(uint64(len(val))),
@@ -607,7 +607,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 			continue
 
 		case string:
-			paramTypes[i<<1] = FIELD_TYPE_STRING
+			paramTypes[i<<1] = fieldTypeString
 			val := []byte(args[i].(string))
 			paramValues[i] = append(
 				lengthEncodedIntegerToBytes(uint64(len(val))),
@@ -617,8 +617,8 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 			continue
 
 		case time.Time:
-			paramTypes[i<<1] = FIELD_TYPE_STRING
-			val := []byte(args[i].(time.Time).Format(TIME_FORMAT))
+			paramTypes[i<<1] = fieldTypeString
+			val := []byte(args[i].(time.Time).Format(timeFormat))
 			paramValues[i] = append(
 				lengthEncodedIntegerToBytes(uint64(len(val))),
 				val...,
@@ -640,7 +640,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
 	data[3] = stmt.mc.sequence
 
 	// command [1 byte]
-	data[4] = byte(COM_STMT_EXECUTE)
+	data[4] = byte(comStmtExecute)
 
 	// statement_id [4 bytes]
 	data[5] = byte(stmt.id)
@@ -715,16 +715,16 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			continue
 		}
 
-		unsigned = rc.columns[i].flags&FLAG_UNSIGNED != 0
+		unsigned = rc.columns[i].flags&flagUnsigned != 0
 
 		// Convert to byte-coded string
 		switch rc.columns[i].fieldType {
-		case FIELD_TYPE_NULL:
+		case fieldTypeNULL:
 			dest[i] = nil
 			continue
 
 		// Numeric Typs
-		case FIELD_TYPE_TINY:
+		case fieldTypeTiny:
 			if unsigned {
 				dest[i] = int64(data[pos])
 			} else {
@@ -733,7 +733,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			pos++
 			continue
 
-		case FIELD_TYPE_SHORT, FIELD_TYPE_YEAR:
+		case fieldTypeShort, fieldTypeYear:
 			if unsigned {
 				dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
 			} else {
@@ -742,7 +742,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			pos += 2
 			continue
 
-		case FIELD_TYPE_INT24, FIELD_TYPE_LONG:
+		case fieldTypeInt24, fieldTypeLong:
 			if unsigned {
 				dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
 			} else {
@@ -751,7 +751,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			pos += 4
 			continue
 
-		case FIELD_TYPE_LONGLONG:
+		case fieldTypeLongLong:
 			if unsigned {
 				val := binary.LittleEndian.Uint64(data[pos : pos+8])
 				if val > math.MaxInt64 {
@@ -765,22 +765,21 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			pos += 8
 			continue
 
-		case FIELD_TYPE_FLOAT:
+		case fieldTypeFloat:
 			dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
 			pos += 4
 			continue
 
-		case FIELD_TYPE_DOUBLE:
+		case fieldTypeDouble:
 			dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
 			pos += 8
 			continue
 
 		// Length coded Binary Strings
-		case FIELD_TYPE_DECIMAL, FIELD_TYPE_NEWDECIMAL, FIELD_TYPE_VARCHAR,
-			FIELD_TYPE_BIT, FIELD_TYPE_ENUM, FIELD_TYPE_SET,
-			FIELD_TYPE_TINY_BLOB, FIELD_TYPE_MEDIUM_BLOB, FIELD_TYPE_LONG_BLOB,
-			FIELD_TYPE_BLOB, FIELD_TYPE_VAR_STRING, FIELD_TYPE_STRING,
-			FIELD_TYPE_GEOMETRY:
+		case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
+			fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
+			fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
+			fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
 			var isNull bool
 			dest[i], isNull, n, err = readLengthEnodedString(data[pos:])
 			pos += n
@@ -795,7 +794,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			return // err
 
 		// Date YYYY-MM-DD
-		case FIELD_TYPE_DATE, FIELD_TYPE_NEWDATE:
+		case fieldTypeDate, fieldTypeNewDate:
 			var num uint64
 			var isNull bool
 			num, isNull, n = readLengthEncodedInteger(data[pos:])
@@ -820,7 +819,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			}
 
 		// Time [-][H]HH:MM:SS[.fractal]
-		case FIELD_TYPE_TIME:
+		case fieldTypeTime:
 			var num uint64
 			var isNull bool
 			num, isNull, n = readLengthEncodedInteger(data[pos:])
@@ -871,7 +870,7 @@ func (rc *mysqlRows) readBinaryRow(dest []driver.Value) (err error) {
 			}
 
 		// Timestamp YYYY-MM-DD HH:MM:SS[.fractal]
-		case FIELD_TYPE_TIMESTAMP, FIELD_TYPE_DATETIME:
+		case fieldTypeTimestamp, fieldTypeDateTime:
 			var num uint64
 			var isNull bool
 			num, isNull, n = readLengthEncodedInteger(data[pos:])

+ 1 - 1
rows.go

@@ -18,7 +18,7 @@ import (
 type mysqlField struct {
 	name      string
 	fieldType byte
-	flags     FieldFlag
+	flags     fieldFlag
 }
 
 type mysqlRows struct {

+ 1 - 1
statement.go

@@ -21,7 +21,7 @@ type mysqlStmt struct {
 }
 
 func (stmt *mysqlStmt) Close() (err error) {
-	err = stmt.mc.writeCommandPacketUint32(COM_STMT_CLOSE, stmt.id)
+	err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
 	stmt.mc = nil
 	return
 }