gocql.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  // Copyright (c) 2012 by Christoph Hack <christoph@tux21b.org>
  // All rights reserved. Distributed under the Simplified BSD License.

  // Package gocql provides a database/sql driver for CQL, the Cassandra
  // query language.
  //
  // This package requires a recent version of Cassandra (≥ 1.2) that supports
  // CQL 3.0 and the new native protocol. The native protocol is still considered
  // beta and must be enabled manually in Cassandra 1.2 by setting
  // "start_native_transport" to true in conf/cassandra.yaml.
  //
  // The data source name passed to sql.Open is space-separated: a
  // comma-separated list of "host:port" addresses (one is chosen at random
  // for each new connection), followed by optional settings such as
  // keyspace=NAME and compression=snappy.
  //
  // Example usage:
  //
  //	db, err := sql.Open("gocql", "localhost:8000 keyspace=system")
  //	// ...
  //	rows, err := db.Query("SELECT keyspace_name FROM schema_keyspaces")
  //	// ...
  //	for rows.Next() {
  //		var keyspace string
  //		err = rows.Scan(&keyspace)
  //		// ...
  //		fmt.Println(keyspace)
  //	}
  //	if err := rows.Err(); err != nil {
  //		// ...
  //	}
  //
  27. package gocql
  28. import (
  29. "bytes"
  30. "code.google.com/p/snappy-go/snappy"
  31. "database/sql"
  32. "database/sql/driver"
  33. "encoding/binary"
  34. "fmt"
  35. "io"
  36. "math/rand"
  37. "net"
  38. "strings"
  39. )
  // Frame-level constants for the native protocol.
  const (
  	// First header byte of a frame. Requests use protoRequest; the
  	// response value is the same with the high bit set.
  	protoRequest  byte = 0x01
  	protoResponse byte = 0x81

  	// flagCompressed is the bit in the flags byte (header byte 1) that
  	// marks a compressed body.
  	flagCompressed byte = 0x01
  )

  // Opcodes carried in header byte 3. They are consecutive, starting at 0.
  const (
  	opError        byte = iota // 0x00
  	opStartup                  // 0x01
  	opReady                    // 0x02
  	opAuthenticate             // 0x03
  	opCredentials              // 0x04
  	opOptions                  // 0x05
  	opSupported                // 0x06
  	opQuery                    // 0x07
  	opResult                   // 0x08
  	opPrepare                  // 0x09
  	opExecute                  // 0x0A
  )
  // sharedRand hands out pseudo-random indices from math/rand's package-level
  // source. Unlike a *rand.Rand built with rand.New, the package-level
  // functions are safe for concurrent use, which matters because
  // database/sql may call the driver's Open from several goroutines at once.
  // Since Go 1.20 that source is randomly seeded; on older toolchains it is
  // deterministic unless the program calls rand.Seed.
  type sharedRand struct{}

  // Intn returns a pseudo-random int in [0, n). It panics if n <= 0.
  func (sharedRand) Intn(n int) int { return rand.Intn(n) }

  // rnd chooses which of the data source name's addresses Open dials.
  var rnd sharedRand
  57. type drv struct{}
  58. func (d drv) Open(name string) (driver.Conn, error) {
  59. return Open(name)
  60. }
  // connection is one native-protocol connection to a Cassandra node. It is
  // returned to database/sql as a driver.Conn, and Begin hands the same
  // value back as a driver.Tx.
  type connection struct {
  	// c is the underlying TCP connection that frames are written to and
  	// read from.
  	c net.Conn

  	// compression is the algorithm requested in STARTUP: either "" or
  	// "snappy".
  	compression string
  }
  65. func Open(name string) (*connection, error) {
  66. parts := strings.Split(name, " ")
  67. address := ""
  68. if len(parts) >= 1 {
  69. addresses := strings.Split(parts[0], ",")
  70. if len(addresses) > 0 {
  71. address = addresses[rnd.Intn(len(addresses))]
  72. }
  73. }
  74. c, err := net.Dial("tcp", address)
  75. if err != nil {
  76. return nil, err
  77. }
  78. version := "3.0.0"
  79. keyspace := ""
  80. compression := ""
  81. for i := 1; i < len(parts); i++ {
  82. switch {
  83. case parts[i] == "":
  84. continue
  85. case strings.HasPrefix(parts[i], "keyspace="):
  86. keyspace = strings.TrimSpace(parts[i][9:])
  87. case strings.HasPrefix(parts[i], "compression="):
  88. compression = strings.TrimSpace(parts[i][12:])
  89. if compression != "snappy" {
  90. return nil, fmt.Errorf("unknown compression algorithm %q",
  91. compression)
  92. }
  93. case strings.HasPrefix(parts[i], "version="):
  94. compression = strings.TrimSpace(parts[i][8:])
  95. default:
  96. return nil, fmt.Errorf("unsupported option %q", parts[i])
  97. }
  98. }
  99. cn := &connection{c: c, compression: compression}
  100. b := &bytes.Buffer{}
  101. binary.Write(b, binary.BigEndian, uint16(len(version)))
  102. b.WriteString(version)
  103. if compression != "" {
  104. binary.Write(b, binary.BigEndian, uint16(1))
  105. binary.Write(b, binary.BigEndian, uint16(1))
  106. binary.Write(b, binary.BigEndian, uint16(len(compression)))
  107. b.WriteString(compression)
  108. } else {
  109. binary.Write(b, binary.BigEndian, uint16(0))
  110. }
  111. if err := cn.send(opStartup, b.Bytes()); err != nil {
  112. return nil, err
  113. }
  114. opcode, _, err := cn.recv()
  115. if err != nil {
  116. return nil, err
  117. }
  118. if opcode != opReady {
  119. return nil, fmt.Errorf("connection not ready")
  120. }
  121. if keyspace != "" {
  122. st, err := cn.Prepare(fmt.Sprintf("USE %s", keyspace))
  123. if err != nil {
  124. return nil, err
  125. }
  126. if _, err = st.Exec([]driver.Value{}); err != nil {
  127. return nil, err
  128. }
  129. }
  130. return cn, nil
  131. }
  132. func (cn *connection) send(opcode byte, body []byte) error {
  133. frame := make([]byte, len(body)+8)
  134. frame[0] = protoRequest
  135. frame[1] = 0
  136. frame[2] = 0
  137. frame[3] = opcode
  138. binary.BigEndian.PutUint32(frame[4:8], uint32(len(body)))
  139. copy(frame[8:], body)
  140. if _, err := cn.c.Write(frame); err != nil {
  141. return err
  142. }
  143. return nil
  144. }
  145. func (cn *connection) recv() (byte, []byte, error) {
  146. header := make([]byte, 8)
  147. if _, err := cn.c.Read(header); err != nil {
  148. return 0, nil, err
  149. }
  150. opcode := header[3]
  151. length := binary.BigEndian.Uint32(header[4:8])
  152. var body []byte
  153. if length > 0 {
  154. body = make([]byte, length)
  155. if _, err := cn.c.Read(body); err != nil {
  156. return 0, nil, err
  157. }
  158. }
  159. if header[1]&flagCompressed != 0 && cn.compression == "snappy" {
  160. var err error
  161. body, err = snappy.Decode(nil, body)
  162. if err != nil {
  163. return 0, nil, err
  164. }
  165. }
  166. if opcode == opError {
  167. code := binary.BigEndian.Uint32(body[0:4])
  168. msglen := binary.BigEndian.Uint16(body[4:6])
  169. msg := string(body[6 : 6+msglen])
  170. return opcode, body, Error{Code: int(code), Msg: msg}
  171. }
  172. return opcode, body, nil
  173. }
  174. func (cn *connection) Begin() (driver.Tx, error) {
  175. return cn, nil
  176. }
  177. func (cn *connection) Commit() error {
  178. return nil
  179. }
  180. func (cn *connection) Close() error {
  181. return cn.c.Close()
  182. }
  183. func (cn *connection) Rollback() error {
  184. return nil
  185. }
  186. func (cn *connection) Prepare(query string) (driver.Stmt, error) {
  187. body := make([]byte, len(query)+4)
  188. binary.BigEndian.PutUint32(body[0:4], uint32(len(query)))
  189. copy(body[4:], []byte(query))
  190. if err := cn.send(opPrepare, body); err != nil {
  191. return nil, err
  192. }
  193. opcode, body, err := cn.recv()
  194. if err != nil {
  195. return nil, err
  196. }
  197. if opcode != opResult || binary.BigEndian.Uint32(body) != 4 {
  198. return nil, fmt.Errorf("expected prepared result")
  199. }
  200. prepared := int(binary.BigEndian.Uint32(body[4:]))
  201. columns, meta, _ := parseMeta(body[8:])
  202. return &statement{cn: cn, query: query,
  203. prepared: prepared, columns: columns, meta: meta}, nil
  204. }
  205. type statement struct {
  206. cn *connection
  207. query string
  208. prepared int
  209. columns []string
  210. meta []uint16
  211. }
  212. func (s *statement) Close() error {
  213. return nil
  214. }
  215. func (st *statement) ColumnConverter(idx int) driver.ValueConverter {
  216. return (&columnEncoder{st.meta}).ColumnConverter(idx)
  217. }
  218. func (st *statement) NumInput() int {
  219. return len(st.columns)
  220. }
  // parseMeta decodes a metadata section, as found in PREPARE and rows
  // results. It returns the column names, each column's type id and the
  // number of bytes of body consumed.
  //
  // Layout read here:
  //
  //	<flags:uint32><column_count:uint32>
  //	<keyspace:string><table:string>      only if flags&1 (global table spec)
  //	column_count times:
  //	  <keyspace:string><table:string>    only if flags&1 is NOT set
  //	  <name:string><type id:uint16>
  //
  // A [string] is a uint16 length followed by that many bytes. Keyspace and
  // table names are skipped. The signature has no error result, so
  // truncated input panics with an index-out-of-range error.
  //
  // NOTE(review): each type is read as a bare uint16. If the server sends
  // extra type parameters (e.g. element types for collections), parsing
  // will desynchronize; confirm against the protocol version in use.
  func parseMeta(body []byte) ([]string, []uint16, int) {
  	flags := binary.BigEndian.Uint32(body)
  	globalTableSpec := flags&1 == 1
  	columnCount := int(binary.BigEndian.Uint32(body[4:]))
  	i := 8
  	if globalTableSpec {
  		_, i = readShortString(body, i) // keyspace
  		_, i = readShortString(body, i) // table
  	}
  	columns := make([]string, columnCount)
  	meta := make([]uint16, columnCount)
  	for c := range columns {
  		if !globalTableSpec {
  			// Without the global flag, every column carries its own
  			// keyspace and table name ahead of the column name.
  			_, i = readShortString(body, i)
  			_, i = readShortString(body, i)
  		}
  		columns[c], i = readShortString(body, i)
  		meta[c] = binary.BigEndian.Uint16(body[i:])
  		i += 2
  	}
  	return columns, meta, i
  }

  // readShortString reads the [string] (uint16 length + bytes) that starts
  // at offset i of body. It returns the string and the offset just past it.
  func readShortString(body []byte, i int) (string, int) {
  	l := int(binary.BigEndian.Uint16(body[i:]))
  	return string(body[i+2 : i+2+l]), i + 2 + l
  }
  246. func (st *statement) exec(v []driver.Value) error {
  247. sz := 8
  248. for i := range v {
  249. if b, ok := v[i].([]byte); ok {
  250. sz += len(b) + 4
  251. }
  252. }
  253. body, p := make([]byte, sz), 6
  254. binary.BigEndian.PutUint32(body, uint32(st.prepared))
  255. binary.BigEndian.PutUint16(body[4:], uint16(len(v)))
  256. for i := range v {
  257. b, ok := v[i].([]byte)
  258. if !ok {
  259. return fmt.Errorf("unsupported type %T at column %d", v[i], i)
  260. }
  261. binary.BigEndian.PutUint32(body[p:], uint32(len(b)))
  262. copy(body[p+4:], b)
  263. p += 4 + len(b)
  264. }
  265. if err := st.cn.send(opExecute, body); err != nil {
  266. return err
  267. }
  268. return nil
  269. }
  270. func (st *statement) Exec(v []driver.Value) (driver.Result, error) {
  271. if err := st.exec(v); err != nil {
  272. return nil, err
  273. }
  274. opcode, body, err := st.cn.recv()
  275. if err != nil {
  276. return nil, err
  277. }
  278. _, _ = opcode, body
  279. return nil, nil
  280. }
  281. func (st *statement) Query(v []driver.Value) (driver.Rows, error) {
  282. if err := st.exec(v); err != nil {
  283. return nil, err
  284. }
  285. opcode, body, err := st.cn.recv()
  286. if err != nil {
  287. return nil, err
  288. }
  289. kind := binary.BigEndian.Uint32(body[0:4])
  290. if opcode != opResult || kind != 2 {
  291. return nil, fmt.Errorf("expected rows as result")
  292. }
  293. columns, meta, n := parseMeta(body[4:])
  294. i := n + 4
  295. rows := &rows{
  296. columns: columns,
  297. meta: meta,
  298. numRows: int(binary.BigEndian.Uint32(body[i:])),
  299. }
  300. i += 4
  301. rows.body = body[i:]
  302. return rows, nil
  303. }
  304. type rows struct {
  305. columns []string
  306. meta []uint16
  307. body []byte
  308. row int
  309. numRows int
  310. }
  311. func (r *rows) Close() error {
  312. return nil
  313. }
  314. func (r *rows) Columns() []string {
  315. return r.columns
  316. }
  317. func (r *rows) Next(values []driver.Value) error {
  318. if r.row >= r.numRows {
  319. return io.EOF
  320. }
  321. for column := 0; column < len(r.columns); column++ {
  322. n := int(binary.BigEndian.Uint32(r.body))
  323. r.body = r.body[4:]
  324. if n >= 0 {
  325. values[column] = decode(r.body[:n], r.meta[column])
  326. r.body = r.body[n:]
  327. } else {
  328. fmt.Println(column, n)
  329. values[column] = nil
  330. }
  331. }
  332. r.row++
  333. return nil
  334. }
  // Error is the error recv builds from a server ERROR frame. Code is the
  // numeric code from the frame and Msg is the server-supplied message.
  type Error struct {
  	Code int
  	Msg  string
  }

  // Error implements the error interface by returning the server's message
  // verbatim; the code is not included.
  func (e Error) Error() string {
  	msg := e.Msg
  	return msg
  }
  342. func init() {
  343. sql.Register("gocql", &drv{})
  344. }