driver.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
// Package godrv implements a database/sql driver for MySQL, built on mymysql.
  2. package godrv
  3. import (
  4. "database/sql"
  5. "database/sql/driver"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "strconv"
  11. "strings"
  12. "time"
  13. "github.com/ziutek/mymysql/mysql"
  14. "github.com/ziutek/mymysql/native"
  15. )
  16. type conn struct {
  17. my mysql.Conn
  18. }
  19. type rowsRes struct {
  20. row mysql.Row
  21. my mysql.Result
  22. simpleQuery mysql.Stmt
  23. }
  24. func errFilter(err error) error {
  25. if err == io.ErrUnexpectedEOF {
  26. return driver.ErrBadConn
  27. }
  28. if _, ok := err.(net.Error); ok {
  29. return driver.ErrBadConn
  30. }
  31. return err
  32. }
  33. func join(a []string) string {
  34. n := 0
  35. for _, s := range a {
  36. n += len(s)
  37. }
  38. b := make([]byte, n)
  39. n = 0
  40. for _, s := range a {
  41. n += copy(b[n:], s)
  42. }
  43. return string(b)
  44. }
  45. func (c conn) parseQuery(query string, args []driver.Value) (string, error) {
  46. if len(args) == 0 {
  47. return query, nil
  48. }
  49. if strings.ContainsAny(query, `'"`) {
  50. return "", nil
  51. }
  52. q := make([]string, 2*len(args)+1)
  53. n := 0
  54. for _, a := range args {
  55. i := strings.IndexRune(query, '?')
  56. if i == -1 {
  57. return "", errors.New("number of parameters doesn't match number of placeholders")
  58. }
  59. var s string
  60. switch v := a.(type) {
  61. case nil:
  62. s = "NULL"
  63. case string:
  64. s = "'" + c.my.Escape(v) + "'"
  65. case []byte:
  66. s = "'" + c.my.Escape(string(v)) + "'"
  67. case int64:
  68. s = strconv.FormatInt(v, 10)
  69. case time.Time:
  70. s = "'" + v.Format(mysql.TimeFormat) + "'"
  71. case bool:
  72. if v {
  73. s = "1"
  74. } else {
  75. s = "0"
  76. }
  77. case float64:
  78. s = strconv.FormatFloat(v, 'e', 12, 64)
  79. default:
  80. panic(fmt.Sprintf("%v (%T) can't be handled by godrv", v, v))
  81. }
  82. q[n] = query[:i]
  83. q[n+1] = s
  84. query = query[i+1:]
  85. n += 2
  86. }
  87. q[n] = query
  88. return join(q), nil
  89. }
  90. func (c conn) Exec(query string, args []driver.Value) (driver.Result, error) {
  91. q, err := c.parseQuery(query, args)
  92. if err != nil {
  93. return nil, err
  94. }
  95. if len(q) == 0 {
  96. return nil, driver.ErrSkip
  97. }
  98. res, err := c.my.Start(q)
  99. if err != nil {
  100. return nil, errFilter(err)
  101. }
  102. return &rowsRes{my: res}, nil
  103. }
  104. var textQuery = mysql.Stmt(new(native.Stmt))
  105. func (c conn) Query(query string, args []driver.Value) (driver.Rows, error) {
  106. q, err := c.parseQuery(query, args)
  107. if err != nil {
  108. return nil, err
  109. }
  110. if len(q) == 0 {
  111. return nil, driver.ErrSkip
  112. }
  113. res, err := c.my.Start(q)
  114. if err != nil {
  115. return nil, errFilter(err)
  116. }
  117. return &rowsRes{row: res.MakeRow(), my: res, simpleQuery: textQuery}, nil
  118. }
  119. type stmt struct {
  120. my mysql.Stmt
  121. args []interface{}
  122. }
  123. func (s *stmt) run(args []driver.Value) (*rowsRes, error) {
  124. for i, v := range args {
  125. s.args[i] = interface{}(v)
  126. }
  127. res, err := s.my.Run(s.args...)
  128. if err != nil {
  129. return nil, errFilter(err)
  130. }
  131. return &rowsRes{my: res}, nil
  132. }
  133. func (c conn) Prepare(query string) (driver.Stmt, error) {
  134. st, err := c.my.Prepare(query)
  135. if err != nil {
  136. return nil, errFilter(err)
  137. }
  138. return &stmt{st, make([]interface{}, st.NumParam())}, nil
  139. }
  140. func (c *conn) Close() (err error) {
  141. err = c.my.Close()
  142. c.my = nil
  143. if err != nil {
  144. err = errFilter(err)
  145. }
  146. return
  147. }
  148. type tx struct {
  149. my mysql.Transaction
  150. }
  151. func (c conn) Begin() (driver.Tx, error) {
  152. t, err := c.my.Begin()
  153. if err != nil {
  154. return nil, errFilter(err)
  155. }
  156. return tx{t}, nil
  157. }
  158. func (t tx) Commit() (err error) {
  159. err = t.my.Commit()
  160. if err != nil {
  161. err = errFilter(err)
  162. }
  163. return
  164. }
  165. func (t tx) Rollback() (err error) {
  166. err = t.my.Rollback()
  167. if err != nil {
  168. err = errFilter(err)
  169. }
  170. return
  171. }
  172. func (s *stmt) Close() (err error) {
  173. if s.my == nil {
  174. panic("godrv: stmt closed twice")
  175. }
  176. err = s.my.Delete()
  177. s.my = nil
  178. if err != nil {
  179. err = errFilter(err)
  180. }
  181. return
  182. }
  183. func (s *stmt) NumInput() int {
  184. return s.my.NumParam()
  185. }
  186. func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
  187. return s.run(args)
  188. }
  189. func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
  190. r, err := s.run(args)
  191. if err != nil {
  192. return nil, err
  193. }
  194. r.row = r.my.MakeRow()
  195. return r, nil
  196. }
  197. func (r *rowsRes) LastInsertId() (int64, error) {
  198. return int64(r.my.InsertId()), nil
  199. }
  200. func (r *rowsRes) RowsAffected() (int64, error) {
  201. return int64(r.my.AffectedRows()), nil
  202. }
  203. func (r *rowsRes) Columns() []string {
  204. flds := r.my.Fields()
  205. cls := make([]string, len(flds))
  206. for i, f := range flds {
  207. cls[i] = f.Name
  208. }
  209. return cls
  210. }
  211. func (r *rowsRes) Close() error {
  212. if r.my == nil {
  213. return nil // closed before
  214. }
  215. if err := r.my.End(); err != nil {
  216. return errFilter(err)
  217. }
  218. if r.simpleQuery != nil && r.simpleQuery != textQuery {
  219. if err := r.simpleQuery.Delete(); err != nil {
  220. return errFilter(err)
  221. }
  222. }
  223. r.my = nil
  224. return nil
  225. }
  226. var location = time.Local
  227. // DATE, DATETIME, TIMESTAMP are treated as they are in Local time zone (this
  228. // can be changed globaly using SetLocation function).
  229. func (r *rowsRes) Next(dest []driver.Value) error {
  230. if r.my == nil {
  231. return io.EOF // closed before
  232. }
  233. err := r.my.ScanRow(r.row)
  234. if err == nil {
  235. if r.simpleQuery == textQuery {
  236. // workaround for time.Time from text queries
  237. for i, f := range r.my.Fields() {
  238. if r.row[i] != nil {
  239. switch f.Type {
  240. case native.MYSQL_TYPE_TIMESTAMP, native.MYSQL_TYPE_DATETIME,
  241. native.MYSQL_TYPE_DATE, native.MYSQL_TYPE_NEWDATE:
  242. r.row[i] = r.row.ForceTime(i, location)
  243. }
  244. }
  245. }
  246. }
  247. for i, d := range r.row {
  248. dest[i] = driver.Value(d)
  249. }
  250. return nil
  251. }
  252. if err != io.EOF {
  253. return errFilter(err)
  254. }
  255. if r.simpleQuery != nil && r.simpleQuery != textQuery {
  256. if err = r.simpleQuery.Delete(); err != nil {
  257. return errFilter(err)
  258. }
  259. }
  260. r.my = nil
  261. return io.EOF
  262. }
  263. // Driver implements database/sql/driver interface.
  264. type Driver struct {
  265. // Defaults
  266. proto, laddr, raddr, user, passwd, db string
  267. timeout time.Duration
  268. dialer Dialer
  269. initCmds []string
  270. }
  271. // Open creates a new connection. The uri needs to have the following syntax:
  272. //
  273. // [PROTOCOL_SPECFIIC*]DBNAME/USER/PASSWD
  274. //
  275. // where protocol specific part may be empty (this means connection to
  276. // local server using default protocol). Currently possible forms are:
  277. //
  278. // DBNAME/USER/PASSWD
  279. // unix:SOCKPATH*DBNAME/USER/PASSWD
  280. // unix:SOCKPATH,OPTIONS*DBNAME/USER/PASSWD
  281. // tcp:ADDR*DBNAME/USER/PASSWD
  282. // tcp:ADDR,OPTIONS*DBNAME/USER/PASSWD
  283. // cloudsql:INSTANCE*DBNAME/USER/PASSWD
  284. //
  285. // OPTIONS can contain comma separated list of options in form:
  286. // opt1=VAL1,opt2=VAL2,boolopt3,boolopt4
  287. // Currently implemented options, in addition to default MySQL variables:
  288. // laddr - local address/port (eg. 1.2.3.4:0)
  289. // timeout - connect timeout in format accepted by time.ParseDuration
  290. func (d *Driver) Open(uri string) (driver.Conn, error) {
  291. cfg := *d // copy default configuration
  292. pd := strings.SplitN(uri, "*", 2)
  293. connCommands := []string{}
  294. if len(pd) == 2 {
  295. // Parse protocol part of URI
  296. p := strings.SplitN(pd[0], ":", 2)
  297. if len(p) != 2 {
  298. return nil, errors.New("Wrong protocol part of URI")
  299. }
  300. cfg.proto = p[0]
  301. options := strings.Split(p[1], ",")
  302. cfg.raddr = options[0]
  303. for _, o := range options[1:] {
  304. kv := strings.SplitN(o, "=", 2)
  305. var k, v string
  306. if len(kv) == 2 {
  307. k, v = kv[0], kv[1]
  308. } else {
  309. k, v = o, "true"
  310. }
  311. switch k {
  312. case "laddr":
  313. cfg.laddr = v
  314. case "timeout":
  315. to, err := time.ParseDuration(v)
  316. if err != nil {
  317. return nil, err
  318. }
  319. cfg.timeout = to
  320. default:
  321. connCommands = append(connCommands, "SET "+k+"="+v)
  322. }
  323. }
  324. // Remove protocol part
  325. pd = pd[1:]
  326. }
  327. // Parse database part of URI
  328. dup := strings.SplitN(pd[0], "/", 3)
  329. if len(dup) != 3 {
  330. return nil, errors.New("Wrong database part of URI")
  331. }
  332. cfg.db = dup[0]
  333. cfg.user = dup[1]
  334. cfg.passwd = dup[2]
  335. c := conn{mysql.New(
  336. cfg.proto, cfg.laddr, cfg.raddr, cfg.user, cfg.passwd, cfg.db,
  337. )}
  338. if d.dialer != nil {
  339. dialer := func(proto, laddr, raddr string, timeout time.Duration) (
  340. net.Conn, error) {
  341. return d.dialer(proto, laddr, raddr, cfg.user, cfg.passwd, timeout)
  342. }
  343. c.my.SetDialer(dialer)
  344. }
  345. // Establish the connection
  346. c.my.SetTimeout(cfg.timeout)
  347. for _, q := range cfg.initCmds {
  348. c.my.Register(q) // Register initialisation commands
  349. }
  350. for _, q := range connCommands {
  351. c.my.Register(q)
  352. }
  353. if err := c.my.Connect(); err != nil {
  354. return nil, errFilter(err)
  355. }
  356. c.my.NarrowTypeSet(true)
  357. c.my.FullFieldInfo(false)
  358. return &c, nil
  359. }
  360. // Register registers initialization commands.
  361. // This is workaround, see http://codereview.appspot.com/5706047
  362. func (drv *Driver) Register(query string) {
  363. drv.initCmds = append(drv.initCmds, query)
  364. }
  365. // Dialer can be used to dial connections to MySQL. If Dialer returns (nil, nil)
  366. // the hook is skipped and normal dialing proceeds. user and dbname are there
  367. // only for logging.
  368. type Dialer func(proto, laddr, raddr, user, dbname string, timeout time.Duration) (net.Conn, error)
  369. // SetDialer sets custom Dialer used by Driver to make connections.
  370. func (drv *Driver) SetDialer(dialer Dialer) {
  371. drv.dialer = dialer
  372. }
  373. // Driver automatically registered in database/sql.
  374. var dfltdrv = Driver{proto: "tcp", raddr: "127.0.0.1:3306"}
  375. // Register calls Register method on driver registered in database/sql.
  376. // If Register is called twice with the same name it panics.
  377. func Register(query string) {
  378. dfltdrv.Register(query)
  379. }
  380. // SetDialer calls SetDialer method on driver registered in database/sql.
  381. func SetDialer(dialer Dialer) {
  382. dfltdrv.SetDialer(dialer)
  383. }
  384. func init() {
  385. Register("SET NAMES utf8")
  386. sql.Register("mymysql", &dfltdrv)
  387. }
  388. // Version returns mymysql version string.
  389. func Version() string {
  390. return mysql.Version()
  391. }
  392. // SetLocation changes default location used to convert dates obtained from
  393. // server to time.Time.
  394. func SetLocation(loc *time.Location) {
  395. location = loc
  396. }