connection.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "io"
  12. "net"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "time"
  18. )
  19. // a copy of context.Context for Go 1.7 and later.
  20. type mysqlContext interface {
  21. Done() <-chan struct{}
  22. Err() error
  23. // They are defined in context.Context, but go-mysql-driver does not use them.
  24. // Deadline() (deadline time.Time, ok bool)
  25. // Value(key interface{}) interface{}
  26. }
  27. type mysqlConn struct {
  28. buf buffer
  29. netConn net.Conn
  30. affectedRows uint64
  31. insertId uint64
  32. cfg *Config
  33. maxAllowedPacket int
  34. maxWriteSize int
  35. writeTimeout time.Duration
  36. flags clientFlag
  37. status statusFlag
  38. sequence uint8
  39. parseTime bool
  40. strict bool
  41. // for context support (From Go 1.8)
  42. watching bool
  43. watcher chan<- mysqlContext
  44. closech chan struct{}
  45. finished chan<- struct{}
  46. // set non-zero when conn is closed, before closech is closed.
  47. // accessed atomically.
  48. closed int32
  49. mu sync.Mutex // guards following fields
  50. canceledErr error // set non-nil if conn is canceled
  51. }
  52. // Handles parameters set in DSN after the connection is established
  53. func (mc *mysqlConn) handleParams() (err error) {
  54. for param, val := range mc.cfg.Params {
  55. switch param {
  56. // Charset
  57. case "charset":
  58. charsets := strings.Split(val, ",")
  59. for i := range charsets {
  60. // ignore errors here - a charset may not exist
  61. err = mc.exec("SET NAMES " + charsets[i])
  62. if err == nil {
  63. break
  64. }
  65. }
  66. if err != nil {
  67. return
  68. }
  69. // System Vars
  70. default:
  71. err = mc.exec("SET " + param + "=" + val + "")
  72. if err != nil {
  73. return
  74. }
  75. }
  76. }
  77. return
  78. }
  79. func (mc *mysqlConn) Begin() (driver.Tx, error) {
  80. if mc.isBroken() {
  81. errLog.Print(ErrInvalidConn)
  82. return nil, driver.ErrBadConn
  83. }
  84. err := mc.exec("START TRANSACTION")
  85. if err == nil {
  86. return &mysqlTx{mc}, err
  87. }
  88. return nil, err
  89. }
  90. func (mc *mysqlConn) Close() (err error) {
  91. // Makes Close idempotent
  92. if !mc.isBroken() {
  93. err = mc.writeCommandPacket(comQuit)
  94. }
  95. mc.cleanup()
  96. return
  97. }
  98. // Closes the network connection and unsets internal variables. Do not call this
  99. // function after successfully authentication, call Close instead. This function
  100. // is called before auth or on auth failure because MySQL will have already
  101. // closed the network connection.
  102. func (mc *mysqlConn) cleanup() {
  103. if atomic.SwapInt32(&mc.closed, 1) != 0 {
  104. return
  105. }
  106. // Makes cleanup idempotent
  107. close(mc.closech)
  108. if mc.netConn == nil {
  109. return
  110. }
  111. if err := mc.netConn.Close(); err != nil {
  112. errLog.Print(err)
  113. }
  114. }
  115. func (mc *mysqlConn) isBroken() bool {
  116. return atomic.LoadInt32(&mc.closed) != 0
  117. }
  118. func (mc *mysqlConn) error() error {
  119. if mc.isBroken() {
  120. if err := mc.canceled(); err != nil {
  121. return err
  122. }
  123. return ErrInvalidConn
  124. }
  125. return nil
  126. }
  127. func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
  128. if mc.isBroken() {
  129. errLog.Print(ErrInvalidConn)
  130. return nil, driver.ErrBadConn
  131. }
  132. // Send command
  133. err := mc.writeCommandPacketStr(comStmtPrepare, query)
  134. if err != nil {
  135. return nil, err
  136. }
  137. stmt := &mysqlStmt{
  138. mc: mc,
  139. }
  140. // Read Result
  141. columnCount, err := stmt.readPrepareResultPacket()
  142. if err == nil {
  143. if stmt.paramCount > 0 {
  144. if err = mc.readUntilEOF(); err != nil {
  145. return nil, err
  146. }
  147. }
  148. if columnCount > 0 {
  149. err = mc.readUntilEOF()
  150. }
  151. }
  152. return stmt, err
  153. }
  154. func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
  155. // Number of ? should be same to len(args)
  156. if strings.Count(query, "?") != len(args) {
  157. return "", driver.ErrSkip
  158. }
  159. buf := mc.buf.takeCompleteBuffer()
  160. if buf == nil {
  161. // can not take the buffer. Something must be wrong with the connection
  162. errLog.Print(ErrBusyBuffer)
  163. return "", driver.ErrBadConn
  164. }
  165. buf = buf[:0]
  166. argPos := 0
  167. for i := 0; i < len(query); i++ {
  168. q := strings.IndexByte(query[i:], '?')
  169. if q == -1 {
  170. buf = append(buf, query[i:]...)
  171. break
  172. }
  173. buf = append(buf, query[i:i+q]...)
  174. i += q
  175. arg := args[argPos]
  176. argPos++
  177. if arg == nil {
  178. buf = append(buf, "NULL"...)
  179. continue
  180. }
  181. switch v := arg.(type) {
  182. case int64:
  183. buf = strconv.AppendInt(buf, v, 10)
  184. case float64:
  185. buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
  186. case bool:
  187. if v {
  188. buf = append(buf, '1')
  189. } else {
  190. buf = append(buf, '0')
  191. }
  192. case time.Time:
  193. if v.IsZero() {
  194. buf = append(buf, "'0000-00-00'"...)
  195. } else {
  196. v := v.In(mc.cfg.Loc)
  197. v = v.Add(time.Nanosecond * 500) // To round under microsecond
  198. year := v.Year()
  199. year100 := year / 100
  200. year1 := year % 100
  201. month := v.Month()
  202. day := v.Day()
  203. hour := v.Hour()
  204. minute := v.Minute()
  205. second := v.Second()
  206. micro := v.Nanosecond() / 1000
  207. buf = append(buf, []byte{
  208. '\'',
  209. digits10[year100], digits01[year100],
  210. digits10[year1], digits01[year1],
  211. '-',
  212. digits10[month], digits01[month],
  213. '-',
  214. digits10[day], digits01[day],
  215. ' ',
  216. digits10[hour], digits01[hour],
  217. ':',
  218. digits10[minute], digits01[minute],
  219. ':',
  220. digits10[second], digits01[second],
  221. }...)
  222. if micro != 0 {
  223. micro10000 := micro / 10000
  224. micro100 := micro / 100 % 100
  225. micro1 := micro % 100
  226. buf = append(buf, []byte{
  227. '.',
  228. digits10[micro10000], digits01[micro10000],
  229. digits10[micro100], digits01[micro100],
  230. digits10[micro1], digits01[micro1],
  231. }...)
  232. }
  233. buf = append(buf, '\'')
  234. }
  235. case []byte:
  236. if v == nil {
  237. buf = append(buf, "NULL"...)
  238. } else {
  239. buf = append(buf, "_binary'"...)
  240. if mc.status&statusNoBackslashEscapes == 0 {
  241. buf = escapeBytesBackslash(buf, v)
  242. } else {
  243. buf = escapeBytesQuotes(buf, v)
  244. }
  245. buf = append(buf, '\'')
  246. }
  247. case string:
  248. buf = append(buf, '\'')
  249. if mc.status&statusNoBackslashEscapes == 0 {
  250. buf = escapeStringBackslash(buf, v)
  251. } else {
  252. buf = escapeStringQuotes(buf, v)
  253. }
  254. buf = append(buf, '\'')
  255. default:
  256. return "", driver.ErrSkip
  257. }
  258. if len(buf)+4 > mc.maxAllowedPacket {
  259. return "", driver.ErrSkip
  260. }
  261. }
  262. if argPos != len(args) {
  263. return "", driver.ErrSkip
  264. }
  265. return string(buf), nil
  266. }
  267. func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  268. if mc.isBroken() {
  269. errLog.Print(ErrInvalidConn)
  270. return nil, driver.ErrBadConn
  271. }
  272. if len(args) != 0 {
  273. if !mc.cfg.InterpolateParams {
  274. return nil, driver.ErrSkip
  275. }
  276. // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
  277. prepared, err := mc.interpolateParams(query, args)
  278. if err != nil {
  279. return nil, err
  280. }
  281. query = prepared
  282. }
  283. mc.affectedRows = 0
  284. mc.insertId = 0
  285. err := mc.exec(query)
  286. if err == nil {
  287. return &mysqlResult{
  288. affectedRows: int64(mc.affectedRows),
  289. insertId: int64(mc.insertId),
  290. }, err
  291. }
  292. return nil, err
  293. }
  294. // Internal function to execute commands
  295. func (mc *mysqlConn) exec(query string) error {
  296. // Send command
  297. if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
  298. return err
  299. }
  300. // Read Result
  301. resLen, err := mc.readResultSetHeaderPacket()
  302. if err != nil {
  303. return err
  304. }
  305. if resLen > 0 {
  306. // columns
  307. if err := mc.readUntilEOF(); err != nil {
  308. return err
  309. }
  310. // rows
  311. if err := mc.readUntilEOF(); err != nil {
  312. return err
  313. }
  314. }
  315. return mc.discardResults()
  316. }
  317. func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
  318. return mc.query(query, args)
  319. }
  320. func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
  321. if mc.isBroken() {
  322. errLog.Print(ErrInvalidConn)
  323. return nil, driver.ErrBadConn
  324. }
  325. if len(args) != 0 {
  326. if !mc.cfg.InterpolateParams {
  327. return nil, driver.ErrSkip
  328. }
  329. // try client-side prepare to reduce roundtrip
  330. prepared, err := mc.interpolateParams(query, args)
  331. if err != nil {
  332. return nil, err
  333. }
  334. query = prepared
  335. }
  336. // Send command
  337. err := mc.writeCommandPacketStr(comQuery, query)
  338. if err == nil {
  339. // Read Result
  340. var resLen int
  341. resLen, err = mc.readResultSetHeaderPacket()
  342. if err == nil {
  343. rows := new(textRows)
  344. rows.mc = mc
  345. if resLen == 0 {
  346. rows.rs.done = true
  347. switch err := rows.NextResultSet(); err {
  348. case nil, io.EOF:
  349. return rows, nil
  350. default:
  351. return nil, err
  352. }
  353. }
  354. // Columns
  355. rows.rs.columns, err = mc.readColumns(resLen)
  356. return rows, err
  357. }
  358. }
  359. return nil, err
  360. }
  361. // Gets the value of the given MySQL System Variable
  362. // The returned byte slice is only valid until the next read
  363. func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
  364. // Send command
  365. if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
  366. return nil, err
  367. }
  368. // Read Result
  369. resLen, err := mc.readResultSetHeaderPacket()
  370. if err == nil {
  371. rows := new(textRows)
  372. rows.mc = mc
  373. rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
  374. if resLen > 0 {
  375. // Columns
  376. if err := mc.readUntilEOF(); err != nil {
  377. return nil, err
  378. }
  379. }
  380. dest := make([]driver.Value, resLen)
  381. if err = rows.readRow(dest); err == nil {
  382. return dest[0].([]byte), mc.readUntilEOF()
  383. }
  384. }
  385. return nil, err
  386. }
  387. // finish is called when the query has canceled.
  388. func (mc *mysqlConn) cancel(err error) {
  389. mc.mu.Lock()
  390. mc.canceledErr = err
  391. mc.mu.Unlock()
  392. mc.cleanup()
  393. }
  394. // canceled returns non-nil if the connection was closed due to context cancelation.
  395. func (mc *mysqlConn) canceled() error {
  396. mc.mu.Lock()
  397. defer mc.mu.Unlock()
  398. return mc.canceledErr
  399. }
  400. // finish is called when the query has succeeded.
  401. func (mc *mysqlConn) finish() {
  402. if !mc.watching || mc.finished == nil {
  403. return
  404. }
  405. select {
  406. case mc.finished <- struct{}{}:
  407. mc.watching = false
  408. case <-mc.closech:
  409. }
  410. }