123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649 |
- // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
- //
- // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this file,
- // You can obtain one at http://mozilla.org/MPL/2.0/.
- package mysql
- import (
- "context"
- "database/sql"
- "database/sql/driver"
- "io"
- "net"
- "strconv"
- "strings"
- "time"
- )
- type mysqlConn struct {
- buf buffer
- netConn net.Conn
- rawConn net.Conn // underlying connection when netConn is TLS connection.
- affectedRows uint64
- insertId uint64
- cfg *Config
- maxAllowedPacket int
- maxWriteSize int
- writeTimeout time.Duration
- flags clientFlag
- status statusFlag
- sequence uint8
- parseTime bool
- reset bool // set when the Go SQL package calls ResetSession
- // for context support (Go 1.8+)
- watching bool
- watcher chan<- context.Context
- closech chan struct{}
- finished chan<- struct{}
- canceled atomicError // set non-nil if conn is canceled
- closed atomicBool // set when conn is closed, before closech is closed
- }
- // Handles parameters set in DSN after the connection is established
- func (mc *mysqlConn) handleParams() (err error) {
- for param, val := range mc.cfg.Params {
- switch param {
- // Charset
- case "charset":
- charsets := strings.Split(val, ",")
- for i := range charsets {
- // ignore errors here - a charset may not exist
- err = mc.exec("SET NAMES " + charsets[i])
- if err == nil {
- break
- }
- }
- if err != nil {
- return
- }
- // System Vars
- default:
- err = mc.exec("SET " + param + "=" + val + "")
- if err != nil {
- return
- }
- }
- }
- return
- }
- func (mc *mysqlConn) markBadConn(err error) error {
- if mc == nil {
- return err
- }
- if err != errBadConnNoWrite {
- return err
- }
- return driver.ErrBadConn
- }
- func (mc *mysqlConn) Begin() (driver.Tx, error) {
- return mc.begin(false)
- }
- func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
- if mc.closed.IsSet() {
- errLog.Print(ErrInvalidConn)
- return nil, driver.ErrBadConn
- }
- var q string
- if readOnly {
- q = "START TRANSACTION READ ONLY"
- } else {
- q = "START TRANSACTION"
- }
- err := mc.exec(q)
- if err == nil {
- return &mysqlTx{mc}, err
- }
- return nil, mc.markBadConn(err)
- }
- func (mc *mysqlConn) Close() (err error) {
- // Makes Close idempotent
- if !mc.closed.IsSet() {
- err = mc.writeCommandPacket(comQuit)
- }
- mc.cleanup()
- return
- }
- // Closes the network connection and unsets internal variables. Do not call this
- // function after successfully authentication, call Close instead. This function
- // is called before auth or on auth failure because MySQL will have already
- // closed the network connection.
- func (mc *mysqlConn) cleanup() {
- if !mc.closed.TrySet(true) {
- return
- }
- // Makes cleanup idempotent
- close(mc.closech)
- if mc.netConn == nil {
- return
- }
- if err := mc.netConn.Close(); err != nil {
- errLog.Print(err)
- }
- }
- func (mc *mysqlConn) error() error {
- if mc.closed.IsSet() {
- if err := mc.canceled.Value(); err != nil {
- return err
- }
- return ErrInvalidConn
- }
- return nil
- }
- func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
- if mc.closed.IsSet() {
- errLog.Print(ErrInvalidConn)
- return nil, driver.ErrBadConn
- }
- // Send command
- err := mc.writeCommandPacketStr(comStmtPrepare, query)
- if err != nil {
- return nil, mc.markBadConn(err)
- }
- stmt := &mysqlStmt{
- mc: mc,
- }
- // Read Result
- columnCount, err := stmt.readPrepareResultPacket()
- if err == nil {
- if stmt.paramCount > 0 {
- if err = mc.readUntilEOF(); err != nil {
- return nil, err
- }
- }
- if columnCount > 0 {
- err = mc.readUntilEOF()
- }
- }
- return stmt, err
- }
- func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
- // Number of ? should be same to len(args)
- if strings.Count(query, "?") != len(args) {
- return "", driver.ErrSkip
- }
- buf, err := mc.buf.takeCompleteBuffer()
- if err != nil {
- // can not take the buffer. Something must be wrong with the connection
- errLog.Print(err)
- return "", ErrInvalidConn
- }
- buf = buf[:0]
- argPos := 0
- for i := 0; i < len(query); i++ {
- q := strings.IndexByte(query[i:], '?')
- if q == -1 {
- buf = append(buf, query[i:]...)
- break
- }
- buf = append(buf, query[i:i+q]...)
- i += q
- arg := args[argPos]
- argPos++
- if arg == nil {
- buf = append(buf, "NULL"...)
- continue
- }
- switch v := arg.(type) {
- case int64:
- buf = strconv.AppendInt(buf, v, 10)
- case uint64:
- // Handle uint64 explicitly because our custom ConvertValue emits unsigned values
- buf = strconv.AppendUint(buf, v, 10)
- case float64:
- buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
- case bool:
- if v {
- buf = append(buf, '1')
- } else {
- buf = append(buf, '0')
- }
- case time.Time:
- if v.IsZero() {
- buf = append(buf, "'0000-00-00'"...)
- } else {
- v := v.In(mc.cfg.Loc)
- v = v.Add(time.Nanosecond * 500) // To round under microsecond
- year := v.Year()
- year100 := year / 100
- year1 := year % 100
- month := v.Month()
- day := v.Day()
- hour := v.Hour()
- minute := v.Minute()
- second := v.Second()
- micro := v.Nanosecond() / 1000
- buf = append(buf, []byte{
- '\'',
- digits10[year100], digits01[year100],
- digits10[year1], digits01[year1],
- '-',
- digits10[month], digits01[month],
- '-',
- digits10[day], digits01[day],
- ' ',
- digits10[hour], digits01[hour],
- ':',
- digits10[minute], digits01[minute],
- ':',
- digits10[second], digits01[second],
- }...)
- if micro != 0 {
- micro10000 := micro / 10000
- micro100 := micro / 100 % 100
- micro1 := micro % 100
- buf = append(buf, []byte{
- '.',
- digits10[micro10000], digits01[micro10000],
- digits10[micro100], digits01[micro100],
- digits10[micro1], digits01[micro1],
- }...)
- }
- buf = append(buf, '\'')
- }
- case []byte:
- if v == nil {
- buf = append(buf, "NULL"...)
- } else {
- buf = append(buf, "_binary'"...)
- if mc.status&statusNoBackslashEscapes == 0 {
- buf = escapeBytesBackslash(buf, v)
- } else {
- buf = escapeBytesQuotes(buf, v)
- }
- buf = append(buf, '\'')
- }
- case string:
- buf = append(buf, '\'')
- if mc.status&statusNoBackslashEscapes == 0 {
- buf = escapeStringBackslash(buf, v)
- } else {
- buf = escapeStringQuotes(buf, v)
- }
- buf = append(buf, '\'')
- default:
- return "", driver.ErrSkip
- }
- if len(buf)+4 > mc.maxAllowedPacket {
- return "", driver.ErrSkip
- }
- }
- if argPos != len(args) {
- return "", driver.ErrSkip
- }
- return string(buf), nil
- }
- func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
- if mc.closed.IsSet() {
- errLog.Print(ErrInvalidConn)
- return nil, driver.ErrBadConn
- }
- if len(args) != 0 {
- if !mc.cfg.InterpolateParams {
- return nil, driver.ErrSkip
- }
- // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
- prepared, err := mc.interpolateParams(query, args)
- if err != nil {
- return nil, err
- }
- query = prepared
- }
- mc.affectedRows = 0
- mc.insertId = 0
- err := mc.exec(query)
- if err == nil {
- return &mysqlResult{
- affectedRows: int64(mc.affectedRows),
- insertId: int64(mc.insertId),
- }, err
- }
- return nil, mc.markBadConn(err)
- }
- // Internal function to execute commands
- func (mc *mysqlConn) exec(query string) error {
- // Send command
- if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
- return mc.markBadConn(err)
- }
- // Read Result
- resLen, err := mc.readResultSetHeaderPacket()
- if err != nil {
- return err
- }
- if resLen > 0 {
- // columns
- if err := mc.readUntilEOF(); err != nil {
- return err
- }
- // rows
- if err := mc.readUntilEOF(); err != nil {
- return err
- }
- }
- return mc.discardResults()
- }
- func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
- return mc.query(query, args)
- }
- func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
- if mc.closed.IsSet() {
- errLog.Print(ErrInvalidConn)
- return nil, driver.ErrBadConn
- }
- if len(args) != 0 {
- if !mc.cfg.InterpolateParams {
- return nil, driver.ErrSkip
- }
- // try client-side prepare to reduce roundtrip
- prepared, err := mc.interpolateParams(query, args)
- if err != nil {
- return nil, err
- }
- query = prepared
- }
- // Send command
- err := mc.writeCommandPacketStr(comQuery, query)
- if err == nil {
- // Read Result
- var resLen int
- resLen, err = mc.readResultSetHeaderPacket()
- if err == nil {
- rows := new(textRows)
- rows.mc = mc
- if resLen == 0 {
- rows.rs.done = true
- switch err := rows.NextResultSet(); err {
- case nil, io.EOF:
- return rows, nil
- default:
- return nil, err
- }
- }
- // Columns
- rows.rs.columns, err = mc.readColumns(resLen)
- return rows, err
- }
- }
- return nil, mc.markBadConn(err)
- }
- // Gets the value of the given MySQL System Variable
- // The returned byte slice is only valid until the next read
- func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
- // Send command
- if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
- return nil, err
- }
- // Read Result
- resLen, err := mc.readResultSetHeaderPacket()
- if err == nil {
- rows := new(textRows)
- rows.mc = mc
- rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
- if resLen > 0 {
- // Columns
- if err := mc.readUntilEOF(); err != nil {
- return nil, err
- }
- }
- dest := make([]driver.Value, resLen)
- if err = rows.readRow(dest); err == nil {
- return dest[0].([]byte), mc.readUntilEOF()
- }
- }
- return nil, err
- }
- // finish is called when the query has canceled.
- func (mc *mysqlConn) cancel(err error) {
- mc.canceled.Set(err)
- mc.cleanup()
- }
- // finish is called when the query has succeeded.
- func (mc *mysqlConn) finish() {
- if !mc.watching || mc.finished == nil {
- return
- }
- select {
- case mc.finished <- struct{}{}:
- mc.watching = false
- case <-mc.closech:
- }
- }
- // Ping implements driver.Pinger interface
- func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
- if mc.closed.IsSet() {
- errLog.Print(ErrInvalidConn)
- return driver.ErrBadConn
- }
- if err = mc.watchCancel(ctx); err != nil {
- return
- }
- defer mc.finish()
- if err = mc.writeCommandPacket(comPing); err != nil {
- return mc.markBadConn(err)
- }
- return mc.readResultOK()
- }
- // BeginTx implements driver.ConnBeginTx interface
- func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
- if err := mc.watchCancel(ctx); err != nil {
- return nil, err
- }
- defer mc.finish()
- if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
- level, err := mapIsolationLevel(opts.Isolation)
- if err != nil {
- return nil, err
- }
- err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
- if err != nil {
- return nil, err
- }
- }
- return mc.begin(opts.ReadOnly)
- }
- func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
- dargs, err := namedValueToValue(args)
- if err != nil {
- return nil, err
- }
- if err := mc.watchCancel(ctx); err != nil {
- return nil, err
- }
- rows, err := mc.query(query, dargs)
- if err != nil {
- mc.finish()
- return nil, err
- }
- rows.finish = mc.finish
- return rows, err
- }
- func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
- dargs, err := namedValueToValue(args)
- if err != nil {
- return nil, err
- }
- if err := mc.watchCancel(ctx); err != nil {
- return nil, err
- }
- defer mc.finish()
- return mc.Exec(query, dargs)
- }
- func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
- if err := mc.watchCancel(ctx); err != nil {
- return nil, err
- }
- stmt, err := mc.Prepare(query)
- mc.finish()
- if err != nil {
- return nil, err
- }
- select {
- default:
- case <-ctx.Done():
- stmt.Close()
- return nil, ctx.Err()
- }
- return stmt, nil
- }
- func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
- dargs, err := namedValueToValue(args)
- if err != nil {
- return nil, err
- }
- if err := stmt.mc.watchCancel(ctx); err != nil {
- return nil, err
- }
- rows, err := stmt.query(dargs)
- if err != nil {
- stmt.mc.finish()
- return nil, err
- }
- rows.finish = stmt.mc.finish
- return rows, err
- }
- func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
- dargs, err := namedValueToValue(args)
- if err != nil {
- return nil, err
- }
- if err := stmt.mc.watchCancel(ctx); err != nil {
- return nil, err
- }
- defer stmt.mc.finish()
- return stmt.Exec(dargs)
- }
- func (mc *mysqlConn) watchCancel(ctx context.Context) error {
- if mc.watching {
- // Reach here if canceled,
- // so the connection is already invalid
- mc.cleanup()
- return nil
- }
- // When ctx is already cancelled, don't watch it.
- if err := ctx.Err(); err != nil {
- return err
- }
- // When ctx is not cancellable, don't watch it.
- if ctx.Done() == nil {
- return nil
- }
- // When watcher is not alive, can't watch it.
- if mc.watcher == nil {
- return nil
- }
- mc.watching = true
- mc.watcher <- ctx
- return nil
- }
- func (mc *mysqlConn) startWatcher() {
- watcher := make(chan context.Context, 1)
- mc.watcher = watcher
- finished := make(chan struct{})
- mc.finished = finished
- go func() {
- for {
- var ctx context.Context
- select {
- case ctx = <-watcher:
- case <-mc.closech:
- return
- }
- select {
- case <-ctx.Done():
- mc.cancel(ctx.Err())
- case <-finished:
- case <-mc.closech:
- return
- }
- }
- }()
- }
- func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
- nv.Value, err = converter{}.ConvertValue(nv.Value)
- return
- }
- // ResetSession implements driver.SessionResetter.
- // (From Go 1.10)
- func (mc *mysqlConn) ResetSession(ctx context.Context) error {
- if mc.closed.IsSet() {
- return driver.ErrBadConn
- }
- mc.reset = true
- return nil
- }
|