connection.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "io"
  12. "net"
  13. "strconv"
  14. "strings"
  15. "time"
  16. )
  // mysqlContext is a stand-in for context.Context on Go 1.7 and earlier.
// It declares only the two methods this driver relies on; Deadline and Value,
// the other context.Context methods, are deliberately left out because the
// driver never calls them.
type mysqlContext interface {
	// Err reports why the context ended, or nil while it is still active.
	Err() error
	// Done returns a channel that is closed when the context ends.
	Done() <-chan struct{}
}
  25. type mysqlConn struct {
  26. buf buffer
  27. netConn net.Conn
  28. affectedRows uint64
  29. insertId uint64
  30. cfg *Config
  31. maxAllowedPacket int
  32. maxWriteSize int
  33. writeTimeout time.Duration
  34. flags clientFlag
  35. status statusFlag
  36. sequence uint8
  37. parseTime bool
  38. strict bool
  39. // for context support (Go 1.8+)
  40. watching bool
  41. watcher chan<- mysqlContext
  42. closech chan struct{}
  43. finished chan<- struct{}
  44. canceled atomicError // set non-nil if conn is canceled
  45. closed atomicBool // set when conn is closed, before closech is closed
  46. }
  47. // Handles parameters set in DSN after the connection is established
  48. func (mc *mysqlConn) handleParams() (err error) {
  49. for param, val := range mc.cfg.Params {
  50. switch param {
  51. // Charset
  52. case "charset":
  53. charsets := strings.Split(val, ",")
  54. for i := range charsets {
  55. // ignore errors here - a charset may not exist
  56. err = mc.exec("SET NAMES " + charsets[i])
  57. if err == nil {
  58. break
  59. }
  60. }
  61. if err != nil {
  62. return
  63. }
  64. // System Vars
  65. default:
  66. err = mc.exec("SET " + param + "=" + val + "")
  67. if err != nil {
  68. return
  69. }
  70. }
  71. }
  72. return
  73. }
  74. func (mc *mysqlConn) markBadConn(err error) error {
  75. if mc == nil {
  76. return err
  77. }
  78. if err != errBadConnNoWrite {
  79. return err
  80. }
  81. return driver.ErrBadConn
  82. }
  83. func (mc *mysqlConn) Begin() (driver.Tx, error) {
  84. return mc.begin(false)
  85. }
  86. func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
  87. if mc.closed.IsSet() {
  88. errLog.Print(ErrInvalidConn)
  89. return nil, driver.ErrBadConn
  90. }
  91. var q string
  92. if readOnly {
  93. q = "START TRANSACTION READ ONLY"
  94. } else {
  95. q = "START TRANSACTION"
  96. }
  97. err := mc.exec(q)
  98. if err == nil {
  99. return &mysqlTx{mc}, err
  100. }
  101. return nil, mc.markBadConn(err)
  102. }
  103. func (mc *mysqlConn) Close() (err error) {
  104. // Makes Close idempotent
  105. if !mc.closed.IsSet() {
  106. err = mc.writeCommandPacket(comQuit)
  107. }
  108. mc.cleanup()
  109. return
  110. }
  111. // Closes the network connection and unsets internal variables. Do not call this
  112. // function after successfully authentication, call Close instead. This function
  113. // is called before auth or on auth failure because MySQL will have already
  114. // closed the network connection.
  115. func (mc *mysqlConn) cleanup() {
  116. if !mc.closed.TrySet(true) {
  117. return
  118. }
  119. // Makes cleanup idempotent
  120. close(mc.closech)
  121. if mc.netConn == nil {
  122. return
  123. }
  124. if err := mc.netConn.Close(); err != nil {
  125. errLog.Print(err)
  126. }
  127. }
  128. func (mc *mysqlConn) error() error {
  129. if mc.closed.IsSet() {
  130. if err := mc.canceled.Value(); err != nil {
  131. return err
  132. }
  133. return ErrInvalidConn
  134. }
  135. return nil
  136. }
  137. func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
  138. if mc.closed.IsSet() {
  139. errLog.Print(ErrInvalidConn)
  140. return nil, driver.ErrBadConn
  141. }
  142. // Send command
  143. err := mc.writeCommandPacketStr(comStmtPrepare, query)
  144. if err != nil {
  145. return nil, mc.markBadConn(err)
  146. }
  147. stmt := &mysqlStmt{
  148. mc: mc,
  149. }
  150. // Read Result
  151. columnCount, err := stmt.readPrepareResultPacket()
  152. if err == nil {
  153. if stmt.paramCount > 0 {
  154. if err = mc.readUntilEOF(); err != nil {
  155. return nil, err
  156. }
  157. }
  158. if columnCount > 0 {
  159. err = mc.readUntilEOF()
  160. }
  161. }
  162. return stmt, err
  163. }
  164. func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
  165. // Number of ? should be same to len(args)
  166. if strings.Count(query, "?") != len(args) {
  167. return "", driver.ErrSkip
  168. }
  169. buf := mc.buf.takeCompleteBuffer()
  170. if buf == nil {
  171. // can not take the buffer. Something must be wrong with the connection
  172. errLog.Print(ErrBusyBuffer)
  173. return "", ErrInvalidConn
  174. }
  175. buf = buf[:0]
  176. argPos := 0
  177. for i := 0; i < len(query); i++ {
  178. q := strings.IndexByte(query[i:], '?')
  179. if q == -1 {
  180. buf = append(buf, query[i:]...)
  181. break
  182. }
  183. buf = append(buf, query[i:i+q]...)
  184. i += q
  185. arg := args[argPos]
  186. argPos++
  187. if arg == nil {
  188. buf = append(buf, "NULL"...)
  189. continue
  190. }
  191. switch v := arg.(type) {
  192. case int64:
  193. buf = strconv.AppendInt(buf, v, 10)
  194. case float64:
  195. buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
  196. case bool:
  197. if v {
  198. buf = append(buf, '1')
  199. } else {
  200. buf = append(buf, '0')
  201. }
  202. case time.Time:
  203. if v.IsZero() {
  204. buf = append(buf, "'0000-00-00'"...)
  205. } else {
  206. v := v.In(mc.cfg.Loc)
  207. v = v.Add(time.Nanosecond * 500) // To round under microsecond
  208. year := v.Year()
  209. year100 := year / 100
  210. year1 := year % 100
  211. month := v.Month()
  212. day := v.Day()
  213. hour := v.Hour()
  214. minute := v.Minute()
  215. second := v.Second()
  216. micro := v.Nanosecond() / 1000
  217. buf = append(buf, []byte{
  218. '\'',
  219. digits10[year100], digits01[year100],
  220. digits10[year1], digits01[year1],
  221. '-',
  222. digits10[month], digits01[month],
  223. '-',
  224. digits10[day], digits01[day],
  225. ' ',
  226. digits10[hour], digits01[hour],
  227. ':',
  228. digits10[minute], digits01[minute],
  229. ':',
  230. digits10[second], digits01[second],
  231. }...)
  232. if micro != 0 {
  233. micro10000 := micro / 10000
  234. micro100 := micro / 100 % 100
  235. micro1 := micro % 100
  236. buf = append(buf, []byte{
  237. '.',
  238. digits10[micro10000], digits01[micro10000],
  239. digits10[micro100], digits01[micro100],
  240. digits10[micro1], digits01[micro1],
  241. }...)
  242. }
  243. buf = append(buf, '\'')
  244. }
  245. case []byte:
  246. if v == nil {
  247. buf = append(buf, "NULL"...)
  248. } else {
  249. buf = append(buf, "_binary'"...)
  250. if mc.status&statusNoBackslashEscapes == 0 {
  251. buf = escapeBytesBackslash(buf, v)
  252. } else {
  253. buf = escapeBytesQuotes(buf, v)
  254. }
  255. buf = append(buf, '\'')
  256. }
  257. case string:
  258. buf = append(buf, '\'')
  259. if mc.status&statusNoBackslashEscapes == 0 {
  260. buf = escapeStringBackslash(buf, v)
  261. } else {
  262. buf = escapeStringQuotes(buf, v)
  263. }
  264. buf = append(buf, '\'')
  265. default:
  266. return "", driver.ErrSkip
  267. }
  268. if len(buf)+4 > mc.maxAllowedPacket {
  269. return "", driver.ErrSkip
  270. }
  271. }
  272. if argPos != len(args) {
  273. return "", driver.ErrSkip
  274. }
  275. return string(buf), nil
  276. }
  277. func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  278. if mc.closed.IsSet() {
  279. errLog.Print(ErrInvalidConn)
  280. return nil, driver.ErrBadConn
  281. }
  282. if len(args) != 0 {
  283. if !mc.cfg.InterpolateParams {
  284. return nil, driver.ErrSkip
  285. }
  286. // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
  287. prepared, err := mc.interpolateParams(query, args)
  288. if err != nil {
  289. return nil, err
  290. }
  291. query = prepared
  292. }
  293. mc.affectedRows = 0
  294. mc.insertId = 0
  295. err := mc.exec(query)
  296. if err == nil {
  297. return &mysqlResult{
  298. affectedRows: int64(mc.affectedRows),
  299. insertId: int64(mc.insertId),
  300. }, err
  301. }
  302. return nil, mc.markBadConn(err)
  303. }
  304. // Internal function to execute commands
  305. func (mc *mysqlConn) exec(query string) error {
  306. // Send command
  307. if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
  308. return mc.markBadConn(err)
  309. }
  310. // Read Result
  311. resLen, err := mc.readResultSetHeaderPacket()
  312. if err != nil {
  313. return err
  314. }
  315. if resLen > 0 {
  316. // columns
  317. if err := mc.readUntilEOF(); err != nil {
  318. return err
  319. }
  320. // rows
  321. if err := mc.readUntilEOF(); err != nil {
  322. return err
  323. }
  324. }
  325. return mc.discardResults()
  326. }
  327. func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
  328. return mc.query(query, args)
  329. }
  330. func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
  331. if mc.closed.IsSet() {
  332. errLog.Print(ErrInvalidConn)
  333. return nil, driver.ErrBadConn
  334. }
  335. if len(args) != 0 {
  336. if !mc.cfg.InterpolateParams {
  337. return nil, driver.ErrSkip
  338. }
  339. // try client-side prepare to reduce roundtrip
  340. prepared, err := mc.interpolateParams(query, args)
  341. if err != nil {
  342. return nil, err
  343. }
  344. query = prepared
  345. }
  346. // Send command
  347. err := mc.writeCommandPacketStr(comQuery, query)
  348. if err == nil {
  349. // Read Result
  350. var resLen int
  351. resLen, err = mc.readResultSetHeaderPacket()
  352. if err == nil {
  353. rows := new(textRows)
  354. rows.mc = mc
  355. if resLen == 0 {
  356. rows.rs.done = true
  357. switch err := rows.NextResultSet(); err {
  358. case nil, io.EOF:
  359. return rows, nil
  360. default:
  361. return nil, err
  362. }
  363. }
  364. // Columns
  365. rows.rs.columns, err = mc.readColumns(resLen)
  366. return rows, err
  367. }
  368. }
  369. return nil, mc.markBadConn(err)
  370. }
  371. // Gets the value of the given MySQL System Variable
  372. // The returned byte slice is only valid until the next read
  373. func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
  374. // Send command
  375. if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
  376. return nil, err
  377. }
  378. // Read Result
  379. resLen, err := mc.readResultSetHeaderPacket()
  380. if err == nil {
  381. rows := new(textRows)
  382. rows.mc = mc
  383. rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
  384. if resLen > 0 {
  385. // Columns
  386. if err := mc.readUntilEOF(); err != nil {
  387. return nil, err
  388. }
  389. }
  390. dest := make([]driver.Value, resLen)
  391. if err = rows.readRow(dest); err == nil {
  392. return dest[0].([]byte), mc.readUntilEOF()
  393. }
  394. }
  395. return nil, err
  396. }
  397. // finish is called when the query has canceled.
  398. func (mc *mysqlConn) cancel(err error) {
  399. mc.canceled.Set(err)
  400. mc.cleanup()
  401. }
  402. // finish is called when the query has succeeded.
  403. func (mc *mysqlConn) finish() {
  404. if !mc.watching || mc.finished == nil {
  405. return
  406. }
  407. select {
  408. case mc.finished <- struct{}{}:
  409. mc.watching = false
  410. case <-mc.closech:
  411. }
  412. }