connection.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "net"
  12. "strconv"
  13. "strings"
  14. "time"
  15. )
  16. type mysqlConn struct {
  17. buf buffer
  18. netConn net.Conn
  19. affectedRows uint64
  20. insertId uint64
  21. cfg *Config
  22. maxPacketAllowed int
  23. maxWriteSize int
  24. writeTimeout time.Duration
  25. flags clientFlag
  26. status statusFlag
  27. sequence uint8
  28. parseTime bool
  29. strict bool
  30. }
  31. // Handles parameters set in DSN after the connection is established
  32. func (mc *mysqlConn) handleParams() (err error) {
  33. for param, val := range mc.cfg.Params {
  34. switch param {
  35. // Charset
  36. case "charset":
  37. charsets := strings.Split(val, ",")
  38. for i := range charsets {
  39. // ignore errors here - a charset may not exist
  40. err = mc.exec("SET NAMES " + charsets[i])
  41. if err == nil {
  42. break
  43. }
  44. }
  45. if err != nil {
  46. return
  47. }
  48. // System Vars
  49. default:
  50. err = mc.exec("SET " + param + "=" + val + "")
  51. if err != nil {
  52. return
  53. }
  54. }
  55. }
  56. return
  57. }
  58. func (mc *mysqlConn) Begin() (driver.Tx, error) {
  59. if mc.netConn == nil {
  60. errLog.Print(ErrInvalidConn)
  61. return nil, driver.ErrBadConn
  62. }
  63. err := mc.exec("START TRANSACTION")
  64. if err == nil {
  65. return &mysqlTx{mc}, err
  66. }
  67. return nil, err
  68. }
  69. func (mc *mysqlConn) Close() (err error) {
  70. // Makes Close idempotent
  71. if mc.netConn != nil {
  72. err = mc.writeCommandPacket(comQuit)
  73. }
  74. mc.cleanup()
  75. return
  76. }
  77. // Closes the network connection and unsets internal variables. Do not call this
  78. // function after successfully authentication, call Close instead. This function
  79. // is called before auth or on auth failure because MySQL will have already
  80. // closed the network connection.
  81. func (mc *mysqlConn) cleanup() {
  82. // Makes cleanup idempotent
  83. if mc.netConn != nil {
  84. if err := mc.netConn.Close(); err != nil {
  85. errLog.Print(err)
  86. }
  87. mc.netConn = nil
  88. }
  89. mc.cfg = nil
  90. mc.buf.nc = nil
  91. }
  92. func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
  93. if mc.netConn == nil {
  94. errLog.Print(ErrInvalidConn)
  95. return nil, driver.ErrBadConn
  96. }
  97. // Send command
  98. err := mc.writeCommandPacketStr(comStmtPrepare, query)
  99. if err != nil {
  100. return nil, err
  101. }
  102. stmt := &mysqlStmt{
  103. mc: mc,
  104. }
  105. // Read Result
  106. columnCount, err := stmt.readPrepareResultPacket()
  107. if err == nil {
  108. if stmt.paramCount > 0 {
  109. if err = mc.readUntilEOF(); err != nil {
  110. return nil, err
  111. }
  112. }
  113. if columnCount > 0 {
  114. err = mc.readUntilEOF()
  115. }
  116. }
  117. return stmt, err
  118. }
  119. func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
  120. buf := mc.buf.takeCompleteBuffer()
  121. if buf == nil {
  122. // can not take the buffer. Something must be wrong with the connection
  123. errLog.Print(ErrBusyBuffer)
  124. return "", driver.ErrBadConn
  125. }
  126. buf = buf[:0]
  127. argPos := 0
  128. for i := 0; i < len(query); i++ {
  129. q := strings.IndexByte(query[i:], '?')
  130. if q == -1 {
  131. buf = append(buf, query[i:]...)
  132. break
  133. }
  134. buf = append(buf, query[i:i+q]...)
  135. i += q
  136. if argPos >= len(args) {
  137. return "", driver.ErrSkip
  138. }
  139. arg := args[argPos]
  140. argPos++
  141. if arg == nil {
  142. buf = append(buf, "NULL"...)
  143. continue
  144. }
  145. switch v := arg.(type) {
  146. case int64:
  147. buf = strconv.AppendInt(buf, v, 10)
  148. case float64:
  149. buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
  150. case bool:
  151. if v {
  152. buf = append(buf, '1')
  153. } else {
  154. buf = append(buf, '0')
  155. }
  156. case time.Time:
  157. if v.IsZero() {
  158. buf = append(buf, "'0000-00-00'"...)
  159. } else {
  160. v := v.In(mc.cfg.Loc)
  161. v = v.Add(time.Nanosecond * 500) // To round under microsecond
  162. year := v.Year()
  163. year100 := year / 100
  164. year1 := year % 100
  165. month := v.Month()
  166. day := v.Day()
  167. hour := v.Hour()
  168. minute := v.Minute()
  169. second := v.Second()
  170. micro := v.Nanosecond() / 1000
  171. buf = append(buf, []byte{
  172. '\'',
  173. digits10[year100], digits01[year100],
  174. digits10[year1], digits01[year1],
  175. '-',
  176. digits10[month], digits01[month],
  177. '-',
  178. digits10[day], digits01[day],
  179. ' ',
  180. digits10[hour], digits01[hour],
  181. ':',
  182. digits10[minute], digits01[minute],
  183. ':',
  184. digits10[second], digits01[second],
  185. }...)
  186. if micro != 0 {
  187. micro10000 := micro / 10000
  188. micro100 := micro / 100 % 100
  189. micro1 := micro % 100
  190. buf = append(buf, []byte{
  191. '.',
  192. digits10[micro10000], digits01[micro10000],
  193. digits10[micro100], digits01[micro100],
  194. digits10[micro1], digits01[micro1],
  195. }...)
  196. }
  197. buf = append(buf, '\'')
  198. }
  199. case []byte:
  200. if v == nil {
  201. buf = append(buf, "NULL"...)
  202. } else {
  203. buf = append(buf, "_binary'"...)
  204. if mc.status&statusNoBackslashEscapes == 0 {
  205. buf = escapeBytesBackslash(buf, v)
  206. } else {
  207. buf = escapeBytesQuotes(buf, v)
  208. }
  209. buf = append(buf, '\'')
  210. }
  211. case string:
  212. buf = append(buf, '\'')
  213. if mc.status&statusNoBackslashEscapes == 0 {
  214. buf = escapeStringBackslash(buf, v)
  215. } else {
  216. buf = escapeStringQuotes(buf, v)
  217. }
  218. buf = append(buf, '\'')
  219. default:
  220. return "", driver.ErrSkip
  221. }
  222. if len(buf)+4 > mc.maxPacketAllowed {
  223. return "", driver.ErrSkip
  224. }
  225. }
  226. if argPos != len(args) {
  227. return "", driver.ErrSkip
  228. }
  229. return string(buf), nil
  230. }
  231. func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  232. if mc.netConn == nil {
  233. errLog.Print(ErrInvalidConn)
  234. return nil, driver.ErrBadConn
  235. }
  236. if len(args) != 0 {
  237. if !mc.cfg.InterpolateParams {
  238. return nil, driver.ErrSkip
  239. }
  240. // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
  241. prepared, err := mc.interpolateParams(query, args)
  242. if err != nil {
  243. return nil, err
  244. }
  245. query = prepared
  246. args = nil
  247. }
  248. mc.affectedRows = 0
  249. mc.insertId = 0
  250. err := mc.exec(query)
  251. if err == nil {
  252. return &mysqlResult{
  253. affectedRows: int64(mc.affectedRows),
  254. insertId: int64(mc.insertId),
  255. }, err
  256. }
  257. return nil, err
  258. }
  259. // Internal function to execute commands
  260. func (mc *mysqlConn) exec(query string) error {
  261. // Send command
  262. err := mc.writeCommandPacketStr(comQuery, query)
  263. if err != nil {
  264. return err
  265. }
  266. // Read Result
  267. resLen, err := mc.readResultSetHeaderPacket()
  268. if err == nil && resLen > 0 {
  269. if err = mc.readUntilEOF(); err != nil {
  270. return err
  271. }
  272. err = mc.readUntilEOF()
  273. }
  274. return err
  275. }
  276. func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
  277. if mc.netConn == nil {
  278. errLog.Print(ErrInvalidConn)
  279. return nil, driver.ErrBadConn
  280. }
  281. if len(args) != 0 {
  282. if !mc.cfg.InterpolateParams {
  283. return nil, driver.ErrSkip
  284. }
  285. // try client-side prepare to reduce roundtrip
  286. prepared, err := mc.interpolateParams(query, args)
  287. if err != nil {
  288. return nil, err
  289. }
  290. query = prepared
  291. args = nil
  292. }
  293. // Send command
  294. err := mc.writeCommandPacketStr(comQuery, query)
  295. if err == nil {
  296. // Read Result
  297. var resLen int
  298. resLen, err = mc.readResultSetHeaderPacket()
  299. if err == nil {
  300. rows := new(textRows)
  301. rows.mc = mc
  302. if resLen == 0 {
  303. // no columns, no more data
  304. return emptyRows{}, nil
  305. }
  306. // Columns
  307. rows.columns, err = mc.readColumns(resLen)
  308. return rows, err
  309. }
  310. }
  311. return nil, err
  312. }
  313. // Gets the value of the given MySQL System Variable
  314. // The returned byte slice is only valid until the next read
  315. func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
  316. // Send command
  317. if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
  318. return nil, err
  319. }
  320. // Read Result
  321. resLen, err := mc.readResultSetHeaderPacket()
  322. if err == nil {
  323. rows := new(textRows)
  324. rows.mc = mc
  325. rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
  326. if resLen > 0 {
  327. // Columns
  328. if err := mc.readUntilEOF(); err != nil {
  329. return nil, err
  330. }
  331. }
  332. dest := make([]driver.Value, resLen)
  333. if err = rows.readRow(dest); err == nil {
  334. return dest[0].([]byte), mc.readUntilEOF()
  335. }
  336. }
  337. return nil, err
  338. }