gocql.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package gocql provides a database/sql driver for CQL, the Cassandra
// query language.
  6. //
  7. // This package requires a recent version of Cassandra (≥ 1.2) that supports
  8. // CQL 3.0 and the new native protocol. The native protocol is still considered
  9. // beta and must be enabled manually in Cassandra 1.2 by setting
  10. // "start_native_transport" to true in conf/cassandra.yaml.
  11. //
  12. // Example Usage:
  13. //
//	db, err := sql.Open("gocql", "localhost:9042 keyspace=system")
//	// ...
//	rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces")
//	// ...
//	for rows.Next() {
//		var keyspace string
//		err = rows.Scan(&keyspace)
//		// ...
//		fmt.Println(keyspace)
//	}
//	if err := rows.Err(); err != nil {
//		// ...
//	}
  27. //
  28. package gocql
  29. import (
  30. "bytes"
  31. "code.google.com/p/snappy-go/snappy"
  32. "database/sql"
  33. "database/sql/driver"
  34. "encoding/binary"
  35. "fmt"
  36. "io"
  37. "math/rand"
  38. "net"
  39. "strings"
  40. )
  41. const (
  42. protoRequest byte = 0x01
  43. protoResponse byte = 0x81
  44. opError byte = 0x00
  45. opStartup byte = 0x01
  46. opReady byte = 0x02
  47. opAuthenticate byte = 0x03
  48. opCredentials byte = 0x04
  49. opOptions byte = 0x05
  50. opSupported byte = 0x06
  51. opQuery byte = 0x07
  52. opResult byte = 0x08
  53. opPrepare byte = 0x09
  54. opExecute byte = 0x0A
  55. flagCompressed byte = 0x01
  56. keyVersion string = "CQL_VERSION"
  57. keyCompression string = "COMPRESSION"
  58. )
  59. var rnd = rand.New(rand.NewSource(0))
  60. type drv struct{}
  61. func (d drv) Open(name string) (driver.Conn, error) {
  62. return Open(name)
  63. }
  64. type connection struct {
  65. c net.Conn
  66. compression string
  67. }
  68. func Open(name string) (*connection, error) {
  69. parts := strings.Split(name, " ")
  70. address := ""
  71. if len(parts) >= 1 {
  72. addresses := strings.Split(parts[0], ",")
  73. if len(addresses) > 0 {
  74. address = addresses[rnd.Intn(len(addresses))]
  75. }
  76. }
  77. c, err := net.Dial("tcp", address)
  78. if err != nil {
  79. return nil, err
  80. }
  81. version := "3.0.0"
  82. keyspace := ""
  83. compression := ""
  84. for i := 1; i < len(parts); i++ {
  85. switch {
  86. case parts[i] == "":
  87. continue
  88. case strings.HasPrefix(parts[i], "keyspace="):
  89. keyspace = strings.TrimSpace(parts[i][9:])
  90. case strings.HasPrefix(parts[i], "compression="):
  91. compression = strings.TrimSpace(parts[i][12:])
  92. if compression != "snappy" {
  93. return nil, fmt.Errorf("unknown compression algorithm %q",
  94. compression)
  95. }
  96. case strings.HasPrefix(parts[i], "version="):
  97. compression = strings.TrimSpace(parts[i][8:])
  98. default:
  99. return nil, fmt.Errorf("unsupported option %q", parts[i])
  100. }
  101. }
  102. cn := &connection{c: c, compression: compression}
  103. b := &bytes.Buffer{}
  104. if compression != "" {
  105. binary.Write(b, binary.BigEndian, uint16(2))
  106. } else {
  107. binary.Write(b, binary.BigEndian, uint16(1))
  108. }
  109. binary.Write(b, binary.BigEndian, uint16(len(keyVersion)))
  110. b.WriteString(keyVersion)
  111. binary.Write(b, binary.BigEndian, uint16(len(version)))
  112. b.WriteString(version)
  113. if compression != "" {
  114. binary.Write(b, binary.BigEndian, uint16(len(keyCompression)))
  115. b.WriteString(keyCompression)
  116. binary.Write(b, binary.BigEndian, uint16(len(compression)))
  117. b.WriteString(compression)
  118. }
  119. if err := cn.send(opStartup, b.Bytes()); err != nil {
  120. return nil, err
  121. }
  122. opcode, _, err := cn.recv()
  123. if err != nil {
  124. return nil, err
  125. }
  126. if opcode != opReady {
  127. return nil, fmt.Errorf("connection not ready")
  128. }
  129. if keyspace != "" {
  130. st, err := cn.Prepare(fmt.Sprintf("USE %s", keyspace))
  131. if err != nil {
  132. return nil, err
  133. }
  134. if _, err = st.Exec([]driver.Value{}); err != nil {
  135. return nil, err
  136. }
  137. }
  138. return cn, nil
  139. }
  140. func (cn *connection) send(opcode byte, body []byte) error {
  141. frame := make([]byte, len(body)+8)
  142. frame[0] = protoRequest
  143. frame[1] = 0
  144. frame[2] = 0
  145. frame[3] = opcode
  146. binary.BigEndian.PutUint32(frame[4:8], uint32(len(body)))
  147. copy(frame[8:], body)
  148. if _, err := cn.c.Write(frame); err != nil {
  149. return err
  150. }
  151. return nil
  152. }
  153. func (cn *connection) recv() (byte, []byte, error) {
  154. header := make([]byte, 8)
  155. if _, err := cn.c.Read(header); err != nil {
  156. return 0, nil, err
  157. }
  158. opcode := header[3]
  159. length := binary.BigEndian.Uint32(header[4:8])
  160. var body []byte
  161. if length > 0 {
  162. body = make([]byte, length)
  163. if _, err := cn.c.Read(body); err != nil {
  164. return 0, nil, err
  165. }
  166. }
  167. if header[1]&flagCompressed != 0 && cn.compression == "snappy" {
  168. var err error
  169. body, err = snappy.Decode(nil, body)
  170. if err != nil {
  171. return 0, nil, err
  172. }
  173. }
  174. if opcode == opError {
  175. code := binary.BigEndian.Uint32(body[0:4])
  176. msglen := binary.BigEndian.Uint16(body[4:6])
  177. msg := string(body[6 : 6+msglen])
  178. return opcode, body, Error{Code: int(code), Msg: msg}
  179. }
  180. return opcode, body, nil
  181. }
  182. func (cn *connection) Begin() (driver.Tx, error) {
  183. return cn, nil
  184. }
  185. func (cn *connection) Commit() error {
  186. return nil
  187. }
  188. func (cn *connection) Close() error {
  189. return cn.c.Close()
  190. }
  191. func (cn *connection) Rollback() error {
  192. return nil
  193. }
  194. func (cn *connection) Prepare(query string) (driver.Stmt, error) {
  195. body := make([]byte, len(query)+4)
  196. binary.BigEndian.PutUint32(body[0:4], uint32(len(query)))
  197. copy(body[4:], []byte(query))
  198. if err := cn.send(opPrepare, body); err != nil {
  199. return nil, err
  200. }
  201. opcode, body, err := cn.recv()
  202. if err != nil {
  203. return nil, err
  204. }
  205. if opcode != opResult || binary.BigEndian.Uint32(body) != 4 {
  206. return nil, fmt.Errorf("expected prepared result")
  207. }
  208. n := int(binary.BigEndian.Uint16(body[4:]))
  209. prepared := body[6 : 6+n]
  210. columns, meta, _ := parseMeta(body[6+n:])
  211. return &statement{cn: cn, query: query,
  212. prepared: prepared, columns: columns, meta: meta}, nil
  213. }
  214. type statement struct {
  215. cn *connection
  216. query string
  217. prepared []byte
  218. columns []string
  219. meta []uint16
  220. }
  221. func (s *statement) Close() error {
  222. return nil
  223. }
  224. func (st *statement) ColumnConverter(idx int) driver.ValueConverter {
  225. return (&columnEncoder{st.meta}).ColumnConverter(idx)
  226. }
  227. func (st *statement) NumInput() int {
  228. return len(st.columns)
  229. }
  230. func parseMeta(body []byte) ([]string, []uint16, int) {
  231. flags := binary.BigEndian.Uint32(body)
  232. globalTableSpec := flags&1 == 1
  233. columnCount := int(binary.BigEndian.Uint32(body[4:]))
  234. i := 8
  235. if globalTableSpec {
  236. l := int(binary.BigEndian.Uint16(body[i:]))
  237. keyspace := string(body[i+2 : i+2+l])
  238. i += 2 + l
  239. l = int(binary.BigEndian.Uint16(body[i:]))
  240. tablename := string(body[i+2 : i+2+l])
  241. i += 2 + l
  242. _, _ = keyspace, tablename
  243. }
  244. columns := make([]string, columnCount)
  245. meta := make([]uint16, columnCount)
  246. for c := 0; c < columnCount; c++ {
  247. l := int(binary.BigEndian.Uint16(body[i:]))
  248. columns[c] = string(body[i+2 : i+2+l])
  249. i += 2 + l
  250. meta[c] = binary.BigEndian.Uint16(body[i:])
  251. i += 2
  252. }
  253. return columns, meta, i
  254. }
  255. func (st *statement) exec(v []driver.Value) error {
  256. sz := 6 + len(st.prepared)
  257. for i := range v {
  258. if b, ok := v[i].([]byte); ok {
  259. sz += len(b) + 4
  260. }
  261. }
  262. body, p := make([]byte, sz), 4+len(st.prepared)
  263. binary.BigEndian.PutUint16(body, uint16(len(st.prepared)))
  264. copy(body[2:], st.prepared)
  265. binary.BigEndian.PutUint16(body[p-2:], uint16(len(v)))
  266. for i := range v {
  267. b, ok := v[i].([]byte)
  268. if !ok {
  269. return fmt.Errorf("unsupported type %T at column %d", v[i], i)
  270. }
  271. binary.BigEndian.PutUint32(body[p:], uint32(len(b)))
  272. copy(body[p+4:], b)
  273. p += 4 + len(b)
  274. }
  275. if err := st.cn.send(opExecute, body); err != nil {
  276. return err
  277. }
  278. return nil
  279. }
  280. func (st *statement) Exec(v []driver.Value) (driver.Result, error) {
  281. if err := st.exec(v); err != nil {
  282. return nil, err
  283. }
  284. opcode, body, err := st.cn.recv()
  285. if err != nil {
  286. return nil, err
  287. }
  288. _, _ = opcode, body
  289. return nil, nil
  290. }
  291. func (st *statement) Query(v []driver.Value) (driver.Rows, error) {
  292. if err := st.exec(v); err != nil {
  293. return nil, err
  294. }
  295. opcode, body, err := st.cn.recv()
  296. if err != nil {
  297. return nil, err
  298. }
  299. kind := binary.BigEndian.Uint32(body[0:4])
  300. if opcode != opResult || kind != 2 {
  301. return nil, fmt.Errorf("expected rows as result")
  302. }
  303. columns, meta, n := parseMeta(body[4:])
  304. i := n + 4
  305. rows := &rows{
  306. columns: columns,
  307. meta: meta,
  308. numRows: int(binary.BigEndian.Uint32(body[i:])),
  309. }
  310. i += 4
  311. rows.body = body[i:]
  312. return rows, nil
  313. }
  314. type rows struct {
  315. columns []string
  316. meta []uint16
  317. body []byte
  318. row int
  319. numRows int
  320. }
  321. func (r *rows) Close() error {
  322. return nil
  323. }
  324. func (r *rows) Columns() []string {
  325. return r.columns
  326. }
  327. func (r *rows) Next(values []driver.Value) error {
  328. if r.row >= r.numRows {
  329. return io.EOF
  330. }
  331. for column := 0; column < len(r.columns); column++ {
  332. n := int(binary.BigEndian.Uint32(r.body))
  333. r.body = r.body[4:]
  334. if n >= 0 {
  335. values[column] = decode(r.body[:n], r.meta[column])
  336. r.body = r.body[n:]
  337. } else {
  338. values[column] = nil
  339. }
  340. }
  341. r.row++
  342. return nil
  343. }
  344. type Error struct {
  345. Code int
  346. Msg string
  347. }
  348. func (e Error) Error() string {
  349. return e.Msg
  350. }
  351. func init() {
  352. sql.Register("gocql", &drv{})
  353. }