| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480 |
- // Copyright (c) 2012 The gocql Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- // The gocql package provides a database/sql driver for CQL, the Cassandra
- // query language.
- //
- // This package requires a recent version of Cassandra (≥ 1.2) that supports
- // CQL 3.0 and the new native protocol. The native protocol is still considered
- // beta and must be enabled manually in Cassandra 1.2 by setting
- // "start_native_transport" to true in conf/cassandra.yaml.
- //
- // Example Usage:
- //
- // db, err := sql.Open("gocql", "localhost:9042 keyspace=system")
- // // ...
- // rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces")
- // // ...
- // for rows.Next() {
- // var keyspace string
- // err = rows.Scan(&keyspace)
- // // ...
- // fmt.Println(keyspace)
- // }
- // if err := rows.Err(); err != nil {
- // // ...
- // }
- //
- package gocql
- import (
- "bytes"
- "code.google.com/p/snappy-go/snappy"
- "database/sql"
- "database/sql/driver"
- "encoding/binary"
- "fmt"
- "io"
- "math/rand"
- "net"
- "strings"
- )
// Values of the first frame byte: recv only accepts protoResponse, and _send
// stamps every outgoing frame with protoRequest.
const (
	protoRequest  byte = 0x01
	protoResponse byte = 0x81
)

// Frame opcodes. They are contiguous, starting at zero.
const (
	opError        byte = iota // 0x00
	opStartup                  // 0x01
	opReady                    // 0x02
	opAuthenticate             // 0x03
	opCredentials              // 0x04
	opOptions                  // 0x05
	opSupported                // 0x06
	opQuery                    // 0x07
	opResult                   // 0x08
	opPrepare                  // 0x09
	opExecute                  // 0x0A

	// opLAST is not a real opcode -- it is the upper bound used to check for
	// valid opcodes.
	opLAST = opExecute
)

// flagCompressed is the frame flag bit marking a compressed body.
const flagCompressed byte = 0x01

