// Copyright 2012 Gary Burd // // Licensed under the Apache License, Version 2.0 (the "License"): you may // not use this file except in compliance with the License. You may obtain // a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations // under the License. package redis import ( "bufio" "bytes" "errors" "fmt" "io" "net" "strconv" "sync" "time" ) // conn is the low-level implementation of Conn type conn struct { // Shared mu sync.Mutex pending int err error conn net.Conn // Read readTimeout time.Duration br *bufio.Reader // Write writeTimeout time.Duration bw *bufio.Writer // Scratch space for formatting argument length. // '*' or '$', length, "\r\n" lenScratch [1 + 19 + 2]byte // Scratch space for formatting integers and floats. numScratch [40]byte } // Dial connects to the Redis server at the given network and address. func Dial(network, address string) (Conn, error) { c, err := net.Dial(network, address) if err != nil { return nil, errors.New("Could not connect to Redis server: " + err.Error()) } return NewConn(c, 0, 0), nil } // DialTimeout acts like Dial but takes timeouts for establishing the // connection to the server, writing a command and reading a reply. func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) { var c net.Conn var err error if connectTimeout > 0 { c, err = net.DialTimeout(network, address, connectTimeout) } else { c, err = net.Dial(network, address) } if err != nil { return nil, errors.New("Could not connect to Redis server: " + err.Error()) } return NewConn(c, readTimeout, writeTimeout), nil } // NewConn returns a new Redigo connection for the given net connection. func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn { return &conn{ conn: netConn, bw: bufio.NewWriter(netConn), br: bufio.NewReader(netConn), readTimeout: readTimeout, writeTimeout: writeTimeout, } } func (c *conn) Close() error { err := c.conn.Close() if err != nil { c.fatal(err) } else { c.fatal(errors.New("redigo: closed")) } return err } func (c *conn) fatal(err error) error { c.mu.Lock() if c.err == nil { c.err = err } c.mu.Unlock() return err } func (c *conn) Err() error { c.mu.Lock() err := c.err c.mu.Unlock() return err } func (c *conn) writeLen(prefix byte, n int) error { c.lenScratch[0] = prefix s := strconv.AppendInt(c.lenScratch[:1], int64(n), 10) s = append(s, "\r\n"...) _, err := c.bw.Write(s) return err } func (c *conn) writeString(s string) error { c.writeLen('$', len(s)) c.bw.WriteString(s) _, err := c.bw.WriteString("\r\n") return err } func (c *conn) writeBytes(p []byte) error { c.writeLen('$', len(p)) c.bw.Write(p) _, err := c.bw.WriteString("\r\n") return err } func (c *conn) writeInt64(n int64) error { return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10)) } func (c *conn) writeFloat64(n float64) error { return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64)) } func (c *conn) writeCommand(cmd string, args []interface{}) (err error) { c.writeLen('*', 1+len(args)) err = c.writeString(cmd) for _, arg := range args { if err != nil { break } switch arg := arg.(type) { case string: err = c.writeString(arg) case []byte: err = c.writeBytes(arg) case int: err = c.writeInt64(int64(arg)) case int64: err = c.writeInt64(arg) case float64: err = c.writeFloat64(arg) case bool: if arg { err = c.writeString("1") } else { err = c.writeString("0") } case nil: err = c.writeString("") default: var buf bytes.Buffer fmt.Fprint(&buf, arg) err = c.writeBytes(buf.Bytes()) } } return err } func (c *conn) readLine() ([]byte, error) { p, err := c.br.ReadSlice('\n') if err == bufio.ErrBufferFull { return nil, errors.New("redigo: long response line") } if err != nil { return nil, err } i := len(p) - 2 if i < 0 || p[i] != '\r' { return nil, errors.New("redigo: bad response line terminator") } return p[:i], nil } func (c *conn) readReply() (interface{}, error) { line, err := c.readLine() if err != nil { return nil, err } if len(line) == 0 { return nil, errors.New("redigo: short response line") } switch line[0] { case '+': return string(line[1:]), nil case '-': return Error(string(line[1:])), nil case ':': n, err := strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { return nil, err } return n, nil case '$': n, err := strconv.Atoi(string(line[1:])) if err != nil || n < 0 { return nil, err } p := make([]byte, n) _, err = io.ReadFull(c.br, p) if err != nil { return nil, err } line, err := c.readLine() if err != nil { return nil, err } if len(line) != 0 { return nil, errors.New("redigo: bad bulk format") } return p, nil case '*': n, err := strconv.Atoi(string(line[1:])) if err != nil || n < 0 { return nil, err } r := make([]interface{}, n) for i := range r { r[i], err = c.readReply() if err != nil { return nil, err } } return r, nil } return nil, errors.New("redigo: unexpected response line") } func (c *conn) Send(cmd string, args ...interface{}) error { c.mu.Lock() c.pending += 1 c.mu.Unlock() if c.writeTimeout != 0 { c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) } if err := c.writeCommand(cmd, args); err != nil { return c.fatal(err) } return nil } func (c *conn) Flush() error { if c.writeTimeout != 0 { c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) } if err := c.bw.Flush(); err != nil { return c.fatal(err) } return nil } func (c *conn) Receive() (reply interface{}, err error) { c.mu.Lock() // There can be more receives than sends when using pub/sub. To allow // normal use of the connection after unsubscribe from all channels, do not // decrement pending to a negative value. if c.pending > 0 { c.pending -= 1 } c.mu.Unlock() if c.readTimeout != 0 { c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) } if reply, err = c.readReply(); err != nil { return nil, c.fatal(err) } if err, ok := reply.(Error); ok { return nil, err } return } func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) { if c.writeTimeout != 0 { c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) } if cmd != "" { c.writeCommand(cmd, args) } if err := c.bw.Flush(); err != nil { return nil, c.fatal(err) } c.mu.Lock() pending := c.pending c.pending = 0 c.mu.Unlock() if c.readTimeout != 0 { c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) } if cmd == "" { reply := make([]interface{}, pending) for i := range reply { if r, e := c.readReply(); e != nil { return nil, c.fatal(e) } else { reply[i] = r } } return reply, nil } var err error var reply interface{} for i := 0; i <= pending; i++ { var e error if reply, e = c.readReply(); e != nil { return nil, c.fatal(e) } if e, ok := reply.(Error); ok && err == nil { err = e } } return reply, err }