| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426 |
- // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
- //
- // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this file,
- // You can obtain one at http://mozilla.org/MPL/2.0/.
- package mysql
- import (
- "crypto/tls"
- "database/sql/driver"
- "errors"
- "net"
- "strconv"
- "strings"
- "time"
- )
- type mysqlConn struct {
- buf buffer
- netConn net.Conn
- affectedRows uint64
- insertId uint64
- cfg *config
- maxPacketAllowed int
- maxWriteSize int
- flags clientFlag
- status statusFlag
- sequence uint8
- parseTime bool
- strict bool
- }
- type config struct {
- user string
- passwd string
- net string
- addr string
- dbname string
- params map[string]string
- loc *time.Location
- tls *tls.Config
- timeout time.Duration
- collation uint8
- allowAllFiles bool
- allowOldPasswords bool
- clientFoundRows bool
- columnsWithAlias bool
- interpolateParams bool
- }
- // Handles parameters set in DSN after the connection is established
- func (mc *mysqlConn) handleParams() (err error) {
- for param, val := range mc.cfg.params {
- switch param {
- // Charset
- case "charset":
- charsets := strings.Split(val, ",")
- for i := range charsets {
- // ignore errors here - a charset may not exist
- err = mc.exec("SET NAMES " + charsets[i])
- if err == nil {
- break
- }
- }
- if err != nil {
- return
- }
- // time.Time parsing
- case "parseTime":
- var isBool bool
- mc.parseTime, isBool = readBool(val)
- if !isBool {
- return errors.New("Invalid Bool value: " + val)
- }
- // Strict mode
- case "strict":
- var isBool bool
- mc.strict, isBool = readBool(val)
- if !isBool {
- return errors.New("Invalid Bool value: " + val)
- }
- // Compression
- case "compress":
- err = errors.New("Compression not implemented yet")
- return
- // System Vars
- default:
- err = mc.exec("SET " + param + "=" + val + "")
- if err != nil {
- return
- }
- }
- }
- return
- }
- func (mc *mysqlConn) Begin() (driver.Tx, error) {
- if mc.netConn == nil {
- errLog.Print(ErrInvalidConn)
- return nil, driver.ErrBadConn
- }
- err := mc.exec("START TRANSACTION")
- if err == nil {
- return &mysqlTx{mc}, err
- }
- return nil, err
- }
- func (mc *mysqlConn) Close() (err error) {
- // Makes Close idempotent
- if mc.netConn != nil {
- err = mc.writeCommandPacket(comQuit)
- if err == nil {
- err = mc.netConn.Close()
- } else {
- mc.netConn.Close()
- }
- mc.netConn = nil
- }
- mc.cfg = nil
- mc.buf.rd = nil
- return
- }
- func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
- if mc.netConn == nil {
- errLog.Print(ErrInvalidConn)
- return nil, driver.ErrBadConn
- }
- // Send command
- err := mc.writeCommandPacketStr(comStmtPrepare, query)
- if err != nil {
- return nil, err
- }
- stmt := &mysqlStmt{
- mc: mc,
- }
- // Read Result
- columnCount, err := stmt.readPrepareResultPacket()
- if err == nil {
- if stmt.paramCount > 0 {
- if err = mc.readUntilEOF(); err != nil {
- return nil, err
- }
- }
- if columnCount > 0 {
- err = mc.readUntilEOF()
- }
- }
- return stmt, err
- }
- // https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156
- func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte {
- buf = append(buf, '\'')
- if mc.status&statusNoBackslashEscapes == 0 {
- buf = escapeBackslash(buf, v)
- } else {
- buf = escapeQuotes(buf, v)
- }
- return append(buf, '\'')
- }
// estimateParamLength returns an upper bound on the number of bytes that
// interpolateParams appends for args, so its output buffer can be allocated
// once. The boolean result is false when some argument has a type that cannot
// be interpolated client-side.
//
// A nil argument is rendered as NULL by interpolateParams. It used to fall
// into the default case, which made every query with a NULL argument skip
// client-side interpolation.
func estimateParamLength(args []driver.Value) (int, bool) {
	n := 0
	for _, arg := range args {
		switch v := arg.(type) {
		case nil:
			n += 4 // NULL
		case int64, float64:
			// int64 needs at most 20 bytes. The shortest round-trip float64
			// form is at most 24 (e.g. -1.7976931348623157e+308).
			n += 25
		case bool:
			n++ // 0 or 1
		case time.Time:
			n += 30 // '1234-12-23 12:34:56.777777'
		case string:
			// Room for every byte to be escaped into two, plus the quotes.
			n += len(v)*2 + 2
		case []byte:
			if v == nil {
				n += 4 // NULL
			} else {
				n += len(v)*2 + 2
			}
		default:
			return 0, false
		}
	}
	return n, true
}
- func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
- estimated, ok := estimateParamLength(args)
- if !ok {
- return "", driver.ErrSkip
- }
- estimated += len(query)
- buf := make([]byte, 0, estimated)
- argPos := 0
- for i := 0; i < len(query); i++ {
- c := query[i]
- if c != '?' {
- buf = append(buf, c)
- continue
- }
- arg := args[argPos]
- argPos++
- if arg == nil {
- buf = append(buf, []byte("NULL")...)
- continue
- }
- switch v := arg.(type) {
- case int64:
- buf = strconv.AppendInt(buf, v, 10)
- case float64:
- buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
- case bool:
- if v {
- buf = append(buf, '1')
- } else {
- buf = append(buf, '0')
- }
- case time.Time:
- if v.IsZero() {
- buf = append(buf, []byte("'0000-00-00'")...)
- } else {
- v := v.In(mc.cfg.loc)
- year := v.Year()
- month := v.Month()
- day := v.Day()
- hour := v.Hour()
- minute := v.Minute()
- second := v.Second()
- micro := v.Nanosecond() / 1000
- buf = append(buf, []byte{
- byte('\''),
- byte('0' + year/1000),
- byte('0' + year/100%10),
- byte('0' + year/10%10),
- byte('0' + year%10),
- byte('-'),
- byte('0' + month/10),
- byte('0' + month%10),
- byte('-'),
- byte('0' + day/10),
- byte('0' + day%10),
- byte(' '),
- byte('0' + hour/10),
- byte('0' + hour%10),
- byte(':'),
- byte('0' + minute/10),
- byte('0' + minute%10),
- byte(':'),
- byte('0' + second/10),
- byte('0' + second%10),
- }...)
- if micro != 0 {
- buf = append(buf, []byte{
- byte('.'),
- byte('0' + micro/100000),
- byte('0' + micro/10000%10),
- byte('0' + micro/1000%10),
- byte('0' + micro/100%10),
- byte('0' + micro/10%10),
- byte('0' + micro%10),
- }...)
- }
- buf = append(buf, '\'')
- }
- case []byte:
- if v == nil {
- buf = append(buf, []byte("NULL")...)
- } else {
- buf = mc.escapeBytes(buf, v)
- }
- case string:
- buf = mc.escapeBytes(buf, []byte(v))
- default:
- return "", driver.ErrSkip
- }
- if len(buf)+4 > mc.maxPacketAllowed {
- return "", driver.ErrSkip
- }
- }
- if argPos != len(args) {
- return "", driver.ErrSkip
- }
- return string(buf), nil
- }
- func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
- if mc.netConn == nil {
- errLog.Print(ErrInvalidConn)
- return nil, driver.ErrBadConn
- }
- if len(args) != 0 {
- if !mc.cfg.interpolateParams {
- return nil, driver.ErrSkip
- }
- // try client-side prepare to reduce roundtrip
- prepared, err := mc.interpolateParams(query, args)
- if err != nil {
- return nil, err
- }
- query = prepared
- args = nil
- }
- mc.affectedRows = 0
- mc.insertId = 0
- err := mc.exec(query)
- if err == nil {
- return &mysqlResult{
- affectedRows: int64(mc.affectedRows),
- insertId: int64(mc.insertId),
- }, err
- }
- return nil, err
- }
- // Internal function to execute commands
- func (mc *mysqlConn) exec(query string) error {
- // Send command
- err := mc.writeCommandPacketStr(comQuery, query)
- if err != nil {
- return err
- }
- // Read Result
- resLen, err := mc.readResultSetHeaderPacket()
- if err == nil && resLen > 0 {
- if err = mc.readUntilEOF(); err != nil {
- return err
- }
- err = mc.readUntilEOF()
- }
- return err
- }
- func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
- if mc.netConn == nil {
- errLog.Print(ErrInvalidConn)
- return nil, driver.ErrBadConn
- }
- if len(args) != 0 {
- if !mc.cfg.interpolateParams {
- return nil, driver.ErrSkip
- }
- // try client-side prepare to reduce roundtrip
- prepared, err := mc.interpolateParams(query, args)
- if err != nil {
- return nil, err
- }
- query = prepared
- args = nil
- }
- // Send command
- err := mc.writeCommandPacketStr(comQuery, query)
- if err == nil {
- // Read Result
- var resLen int
- resLen, err = mc.readResultSetHeaderPacket()
- if err == nil {
- rows := new(textRows)
- rows.mc = mc
- if resLen == 0 {
- // no columns, no more data
- return emptyRows{}, nil
- }
- // Columns
- rows.columns, err = mc.readColumns(resLen)
- return rows, err
- }
- }
- return nil, err
- }
- // Gets the value of the given MySQL System Variable
- // The returned byte slice is only valid until the next read
- func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
- // Send command
- if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
- return nil, err
- }
- // Read Result
- resLen, err := mc.readResultSetHeaderPacket()
- if err == nil {
- rows := new(textRows)
- rows.mc = mc
- if resLen > 0 {
- // Columns
- if err := mc.readUntilEOF(); err != nil {
- return nil, err
- }
- }
- dest := make([]driver.Value, resLen)
- if err = rows.readRow(dest); err == nil {
- return dest[0].([]byte), mc.readUntilEOF()
- }
- }
- return nil, err
- }
|