// Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 Julien Schmidt. All rights reserved. // http://www.julienschmidt.com // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. package mysql import ( "crypto/sha1" "database/sql/driver" "encoding/binary" "fmt" "io" "log" "os" "regexp" "strings" "time" ) // NullTime represents a time.Time that may be NULL. // NullTime implements the Scanner interface so // it can be used as a scan destination: // // var nt NullTime // err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) // ... // if nt.Valid { // // use nt.Time // } else { // // NULL value // } // // This NullTime implementation is not driver-specific type NullTime struct { Time time.Time Valid bool // Valid is true if Time is not NULL } // Scan implements the Scanner interface. // The value type must be time.Time or string / []byte (formatted time-string), // otherwise Scan fails. func (nt *NullTime) Scan(value interface{}) (err error) { if value == nil { nt.Time, nt.Valid = time.Time{}, false return } switch v := value.(type) { case time.Time: nt.Time, nt.Valid = v, true return case []byte: nt.Time, err = parseDateTime(string(v), time.UTC) nt.Valid = (err == nil) return case string: nt.Time, err = parseDateTime(v, time.UTC) nt.Valid = (err == nil) return } nt.Valid = false return fmt.Errorf("Can't convert %T to time.Time", value) } // Value implements the driver Valuer interface. func (nt NullTime) Value() (driver.Value, error) { if !nt.Valid { return nil, nil } return nt.Time, nil } // Logger var ( errLog *log.Logger ) func init() { errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile) dsnPattern = regexp.MustCompile( `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] `\/(?P.*?)` + // /dbname `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] } // Data Source Name Parser var dsnPattern *regexp.Regexp func parseDSN(dsn string) (cfg *config, err error) { cfg = new(config) cfg.params = make(map[string]string) matches := dsnPattern.FindStringSubmatch(dsn) names := dsnPattern.SubexpNames() for i, match := range matches { switch names[i] { case "user": cfg.user = match case "passwd": cfg.passwd = match case "net": cfg.net = match case "addr": cfg.addr = match case "dbname": cfg.dbname = match case "params": for _, v := range strings.Split(match, "&") { param := strings.SplitN(v, "=", 2) if len(param) != 2 { continue } cfg.params[param[0]] = param[1] } } } // Set default network if empty if cfg.net == "" { cfg.net = "tcp" } // Set default adress if empty if cfg.addr == "" { cfg.addr = "127.0.0.1:3306" } cfg.loc, err = time.LoadLocation(cfg.params["loc"]) return } // Encrypt password using 4.1+ method // http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later func scramblePassword(scramble, password []byte) []byte { if len(password) == 0 { return nil } // stage1Hash = SHA1(password) crypt := sha1.New() crypt.Write(password) stage1 := crypt.Sum(nil) // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) // inner Hash crypt.Reset() crypt.Write(stage1) hash := crypt.Sum(nil) // outer Hash crypt.Reset() crypt.Write(scramble) crypt.Write(hash) scramble = crypt.Sum(nil) // token = scrambleHash XOR stage1Hash for i := range scramble { scramble[i] ^= stage1[i] } return scramble } func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { switch len(str) { case 10: // YYYY-MM-DD if str == "0000-00-00" { return } t, err = time.Parse(timeFormat[:10], str) case 19: // YYYY-MM-DD HH:MM:SS if str == "0000-00-00 00:00:00" { return } t, err = time.Parse(timeFormat, str) default: err = fmt.Errorf("Invalid Time-String: %s", str) return } // Adjust location if err == nil && loc != time.UTC { y, mo, d := t.Date() h, mi, s := t.Clock() t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil } return } func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { switch num { case 0: return time.Time{}, nil case 4: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day 0, 0, 0, 0, loc, ), nil case 7: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day int(data[4]), // hour int(data[5]), // minutes int(data[6]), // seconds 0, loc, ), nil case 11: return time.Date( int(binary.LittleEndian.Uint16(data[:2])), // year time.Month(data[2]), // month int(data[3]), // day int(data[4]), // hour int(data[5]), // minutes int(data[6]), // seconds int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds loc, ), nil } return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) } func formatBinaryDate(num uint64, data []byte) (driver.Value, error) { switch num { case 0: return []byte("0000-00-00"), nil case 4: return []byte(fmt.Sprintf( "%04d-%02d-%02d", binary.LittleEndian.Uint16(data[:2]), data[2], data[3], )), nil } return nil, fmt.Errorf("Invalid DATE-packet length %d", num) } func formatBinaryDateTime(num uint64, data []byte) (driver.Value, error) { switch num { case 0: return []byte("0000-00-00 00:00:00"), nil case 4: return []byte(fmt.Sprintf( "%04d-%02d-%02d 00:00:00", binary.LittleEndian.Uint16(data[:2]), data[2], data[3], )), nil case 7: return []byte(fmt.Sprintf( "%04d-%02d-%02d %02d:%02d:%02d", binary.LittleEndian.Uint16(data[:2]), data[2], data[3], data[4], data[5], data[6], )), nil case 11: return []byte(fmt.Sprintf( "%04d-%02d-%02d %02d:%02d:%02d.%06d", binary.LittleEndian.Uint16(data[:2]), data[2], data[3], data[4], data[5], data[6], binary.LittleEndian.Uint32(data[7:11]), )), nil } return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num) } func readBool(value string) bool { switch strings.ToLower(value) { case "true": return true case "1": return true } return false } /****************************************************************************** * Convert from and to bytes * ******************************************************************************/ func uint64ToBytes(n uint64) []byte { return []byte{ byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24), byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56), } } func uint64ToString(n uint64) []byte { var a [20]byte i := 20 // U+0030 = 0 // ... // U+0039 = 9 var q uint64 for n >= 10 { i-- q = n / 10 a[i] = uint8(n-q*10) + 0x30 n = q } i-- a[i] = uint8(n) + 0x30 return a[i:] } // treats string value as unsigned integer representation func stringToInt(b []byte) int { val := 0 for i := range b { val *= 10 val += int(b[i] - 0x30) } return val } func readLengthEnodedString(b []byte) ([]byte, bool, int, error) { // Get length num, isNull, n := readLengthEncodedInteger(b) if num < 1 { return b[n:n], isNull, n, nil } n += int(num) // Check data length if len(b) >= n { return b[n-int(num) : n], false, n, nil } return nil, false, n, io.EOF } func skipLengthEnodedString(b []byte) (int, error) { // Get length num, _, n := readLengthEncodedInteger(b) if num < 1 { return n, nil } n += int(num) // Check data length if len(b) >= n { return n, nil } return n, io.EOF } func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int) { switch b[0] { // 251: NULL case 0xfb: n = 1 isNull = true return // 252: value of following 2 case 0xfc: num = uint64(b[1]) | uint64(b[2])<<8 n = 3 return // 253: value of following 3 case 0xfd: num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 n = 4 return // 254: value of following 8 case 0xfe: num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | uint64(b[7])<<48 | uint64(b[8])<<54 n = 9 return } // 0-250: value of first byte num = uint64(b[0]) n = 1 return } func lengthEncodedIntegerToBytes(n uint64) []byte { switch { case n <= 250: return []byte{byte(n)} case n <= 0xffff: return []byte{0xfc, byte(n), byte(n >> 8)} case n <= 0xffffff: return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)} } return nil }