conn.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721
  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
  5. import (
  6. "bufio"
  7. "crypto/tls"
  8. "errors"
  9. "fmt"
  10. "log"
  11. "net"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "sync/atomic"
  16. "time"
  17. )
  18. const (
  19. defaultFrameSize = 4096
  20. flagResponse = 0x80
  21. maskVersion = 0x7F
  22. )
  23. //JoinHostPort is a utility to return a address string that can be used
  24. //gocql.Conn to form a connection with a host.
  25. func JoinHostPort(addr string, port int) string {
  26. addr = strings.TrimSpace(addr)
  27. if _, _, err := net.SplitHostPort(addr); err != nil {
  28. addr = net.JoinHostPort(addr, strconv.Itoa(port))
  29. }
  30. return addr
  31. }
  32. type Authenticator interface {
  33. Challenge(req []byte) (resp []byte, auth Authenticator, err error)
  34. Success(data []byte) error
  35. }
  36. type PasswordAuthenticator struct {
  37. Username string
  38. Password string
  39. }
  40. func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
  41. if string(req) != "org.apache.cassandra.auth.PasswordAuthenticator" {
  42. return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
  43. }
  44. resp := make([]byte, 2+len(p.Username)+len(p.Password))
  45. resp[0] = 0
  46. copy(resp[1:], p.Username)
  47. resp[len(p.Username)+1] = 0
  48. copy(resp[2+len(p.Username):], p.Password)
  49. return resp, nil, nil
  50. }
  51. func (p PasswordAuthenticator) Success(data []byte) error {
  52. return nil
  53. }
  54. type SslOptions struct {
  55. CertPath string
  56. KeyPath string
  57. CaPath string //optional depending on server config
  58. // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this on
  59. // This option is basically the inverse of InSecureSkipVerify
  60. // See InSecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info
  61. EnableHostVerification bool
  62. }
  63. type ConnConfig struct {
  64. ProtoVersion int
  65. CQLVersion string
  66. Timeout time.Duration
  67. NumStreams int
  68. Compressor Compressor
  69. Authenticator Authenticator
  70. Keepalive time.Duration
  71. tlsConfig *tls.Config
  72. }
  73. // Conn is a single connection to a Cassandra node. It can be used to execute
  74. // queries, but users are usually advised to use a more reliable, higher
  75. // level API.
  76. type Conn struct {
  77. conn net.Conn
  78. r *bufio.Reader
  79. timeout time.Duration
  80. headerBuf []byte
  81. uniq chan int
  82. calls []callReq
  83. nwait int32
  84. pool ConnectionPool
  85. compressor Compressor
  86. auth Authenticator
  87. addr string
  88. version uint8
  89. currentKeyspace string
  90. started bool
  91. closedMu sync.RWMutex
  92. isClosed bool
  93. }
  94. // Connect establishes a connection to a Cassandra node.
  95. // You must also call the Serve method before you can execute any queries.
  96. func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
  97. var (
  98. err error
  99. conn net.Conn
  100. )
  101. if cfg.tlsConfig != nil {
  102. // the TLS config is safe to be reused by connections but it must not
  103. // be modified after being used.
  104. if conn, err = tls.Dial("tcp", addr, cfg.tlsConfig); err != nil {
  105. return nil, err
  106. }
  107. } else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {
  108. return nil, err
  109. }
  110. // going to default to proto 2
  111. if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
  112. log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
  113. cfg.ProtoVersion = 2
  114. }
  115. headerSize := 8
  116. maxStreams := 128
  117. if cfg.ProtoVersion > protoVersion2 {
  118. maxStreams = 32768
  119. headerSize = 9
  120. }
  121. if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
  122. cfg.NumStreams = maxStreams
  123. }
  124. c := &Conn{
  125. conn: conn,
  126. r: bufio.NewReader(conn),
  127. uniq: make(chan int, cfg.NumStreams),
  128. calls: make([]callReq, cfg.NumStreams),
  129. timeout: cfg.Timeout,
  130. version: uint8(cfg.ProtoVersion),
  131. addr: conn.RemoteAddr().String(),
  132. pool: pool,
  133. compressor: cfg.Compressor,
  134. auth: cfg.Authenticator,
  135. headerBuf: make([]byte, headerSize),
  136. }
  137. if cfg.Keepalive > 0 {
  138. c.setKeepalive(cfg.Keepalive)
  139. }
  140. for i := 0; i < cfg.NumStreams; i++ {
  141. c.uniq <- i
  142. }
  143. go c.serve()
  144. if err := c.startup(&cfg); err != nil {
  145. conn.Close()
  146. return nil, err
  147. }
  148. c.started = true
  149. return c, nil
  150. }
  151. func (c *Conn) Write(p []byte) (int, error) {
  152. c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
  153. return c.conn.Write(p)
  154. }
  155. func (c *Conn) Read(p []byte) (int, error) {
  156. return c.r.Read(p)
  157. }
  158. func (c *Conn) startup(cfg *ConnConfig) error {
  159. m := map[string]string{
  160. "CQL_VERSION": cfg.CQLVersion,
  161. }
  162. if c.compressor != nil {
  163. m["COMPRESSION"] = c.compressor.Name()
  164. }
  165. frame, err := c.exec(&writeStartupFrame{opts: m}, nil)
  166. if err != nil {
  167. return err
  168. }
  169. switch v := frame.(type) {
  170. case error:
  171. return v
  172. case *readyFrame:
  173. return nil
  174. case *authenticateFrame:
  175. return c.authenticateHandshake(v)
  176. default:
  177. return NewErrProtocol("Unknown type of response to startup frame: %s", v)
  178. }
  179. }
  180. func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
  181. if c.auth == nil {
  182. return fmt.Errorf("authentication required (using %q)", authFrame.class)
  183. }
  184. resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
  185. if err != nil {
  186. return err
  187. }
  188. req := &writeAuthResponseFrame{data: resp}
  189. for {
  190. frame, err := c.exec(req, nil)
  191. if err != nil {
  192. return err
  193. }
  194. switch v := frame.(type) {
  195. case error:
  196. return v
  197. case authSuccessFrame:
  198. if challenger != nil {
  199. return challenger.Success(v.data)
  200. }
  201. return nil
  202. case authChallengeFrame:
  203. resp, challenger, err = challenger.Challenge(v.data)
  204. if err != nil {
  205. return err
  206. }
  207. req = &writeAuthResponseFrame{
  208. data: resp,
  209. }
  210. }
  211. }
  212. }
  213. // Serve starts the stream multiplexer for this connection, which is required
  214. // to execute any queries. This method runs as long as the connection is
  215. // open and is therefore usually called in a separate goroutine.
  216. func (c *Conn) serve() {
  217. var (
  218. err error
  219. header frameHeader
  220. )
  221. for {
  222. header, err = c.recv()
  223. if err != nil {
  224. break
  225. }
  226. c.dispatch(header)
  227. }
  228. c.Close()
  229. for id := 0; id < len(c.calls); id++ {
  230. req := &c.calls[id]
  231. if atomic.CompareAndSwapInt32(&req.active, 1, 0) {
  232. req.resp <- struct{}{}
  233. close(req.resp)
  234. }
  235. }
  236. if c.started {
  237. c.pool.HandleError(c, err, true)
  238. }
  239. }
  240. func (c *Conn) recv() (frameHeader, error) {
  241. // not safe for concurrent reads
  242. // read a full header, ignore timeouts, as this is being ran in a loop
  243. // TODO: TCP level deadlines? or just query level deadlines?
  244. // were just reading headers over and over and copy bodies
  245. head, err := readHeader(c.r, c.headerBuf)
  246. if err != nil {
  247. return frameHeader{}, err
  248. }
  249. call := &c.calls[head.stream]
  250. call.mu.Lock()
  251. log.Printf("readframe stream=%v\n", head.stream)
  252. err = call.framer.readFrame(&head)
  253. call.mu.Unlock()
  254. if err != nil {
  255. return frameHeader{}, err
  256. }
  257. if head.version.version() != c.version {
  258. return frameHeader{}, NewErrProtocol("unexpected protocol version in response: got %d expected %d", head.version.version(), c.version)
  259. }
  260. return head, nil
  261. }
  262. func (c *Conn) dispatch(header frameHeader) {
  263. id := header.stream
  264. if id >= len(c.calls) {
  265. // should this panic?
  266. return
  267. }
  268. // TODO: replace this with a sparse map
  269. call := &c.calls[id]
  270. call.resp <- struct{}{}
  271. atomic.AddInt32(&c.nwait, -1)
  272. c.uniq <- id
  273. }
  274. type callReq struct {
  275. active int32
  276. // could use a waitgroup but this allows us to do timeouts on the read/send
  277. resp chan struct{}
  278. mu sync.Mutex
  279. framer *framer
  280. err error
  281. }
  282. type callResp struct {
  283. framer *framer
  284. err error
  285. }
  286. func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
  287. // TODO: move tracer onto conn
  288. stream := <-c.uniq
  289. call := &c.calls[stream]
  290. atomic.StoreInt32(&call.active, 1)
  291. defer atomic.StoreInt32(&call.active, 0)
  292. // resp is basically a waiting semafore protecting the framer
  293. call.resp = make(chan struct{}, 1)
  294. // log.Printf("%v: OUT stream=%d (%T) req=%v\n", c.conn.LocalAddr(), stream, req, req)
  295. framer := newFramer(c, c, c.compressor, c.version)
  296. defer framerPool.Put(framer)
  297. call.framer = framer
  298. if tracer != nil {
  299. framer.trace()
  300. }
  301. log.Printf("writing frame stream=%v\n", stream)
  302. call.mu.Lock()
  303. err := req.writeFrame(framer, stream)
  304. call.mu.Unlock()
  305. if err != nil {
  306. return nil, err
  307. }
  308. <-call.resp
  309. if call.err != nil {
  310. return nil, call.err
  311. }
  312. frame, err := framer.parseFrame()
  313. if err != nil {
  314. return nil, err
  315. }
  316. if len(framer.traceID) > 0 {
  317. tracer.Trace(framer.traceID)
  318. }
  319. // log.Printf("%v: IN stream=%d (%T) resp=%v\n", c.conn.LocalAddr(), stream, frame, frame)
  320. return frame, nil
  321. }
  322. func (c *Conn) prepareStatement(stmt string, trace Tracer) (*resultPreparedFrame, error) {
  323. stmtsLRU.Lock()
  324. if stmtsLRU.lru == nil {
  325. initStmtsLRU(defaultMaxPreparedStmts)
  326. }
  327. stmtCacheKey := c.addr + c.currentKeyspace + stmt
  328. if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  329. stmtsLRU.Unlock()
  330. flight := val.(*inflightPrepare)
  331. flight.wg.Wait()
  332. return flight.info, flight.err
  333. }
  334. flight := new(inflightPrepare)
  335. flight.wg.Add(1)
  336. stmtsLRU.lru.Add(stmtCacheKey, flight)
  337. stmtsLRU.Unlock()
  338. prep := &writePrepareFrame{
  339. statement: stmt,
  340. }
  341. resp, err := c.exec(prep, trace)
  342. if err != nil {
  343. flight.err = err
  344. flight.wg.Done()
  345. return nil, err
  346. }
  347. switch x := resp.(type) {
  348. case *resultPreparedFrame:
  349. // log.Printf("prepared %q => %x\n", stmt, x.preparedID)
  350. flight.info = x
  351. case error:
  352. flight.err = x
  353. default:
  354. flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
  355. }
  356. flight.wg.Done()
  357. if flight.err != nil {
  358. stmtsLRU.Lock()
  359. stmtsLRU.lru.Remove(stmtCacheKey)
  360. stmtsLRU.Unlock()
  361. }
  362. return flight.info, flight.err
  363. }
  364. func (c *Conn) executeQuery(qry *Query) *Iter {
  365. params := queryParams{
  366. consistency: qry.cons,
  367. }
  368. // TODO: Add DefaultTimestamp, SerialConsistency
  369. if len(qry.pageState) > 0 {
  370. params.pagingState = qry.pageState
  371. }
  372. if qry.pageSize > 0 {
  373. params.pageSize = qry.pageSize
  374. }
  375. // log.Printf("%+#v\n", qry)
  376. var frame frameWriter
  377. if qry.shouldPrepare() {
  378. // Prepare all DML queries. Other queries can not be prepared.
  379. info, err := c.prepareStatement(qry.stmt, qry.trace)
  380. if err != nil {
  381. return &Iter{err: err}
  382. }
  383. var values []interface{}
  384. if qry.binding == nil {
  385. values = qry.values
  386. } else {
  387. binding := &QueryInfo{
  388. Id: info.preparedID,
  389. Args: info.reqMeta.columns,
  390. Rval: info.respMeta.columns,
  391. }
  392. values, err = qry.binding(binding)
  393. if err != nil {
  394. return &Iter{err: err}
  395. }
  396. }
  397. if len(values) != len(info.reqMeta.columns) {
  398. return &Iter{err: ErrQueryArgLength}
  399. }
  400. params.values = make([]queryValues, len(values))
  401. for i := 0; i < len(values); i++ {
  402. val, err := Marshal(info.reqMeta.columns[i].TypeInfo, values[i])
  403. if err != nil {
  404. return &Iter{err: err}
  405. }
  406. v := &params.values[i]
  407. v.value = val
  408. // TODO: handle query binding names
  409. }
  410. frame = &writeExecuteFrame{
  411. preparedID: info.preparedID,
  412. params: params,
  413. }
  414. } else {
  415. frame = &writeQueryFrame{
  416. statement: qry.stmt,
  417. params: params,
  418. }
  419. }
  420. resp, err := c.exec(frame, qry.trace)
  421. if err != nil {
  422. return &Iter{err: err}
  423. }
  424. // log.Printf("resp=%T\n", resp)
  425. switch x := resp.(type) {
  426. case *resultVoidFrame:
  427. return &Iter{}
  428. case *resultRowsFrame:
  429. iter := &Iter{
  430. columns: x.meta.columns,
  431. rows: x.rows,
  432. }
  433. // log.Printf("result meta=%v\n", x.meta)
  434. if len(x.meta.pagingState) > 0 {
  435. iter.next = &nextIter{
  436. qry: *qry,
  437. pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
  438. }
  439. iter.next.qry.pageState = x.meta.pagingState
  440. if iter.next.pos < 1 {
  441. iter.next.pos = 1
  442. }
  443. }
  444. return iter
  445. case *resultKeyspaceFrame, *resultSchemaChangeFrame:
  446. return &Iter{}
  447. case RequestErrUnprepared:
  448. stmtsLRU.Lock()
  449. stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
  450. if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  451. stmtsLRU.lru.Remove(stmtCacheKey)
  452. stmtsLRU.Unlock()
  453. return c.executeQuery(qry)
  454. }
  455. stmtsLRU.Unlock()
  456. panic(x)
  457. return &Iter{err: x}
  458. case error:
  459. return &Iter{err: x}
  460. default:
  461. return &Iter{err: NewErrProtocol("Unknown type in response to execute query: %s", x)}
  462. }
  463. }
  464. func (c *Conn) Pick(qry *Query) *Conn {
  465. if c.Closed() {
  466. return nil
  467. }
  468. return c
  469. }
  470. func (c *Conn) Closed() bool {
  471. c.closedMu.RLock()
  472. closed := c.isClosed
  473. c.closedMu.RUnlock()
  474. return closed
  475. }
  476. func (c *Conn) Close() {
  477. c.closedMu.Lock()
  478. if c.isClosed {
  479. c.closedMu.Unlock()
  480. return
  481. }
  482. c.isClosed = true
  483. c.closedMu.Unlock()
  484. c.conn.Close()
  485. }
  486. func (c *Conn) Address() string {
  487. return c.addr
  488. }
  489. func (c *Conn) AvailableStreams() int {
  490. return len(c.uniq)
  491. }
  492. func (c *Conn) UseKeyspace(keyspace string) error {
  493. q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
  494. q.params.consistency = Any
  495. resp, err := c.exec(q, nil)
  496. if err != nil {
  497. return err
  498. }
  499. switch x := resp.(type) {
  500. case *resultKeyspaceFrame:
  501. case error:
  502. return x
  503. default:
  504. return NewErrProtocol("Unknown type in response to USE: %s", x)
  505. }
  506. c.currentKeyspace = keyspace
  507. return nil
  508. }
  509. func (c *Conn) executeBatch(batch *Batch) error {
  510. if c.version == protoVersion1 {
  511. return ErrUnsupported
  512. }
  513. n := len(batch.Entries)
  514. req := &writeBatchFrame{
  515. typ: batch.Type,
  516. statements: make([]batchStatment, n),
  517. consistency: batch.Cons,
  518. }
  519. stmts := make(map[string]string)
  520. for i := 0; i < n; i++ {
  521. entry := &batch.Entries[i]
  522. b := &req.statements[i]
  523. if len(entry.Args) > 0 || entry.binding != nil {
  524. info, err := c.prepareStatement(entry.Stmt, nil)
  525. if err != nil {
  526. return err
  527. }
  528. var args []interface{}
  529. if entry.binding == nil {
  530. args = entry.Args
  531. } else {
  532. binding := &QueryInfo{
  533. Id: info.preparedID,
  534. Args: info.reqMeta.columns,
  535. Rval: info.respMeta.columns,
  536. }
  537. args, err = entry.binding(binding)
  538. if err != nil {
  539. return err
  540. }
  541. }
  542. if len(args) != len(info.reqMeta.columns) {
  543. return ErrQueryArgLength
  544. }
  545. b.preparedID = info.preparedID
  546. stmts[string(info.preparedID)] = entry.Stmt
  547. b.values = make([]queryValues, len(info.reqMeta.columns))
  548. for j := 0; j < len(info.reqMeta.columns); j++ {
  549. val, err := Marshal(info.reqMeta.columns[j].TypeInfo, args[j])
  550. if err != nil {
  551. return err
  552. }
  553. b.values[j].value = val
  554. // TODO: add names
  555. }
  556. } else {
  557. b.statement = entry.Stmt
  558. }
  559. }
  560. // TODO: should batch support tracing?
  561. resp, err := c.exec(req, nil)
  562. if err != nil {
  563. return err
  564. }
  565. switch x := resp.(type) {
  566. case *resultVoidFrame:
  567. return nil
  568. case RequestErrUnprepared:
  569. stmt, found := stmts[string(x.StatementId)]
  570. if found {
  571. stmtsLRU.Lock()
  572. stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
  573. stmtsLRU.Unlock()
  574. }
  575. if found {
  576. return c.executeBatch(batch)
  577. } else {
  578. return x
  579. }
  580. case error:
  581. return x
  582. default:
  583. return NewErrProtocol("Unknown type in response to batch statement: %s", x)
  584. }
  585. }
  586. func (c *Conn) setKeepalive(d time.Duration) error {
  587. if tc, ok := c.conn.(*net.TCPConn); ok {
  588. err := tc.SetKeepAlivePeriod(d)
  589. if err != nil {
  590. return err
  591. }
  592. return tc.SetKeepAlive(true)
  593. }
  594. return nil
  595. }
  596. type inflightPrepare struct {
  597. info *resultPreparedFrame
  598. err error
  599. wg sync.WaitGroup
  600. }
  601. var (
  602. ErrQueryArgLength = errors.New("query argument length mismatch")
  603. )