connection.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "io"
  12. "net"
  13. "strconv"
  14. "strings"
  15. "time"
  16. )
  17. type mysqlConn struct {
  18. buf buffer
  19. netConn net.Conn
  20. affectedRows uint64
  21. insertId uint64
  22. cfg *Config
  23. maxAllowedPacket int
  24. maxWriteSize int
  25. writeTimeout time.Duration
  26. flags clientFlag
  27. status statusFlag
  28. sequence uint8
  29. parseTime bool
  30. strict bool
  31. }
  32. // Handles parameters set in DSN after the connection is established
  33. func (mc *mysqlConn) handleParams() (err error) {
  34. for param, val := range mc.cfg.Params {
  35. switch param {
  36. // Charset
  37. case "charset":
  38. charsets := strings.Split(val, ",")
  39. for i := range charsets {
  40. // ignore errors here - a charset may not exist
  41. err = mc.exec("SET NAMES " + charsets[i])
  42. if err == nil {
  43. break
  44. }
  45. }
  46. if err != nil {
  47. return
  48. }
  49. // System Vars
  50. default:
  51. err = mc.exec("SET " + param + "=" + val + "")
  52. if err != nil {
  53. return
  54. }
  55. }
  56. }
  57. return
  58. }
  59. func (mc *mysqlConn) Begin() (driver.Tx, error) {
  60. if mc.netConn == nil {
  61. errLog.Print(ErrInvalidConn)
  62. return nil, driver.ErrBadConn
  63. }
  64. err := mc.exec("START TRANSACTION")
  65. if err == nil {
  66. return &mysqlTx{mc}, err
  67. }
  68. return nil, err
  69. }
  70. func (mc *mysqlConn) Close() (err error) {
  71. // Makes Close idempotent
  72. if mc.netConn != nil {
  73. err = mc.writeCommandPacket(comQuit)
  74. }
  75. mc.cleanup()
  76. return
  77. }
  78. // Closes the network connection and unsets internal variables. Do not call this
  79. // function after successfully authentication, call Close instead. This function
  80. // is called before auth or on auth failure because MySQL will have already
  81. // closed the network connection.
  82. func (mc *mysqlConn) cleanup() {
  83. // Makes cleanup idempotent
  84. if mc.netConn != nil {
  85. if err := mc.netConn.Close(); err != nil {
  86. errLog.Print(err)
  87. }
  88. mc.netConn = nil
  89. }
  90. mc.cfg = nil
  91. mc.buf.nc = nil
  92. }
  93. func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
  94. if mc.netConn == nil {
  95. errLog.Print(ErrInvalidConn)
  96. return nil, driver.ErrBadConn
  97. }
  98. // Send command
  99. err := mc.writeCommandPacketStr(comStmtPrepare, query)
  100. if err != nil {
  101. return nil, err
  102. }
  103. stmt := &mysqlStmt{
  104. mc: mc,
  105. }
  106. // Read Result
  107. columnCount, err := stmt.readPrepareResultPacket()
  108. if err == nil {
  109. if stmt.paramCount > 0 {
  110. if err = mc.readUntilEOF(); err != nil {
  111. return nil, err
  112. }
  113. }
  114. if columnCount > 0 {
  115. err = mc.readUntilEOF()
  116. }
  117. }
  118. return stmt, err
  119. }
  120. func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
  121. // Number of ? should be same to len(args)
  122. if strings.Count(query, "?") != len(args) {
  123. return "", driver.ErrSkip
  124. }
  125. buf := mc.buf.takeCompleteBuffer()
  126. if buf == nil {
  127. // can not take the buffer. Something must be wrong with the connection
  128. errLog.Print(ErrBusyBuffer)
  129. return "", driver.ErrBadConn
  130. }
  131. buf = buf[:0]
  132. argPos := 0
  133. for i := 0; i < len(query); i++ {
  134. q := strings.IndexByte(query[i:], '?')
  135. if q == -1 {
  136. buf = append(buf, query[i:]...)
  137. break
  138. }
  139. buf = append(buf, query[i:i+q]...)
  140. i += q
  141. arg := args[argPos]
  142. argPos++
  143. if arg == nil {
  144. buf = append(buf, "NULL"...)
  145. continue
  146. }
  147. switch v := arg.(type) {
  148. case int64:
  149. buf = strconv.AppendInt(buf, v, 10)
  150. case float64:
  151. buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
  152. case bool:
  153. if v {
  154. buf = append(buf, '1')
  155. } else {
  156. buf = append(buf, '0')
  157. }
  158. case time.Time:
  159. if v.IsZero() {
  160. buf = append(buf, "'0000-00-00'"...)
  161. } else {
  162. v := v.In(mc.cfg.Loc)
  163. v = v.Add(time.Nanosecond * 500) // To round under microsecond
  164. year := v.Year()
  165. year100 := year / 100
  166. year1 := year % 100
  167. month := v.Month()
  168. day := v.Day()
  169. hour := v.Hour()
  170. minute := v.Minute()
  171. second := v.Second()
  172. micro := v.Nanosecond() / 1000
  173. buf = append(buf, []byte{
  174. '\'',
  175. digits10[year100], digits01[year100],
  176. digits10[year1], digits01[year1],
  177. '-',
  178. digits10[month], digits01[month],
  179. '-',
  180. digits10[day], digits01[day],
  181. ' ',
  182. digits10[hour], digits01[hour],
  183. ':',
  184. digits10[minute], digits01[minute],
  185. ':',
  186. digits10[second], digits01[second],
  187. }...)
  188. if micro != 0 {
  189. micro10000 := micro / 10000
  190. micro100 := micro / 100 % 100
  191. micro1 := micro % 100
  192. buf = append(buf, []byte{
  193. '.',
  194. digits10[micro10000], digits01[micro10000],
  195. digits10[micro100], digits01[micro100],
  196. digits10[micro1], digits01[micro1],
  197. }...)
  198. }
  199. buf = append(buf, '\'')
  200. }
  201. case []byte:
  202. if v == nil {
  203. buf = append(buf, "NULL"...)
  204. } else {
  205. buf = append(buf, "_binary'"...)
  206. if mc.status&statusNoBackslashEscapes == 0 {
  207. buf = escapeBytesBackslash(buf, v)
  208. } else {
  209. buf = escapeBytesQuotes(buf, v)
  210. }
  211. buf = append(buf, '\'')
  212. }
  213. case string:
  214. buf = append(buf, '\'')
  215. if mc.status&statusNoBackslashEscapes == 0 {
  216. buf = escapeStringBackslash(buf, v)
  217. } else {
  218. buf = escapeStringQuotes(buf, v)
  219. }
  220. buf = append(buf, '\'')
  221. default:
  222. return "", driver.ErrSkip
  223. }
  224. if len(buf)+4 > mc.maxAllowedPacket {
  225. return "", driver.ErrSkip
  226. }
  227. }
  228. if argPos != len(args) {
  229. return "", driver.ErrSkip
  230. }
  231. return string(buf), nil
  232. }
  233. func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  234. if mc.netConn == nil {
  235. errLog.Print(ErrInvalidConn)
  236. return nil, driver.ErrBadConn
  237. }
  238. if len(args) != 0 {
  239. if !mc.cfg.InterpolateParams {
  240. return nil, driver.ErrSkip
  241. }
  242. // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
  243. prepared, err := mc.interpolateParams(query, args)
  244. if err != nil {
  245. return nil, err
  246. }
  247. query = prepared
  248. }
  249. mc.affectedRows = 0
  250. mc.insertId = 0
  251. err := mc.exec(query)
  252. if err == nil {
  253. return &mysqlResult{
  254. affectedRows: int64(mc.affectedRows),
  255. insertId: int64(mc.insertId),
  256. }, err
  257. }
  258. return nil, err
  259. }
  260. // Internal function to execute commands
  261. func (mc *mysqlConn) exec(query string) error {
  262. // Send command
  263. if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
  264. return err
  265. }
  266. // Read Result
  267. resLen, err := mc.readResultSetHeaderPacket()
  268. if err != nil {
  269. return err
  270. }
  271. if resLen > 0 {
  272. // columns
  273. if err := mc.readUntilEOF(); err != nil {
  274. return err
  275. }
  276. // rows
  277. if err := mc.readUntilEOF(); err != nil {
  278. return err
  279. }
  280. }
  281. return mc.discardResults()
  282. }
  283. func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
  284. if mc.netConn == nil {
  285. errLog.Print(ErrInvalidConn)
  286. return nil, driver.ErrBadConn
  287. }
  288. if len(args) != 0 {
  289. if !mc.cfg.InterpolateParams {
  290. return nil, driver.ErrSkip
  291. }
  292. // try client-side prepare to reduce roundtrip
  293. prepared, err := mc.interpolateParams(query, args)
  294. if err != nil {
  295. return nil, err
  296. }
  297. query = prepared
  298. }
  299. // Send command
  300. err := mc.writeCommandPacketStr(comQuery, query)
  301. if err == nil {
  302. // Read Result
  303. var resLen int
  304. resLen, err = mc.readResultSetHeaderPacket()
  305. if err == nil {
  306. rows := new(textRows)
  307. rows.mc = mc
  308. if resLen == 0 {
  309. rows.rs.done = true
  310. switch err := rows.NextResultSet(); err {
  311. case nil, io.EOF:
  312. return rows, nil
  313. default:
  314. return nil, err
  315. }
  316. }
  317. // Columns
  318. rows.rs.columns, err = mc.readColumns(resLen)
  319. return rows, err
  320. }
  321. }
  322. return nil, err
  323. }
  324. // Gets the value of the given MySQL System Variable
  325. // The returned byte slice is only valid until the next read
  326. func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
  327. // Send command
  328. if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
  329. return nil, err
  330. }
  331. // Read Result
  332. resLen, err := mc.readResultSetHeaderPacket()
  333. if err == nil {
  334. rows := new(textRows)
  335. rows.mc = mc
  336. rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
  337. if resLen > 0 {
  338. // Columns
  339. if err := mc.readUntilEOF(); err != nil {
  340. return nil, err
  341. }
  342. }
  343. dest := make([]driver.Value, resLen)
  344. if err = rows.readRow(dest); err == nil {
  345. return dest[0].([]byte), mc.readUntilEOF()
  346. }
  347. }
  348. return nil, err
  349. }