gocql.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package gocqldriver provides a database/sql driver for CQL, the Cassandra
// query language. It registers itself under the driver name "gocql".
  6. //
  7. // This package requires a recent version of Cassandra (≥ 1.2) that supports
  8. // CQL 3.0 and the new native protocol. The native protocol is still considered
  9. // beta and must be enabled manually in Cassandra 1.2 by setting
  10. // "start_native_transport" to true in conf/cassandra.yaml.
  11. //
  12. // Example Usage:
  13. //
//	db, err := sql.Open("gocql", "localhost:9042 keyspace=system")
//	// ...
//	rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces")
//	// ...
//	for rows.Next() {
//		var keyspace string
//		err = rows.Scan(&keyspace)
//		// ...
//		fmt.Println(keyspace)
//	}
//	if err := rows.Err(); err != nil {
//		// ...
//	}
  27. //
  28. package gocqldriver
  29. import (
  30. "bytes"
  31. "code.google.com/p/snappy-go/snappy"
  32. "database/sql"
  33. "database/sql/driver"
  34. "encoding/binary"
  35. "fmt"
  36. "io"
  37. "net"
  38. "strings"
  39. "time"
  40. )
  41. const (
  42. protoRequest byte = 0x01
  43. protoResponse byte = 0x81
  44. opError byte = 0x00
  45. opStartup byte = 0x01
  46. opReady byte = 0x02
  47. opAuthenticate byte = 0x03
  48. opCredentials byte = 0x04
  49. opOptions byte = 0x05
  50. opSupported byte = 0x06
  51. opQuery byte = 0x07
  52. opResult byte = 0x08
  53. opPrepare byte = 0x09
  54. opExecute byte = 0x0A
  55. opLAST byte = 0x0A // not a real opcode -- used to check for valid opcodes
  56. flagCompressed byte = 0x01
  57. keyVersion string = "CQL_VERSION"
  58. keyCompression string = "COMPRESSION"
  59. keyspaceQuery string = "USE "
  60. )
  61. var consistencyLevels = map[string]byte{"any": 0x00, "one": 0x01, "two": 0x02,
  62. "three": 0x03, "quorum": 0x04, "all": 0x05, "local_quorum": 0x06, "each_quorum": 0x07}
  63. type drv struct{}
  64. func (d drv) Open(name string) (driver.Conn, error) {
  65. return Open(name)
  66. }
  67. type connection struct {
  68. c net.Conn
  69. address string
  70. alive bool
  71. pool *pool
  72. }
  73. type pool struct {
  74. connections []*connection
  75. i int
  76. keyspace string
  77. version string
  78. compression string
  79. consistency byte
  80. dead bool
  81. stop chan struct{}
  82. }
  83. func Open(name string) (*pool, error) {
  84. parts := strings.Split(name, " ")
  85. var addresses []string
  86. if len(parts) >= 1 {
  87. addresses = strings.Split(parts[0], ",")
  88. }
  89. version := "3.0.0"
  90. var (
  91. keyspace string
  92. compression string
  93. consistency byte = 0x01
  94. ok bool
  95. )
  96. for i := 1; i < len(parts); i++ {
  97. switch {
  98. case parts[i] == "":
  99. continue
  100. case strings.HasPrefix(parts[i], "keyspace="):
  101. keyspace = strings.TrimSpace(parts[i][9:])
  102. case strings.HasPrefix(parts[i], "compression="):
  103. compression = strings.TrimSpace(parts[i][12:])
  104. if compression != "snappy" {
  105. return nil, fmt.Errorf("unknown compression algorithm %q",
  106. compression)
  107. }
  108. case strings.HasPrefix(parts[i], "version="):
  109. version = strings.TrimSpace(parts[i][8:])
  110. case strings.HasPrefix(parts[i], "consistency="):
  111. cs := strings.TrimSpace(parts[i][12:])
  112. if consistency, ok = consistencyLevels[cs]; !ok {
  113. return nil, fmt.Errorf("unknown consistency level %q", cs)
  114. }
  115. default:
  116. return nil, fmt.Errorf("unsupported option %q", parts[i])
  117. }
  118. }
  119. pool := &pool{
  120. keyspace: keyspace,
  121. version: version,
  122. compression: compression,
  123. consistency: consistency,
  124. stop: make(chan struct{}),
  125. }
  126. for _, address := range addresses {
  127. pool.connections = append(pool.connections, &connection{address: address, pool: pool})
  128. }
  129. pool.join()
  130. return pool, nil
  131. }
  132. func (cn *connection) open() {
  133. cn.alive = false
  134. var err error
  135. cn.c, err = net.Dial("tcp", cn.address)
  136. if err != nil {
  137. return
  138. }
  139. var (
  140. version = cn.pool.version
  141. compression = cn.pool.compression
  142. keyspace = cn.pool.keyspace
  143. )
  144. b := &bytes.Buffer{}
  145. if compression != "" {
  146. binary.Write(b, binary.BigEndian, uint16(2))
  147. } else {
  148. binary.Write(b, binary.BigEndian, uint16(1))
  149. }
  150. binary.Write(b, binary.BigEndian, uint16(len(keyVersion)))
  151. b.WriteString(keyVersion)
  152. binary.Write(b, binary.BigEndian, uint16(len(version)))
  153. b.WriteString(version)
  154. if compression != "" {
  155. binary.Write(b, binary.BigEndian, uint16(len(keyCompression)))
  156. b.WriteString(keyCompression)
  157. binary.Write(b, binary.BigEndian, uint16(len(compression)))
  158. b.WriteString(compression)
  159. }
  160. if err := cn.sendUncompressed(opStartup, b.Bytes()); err != nil {
  161. return
  162. }
  163. opcode, _, err := cn.recv()
  164. if err != nil {
  165. return
  166. }
  167. if opcode != opReady {
  168. return
  169. }
  170. if keyspace != "" {
  171. cn.UseKeyspace(keyspace)
  172. }
  173. cn.alive = true
  174. }
  175. // close a connection actively, typically used when there's an error and we want to ensure
  176. // we don't repeatedly try to use the broken connection
  177. func (cn *connection) close() {
  178. if cn.c == nil {
  179. return
  180. }
  181. cn.c.Close()
  182. cn.c = nil // ensure we generate ErrBadConn when cn gets reused
  183. cn.alive = false
  184. // Check if the entire pool is dead
  185. for _, cn := range cn.pool.connections {
  186. if cn.alive {
  187. return
  188. }
  189. }
  190. cn.pool.dead = false
  191. }
  192. // explicitly send a request as uncompressed
  193. // This is only really needed for the "startup" handshake
  194. func (cn *connection) sendUncompressed(opcode byte, body []byte) error {
  195. return cn._send(opcode, body, false)
  196. }
  197. func (cn *connection) send(opcode byte, body []byte) error {
  198. return cn._send(opcode, body, cn.pool.compression == "snappy" && len(body) > 0)
  199. }
  200. func (cn *connection) _send(opcode byte, body []byte, compression bool) error {
  201. if cn.c == nil {
  202. return driver.ErrBadConn
  203. }
  204. var flags byte = 0x00
  205. if compression && len(body) > 1 {
  206. var err error
  207. body, err = snappy.Encode(nil, body)
  208. if err != nil {
  209. return err
  210. }
  211. flags = flagCompressed
  212. }
  213. frame := make([]byte, len(body)+8)
  214. frame[0] = protoRequest
  215. frame[1] = flags
  216. frame[2] = 0
  217. frame[3] = opcode
  218. binary.BigEndian.PutUint32(frame[4:8], uint32(len(body)))
  219. copy(frame[8:], body)
  220. if _, err := cn.c.Write(frame); err != nil {
  221. return err
  222. }
  223. return nil
  224. }
  225. func (cn *connection) recv() (byte, []byte, error) {
  226. if cn.c == nil {
  227. return 0, nil, driver.ErrBadConn
  228. }
  229. header := make([]byte, 8)
  230. if _, err := io.ReadFull(cn.c, header); err != nil {
  231. cn.close() // better assume that the connection is broken (may have read some bytes)
  232. return 0, nil, err
  233. }
  234. // verify that the frame starts with version==1 and req/resp flag==response
  235. // this may be overly conservative in that future versions may be backwards compatible
  236. // in that case simply amend the check...
  237. if header[0] != protoResponse {
  238. cn.close()
  239. return 0, nil, fmt.Errorf("unsupported frame version or not a response: 0x%x (header=%v)", header[0], header)
  240. }
  241. // verify that the flags field has only a single flag set, again, this may
  242. // be overly conservative if additional flags are backwards-compatible
  243. if header[1] > 1 {
  244. cn.close()
  245. return 0, nil, fmt.Errorf("unsupported frame flags: 0x%x (header=%v)", header[1], header)
  246. }
  247. opcode := header[3]
  248. if opcode > opLAST {
  249. cn.close()
  250. return 0, nil, fmt.Errorf("unknown opcode: 0x%x (header=%v)", opcode, header)
  251. }
  252. length := binary.BigEndian.Uint32(header[4:8])
  253. var body []byte
  254. if length > 0 {
  255. if length > 256*1024*1024 { // spec says 256MB is max
  256. cn.close()
  257. return 0, nil, fmt.Errorf("frame too large: %d (header=%v)", length, header)
  258. }
  259. body = make([]byte, length)
  260. if _, err := io.ReadFull(cn.c, body); err != nil {
  261. cn.close() // better assume that the connection is broken
  262. return 0, nil, err
  263. }
  264. }
  265. if header[1]&flagCompressed != 0 && cn.pool.compression == "snappy" {
  266. var err error
  267. body, err = snappy.Decode(nil, body)
  268. if err != nil {
  269. cn.close()
  270. return 0, nil, err
  271. }
  272. }
  273. if opcode == opError {
  274. code := binary.BigEndian.Uint32(body[0:4])
  275. msglen := binary.BigEndian.Uint16(body[4:6])
  276. msg := string(body[6 : 6+msglen])
  277. return opcode, body, Error{Code: int(code), Msg: msg}
  278. }
  279. return opcode, body, nil
  280. }
  281. func (p *pool) conn() (*connection, error) {
  282. if p.dead {
  283. return nil, driver.ErrBadConn
  284. }
  285. totalConnections := len(p.connections)
  286. start := p.i + 1 // make sure that we start from the next position in the ring
  287. for i := 0; i < totalConnections; i++ {
  288. idx := (i + start) % totalConnections
  289. cn := p.connections[idx]
  290. if cn.alive {
  291. p.i = idx // set the new 'i' so the ring will start again in the right place
  292. return cn, nil
  293. }
  294. }
  295. // we've exhausted the pool, gonna have a bad time
  296. p.dead = true
  297. return nil, driver.ErrBadConn
  298. }
  299. func (p *pool) join() {
  300. p.reconnect()
  301. // Every 1 second, we want to try reconnecting to disconnected nodes
  302. go func() {
  303. for {
  304. select {
  305. case <-p.stop:
  306. return
  307. default:
  308. p.reconnect()
  309. time.Sleep(time.Second)
  310. }
  311. }
  312. }()
  313. }
  314. func (p *pool) reconnect() {
  315. for _, cn := range p.connections {
  316. if !cn.alive {
  317. cn.open()
  318. }
  319. }
  320. }
  321. func (p *pool) Begin() (driver.Tx, error) {
  322. if p.dead {
  323. return nil, driver.ErrBadConn
  324. }
  325. return p, nil
  326. }
  327. func (p *pool) Commit() error {
  328. if p.dead {
  329. return driver.ErrBadConn
  330. }
  331. return nil
  332. }
  333. func (p *pool) Close() error {
  334. if p.dead {
  335. return driver.ErrBadConn
  336. }
  337. for _, cn := range p.connections {
  338. cn.close()
  339. }
  340. p.stop <- struct{}{}
  341. p.dead = true
  342. return nil
  343. }
  344. func (p *pool) Rollback() error {
  345. if p.dead {
  346. return driver.ErrBadConn
  347. }
  348. return nil
  349. }
  350. func (p *pool) Prepare(query string) (driver.Stmt, error) {
  351. // Explicitly check if the query is a "USE <keyspace>"
  352. // Since it needs to be special cased and run on each server
  353. if strings.HasPrefix(query, keyspaceQuery) {
  354. keyspace := query[len(keyspaceQuery):]
  355. p.UseKeyspace(keyspace)
  356. return &statement{}, nil
  357. }
  358. for {
  359. cn, err := p.conn()
  360. if err != nil {
  361. return nil, err
  362. }
  363. st, err := cn.Prepare(query)
  364. if err == io.EOF {
  365. // the cn has gotten marked as dead already
  366. if p.dead {
  367. // The entire pool is dead, so we bubble up the ErrBadConn
  368. return nil, driver.ErrBadConn
  369. } else {
  370. continue // Retry request on another cn
  371. }
  372. }
  373. return st, err
  374. }
  375. }
  376. func (p *pool) UseKeyspace(keyspace string) {
  377. p.keyspace = keyspace
  378. for _, cn := range p.connections {
  379. cn.UseKeyspace(keyspace)
  380. }
  381. }
  382. func (cn *connection) UseKeyspace(keyspace string) error {
  383. st, err := cn.Prepare(keyspaceQuery + keyspace)
  384. if err != nil {
  385. return err
  386. }
  387. if _, err = st.Exec([]driver.Value{}); err != nil {
  388. return err
  389. }
  390. return nil
  391. }
  392. func (cn *connection) Prepare(query string) (driver.Stmt, error) {
  393. body := make([]byte, len(query)+4)
  394. binary.BigEndian.PutUint32(body[0:4], uint32(len(query)))
  395. copy(body[4:], []byte(query))
  396. if err := cn.send(opPrepare, body); err != nil {
  397. return nil, err
  398. }
  399. opcode, body, err := cn.recv()
  400. if err != nil {
  401. return nil, err
  402. }
  403. if opcode != opResult || binary.BigEndian.Uint32(body) != 4 {
  404. return nil, fmt.Errorf("expected prepared result")
  405. }
  406. n := int(binary.BigEndian.Uint16(body[4:]))
  407. prepared := body[6 : 6+n]
  408. columns, meta, _ := parseMeta(body[6+n:])
  409. return &statement{cn: cn, query: query,
  410. prepared: prepared, columns: columns, meta: meta}, nil
  411. }
  412. type statement struct {
  413. cn *connection
  414. query string
  415. prepared []byte
  416. columns []string
  417. meta []uint16
  418. }
  419. func (s *statement) Close() error {
  420. return nil
  421. }
  422. func (st *statement) ColumnConverter(idx int) driver.ValueConverter {
  423. return (&columnEncoder{st.meta}).ColumnConverter(idx)
  424. }
  425. func (st *statement) NumInput() int {
  426. return len(st.columns)
  427. }
  428. func parseMeta(body []byte) ([]string, []uint16, int) {
  429. flags := binary.BigEndian.Uint32(body)
  430. globalTableSpec := flags&1 == 1
  431. columnCount := int(binary.BigEndian.Uint32(body[4:]))
  432. i := 8
  433. if globalTableSpec {
  434. l := int(binary.BigEndian.Uint16(body[i:]))
  435. keyspace := string(body[i+2 : i+2+l])
  436. i += 2 + l
  437. l = int(binary.BigEndian.Uint16(body[i:]))
  438. tablename := string(body[i+2 : i+2+l])
  439. i += 2 + l
  440. _, _ = keyspace, tablename
  441. }
  442. columns := make([]string, columnCount)
  443. meta := make([]uint16, columnCount)
  444. for c := 0; c < columnCount; c++ {
  445. l := int(binary.BigEndian.Uint16(body[i:]))
  446. columns[c] = string(body[i+2 : i+2+l])
  447. i += 2 + l
  448. meta[c] = binary.BigEndian.Uint16(body[i:])
  449. i += 2
  450. }
  451. return columns, meta, i
  452. }
  453. func (st *statement) exec(v []driver.Value) error {
  454. sz := 6 + len(st.prepared)
  455. for i := range v {
  456. if b, ok := v[i].([]byte); ok {
  457. sz += len(b) + 4
  458. }
  459. }
  460. body, p := make([]byte, sz), 4+len(st.prepared)
  461. binary.BigEndian.PutUint16(body, uint16(len(st.prepared)))
  462. copy(body[2:], st.prepared)
  463. binary.BigEndian.PutUint16(body[p-2:], uint16(len(v)))
  464. for i := range v {
  465. b, ok := v[i].([]byte)
  466. if !ok {
  467. return fmt.Errorf("unsupported type %T at column %d", v[i], i)
  468. }
  469. binary.BigEndian.PutUint32(body[p:], uint32(len(b)))
  470. copy(body[p+4:], b)
  471. p += 4 + len(b)
  472. }
  473. binary.BigEndian.PutUint16(body[p:], uint16(st.cn.pool.consistency))
  474. if err := st.cn.send(opExecute, body); err != nil {
  475. return err
  476. }
  477. return nil
  478. }
  479. func (st *statement) Exec(v []driver.Value) (driver.Result, error) {
  480. if st.cn == nil {
  481. return nil, nil
  482. }
  483. if err := st.exec(v); err != nil {
  484. return nil, err
  485. }
  486. opcode, body, err := st.cn.recv()
  487. if err != nil {
  488. return nil, err
  489. }
  490. _, _ = opcode, body
  491. return nil, nil
  492. }
  493. func (st *statement) Query(v []driver.Value) (driver.Rows, error) {
  494. if err := st.exec(v); err != nil {
  495. return nil, err
  496. }
  497. opcode, body, err := st.cn.recv()
  498. if err != nil {
  499. return nil, err
  500. }
  501. kind := binary.BigEndian.Uint32(body[0:4])
  502. if opcode != opResult || kind != 2 {
  503. return nil, fmt.Errorf("expected rows as result")
  504. }
  505. columns, meta, n := parseMeta(body[4:])
  506. i := n + 4
  507. rows := &rows{
  508. columns: columns,
  509. meta: meta,
  510. numRows: int(binary.BigEndian.Uint32(body[i:])),
  511. }
  512. i += 4
  513. rows.body = body[i:]
  514. return rows, nil
  515. }
  516. type rows struct {
  517. columns []string
  518. meta []uint16
  519. body []byte
  520. row int
  521. numRows int
  522. }
  523. func (r *rows) Close() error {
  524. return nil
  525. }
  526. func (r *rows) Columns() []string {
  527. return r.columns
  528. }
  529. func (r *rows) Next(values []driver.Value) error {
  530. if r.row >= r.numRows {
  531. return io.EOF
  532. }
  533. for column := 0; column < len(r.columns); column++ {
  534. n := int32(binary.BigEndian.Uint32(r.body))
  535. r.body = r.body[4:]
  536. if n >= 0 {
  537. values[column] = decode(r.body[:n], r.meta[column])
  538. r.body = r.body[n:]
  539. } else {
  540. values[column] = nil
  541. }
  542. }
  543. r.row++
  544. return nil
  545. }
  546. type Error struct {
  547. Code int
  548. Msg string
  549. }
  550. func (e Error) Error() string {
  551. return e.Msg
  552. }
  553. func init() {
  554. sql.Register("gocql", &drv{})
  555. }