gocql.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // The gocql package provides a database/sql driver for CQL, the Cassandra
  5. // query language.
  6. //
  7. // This package requires a recent version of Cassandra (≥ 1.2) that supports
  8. // CQL 3.0 and the new native protocol. The native protocol is still considered
  9. // beta and must be enabled manually in Cassandra 1.2 by setting
  10. // "start_native_transport" to true in conf/cassandra.yaml.
  11. //
  12. // Example Usage:
  13. //
  14. // db, err := sql.Open("gocql", "localhost:9042 keyspace=system")
  15. // // ...
  16. // rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces")
  17. // // ...
  18. // for rows.Next() {
  19. // var keyspace string
  20. // err = rows.Scan(&keyspace)
  21. // // ...
  22. // fmt.Println(keyspace)
  23. // }
  24. // if err := rows.Err(); err != nil {
  25. // // ...
  26. // }
  27. //
  28. package gocql
  29. import (
  30. "bytes"
  31. "code.google.com/p/snappy-go/snappy"
  32. "database/sql"
  33. "database/sql/driver"
  34. "encoding/binary"
  35. "fmt"
  36. "io"
  37. "math/rand"
  38. "net"
  39. "strings"
  40. )
  41. const (
  42. protoRequest byte = 0x01
  43. protoResponse byte = 0x81
  44. opError byte = 0x00
  45. opStartup byte = 0x01
  46. opReady byte = 0x02
  47. opAuthenticate byte = 0x03
  48. opCredentials byte = 0x04
  49. opOptions byte = 0x05
  50. opSupported byte = 0x06
  51. opQuery byte = 0x07
  52. opResult byte = 0x08
  53. opPrepare byte = 0x09
  54. opExecute byte = 0x0A
  55. opLAST byte = 0x0A // not a real opcode -- used to check for valid opcodes
  56. flagCompressed byte = 0x01
  57. keyVersion string = "CQL_VERSION"
  58. keyCompression string = "COMPRESSION"
  59. )
  // consistencyLevels maps each name accepted by the "consistency=" option of
  // the data source name to the consistency code sent with EXECUTE requests.
  var consistencyLevels = map[string]byte{
  	"any":          0x00,
  	"one":          0x01,
  	"two":          0x02,
  	"three":        0x03,
  	"quorum":       0x04,
  	"all":          0x05,
  	"local_quorum": 0x06,
  	"each_quorum":  0x07,
  }

  // rnd is a fixed-seed, and therefore deterministic, random source.
  // NOTE(review): a *rand.Rand created with rand.New is not safe for
  // concurrent use; confirm every user serializes access, or use the
  // package-level rand functions instead.
  var rnd = rand.New(rand.NewSource(0))
  63. type drv struct{}
  64. func (d drv) Open(name string) (driver.Conn, error) {
  65. return Open(name)
  66. }
  67. type connection struct {
  68. c net.Conn
  69. compression string
  70. consistency byte
  71. }
  72. func Open(name string) (*connection, error) {
  73. parts := strings.Split(name, " ")
  74. address := ""
  75. if len(parts) >= 1 {
  76. addresses := strings.Split(parts[0], ",")
  77. if len(addresses) > 0 {
  78. address = addresses[rnd.Intn(len(addresses))]
  79. }
  80. }
  81. c, err := net.Dial("tcp", address)
  82. if err != nil {
  83. return nil, err
  84. }
  85. version := "3.0.0"
  86. var (
  87. keyspace string
  88. compression string
  89. consistency byte = 0x01
  90. ok bool
  91. )
  92. for i := 1; i < len(parts); i++ {
  93. switch {
  94. case parts[i] == "":
  95. continue
  96. case strings.HasPrefix(parts[i], "keyspace="):
  97. keyspace = strings.TrimSpace(parts[i][9:])
  98. case strings.HasPrefix(parts[i], "compression="):
  99. compression = strings.TrimSpace(parts[i][12:])
  100. if compression != "snappy" {
  101. return nil, fmt.Errorf("unknown compression algorithm %q",
  102. compression)
  103. }
  104. case strings.HasPrefix(parts[i], "version="):
  105. version = strings.TrimSpace(parts[i][8:])
  106. case strings.HasPrefix(parts[i], "consistency="):
  107. cs := strings.TrimSpace(parts[i][12:])
  108. if consistency, ok = consistencyLevels[cs]; !ok {
  109. return nil, fmt.Errorf("unknown consistency level %q", cs)
  110. }
  111. default:
  112. return nil, fmt.Errorf("unsupported option %q", parts[i])
  113. }
  114. }
  115. cn := &connection{c: c, compression: compression, consistency: consistency}
  116. b := &bytes.Buffer{}
  117. if compression != "" {
  118. binary.Write(b, binary.BigEndian, uint16(2))
  119. } else {
  120. binary.Write(b, binary.BigEndian, uint16(1))
  121. }
  122. binary.Write(b, binary.BigEndian, uint16(len(keyVersion)))
  123. b.WriteString(keyVersion)
  124. binary.Write(b, binary.BigEndian, uint16(len(version)))
  125. b.WriteString(version)
  126. if compression != "" {
  127. binary.Write(b, binary.BigEndian, uint16(len(keyCompression)))
  128. b.WriteString(keyCompression)
  129. binary.Write(b, binary.BigEndian, uint16(len(compression)))
  130. b.WriteString(compression)
  131. }
  132. if err := cn.sendUncompressed(opStartup, b.Bytes()); err != nil {
  133. return nil, err
  134. }
  135. opcode, _, err := cn.recv()
  136. if err != nil {
  137. return nil, err
  138. }
  139. if opcode != opReady {
  140. return nil, fmt.Errorf("connection not ready")
  141. }
  142. if keyspace != "" {
  143. st, err := cn.Prepare(fmt.Sprintf("USE %s", keyspace))
  144. if err != nil {
  145. return nil, err
  146. }
  147. if _, err = st.Exec([]driver.Value{}); err != nil {
  148. return nil, err
  149. }
  150. }
  151. return cn, nil
  152. }
  153. // close a connection actively, typically used when there's an error and we want to ensure
  154. // we don't repeatedly try to use the broken connection
  155. func (cn *connection) close() {
  156. cn.c.Close()
  157. cn.c = nil // ensure we generate ErrBadConn when cn gets reused
  158. }
  159. // explicitly send a request as uncompressed
  160. // This is only really needed for the "startup" handshake
  161. func (cn *connection) sendUncompressed(opcode byte, body []byte) error {
  162. return cn._send(opcode, body, false)
  163. }
  164. func (cn *connection) send(opcode byte, body []byte) error {
  165. return cn._send(opcode, body, cn.compression == "snappy" && len(body) > 0)
  166. }
  167. func (cn *connection) _send(opcode byte, body []byte, compression bool) error {
  168. if cn.c == nil {
  169. return driver.ErrBadConn
  170. }
  171. var flags byte = 0x00
  172. if compression {
  173. var err error
  174. body, err = snappy.Encode(nil, body)
  175. if err != nil {
  176. return err
  177. }
  178. flags = flagCompressed
  179. }
  180. frame := make([]byte, len(body)+8)
  181. frame[0] = protoRequest
  182. frame[1] = flags
  183. frame[2] = 0
  184. frame[3] = opcode
  185. binary.BigEndian.PutUint32(frame[4:8], uint32(len(body)))
  186. copy(frame[8:], body)
  187. if _, err := cn.c.Write(frame); err != nil {
  188. return err
  189. }
  190. return nil
  191. }
  192. func (cn *connection) recv() (byte, []byte, error) {
  193. if cn.c == nil {
  194. return 0, nil, driver.ErrBadConn
  195. }
  196. header := make([]byte, 8)
  197. if _, err := io.ReadFull(cn.c, header); err != nil {
  198. cn.close() // better assume that the connection is broken (may have read some bytes)
  199. return 0, nil, err
  200. }
  201. // verify that the frame starts with version==1 and req/resp flag==response
  202. // this may be overly conservative in that future versions may be backwards compatible
  203. // in that case simply amend the check...
  204. if header[0] != protoResponse {
  205. cn.close()
  206. return 0, nil, fmt.Errorf("unsupported frame version or not a response: 0x%x (header=%v)", header[0], header)
  207. }
  208. // verify that the flags field has only a single flag set, again, this may
  209. // be overly conservative if additional flags are backwards-compatible
  210. if header[1] > 1 {
  211. cn.close()
  212. return 0, nil, fmt.Errorf("unsupported frame flags: 0x%x (header=%v)", header[1], header)
  213. }
  214. opcode := header[3]
  215. if opcode > opLAST {
  216. cn.close()
  217. return 0, nil, fmt.Errorf("unknown opcode: 0x%x (header=%v)", opcode, header)
  218. }
  219. length := binary.BigEndian.Uint32(header[4:8])
  220. var body []byte
  221. if length > 0 {
  222. if length > 256*1024*1024 { // spec says 256MB is max
  223. cn.close()
  224. return 0, nil, fmt.Errorf("frame too large: %d (header=%v)", length, header)
  225. }
  226. body = make([]byte, length)
  227. if _, err := io.ReadFull(cn.c, body); err != nil {
  228. cn.close() // better assume that the connection is broken
  229. return 0, nil, err
  230. }
  231. }
  232. if header[1]&flagCompressed != 0 && cn.compression == "snappy" {
  233. var err error
  234. body, err = snappy.Decode(nil, body)
  235. if err != nil {
  236. cn.close()
  237. return 0, nil, err
  238. }
  239. }
  240. if opcode == opError {
  241. code := binary.BigEndian.Uint32(body[0:4])
  242. msglen := binary.BigEndian.Uint16(body[4:6])
  243. msg := string(body[6 : 6+msglen])
  244. return opcode, body, Error{Code: int(code), Msg: msg}
  245. }
  246. return opcode, body, nil
  247. }
  248. func (cn *connection) Begin() (driver.Tx, error) {
  249. if cn.c == nil {
  250. return nil, driver.ErrBadConn
  251. }
  252. return cn, nil
  253. }
  254. func (cn *connection) Commit() error {
  255. if cn.c == nil {
  256. return driver.ErrBadConn
  257. }
  258. return nil
  259. }
  260. func (cn *connection) Close() error {
  261. if cn.c == nil {
  262. return driver.ErrBadConn
  263. }
  264. cn.close()
  265. return nil
  266. }
  267. func (cn *connection) Rollback() error {
  268. if cn.c == nil {
  269. return driver.ErrBadConn
  270. }
  271. return nil
  272. }
  273. func (cn *connection) Prepare(query string) (driver.Stmt, error) {
  274. body := make([]byte, len(query)+4)
  275. binary.BigEndian.PutUint32(body[0:4], uint32(len(query)))
  276. copy(body[4:], []byte(query))
  277. if err := cn.send(opPrepare, body); err != nil {
  278. return nil, err
  279. }
  280. opcode, body, err := cn.recv()
  281. if err != nil {
  282. return nil, err
  283. }
  284. if opcode != opResult || binary.BigEndian.Uint32(body) != 4 {
  285. return nil, fmt.Errorf("expected prepared result")
  286. }
  287. n := int(binary.BigEndian.Uint16(body[4:]))
  288. prepared := body[6 : 6+n]
  289. columns, meta, _ := parseMeta(body[6+n:])
  290. return &statement{cn: cn, query: query,
  291. prepared: prepared, columns: columns, meta: meta}, nil
  292. }
  293. type statement struct {
  294. cn *connection
  295. query string
  296. prepared []byte
  297. columns []string
  298. meta []uint16
  299. }
  300. func (s *statement) Close() error {
  301. return nil
  302. }
  303. func (st *statement) ColumnConverter(idx int) driver.ValueConverter {
  304. return (&columnEncoder{st.meta}).ColumnConverter(idx)
  305. }
  306. func (st *statement) NumInput() int {
  307. return len(st.columns)
  308. }
  // parseMeta decodes a result metadata section, as found in Prepared and Rows
  // results. body must start at the 4-byte flags field, which is followed by a
  // 4-byte column count. It returns the column names, each column's top-level
  // type id, and the number of bytes of body consumed.
  //
  // Keyspace/table names: if flag bit 0 (global table spec) is set, one
  // keyspace/table name pair precedes all the columns; otherwise every column
  // carries its own pair. Both forms are skipped.
  //
  // Type payloads: some type ids are followed by extra data, which is also
  // skipped so that later columns, and the data after the metadata, are found
  // at the right offset:
  //   - custom (0x0000): a class name;
  //   - list (0x0020) and set (0x0022): an element type;
  //   - map (0x0021): a key type and a value type.
  // Only the top-level id is reported in meta.
  //
  // The signature has no error result, so truncated input causes an
  // index-out-of-range panic.
  func parseMeta(body []byte) ([]string, []uint16, int) {
  	i := 0
  	readShort := func() int {
  		v := int(binary.BigEndian.Uint16(body[i:]))
  		i += 2
  		return v
  	}
  	readString := func() string {
  		l := readShort()
  		s := string(body[i : i+l])
  		i += l
  		return s
  	}
  	var readType func() uint16
  	readType = func() uint16 {
  		id := uint16(readShort())
  		switch id {
  		case 0x0000: // custom: class name
  			readString()
  		case 0x0020, 0x0022: // list, set: element type
  			readType()
  		case 0x0021: // map: key type, then value type
  			readType()
  			readType()
  		}
  		return id
  	}

  	flags := binary.BigEndian.Uint32(body)
  	columnCount := int(binary.BigEndian.Uint32(body[4:]))
  	i = 8
  	globalTableSpec := flags&1 == 1
  	if globalTableSpec {
  		readString() // keyspace
  		readString() // table
  	}
  	columns := make([]string, columnCount)
  	meta := make([]uint16, columnCount)
  	for c := range columns {
  		if !globalTableSpec {
  			readString() // keyspace
  			readString() // table
  		}
  		columns[c] = readString()
  		meta[c] = readType()
  	}
  	return columns, meta, i
  }
  334. func (st *statement) exec(v []driver.Value) error {
  335. sz := 6 + len(st.prepared)
  336. for i := range v {
  337. if b, ok := v[i].([]byte); ok {
  338. sz += len(b) + 4
  339. }
  340. }
  341. body, p := make([]byte, sz), 4+len(st.prepared)
  342. binary.BigEndian.PutUint16(body, uint16(len(st.prepared)))
  343. copy(body[2:], st.prepared)
  344. binary.BigEndian.PutUint16(body[p-2:], uint16(len(v)))
  345. for i := range v {
  346. b, ok := v[i].([]byte)
  347. if !ok {
  348. return fmt.Errorf("unsupported type %T at column %d", v[i], i)
  349. }
  350. binary.BigEndian.PutUint32(body[p:], uint32(len(b)))
  351. copy(body[p+4:], b)
  352. p += 4 + len(b)
  353. }
  354. binary.BigEndian.PutUint16(body[p:], uint16(st.cn.consistency))
  355. if err := st.cn.send(opExecute, body); err != nil {
  356. return err
  357. }
  358. return nil
  359. }
  360. func (st *statement) Exec(v []driver.Value) (driver.Result, error) {
  361. if err := st.exec(v); err != nil {
  362. return nil, err
  363. }
  364. opcode, body, err := st.cn.recv()
  365. if err != nil {
  366. return nil, err
  367. }
  368. _, _ = opcode, body
  369. return nil, nil
  370. }
  371. func (st *statement) Query(v []driver.Value) (driver.Rows, error) {
  372. if err := st.exec(v); err != nil {
  373. return nil, err
  374. }
  375. opcode, body, err := st.cn.recv()
  376. if err != nil {
  377. return nil, err
  378. }
  379. kind := binary.BigEndian.Uint32(body[0:4])
  380. if opcode != opResult || kind != 2 {
  381. return nil, fmt.Errorf("expected rows as result")
  382. }
  383. columns, meta, n := parseMeta(body[4:])
  384. i := n + 4
  385. rows := &rows{
  386. columns: columns,
  387. meta: meta,
  388. numRows: int(binary.BigEndian.Uint32(body[i:])),
  389. }
  390. i += 4
  391. rows.body = body[i:]
  392. return rows, nil
  393. }
  // rows is the driver.Rows for a result that has been received in full.
  type rows struct {
  	columns []string // column names from the result metadata
  	meta    []uint16 // type id of each column, parallel to columns
  	body    []byte   // cells not yet decoded
  	row     int      // number of rows returned so far
  	numRows int      // total number of rows in the result
  }

  // Close implements driver.Rows. The result is already in memory, so there
  // is nothing to release.
  func (r *rows) Close() error { return nil }

  // Columns implements driver.Rows by returning the result's column names.
  func (r *rows) Columns() []string { return r.columns }
  407. func (r *rows) Next(values []driver.Value) error {
  408. if r.row >= r.numRows {
  409. return io.EOF
  410. }
  411. for column := 0; column < len(r.columns); column++ {
  412. n := int32(binary.BigEndian.Uint32(r.body))
  413. r.body = r.body[4:]
  414. if n >= 0 {
  415. values[column] = decode(r.body[:n], r.meta[column])
  416. r.body = r.body[n:]
  417. } else {
  418. values[column] = nil
  419. }
  420. }
  421. r.row++
  422. return nil
  423. }
  // Error is an ERROR response returned by the server.
  type Error struct {
  	Code int    // error code from the response body
  	Msg  string // server-supplied message
  }

  // Error implements the error interface. It returns the server's message
  // verbatim; the code is available separately in the Code field.
  func (e Error) Error() string {
  	msg := e.Msg
  	return msg
  }
  431. func init() {
  432. sql.Register("gocql", &drv{})
  433. }