// Keys of the option map sent in the startup request.
const (
	keyVersion     string = "CQL_VERSION"
	keyCompression string = "COMPRESSION"
)
var (
	// consistencyLevels maps the textual consistency level names to the byte
	// codes used on the wire.
	consistencyLevels = map[string]byte{
		"any":          0x00,
		"one":          0x01,
		"two":          0x02,
		"three":        0x03,
		"quorum":       0x04,
		"all":          0x05,
		"local_quorum": 0x06,
		"each_quorum":  0x07,
	}

	// rnd is a package-level pseudo-random generator. It is seeded with the
	// fixed value 0, so it produces the same sequence in every process.
	// NOTE: a *rand.Rand built on rand.NewSource is not safe for concurrent use.
	rnd = rand.New(rand.NewSource(0))
)
- type drv struct{}
- func (d drv) Open(name string) (driver.Conn, error) {
- return Open(name)
- }
// connection is a single client connection to a Cassandra node. It also
// serves as its own (no-op) transaction object.
type connection struct {
	// c is the underlying network connection; it is set to nil once the
	// connection has been closed, which makes later calls report
	// driver.ErrBadConn.
	c net.Conn

	// compression names the negotiated body compression ("snappy"), or is
	// empty when frames are sent uncompressed.
	compression string

	// consistency is the consistency level code appended to every execute
	// request.
	consistency byte
}
- func Open(name string) (*connection, error) {
- parts := strings.Split(name, " ")
- address := ""
- if len(parts) >= 1 {
- addresses := strings.Split(parts[0], ",")
- if len(addresses) > 0 {
- address = addresses[rnd.Intn(len(addresses))]
- }
- }
- c, err := net.Dial("tcp", address)
- if err != nil {
- return nil, err
- }
- version := "3.0.0"
- var (
- keyspace string
- compression string
- consistency byte = 0x01
- ok bool
- )
- for i := 1; i < len(parts); i++ {
- switch {
- case parts[i] == "":
- continue
- case strings.HasPrefix(parts[i], "keyspace="):
- keyspace = strings.TrimSpace(parts[i][9:])
- case strings.HasPrefix(parts[i], "compression="):
- compression = strings.TrimSpace(parts[i][12:])
- if compression != "snappy" {
- return nil, fmt.Errorf("unknown compression algorithm %q",
- compression)
- }
- case strings.HasPrefix(parts[i], "version="):
- version = strings.TrimSpace(parts[i][8:])
- case strings.HasPrefix(parts[i], "consistency="):
- cs := strings.TrimSpace(parts[i][12:])
- if consistency, ok = consistencyLevels[cs]; !ok {
- return nil, fmt.Errorf("unknown consistency level %q", cs)
- }
- default:
- return nil, fmt.Errorf("unsupported option %q", parts[i])
- }
- }
- cn := &connection{c: c, compression: compression, consistency: consistency}
- b := &bytes.Buffer{}
- if compression != "" {
- binary.Write(b, binary.BigEndian, uint16(2))
- } else {
- binary.Write(b, binary.BigEndian, uint16(1))
- }
- binary.Write(b, binary.BigEndian, uint16(len(keyVersion)))
- b.WriteString(keyVersion)
- binary.Write(b, binary.BigEndian, uint16(len(version)))
- b.WriteString(version)
- if compression != "" {
- binary.Write(b, binary.BigEndian, uint16(len(keyCompression)))
- b.WriteString(keyCompression)
- binary.Write(b, binary.BigEndian, uint16(len(compression)))
- b.WriteString(compression)
- }
- if err := cn.sendUncompressed(opStartup, b.Bytes()); err != nil {
- return nil, err
- }
- opcode, _, err := cn.recv()
- if err != nil {
- return nil, err
- }
- if opcode != opReady {
- return nil, fmt.Errorf("connection not ready")
- }
- if keyspace != "" {
- st, err := cn.Prepare(fmt.Sprintf("USE %s", keyspace))
- if err != nil {
- return nil, err
- }
- if _, err = st.Exec([]driver.Value{}); err != nil {
- return nil, err
- }
- }
- return cn, nil
- }
- // close a connection actively, typically used when there's an error and we want to ensure
- // we don't repeatedly try to use the broken connection
- func (cn *connection) close() {
- cn.c.Close()
- cn.c = nil // ensure we generate ErrBadConn when cn gets reused
- }
- // explicitly send a request as uncompressed
- // This is only really needed for the "startup" handshake
- func (cn *connection) sendUncompressed(opcode byte, body []byte) error {
- return cn._send(opcode, body, false)
- }
- func (cn *connection) send(opcode byte, body []byte) error {
- return cn._send(opcode, body, cn.compression == "snappy" && len(body) > 0)
- }
- func (cn *connection) _send(opcode byte, body []byte, compression bool) error {
- if cn.c == nil {
- return driver.ErrBadConn
- }
- var flags byte = 0x00
- if compression {
- var err error
- body, err = snappy.Encode(nil, body)
- if err != nil {
- return err
- }
- flags = flagCompressed
- }
- frame := make([]byte, len(body)+8)
- frame[0] = protoRequest
- frame[1] = flags
- frame[2] = 0
- frame[3] = opcode
- binary.BigEndian.PutUint32(frame[4:8], uint32(len(body)))
- copy(frame[8:], body)
- if _, err := cn.c.Write(frame); err != nil {
- return err
- }
- return nil
- }
- func (cn *connection) recv() (byte, []byte, error) {
- if cn.c == nil {
- return 0, nil, driver.ErrBadConn
- }
- header := make([]byte, 8)
- if _, err := io.ReadFull(cn.c, header); err != nil {
- cn.close() // better assume that the connection is broken (may have read some bytes)
- return 0, nil, err
- }
- // verify that the frame starts with version==1 and req/resp flag==response
- // this may be overly conservative in that future versions may be backwards compatible
- // in that case simply amend the check...
- if header[0] != protoResponse {
- cn.close()
- return 0, nil, fmt.Errorf("unsupported frame version or not a response: 0x%x (header=%v)", header[0], header)
- }
- // verify that the flags field has only a single flag set, again, this may
- // be overly conservative if additional flags are backwards-compatible
- if header[1] > 1 {
- cn.close()
- return 0, nil, fmt.Errorf("unsupported frame flags: 0x%x (header=%v)", header[1], header)
- }
- opcode := header[3]
- if opcode > opLAST {
- cn.close()
- return 0, nil, fmt.Errorf("unknown opcode: 0x%x (header=%v)", opcode, header)
- }
- length := binary.BigEndian.Uint32(header[4:8])
- var body []byte
- if length > 0 {
- if length > 256*1024*1024 { // spec says 256MB is max
- cn.close()
- return 0, nil, fmt.Errorf("frame too large: %d (header=%v)", length, header)
- }
- body = make([]byte, length)
- if _, err := io.ReadFull(cn.c, body); err != nil {
- cn.close() // better assume that the connection is broken
- return 0, nil, err
- }
- }
- if header[1]&flagCompressed != 0 && cn.compression == "snappy" {
- var err error
- body, err = snappy.Decode(nil, body)
- if err != nil {
- cn.close()
- return 0, nil, err
- }
- }
- if opcode == opError {
- code := binary.BigEndian.Uint32(body[0:4])
- msglen := binary.BigEndian.Uint16(body[4:6])
- msg := string(body[6 : 6+msglen])
- return opcode, body, Error{Code: int(code), Msg: msg}
- }
- return opcode, body, nil
- }
- func (cn *connection) Begin() (driver.Tx, error) {
- if cn.c == nil {
- return nil, driver.ErrBadConn
- }
- return cn, nil
- }
- func (cn *connection) Commit() error {
- if cn.c == nil {
- return driver.ErrBadConn
- }
- return nil
- }
- func (cn *connection) Close() error {
- if cn.c == nil {
- return driver.ErrBadConn
- }
- cn.close()
- return nil
- }
- func (cn *connection) Rollback() error {
- if cn.c == nil {
- return driver.ErrBadConn
- }
- return nil
- }
- func (cn *connection) Prepare(query string) (driver.Stmt, error) {
- body := make([]byte, len(query)+4)
- binary.BigEndian.PutUint32(body[0:4], uint32(len(query)))
- copy(body[4:], []byte(query))
- if err := cn.send(opPrepare, body); err != nil {
- return nil, err
- }
- opcode, body, err := cn.recv()
- if err != nil {
- return nil, err
- }
- if opcode != opResult || binary.BigEndian.Uint32(body) != 4 {
- return nil, fmt.Errorf("expected prepared result")
- }
- n := int(binary.BigEndian.Uint16(body[4:]))
- prepared := body[6 : 6+n]
- columns, meta, _ := parseMeta(body[6+n:])
- return &statement{cn: cn, query: query,
- prepared: prepared, columns: columns, meta: meta}, nil
- }
- type statement struct {
- cn *connection
- query string
- prepared []byte
- columns []string
- meta []uint16
- }
- func (s *statement) Close() error {
- return nil
- }
- func (st *statement) ColumnConverter(idx int) driver.ValueConverter {
- return (&columnEncoder{st.meta}).ColumnConverter(idx)
- }
- func (st *statement) NumInput() int {
- return len(st.columns)
- }
// parseMeta decodes result metadata from the start of body. It returns the
// column names, the top-level type code of each column, and the number of
// bytes of body that the metadata occupies.
//
// Layout: 32-bit flags, 32-bit column count, then, if flag bit 0 (global
// table spec) is set, one keyspace and one table name. Each column follows as
// a name and a type. When the global table spec flag is NOT set, every column
// is preceded by its own keyspace and table name; these are skipped. Types
// that carry parameters (custom, list, map, set) have those parameters
// skipped as well, so that parsing stays aligned for the following columns.
//
// body must contain the complete metadata; truncated input panics with an
// index out of range error.
func parseMeta(body []byte) ([]string, []uint16, int) {
	flags := binary.BigEndian.Uint32(body)
	globalTableSpec := flags&1 == 1
	columnCount := int(binary.BigEndian.Uint32(body[4:]))
	i := 8

	// skipString advances i past one 16-bit length-prefixed string.
	skipString := func() {
		i += 2 + int(binary.BigEndian.Uint16(body[i:]))
	}

	if globalTableSpec {
		skipString() // keyspace
		skipString() // table name
	}

	columns := make([]string, columnCount)
	meta := make([]uint16, columnCount)
	for c := range columns {
		if !globalTableSpec {
			skipString() // per-column keyspace
			skipString() // per-column table name
		}
		l := int(binary.BigEndian.Uint16(body[i:]))
		columns[c] = string(body[i+2 : i+2+l])
		i += 2 + l
		meta[c] = binary.BigEndian.Uint16(body[i:])
		i = skipTypeArgs(body, i+2, meta[c])
	}
	return columns, meta, i
}

