connection_go18.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. // +build go1.8
  9. package mysql
  10. import (
  11. "context"
  12. "database/sql"
  13. "database/sql/driver"
  14. "errors"
  15. )
  16. // Ping implements driver.Pinger interface
  17. func (mc *mysqlConn) Ping(ctx context.Context) error {
  18. if mc.isBroken() {
  19. errLog.Print(ErrInvalidConn)
  20. return driver.ErrBadConn
  21. }
  22. if err := mc.watchCancel(ctx); err != nil {
  23. return err
  24. }
  25. defer mc.finish()
  26. if err := mc.writeCommandPacket(comPing); err != nil {
  27. return err
  28. }
  29. if _, err := mc.readResultOK(); err != nil {
  30. return err
  31. }
  32. return nil
  33. }
  34. // BeginTx implements driver.ConnBeginTx interface
  35. func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
  36. if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
  37. // TODO: support isolation levels
  38. return nil, errors.New("mysql: isolation levels not supported")
  39. }
  40. if opts.ReadOnly {
  41. // TODO: support read-only transactions
  42. return nil, errors.New("mysql: read-only transactions not supported")
  43. }
  44. if err := mc.watchCancel(ctx); err != nil {
  45. return nil, err
  46. }
  47. tx, err := mc.Begin()
  48. mc.finish()
  49. if err != nil {
  50. return nil, err
  51. }
  52. select {
  53. default:
  54. case <-ctx.Done():
  55. tx.Rollback()
  56. return nil, ctx.Err()
  57. }
  58. return tx, err
  59. }
  60. func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
  61. dargs, err := namedValueToValue(args)
  62. if err != nil {
  63. return nil, err
  64. }
  65. if err := mc.watchCancel(ctx); err != nil {
  66. return nil, err
  67. }
  68. rows, err := mc.query(query, dargs)
  69. if err != nil {
  70. mc.finish()
  71. return nil, err
  72. }
  73. rows.finish = mc.finish
  74. return rows, err
  75. }
  76. func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
  77. dargs, err := namedValueToValue(args)
  78. if err != nil {
  79. return nil, err
  80. }
  81. if err := mc.watchCancel(ctx); err != nil {
  82. return nil, err
  83. }
  84. defer mc.finish()
  85. return mc.Exec(query, dargs)
  86. }
  87. func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
  88. if err := mc.watchCancel(ctx); err != nil {
  89. return nil, err
  90. }
  91. stmt, err := mc.Prepare(query)
  92. mc.finish()
  93. if err != nil {
  94. return nil, err
  95. }
  96. select {
  97. default:
  98. case <-ctx.Done():
  99. stmt.Close()
  100. return nil, ctx.Err()
  101. }
  102. return stmt, nil
  103. }
  104. func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
  105. dargs, err := namedValueToValue(args)
  106. if err != nil {
  107. return nil, err
  108. }
  109. if err := stmt.mc.watchCancel(ctx); err != nil {
  110. return nil, err
  111. }
  112. rows, err := stmt.query(dargs)
  113. if err != nil {
  114. stmt.mc.finish()
  115. return nil, err
  116. }
  117. rows.finish = stmt.mc.finish
  118. return rows, err
  119. }
  120. func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
  121. dargs, err := namedValueToValue(args)
  122. if err != nil {
  123. return nil, err
  124. }
  125. if err := stmt.mc.watchCancel(ctx); err != nil {
  126. return nil, err
  127. }
  128. defer stmt.mc.finish()
  129. return stmt.Exec(dargs)
  130. }
  131. func (mc *mysqlConn) watchCancel(ctx context.Context) error {
  132. if mc.watching {
  133. // Reach here if canceled,
  134. // so the connection is already invalid
  135. mc.cleanup()
  136. return nil
  137. }
  138. if ctx.Done() == nil {
  139. return nil
  140. }
  141. mc.watching = true
  142. select {
  143. default:
  144. case <-ctx.Done():
  145. return ctx.Err()
  146. }
  147. if mc.watcher == nil {
  148. return nil
  149. }
  150. mc.watcher <- ctx
  151. return nil
  152. }
  153. func (mc *mysqlConn) startWatcher() {
  154. watcher := make(chan mysqlContext, 1)
  155. mc.watcher = watcher
  156. finished := make(chan struct{})
  157. mc.finished = finished
  158. go func() {
  159. for {
  160. var ctx mysqlContext
  161. select {
  162. case ctx = <-watcher:
  163. case <-mc.closech:
  164. return
  165. }
  166. select {
  167. case <-ctx.Done():
  168. mc.cancel(ctx.Err())
  169. case <-finished:
  170. case <-mc.closech:
  171. return
  172. }
  173. }
  174. }()
  175. }