| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
- //
- // Copyright 2012 Julien Schmidt. All rights reserved.
- // http://www.julienschmidt.com
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this file,
- // You can obtain one at http://mozilla.org/MPL/2.0/.
- package mysql
- import (
- "bytes"
- "crypto/sha1"
- "io"
- "log"
- "math"
- "os"
- "regexp"
- "strconv"
- "strings"
- )
// Package-level loggers: errLog writes to stderr and dbgLog to stdout, both
// prefixed with "[MySQL] " and using the standard log flags.
var (
	errLog = log.New(os.Stderr, "[MySQL] ", log.LstdFlags)
	dbgLog = log.New(os.Stdout, "[MySQL] ", log.LstdFlags)
)

// dsnPattern matches a Data Source Name of the form
//
//	[user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
//
// Only the "/" before dbname is mandatory (dbname itself may be empty). The
// expression is compiled once, during package variable initialization, so
// parseDSN never recompiles it.
var dsnPattern = regexp.MustCompile(
	`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
		`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
		`\/(?P<dbname>.*?)` + // /dbname
		`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
- func parseDSN(dsn string) *config {
- cfg := new(config)
- cfg.params = make(map[string]string)
- matches := dsnPattern.FindStringSubmatch(dsn)
- names := dsnPattern.SubexpNames()
- for i, match := range matches {
- switch names[i] {
- case "user":
- cfg.user = match
- case "passwd":
- cfg.passwd = match
- case "net":
- cfg.net = match
- case "addr":
- cfg.addr = match
- case "dbname":
- cfg.dbname = match
- case "params":
- for _, v := range strings.Split(match, "&") {
- param := strings.SplitN(v, "=", 2)
- if len(param) != 2 {
- continue
- }
- cfg.params[param[0]] = param[1]
- }
- }
- }
- // Set default network if empty
- if cfg.net == "" {
- cfg.net = "tcp"
- }
- // Set default adress if empty
- if cfg.addr == "" {
- cfg.addr = "127.0.0.1:3306"
- }
- return cfg
- }
// scramblePassword computes the authentication token for the MySQL 4.1+
// password method:
//
//	token = SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// An empty password yields a nil token.
// http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
func scramblePassword(scramble, password []byte) (result []byte) {
	if len(password) == 0 {
		return nil
	}

	// One hasher, reset before each digest; chunks are hashed in order.
	h := sha1.New()
	digest := func(chunks ...[]byte) []byte {
		h.Reset()
		for _, chunk := range chunks {
			h.Write(chunk)
		}
		return h.Sum(nil)
	}

	stage1 := digest(password)      // SHA1(password)
	stage2 := digest(stage1)        // SHA1(SHA1(password))
	mixed := digest(scramble, stage2) // SHA1(scramble + stage2)

	result = make([]byte, sha1.Size)
	for i := range result {
		result[i] = stage1[i] ^ mixed[i]
	}
	return result
}
- /******************************************************************************
- * Read data-types from bytes *
- ******************************************************************************/
// readSlice returns the bytes of data that precede the first occurrence of
// delim. If delim does not occur, it returns all of data together with
// io.EOF. The result aliases data; nothing is copied.
func readSlice(data []byte, delim byte) (slice []byte, err error) {
	end := bytes.IndexByte(data, delim)
	if end < 0 {
		return data, io.EOF
	}
	return data[:end], nil
}
- func readLengthCodedBinary(data []byte) (*[]byte, int, bool, error) {
- // Get length
- num, n, err := bytesToLengthCodedBinary(data)
- if err != nil {
- return nil, n, true, err
- }
- // Check data length
- if len(data) < n+int(num) {
- return nil, n, true, io.EOF
- }
- // Check if value is NULL
- var isNull bool
- if data[0] == 251 {
- isNull = true
- } else {
- isNull = false
- }
- // Get bytes
- b := data[n : n+int(num)]
- n += int(num)
- return &b, n, isNull, err
- }
- func readAndDropLengthCodedBinary(data []byte) (n int, err error) {
- // Get length
- num, n, err := bytesToLengthCodedBinary(data)
- if err != nil {
- return
- }
- // Check data length
- if len(data) < n+int(num) {
- err = io.EOF
- return
- }
- n += int(num)
- return
- }
- /******************************************************************************
- * Convert from and to bytes *
- ******************************************************************************/
// byteToUint8 converts a single byte to uint8. In Go byte is an alias for
// uint8, so this is a plain conversion.
func byteToUint8(b byte) (n uint8) {
	return uint8(b)
}
// bytesToUint16 decodes b[0:2] as a little-endian uint16. Extra bytes are
// ignored; it panics if len(b) < 2.
func bytesToUint16(b []byte) (n uint16) {
	lo, hi := b[0], b[1]
	return uint16(hi)<<8 | uint16(lo)
}
// uint24ToBytes encodes the low 24 bits of n as 3 little-endian bytes; the
// most significant byte of n is discarded.
func uint24ToBytes(n uint32) (b []byte) {
	return []byte{byte(n), byte(n >> 8), byte(n >> 16)}
}
// bytesToUint32 decodes b[0:4] as a little-endian uint32. Extra bytes are
// ignored; it panics if len(b) < 4.
func bytesToUint32(b []byte) (n uint32) {
	return uint32(b[0]) |
		uint32(b[1])<<8 |
		uint32(b[2])<<16 |
		uint32(b[3])<<24
}
// uint32ToBytes encodes n as 4 little-endian bytes.
func uint32ToBytes(n uint32) (b []byte) {
	return []byte{
		byte(n),
		byte(n >> 8),
		byte(n >> 16),
		byte(n >> 24),
	}
}
// bytesToUint64 decodes b[0:8] as a little-endian uint64. Extra bytes are
// ignored; it panics if len(b) < 8.
func bytesToUint64(b []byte) (n uint64) {
	return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
		uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
// uint64ToBytes encodes n as 8 little-endian bytes.
func uint64ToBytes(n uint64) (b []byte) {
	return []byte{
		byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
		byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56),
	}
}
- func int64ToBytes(n int64) []byte {
- return uint64ToBytes(uint64(n))
- }
- func bytesToFloat32(b []byte) float32 {
- return math.Float32frombits(bytesToUint32(b))
- }
- func bytesToFloat64(b []byte) float64 {
- return math.Float64frombits(bytesToUint64(b))
- }
- func float64ToBytes(f float64) []byte {
- return uint64ToBytes(math.Float64bits(f))
- }
// bytesToLengthCodedBinary decodes the length-coded integer at the start of b
// and returns its value together with the number of bytes it occupies.
//
// The first byte selects the encoding:
//
//	0-250  the value itself                            (1 byte total)
//	251    NULL marker, reported as length 0           (1 byte total)
//	252    value in the next 2 bytes, little-endian    (3 bytes total)
//	253    value in the next 3 bytes, little-endian    (4 bytes total)
//	254    value in the next 8 bytes, little-endian    (9 bytes total)
//
// It returns io.EOF if b is empty or shorter than the announced prefix
// (n is then the required prefix size), and io.ErrUnexpectedEOF if the first
// byte is 255, which this encoding does not define.
func bytesToLengthCodedBinary(b []byte) (length uint64, n int, err error) {
	if len(b) == 0 {
		return 0, 0, io.EOF
	}

	switch first := b[0]; {
	case first <= 250:
		return uint64(first), 1, nil
	case first == 251:
		return 0, 1, nil
	case first == 252:
		n = 3
	case first == 253:
		n = 4
	case first == 254:
		n = 9
	default:
		// 255: no valid prefix. Without this case n stayed 0 and b[1:0]
		// panicked.
		return 0, 0, io.ErrUnexpectedEOF
	}

	if len(b) < n {
		return 0, n, io.EOF
	}

	// Decode the little-endian value in b[1:n] without a scratch buffer.
	for i := n - 1; i >= 1; i-- {
		length = length<<8 | uint64(b[i])
	}
	return length, n, nil
}
// lengthCodedBinaryToBytes encodes n as a length-coded integer, the format
// decoded by bytesToLengthCodedBinary:
//
//	n <= 250       1 byte:  n
//	n <= 0xffff    3 bytes: 0xfc + 2-byte little-endian n
//	n <= 0xffffff  4 bytes: 0xfd + 3-byte little-endian n
//	otherwise      9 bytes: 0xfe + 8-byte little-endian n
//
// 251 is never emitted as a single byte because it marks NULL.
func lengthCodedBinaryToBytes(n uint64) (b []byte) {
	switch {
	case n <= 250:
		return []byte{byte(n)}
	case n <= 0xffff:
		return []byte{0xfc, byte(n), byte(n >> 8)}
	case n <= 0xffffff:
		return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
	default:
		// Values of 2^24 and above need the full 8-byte form; previously
		// they matched no case and produced nil.
		return []byte{0xfe,
			byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24),
			byte(n >> 32), byte(n >> 40), byte(n >> 48), byte(n >> 56)}
	}
}
// intToByteStr formats i in base 10 and returns the digits as a byte slice.
func intToByteStr(i int64) (b []byte) {
	return []byte(strconv.FormatInt(i, 10))
}
// uintToByteStr formats u in base 10 and returns the digits as a byte slice.
func uintToByteStr(u uint64) (b []byte) {
	return []byte(strconv.FormatUint(u, 10))
}
// float32ToByteStr formats f in plain decimal notation ('f', no exponent)
// using the fewest digits that round-trip as a float32.
func float32ToByteStr(f float32) (b []byte) {
	return []byte(strconv.FormatFloat(float64(f), 'f', -1, 32))
}
// float64ToByteStr formats f in plain decimal notation ('f', no exponent)
// using the fewest digits that round-trip as a float64.
func float64ToByteStr(f float64) (b []byte) {
	return []byte(strconv.FormatFloat(f, 'f', -1, 64))
}
|