// skipTypeArgs returns the offset just past the parameters that follow the
// type code id at body[i:]. Simple types have no parameters.
func skipTypeArgs(body []byte, i int, id uint16) int {
	switch id {
	case 0x0000: // custom type: class name string
		return i + 2 + int(binary.BigEndian.Uint16(body[i:]))
	case 0x0020, 0x0022: // list, set: one element type
		return skipTypeOption(body, i)
	case 0x0021: // map: key type and value type
		return skipTypeOption(body, skipTypeOption(body, i))
	}
	return i
}

// skipTypeOption returns the offset just past the complete type (code plus
// any parameters) starting at body[i:].
func skipTypeOption(body []byte, i int) int {
	id := binary.BigEndian.Uint16(body[i:])
	return skipTypeArgs(body, i+2, id)
}
- func (st *statement) exec(v []driver.Value) error {
- sz := 6 + len(st.prepared)
- for i := range v {
- if b, ok := v[i].([]byte); ok {
- sz += len(b) + 4
- }
- }
- body, p := make([]byte, sz), 4+len(st.prepared)
- binary.BigEndian.PutUint16(body, uint16(len(st.prepared)))
- copy(body[2:], st.prepared)
- binary.BigEndian.PutUint16(body[p-2:], uint16(len(v)))
- for i := range v {
- b, ok := v[i].([]byte)
- if !ok {
- return fmt.Errorf("unsupported type %T at column %d", v[i], i)
- }
- binary.BigEndian.PutUint32(body[p:], uint32(len(b)))
- copy(body[p+4:], b)
- p += 4 + len(b)
- }
- binary.BigEndian.PutUint16(body[p:], uint16(st.cn.consistency))
- if err := st.cn.send(opExecute, body); err != nil {
- return err
- }
- return nil
- }
- func (st *statement) Exec(v []driver.Value) (driver.Result, error) {
- if err := st.exec(v); err != nil {
- return nil, err
- }
- opcode, body, err := st.cn.recv()
- if err != nil {
- return nil, err
- }
- _, _ = opcode, body
- return nil, nil
- }
- func (st *statement) Query(v []driver.Value) (driver.Rows, error) {
- if err := st.exec(v); err != nil {
- return nil, err
- }
- opcode, body, err := st.cn.recv()
- if err != nil {
- return nil, err
- }
- kind := binary.BigEndian.Uint32(body[0:4])
- if opcode != opResult || kind != 2 {
- return nil, fmt.Errorf("expected rows as result")
- }
- columns, meta, n := parseMeta(body[4:])
- i := n + 4
- rows := &rows{
- columns: columns,
- meta: meta,
- numRows: int(binary.BigEndian.Uint32(body[i:])),
- }
- i += 4
- rows.body = body[i:]
- return rows, nil
- }
// rows is an in-memory result set that is decoded lazily, one row per call to
// Next.
type rows struct {
	// columns holds the column names of the result.
	columns []string
	// meta holds the type code of each column, parallel to columns.
	meta []uint16
	// body is the not yet consumed row data.
	body []byte
	// row counts the rows already returned by Next.
	row int
	// numRows is the total number of rows in the result.
	numRows int
}
- func (r *rows) Close() error {
- return nil
- }
- func (r *rows) Columns() []string {
- return r.columns
- }
- func (r *rows) Next(values []driver.Value) error {
- if r.row >= r.numRows {
- return io.EOF
- }
- for column := 0; column < len(r.columns); column++ {
- n := int32(binary.BigEndian.Uint32(r.body))
- r.body = r.body[4:]
- if n >= 0 {
- values[column] = decode(r.body[:n], r.meta[column])
- r.body = r.body[n:]
- } else {
- values[column] = nil
- }
- }
- r.row++
- return nil
- }
// Error is the error value produced for an error response from the server.
type Error struct {
	Code int    // numeric error code taken from the response
	Msg  string // message text taken from the response
}

// Error implements the error interface. Only the message is returned; the
// code is available through the Code field.
func (e Error) Error() string { return e.Msg }
- func init() {
- sql.Register("gocql", &drv{})
- }
|