123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439 |
- // Package godrv implements database/sql MySQL driver.
- package godrv
- import (
- "database/sql"
- "database/sql/driver"
- "errors"
- "fmt"
- "io"
- "net"
- "strconv"
- "strings"
- "time"
- "github.com/ziutek/mymysql/mysql"
- "github.com/ziutek/mymysql/native"
- )
// conn is the driver connection handed to database/sql. It wraps a mymysql
// connection. Open returns it as *conn; most methods use value receivers,
// while Close uses a pointer receiver so that it can clear my.
type conn struct {
	my mysql.Conn // underlying MySQL connection; set to nil by Close
}
// rowsRes serves both as a driver.Result (returned by Exec) and as
// driver.Rows (returned by Query).
type rowsRes struct {
	// Scan buffer filled by Next; only allocated for Query results.
	row mysql.Row
	// Underlying result; set to nil once the rows are closed or exhausted.
	my mysql.Result
	// textQuery for text (non-prepared) queries, nil for prepared statements.
	simpleQuery mysql.Stmt
}
// errFilter translates errors that indicate a broken connection into
// driver.ErrBadConn, so database/sql discards the connection. Such errors are
// an unexpected EOF while reading from the server and any net.Error. Every
// other error, including nil, is returned unchanged.
func errFilter(err error) error {
	_, isNetErr := err.(net.Error)
	if isNetErr || err == io.ErrUnexpectedEOF {
		return driver.ErrBadConn
	}
	return err
}
// join concatenates all elements of a with no separator. The output buffer
// is sized up front, so the bytes are copied only once.
func join(a []string) string {
	size := 0
	for i := range a {
		size += len(a[i])
	}
	buf := make([]byte, 0, size)
	for _, part := range a {
		buf = append(buf, part...)
	}
	return string(buf)
}
- func (c conn) parseQuery(query string, args []driver.Value) (string, error) {
- if len(args) == 0 {
- return query, nil
- }
- if strings.ContainsAny(query, `'"`) {
- return "", nil
- }
- q := make([]string, 2*len(args)+1)
- n := 0
- for _, a := range args {
- i := strings.IndexRune(query, '?')
- if i == -1 {
- return "", errors.New("number of parameters doesn't match number of placeholders")
- }
- var s string
- switch v := a.(type) {
- case nil:
- s = "NULL"
- case string:
- s = "'" + c.my.Escape(v) + "'"
- case []byte:
- s = "'" + c.my.Escape(string(v)) + "'"
- case int64:
- s = strconv.FormatInt(v, 10)
- case time.Time:
- s = "'" + v.Format(mysql.TimeFormat) + "'"
- case bool:
- if v {
- s = "1"
- } else {
- s = "0"
- }
- case float64:
- s = strconv.FormatFloat(v, 'e', 12, 64)
- default:
- panic(fmt.Sprintf("%v (%T) can't be handled by godrv", v, v))
- }
- q[n] = query[:i]
- q[n+1] = s
- query = query[i+1:]
- n += 2
- }
- q[n] = query
- return join(q), nil
- }
- func (c conn) Exec(query string, args []driver.Value) (driver.Result, error) {
- q, err := c.parseQuery(query, args)
- if err != nil {
- return nil, err
- }
- if len(q) == 0 {
- return nil, driver.ErrSkip
- }
- res, err := c.my.Start(q)
- if err != nil {
- return nil, errFilter(err)
- }
- return &rowsRes{my: res}, nil
- }
- var textQuery = mysql.Stmt(new(native.Stmt))
- func (c conn) Query(query string, args []driver.Value) (driver.Rows, error) {
- q, err := c.parseQuery(query, args)
- if err != nil {
- return nil, err
- }
- if len(q) == 0 {
- return nil, driver.ErrSkip
- }
- res, err := c.my.Start(q)
- if err != nil {
- return nil, errFilter(err)
- }
- return &rowsRes{row: res.MakeRow(), my: res, simpleQuery: textQuery}, nil
- }
// stmt is a server-side prepared statement returned by conn.Prepare.
type stmt struct {
	// Underlying prepared statement; set to nil by Close.
	my mysql.Stmt
	// Reusable argument buffer, sized to my.NumParam() by Prepare.
	args []interface{}
}
- func (s *stmt) run(args []driver.Value) (*rowsRes, error) {
- for i, v := range args {
- s.args[i] = interface{}(v)
- }
- res, err := s.my.Run(s.args...)
- if err != nil {
- return nil, errFilter(err)
- }
- return &rowsRes{my: res}, nil
- }
- func (c conn) Prepare(query string) (driver.Stmt, error) {
- st, err := c.my.Prepare(query)
- if err != nil {
- return nil, errFilter(err)
- }
- return &stmt{st, make([]interface{}, st.NumParam())}, nil
- }
- func (c *conn) Close() (err error) {
- err = c.my.Close()
- c.my = nil
- if err != nil {
- err = errFilter(err)
- }
- return
- }
// tx is the driver.Tx returned by conn.Begin; it wraps a mymysql transaction.
type tx struct {
	my mysql.Transaction
}
- func (c conn) Begin() (driver.Tx, error) {
- t, err := c.my.Begin()
- if err != nil {
- return nil, errFilter(err)
- }
- return tx{t}, nil
- }
- func (t tx) Commit() (err error) {
- err = t.my.Commit()
- if err != nil {
- err = errFilter(err)
- }
- return
- }
- func (t tx) Rollback() (err error) {
- err = t.my.Rollback()
- if err != nil {
- err = errFilter(err)
- }
- return
- }
- func (s *stmt) Close() (err error) {
- if s.my == nil {
- panic("godrv: stmt closed twice")
- }
- err = s.my.Delete()
- s.my = nil
- if err != nil {
- err = errFilter(err)
- }
- return
- }
- func (s *stmt) NumInput() int {
- return s.my.NumParam()
- }
- func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
- return s.run(args)
- }
- func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
- r, err := s.run(args)
- if err != nil {
- return nil, err
- }
- r.row = r.my.MakeRow()
- return r, nil
- }
- func (r *rowsRes) LastInsertId() (int64, error) {
- return int64(r.my.InsertId()), nil
- }
- func (r *rowsRes) RowsAffected() (int64, error) {
- return int64(r.my.AffectedRows()), nil
- }
- func (r *rowsRes) Columns() []string {
- flds := r.my.Fields()
- cls := make([]string, len(flds))
- for i, f := range flds {
- cls[i] = f.Name
- }
- return cls
- }
- func (r *rowsRes) Close() error {
- if r.my == nil {
- return nil // closed before
- }
- if err := r.my.End(); err != nil {
- return errFilter(err)
- }
- if r.simpleQuery != nil && r.simpleQuery != textQuery {
- if err := r.simpleQuery.Delete(); err != nil {
- return errFilter(err)
- }
- }
- r.my = nil
- return nil
- }
- var location = time.Local
- // DATE, DATETIME, TIMESTAMP are treated as they are in Local time zone (this
- // can be changed globaly using SetLocation function).
- func (r *rowsRes) Next(dest []driver.Value) error {
- if r.my == nil {
- return io.EOF // closed before
- }
- err := r.my.ScanRow(r.row)
- if err == nil {
- if r.simpleQuery == textQuery {
- // workaround for time.Time from text queries
- for i, f := range r.my.Fields() {
- if r.row[i] != nil {
- switch f.Type {
- case native.MYSQL_TYPE_TIMESTAMP, native.MYSQL_TYPE_DATETIME,
- native.MYSQL_TYPE_DATE, native.MYSQL_TYPE_NEWDATE:
- r.row[i] = r.row.ForceTime(i, location)
- }
- }
- }
- }
- for i, d := range r.row {
- dest[i] = driver.Value(d)
- }
- return nil
- }
- if err != io.EOF {
- return errFilter(err)
- }
- if r.simpleQuery != nil && r.simpleQuery != textQuery {
- if err = r.simpleQuery.Delete(); err != nil {
- return errFilter(err)
- }
- }
- r.my = nil
- return io.EOF
- }
// Driver implements database/sql/driver interface.
//
// Its fields hold default connection settings. Open copies them and then
// overrides them with values parsed from the connection URI.
type Driver struct {
	// Defaults
	proto, laddr, raddr, user, passwd, db string

	// Connect timeout, overridable by the "timeout" URI option.
	timeout time.Duration

	// Optional custom dialer, installed with SetDialer.
	dialer Dialer

	// Commands registered on each connection Open creates; see Register.
	initCmds []string
}
- // Open creates a new connection. The uri needs to have the following syntax:
- //
- // [PROTOCOL_SPECFIIC*]DBNAME/USER/PASSWD
- //
- // where protocol specific part may be empty (this means connection to
- // local server using default protocol). Currently possible forms are:
- //
- // DBNAME/USER/PASSWD
- // unix:SOCKPATH*DBNAME/USER/PASSWD
- // unix:SOCKPATH,OPTIONS*DBNAME/USER/PASSWD
- // tcp:ADDR*DBNAME/USER/PASSWD
- // tcp:ADDR,OPTIONS*DBNAME/USER/PASSWD
- // cloudsql:INSTANCE*DBNAME/USER/PASSWD
- //
- // OPTIONS can contain comma separated list of options in form:
- // opt1=VAL1,opt2=VAL2,boolopt3,boolopt4
- // Currently implemented options, in addition to default MySQL variables:
- // laddr - local address/port (eg. 1.2.3.4:0)
- // timeout - connect timeout in format accepted by time.ParseDuration
- func (d *Driver) Open(uri string) (driver.Conn, error) {
- cfg := *d // copy default configuration
- pd := strings.SplitN(uri, "*", 2)
- connCommands := []string{}
- if len(pd) == 2 {
- // Parse protocol part of URI
- p := strings.SplitN(pd[0], ":", 2)
- if len(p) != 2 {
- return nil, errors.New("Wrong protocol part of URI")
- }
- cfg.proto = p[0]
- options := strings.Split(p[1], ",")
- cfg.raddr = options[0]
- for _, o := range options[1:] {
- kv := strings.SplitN(o, "=", 2)
- var k, v string
- if len(kv) == 2 {
- k, v = kv[0], kv[1]
- } else {
- k, v = o, "true"
- }
- switch k {
- case "laddr":
- cfg.laddr = v
- case "timeout":
- to, err := time.ParseDuration(v)
- if err != nil {
- return nil, err
- }
- cfg.timeout = to
- default:
- connCommands = append(connCommands, "SET "+k+"="+v)
- }
- }
- // Remove protocol part
- pd = pd[1:]
- }
- // Parse database part of URI
- dup := strings.SplitN(pd[0], "/", 3)
- if len(dup) != 3 {
- return nil, errors.New("Wrong database part of URI")
- }
- cfg.db = dup[0]
- cfg.user = dup[1]
- cfg.passwd = dup[2]
- c := conn{mysql.New(
- cfg.proto, cfg.laddr, cfg.raddr, cfg.user, cfg.passwd, cfg.db,
- )}
- if d.dialer != nil {
- dialer := func(proto, laddr, raddr string, timeout time.Duration) (
- net.Conn, error) {
- return d.dialer(proto, laddr, raddr, cfg.user, cfg.passwd, timeout)
- }
- c.my.SetDialer(dialer)
- }
- // Establish the connection
- c.my.SetTimeout(cfg.timeout)
- for _, q := range cfg.initCmds {
- c.my.Register(q) // Register initialisation commands
- }
- for _, q := range connCommands {
- c.my.Register(q)
- }
- if err := c.my.Connect(); err != nil {
- return nil, errFilter(err)
- }
- c.my.NarrowTypeSet(true)
- c.my.FullFieldInfo(false)
- return &c, nil
- }
- // Register registers initialization commands.
- // This is workaround, see http://codereview.appspot.com/5706047
- func (drv *Driver) Register(query string) {
- drv.initCmds = append(drv.initCmds, query)
- }
// Dialer can be used to dial connections to MySQL. If Dialer returns (nil, nil)
// the hook is skipped and normal dialing proceeds. user and dbname are there
// only for logging.
//
// proto, laddr and raddr are the protocol and the local/remote addresses of
// the connection being dialed. timeout is the value mymysql supplies when
// dialing (presumably the connect timeout configured in Open; verify).
type Dialer func(proto, laddr, raddr, user, dbname string, timeout time.Duration) (net.Conn, error)
- // SetDialer sets custom Dialer used by Driver to make connections.
- func (drv *Driver) SetDialer(dialer Dialer) {
- drv.dialer = dialer
- }
// dfltdrv is the Driver automatically registered in database/sql under the
// name "mymysql" (see init). By default it connects over TCP to
// 127.0.0.1:3306.
var dfltdrv = Driver{proto: "tcp", raddr: "127.0.0.1:3306"}
- // Register calls Register method on driver registered in database/sql.
- // If Register is called twice with the same name it panics.
- func Register(query string) {
- dfltdrv.Register(query)
- }
- // SetDialer calls SetDialer method on driver registered in database/sql.
- func SetDialer(dialer Dialer) {
- dfltdrv.SetDialer(dialer)
- }
- func init() {
- Register("SET NAMES utf8")
- sql.Register("mymysql", &dfltdrv)
- }
- // Version returns mymysql version string.
- func Version() string {
- return mysql.Version()
- }
- // SetLocation changes default location used to convert dates obtained from
- // server to time.Time.
- func SetLocation(loc *time.Location) {
- location = loc
- }
|