conn.go 18 KB


  1. // Copyright (c) 2012 The gocql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gocql
  5. import (
  6. "bufio"
  7. "crypto/tls"
  8. "crypto/x509"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "io/ioutil"
  13. "log"
  14. "net"
  15. "strconv"
  16. "strings"
  17. "sync"
  18. "sync/atomic"
  19. "time"
  20. )
  // Frame-level constants used when reading and validating frames.
const (
	// defaultFrameSize is 4 KiB; presumably the initial size of a frame
	// buffer (its users are not visible here).
	defaultFrameSize = 4 << 10

	// flagResponse is OR-ed into the version byte of every frame received
	// from the server; recv and decodeFrame check for it.
	flagResponse = 0x80

	// maskVersion clears the flagResponse bit, leaving only the version.
	maskVersion = 0x7F
)
  // JoinHostPort returns an address of the form "host:port" that can be passed
// to Connect to open a connection to a host.
//
// Surrounding whitespace is trimmed. If addr already carries a port it is
// returned unchanged; otherwise port is appended. IPv6 literals are handled
// both bare ("::1") and bracketed without a port ("[::1]"); in either case
// the result is "[::1]:port".
func JoinHostPort(addr string, port int) string {
	addr = strings.TrimSpace(addr)
	if _, _, err := net.SplitHostPort(addr); err == nil {
		// Already in host:port form.
		return addr
	}
	// net.JoinHostPort adds its own brackets around hosts containing ':',
	// so strip any the caller supplied; otherwise "[::1]" would become
	// "[[::1]]:port".
	if len(addr) >= 2 && strings.HasPrefix(addr, "[") && strings.HasSuffix(addr, "]") {
		addr = addr[1 : len(addr)-1]
	}
	return net.JoinHostPort(addr, strconv.Itoa(port))
}
  // Authenticator answers the authentication exchange a server may demand
// while a connection is starting up.
//
// Challenge is first invoked with the authenticator class name announced by
// the server and returns the response bytes to send back. It may also return
// a follow-up Authenticator: that value receives any further server
// challenges, and its Success method is called with the server's final
// payload. Returning a nil follow-up means further challenges are treated as
// an error and success needs no extra verification.
type Authenticator interface {
	Challenge(req []byte) ([]byte, Authenticator, error)
	Success(data []byte) error
}
  39. type PasswordAuthenticator struct {
  40. Username string
  41. Password string
  42. }
  43. func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
  44. if string(req) != "org.apache.cassandra.auth.PasswordAuthenticator" {
  45. return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
  46. }
  47. resp := make([]byte, 2+len(p.Username)+len(p.Password))
  48. resp[0] = 0
  49. copy(resp[1:], p.Username)
  50. resp[len(p.Username)+1] = 0
  51. copy(resp[2+len(p.Username):], p.Password)
  52. return resp, nil, nil
  53. }
  54. func (p PasswordAuthenticator) Success(data []byte) error {
  55. return nil
  56. }
  // SslOptions configures TLS for connections opened by Connect.
type SslOptions struct {
	// CertPath and KeyPath locate the client certificate and private key,
	// loaded with tls.LoadX509KeyPair.
	CertPath string
	KeyPath  string

	// CaPath optionally names a PEM file of CA certificates used to verify
	// the server; whether it is required depends on the server setup.
	CaPath string

	// EnableHostVerification turns on verification of the server's
	// certificate and hostname (for example a wildcard certificate issued for
	// the cluster). It is the inverse of tls.Config.InsecureSkipVerify; see
	// InsecureSkipVerify in http://golang.org/pkg/crypto/tls/ for details.
	EnableHostVerification bool
}
  66. type ConnConfig struct {
  67. ProtoVersion int
  68. CQLVersion string
  69. Timeout time.Duration
  70. NumStreams int
  71. Compressor Compressor
  72. Authenticator Authenticator
  73. Keepalive time.Duration
  74. SslOpts *SslOptions
  75. }
  76. // Conn is a single connection to a Cassandra node. It can be used to execute
  77. // queries, but users are usually advised to use a more reliable, higher
  78. // level API.
  79. type Conn struct {
  80. conn net.Conn
  81. r *bufio.Reader
  82. timeout time.Duration
  83. uniq chan int
  84. calls []callReq
  85. nwait int32
  86. pool ConnectionPool
  87. compressor Compressor
  88. auth Authenticator
  89. addr string
  90. version uint8
  91. currentKeyspace string
  92. closedMu sync.RWMutex
  93. isClosed bool
  94. }
  95. // Connect establishes a connection to a Cassandra node.
  96. // You must also call the Serve method before you can execute any queries.
  97. func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
  98. var (
  99. err error
  100. conn net.Conn
  101. )
  102. if cfg.SslOpts != nil {
  103. certPool := x509.NewCertPool()
  104. //ca cert is optional
  105. if cfg.SslOpts.CaPath != "" {
  106. pem, err := ioutil.ReadFile(cfg.SslOpts.CaPath)
  107. if err != nil {
  108. return nil, err
  109. }
  110. if !certPool.AppendCertsFromPEM(pem) {
  111. return nil, errors.New("Failed parsing or appending certs")
  112. }
  113. }
  114. mycert, err := tls.LoadX509KeyPair(cfg.SslOpts.CertPath, cfg.SslOpts.KeyPath)
  115. if err != nil {
  116. return nil, err
  117. }
  118. config := tls.Config{
  119. Certificates: []tls.Certificate{mycert},
  120. RootCAs: certPool,
  121. }
  122. config.InsecureSkipVerify = !cfg.SslOpts.EnableHostVerification
  123. if conn, err = tls.Dial("tcp", addr, &config); err != nil {
  124. return nil, err
  125. }
  126. } else if conn, err = net.DialTimeout("tcp", addr, cfg.Timeout); err != nil {
  127. return nil, err
  128. }
  129. // going to default to proto 2
  130. if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion3 {
  131. log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
  132. cfg.ProtoVersion = 2
  133. }
  134. maxStreams := 128
  135. if cfg.ProtoVersion > protoVersion2 {
  136. maxStreams = 32768
  137. }
  138. if cfg.NumStreams <= 0 || cfg.NumStreams > maxStreams {
  139. cfg.NumStreams = maxStreams
  140. }
  141. c := &Conn{
  142. conn: conn,
  143. r: bufio.NewReader(conn),
  144. uniq: make(chan int, cfg.NumStreams),
  145. calls: make([]callReq, cfg.NumStreams),
  146. timeout: cfg.Timeout,
  147. version: uint8(cfg.ProtoVersion),
  148. addr: conn.RemoteAddr().String(),
  149. pool: pool,
  150. compressor: cfg.Compressor,
  151. auth: cfg.Authenticator,
  152. }
  153. if cfg.Keepalive > 0 {
  154. c.setKeepalive(cfg.Keepalive)
  155. }
  156. for i := 0; i < cfg.NumStreams; i++ {
  157. c.uniq <- i
  158. }
  159. if err := c.startup(&cfg); err != nil {
  160. conn.Close()
  161. return nil, err
  162. }
  163. go c.serve()
  164. return c, nil
  165. }
  166. func (c *Conn) startup(cfg *ConnConfig) error {
  167. compression := ""
  168. if c.compressor != nil {
  169. compression = c.compressor.Name()
  170. }
  171. var req operation = &startupFrame{
  172. CQLVersion: cfg.CQLVersion,
  173. Compression: compression,
  174. }
  175. var challenger Authenticator
  176. for {
  177. resp, err := c.execSimple(req)
  178. if err != nil {
  179. return err
  180. }
  181. switch x := resp.(type) {
  182. case readyFrame:
  183. return nil
  184. case error:
  185. return x
  186. case authenticateFrame:
  187. if c.auth == nil {
  188. return fmt.Errorf("authentication required (using %q)", x.Authenticator)
  189. }
  190. var resp []byte
  191. resp, challenger, err = c.auth.Challenge([]byte(x.Authenticator))
  192. if err != nil {
  193. return err
  194. }
  195. req = &authResponseFrame{resp}
  196. case authChallengeFrame:
  197. if challenger == nil {
  198. return fmt.Errorf("authentication error (invalid challenge)")
  199. }
  200. var resp []byte
  201. resp, challenger, err = challenger.Challenge(x.Data)
  202. if err != nil {
  203. return err
  204. }
  205. req = &authResponseFrame{resp}
  206. case authSuccessFrame:
  207. if challenger != nil {
  208. return challenger.Success(x.Data)
  209. }
  210. return nil
  211. default:
  212. return NewErrProtocol("Unknown type of response to startup frame: %s", x)
  213. }
  214. }
  215. }
  216. // Serve starts the stream multiplexer for this connection, which is required
  217. // to execute any queries. This method runs as long as the connection is
  218. // open and is therefore usually called in a separate goroutine.
  219. func (c *Conn) serve() {
  220. var (
  221. err error
  222. resp frame
  223. )
  224. for {
  225. resp, err = c.recv()
  226. if err != nil {
  227. break
  228. }
  229. c.dispatch(resp)
  230. }
  231. c.Close()
  232. for id := 0; id < len(c.calls); id++ {
  233. req := &c.calls[id]
  234. if atomic.LoadInt32(&req.active) == 1 {
  235. req.resp <- callResp{nil, err}
  236. }
  237. }
  238. c.pool.HandleError(c, err, true)
  239. }
  240. func (c *Conn) Write(p []byte) (int, error) {
  241. c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
  242. return c.conn.Write(p)
  243. }
  244. func (c *Conn) Read(p []byte) (int, error) {
  245. return c.r.Read(p)
  246. }
  247. func (c *Conn) recv() (frame, error) {
  248. size := headerProtoSize[c.version]
  249. resp := make(frame, size, size+512)
  250. // read a full header, ignore timeouts, as this is being ran in a loop
  251. c.conn.SetReadDeadline(time.Time{})
  252. _, err := io.ReadFull(c.r, resp[:size])
  253. if err != nil {
  254. return nil, err
  255. }
  256. if v := c.version | flagResponse; resp[0] != v {
  257. return nil, NewErrProtocol("recv: response protocol version does not match connection protocol version (%d != %d)", resp[0], v)
  258. }
  259. bodySize := resp.Length(c.version)
  260. if bodySize == 0 {
  261. return resp, nil
  262. }
  263. resp.grow(bodySize)
  264. const maxAttempts = 5
  265. n := size
  266. for i := 0; i < maxAttempts; i++ {
  267. var nn int
  268. c.conn.SetReadDeadline(time.Now().Add(c.timeout))
  269. nn, err = io.ReadFull(c.r, resp[n:size+bodySize])
  270. if err == nil {
  271. break
  272. }
  273. n += nn
  274. if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
  275. break
  276. }
  277. }
  278. if err != nil {
  279. return nil, err
  280. }
  281. return resp, nil
  282. }
  283. func (c *Conn) execSimple(op operation) (interface{}, error) {
  284. f, err := op.encodeFrame(c.version, nil)
  285. if err != nil {
  286. // this should be a noop err
  287. return nil, err
  288. }
  289. bodyLen := len(f) - headerProtoSize[c.version]
  290. f.setLength(bodyLen, c.version)
  291. if _, err := c.Write([]byte(f)); err != nil {
  292. c.Close()
  293. return nil, err
  294. }
  295. // here recv wont timeout waiting for a header, should it?
  296. if f, err = c.recv(); err != nil {
  297. return nil, err
  298. }
  299. return c.decodeFrame(f, nil)
  300. }
  301. func (c *Conn) exec(op operation, trace Tracer) (interface{}, error) {
  302. req, err := op.encodeFrame(c.version, nil)
  303. if err != nil {
  304. return nil, err
  305. }
  306. if trace != nil {
  307. req[1] |= flagTrace
  308. }
  309. headerSize := headerProtoSize[c.version]
  310. if len(req) > headerSize && c.compressor != nil {
  311. body, err := c.compressor.Encode([]byte(req[headerSize:]))
  312. if err != nil {
  313. return nil, err
  314. }
  315. req = append(req[:headerSize], frame(body)...)
  316. req[1] |= flagCompress
  317. }
  318. bodyLen := len(req) - headerSize
  319. req.setLength(bodyLen, c.version)
  320. id := <-c.uniq
  321. req.setStream(id, c.version)
  322. call := &c.calls[id]
  323. call.resp = make(chan callResp, 1)
  324. atomic.AddInt32(&c.nwait, 1)
  325. atomic.StoreInt32(&call.active, 1)
  326. if _, err := c.Write(req); err != nil {
  327. c.uniq <- id
  328. c.Close()
  329. return nil, err
  330. }
  331. reply := <-call.resp
  332. call.resp = nil
  333. c.uniq <- id
  334. if reply.err != nil {
  335. return nil, reply.err
  336. }
  337. return c.decodeFrame(reply.buf, trace)
  338. }
  339. func (c *Conn) dispatch(resp frame) {
  340. id := resp.Stream(c.version)
  341. if id >= len(c.calls) {
  342. return
  343. }
  344. call := &c.calls[id]
  345. if !atomic.CompareAndSwapInt32(&call.active, 1, 0) {
  346. return
  347. }
  348. atomic.AddInt32(&c.nwait, -1)
  349. call.resp <- callResp{resp, nil}
  350. }
  351. func (c *Conn) ping() error {
  352. _, err := c.exec(&optionsFrame{}, nil)
  353. return err
  354. }
  355. func (c *Conn) prepareStatement(stmt string, trace Tracer) (*QueryInfo, error) {
  356. stmtsLRU.Lock()
  357. if stmtsLRU.lru == nil {
  358. initStmtsLRU(defaultMaxPreparedStmts)
  359. }
  360. stmtCacheKey := c.addr + c.currentKeyspace + stmt
  361. if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  362. flight := val.(*inflightPrepare)
  363. stmtsLRU.Unlock()
  364. flight.wg.Wait()
  365. return flight.info, flight.err
  366. }
  367. flight := new(inflightPrepare)
  368. flight.wg.Add(1)
  369. stmtsLRU.lru.Add(stmtCacheKey, flight)
  370. stmtsLRU.Unlock()
  371. resp, err := c.exec(&prepareFrame{Stmt: stmt}, trace)
  372. if err != nil {
  373. flight.err = err
  374. } else {
  375. switch x := resp.(type) {
  376. case resultPreparedFrame:
  377. flight.info = &QueryInfo{
  378. Id: x.PreparedId,
  379. Args: x.Arguments,
  380. Rval: x.ReturnValues,
  381. }
  382. case error:
  383. flight.err = x
  384. default:
  385. flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
  386. }
  387. err = flight.err
  388. }
  389. flight.wg.Done()
  390. if err != nil {
  391. stmtsLRU.Lock()
  392. stmtsLRU.lru.Remove(stmtCacheKey)
  393. stmtsLRU.Unlock()
  394. }
  395. return flight.info, flight.err
  396. }
  397. func (c *Conn) executeQuery(qry *Query) *Iter {
  398. op := &queryFrame{
  399. Stmt: qry.stmt,
  400. Cons: qry.cons,
  401. PageSize: qry.pageSize,
  402. PageState: qry.pageState,
  403. }
  404. if qry.shouldPrepare() {
  405. // Prepare all DML queries. Other queries can not be prepared.
  406. info, err := c.prepareStatement(qry.stmt, qry.trace)
  407. if err != nil {
  408. return &Iter{err: err}
  409. }
  410. var values []interface{}
  411. if qry.binding == nil {
  412. values = qry.values
  413. } else {
  414. values, err = qry.binding(info)
  415. if err != nil {
  416. return &Iter{err: err}
  417. }
  418. }
  419. if len(values) != len(info.Args) {
  420. return &Iter{err: ErrQueryArgLength}
  421. }
  422. op.Prepared = info.Id
  423. op.Values = make([][]byte, len(values))
  424. for i := 0; i < len(values); i++ {
  425. val, err := Marshal(info.Args[i].TypeInfo, values[i])
  426. if err != nil {
  427. return &Iter{err: err}
  428. }
  429. op.Values[i] = val
  430. }
  431. }
  432. resp, err := c.exec(op, qry.trace)
  433. if err != nil {
  434. return &Iter{err: err}
  435. }
  436. switch x := resp.(type) {
  437. case resultVoidFrame:
  438. return &Iter{}
  439. case resultRowsFrame:
  440. iter := &Iter{columns: x.Columns, rows: x.Rows}
  441. if len(x.PagingState) > 0 {
  442. iter.next = &nextIter{
  443. qry: *qry,
  444. pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
  445. }
  446. iter.next.qry.pageState = x.PagingState
  447. if iter.next.pos < 1 {
  448. iter.next.pos = 1
  449. }
  450. }
  451. return iter
  452. case resultKeyspaceFrame:
  453. return &Iter{}
  454. case RequestErrUnprepared:
  455. stmtsLRU.Lock()
  456. stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
  457. if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
  458. stmtsLRU.lru.Remove(stmtCacheKey)
  459. stmtsLRU.Unlock()
  460. return c.executeQuery(qry)
  461. }
  462. stmtsLRU.Unlock()
  463. return &Iter{err: x}
  464. case error:
  465. return &Iter{err: x}
  466. default:
  467. return &Iter{err: NewErrProtocol("Unknown type in response to execute query: %s", x)}
  468. }
  469. }
  470. func (c *Conn) Pick(qry *Query) *Conn {
  471. if c.Closed() {
  472. return nil
  473. }
  474. return c
  475. }
  476. func (c *Conn) Closed() bool {
  477. c.closedMu.RLock()
  478. closed := c.isClosed
  479. c.closedMu.RUnlock()
  480. return closed
  481. }
  482. func (c *Conn) Close() {
  483. c.closedMu.Lock()
  484. if c.isClosed {
  485. c.closedMu.Unlock()
  486. return
  487. }
  488. c.isClosed = true
  489. c.closedMu.Unlock()
  490. c.conn.Close()
  491. }
  492. func (c *Conn) Address() string {
  493. return c.addr
  494. }
  495. func (c *Conn) AvailableStreams() int {
  496. return len(c.uniq)
  497. }
  498. func (c *Conn) UseKeyspace(keyspace string) error {
  499. resp, err := c.exec(&queryFrame{Stmt: `USE "` + keyspace + `"`, Cons: Any}, nil)
  500. if err != nil {
  501. return err
  502. }
  503. switch x := resp.(type) {
  504. case resultKeyspaceFrame:
  505. case error:
  506. return x
  507. default:
  508. return NewErrProtocol("Unknown type in response to USE: %s", x)
  509. }
  510. c.currentKeyspace = keyspace
  511. return nil
  512. }
  513. func (c *Conn) executeBatch(batch *Batch) error {
  514. if c.version == protoVersion1 {
  515. return ErrUnsupported
  516. }
  517. f := newFrame(c.version)
  518. f.setHeader(c.version, 0, 0, opBatch)
  519. f.writeByte(byte(batch.Type))
  520. f.writeShort(uint16(len(batch.Entries)))
  521. stmts := make(map[string]string)
  522. for i := 0; i < len(batch.Entries); i++ {
  523. entry := &batch.Entries[i]
  524. var info *QueryInfo
  525. var args []interface{}
  526. if len(entry.Args) > 0 || entry.binding != nil {
  527. var err error
  528. info, err = c.prepareStatement(entry.Stmt, nil)
  529. if err != nil {
  530. return err
  531. }
  532. if entry.binding == nil {
  533. args = entry.Args
  534. } else {
  535. args, err = entry.binding(info)
  536. if err != nil {
  537. return err
  538. }
  539. }
  540. if len(args) != len(info.Args) {
  541. return ErrQueryArgLength
  542. }
  543. stmts[string(info.Id)] = entry.Stmt
  544. f.writeByte(1)
  545. f.writeShortBytes(info.Id)
  546. } else {
  547. f.writeByte(0)
  548. f.writeLongString(entry.Stmt)
  549. }
  550. f.writeShort(uint16(len(args)))
  551. for j := 0; j < len(args); j++ {
  552. val, err := Marshal(info.Args[j].TypeInfo, args[j])
  553. if err != nil {
  554. return err
  555. }
  556. f.writeBytes(val)
  557. }
  558. }
  559. f.writeConsistency(batch.Cons)
  560. if c.version >= protoVersion3 {
  561. // TODO: add support for flags here
  562. f.writeByte(0)
  563. }
  564. resp, err := c.exec(f, nil)
  565. if err != nil {
  566. return err
  567. }
  568. switch x := resp.(type) {
  569. case resultVoidFrame:
  570. return nil
  571. case RequestErrUnprepared:
  572. stmt, found := stmts[string(x.StatementId)]
  573. if found {
  574. stmtsLRU.Lock()
  575. stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
  576. stmtsLRU.Unlock()
  577. }
  578. if found {
  579. return c.executeBatch(batch)
  580. } else {
  581. return x
  582. }
  583. case error:
  584. return x
  585. default:
  586. return NewErrProtocol("Unknown type in response to batch statement: %s", x)
  587. }
  588. }
  589. func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error) {
  590. defer func() {
  591. if r := recover(); r != nil {
  592. if e, ok := r.(ErrProtocol); ok {
  593. err = e
  594. return
  595. }
  596. panic(r)
  597. }
  598. }()
  599. headerSize := headerProtoSize[c.version]
  600. if len(f) < headerSize {
  601. return nil, NewErrProtocol("Decoding frame: less data received than required for header: %d < %d", len(f), headerSize)
  602. } else if f[0] != c.version|flagResponse {
  603. return nil, NewErrProtocol("Decoding frame: response protocol version does not match connection protocol version (%d != %d)", f[0], c.version|flagResponse)
  604. }
  605. flags, op, f := f[1], f.Op(c.version), f[headerSize:]
  606. if flags&flagCompress != 0 && len(f) > 0 && c.compressor != nil {
  607. if buf, err := c.compressor.Decode([]byte(f)); err != nil {
  608. return nil, err
  609. } else {
  610. f = frame(buf)
  611. }
  612. }
  613. if flags&flagTrace != 0 {
  614. if len(f) < 16 {
  615. return nil, NewErrProtocol("Decoding frame: length of frame less than 16 while tracing is enabled")
  616. }
  617. traceId := []byte(f[:16])
  618. f = f[16:]
  619. trace.Trace(traceId)
  620. }
  621. switch op {
  622. case opReady:
  623. return readyFrame{}, nil
  624. case opResult:
  625. switch kind := f.readInt(); kind {
  626. case resultKindVoid:
  627. return resultVoidFrame{}, nil
  628. case resultKindRows:
  629. columns, pageState := f.readMetaData(c.version)
  630. numRows := f.readInt()
  631. values := make([][]byte, numRows*len(columns))
  632. for i := 0; i < len(values); i++ {
  633. values[i] = f.readBytes()
  634. }
  635. rows := make([][][]byte, numRows)
  636. for i := 0; i < numRows; i++ {
  637. rows[i], values = values[:len(columns)], values[len(columns):]
  638. }
  639. return resultRowsFrame{columns, rows, pageState}, nil
  640. case resultKindKeyspace:
  641. keyspace := f.readString()
  642. return resultKeyspaceFrame{keyspace}, nil
  643. case resultKindPrepared:
  644. id := f.readShortBytes()
  645. args, _ := f.readMetaData(c.version)
  646. if c.version < 2 {
  647. return resultPreparedFrame{PreparedId: id, Arguments: args}, nil
  648. }
  649. rvals, _ := f.readMetaData(c.version)
  650. return resultPreparedFrame{PreparedId: id, Arguments: args, ReturnValues: rvals}, nil
  651. case resultKindSchemaChanged:
  652. return resultVoidFrame{}, nil
  653. default:
  654. return nil, NewErrProtocol("Decoding frame: unknown result kind %s", kind)
  655. }
  656. case opAuthenticate:
  657. return authenticateFrame{f.readString()}, nil
  658. case opAuthChallenge:
  659. return authChallengeFrame{f.readBytes()}, nil
  660. case opAuthSuccess:
  661. return authSuccessFrame{f.readBytes()}, nil
  662. case opSupported:
  663. return supportedFrame{}, nil
  664. case opError:
  665. return f.readError(), nil
  666. default:
  667. return nil, NewErrProtocol("Decoding frame: unknown op", op)
  668. }
  669. }
  670. func (c *Conn) setKeepalive(d time.Duration) error {
  671. if tc, ok := c.conn.(*net.TCPConn); ok {
  672. err := tc.SetKeepAlivePeriod(d)
  673. if err != nil {
  674. return err
  675. }
  676. return tc.SetKeepAlive(true)
  677. }
  678. return nil
  679. }
  680. // QueryInfo represents the meta data associated with a prepared CQL statement.
  681. type QueryInfo struct {
  682. Id []byte
  683. Args []ColumnInfo
  684. Rval []ColumnInfo
  685. }
  686. type callReq struct {
  687. active int32
  688. resp chan callResp
  689. }
  690. type callResp struct {
  691. buf frame
  692. err error
  693. }
  694. type inflightPrepare struct {
  695. info *QueryInfo
  696. err error
  697. wg sync.WaitGroup
  698. }
  // ErrQueryArgLength is returned when the number of values bound to a
// prepared statement differs from the number of arguments it declares.
var ErrQueryArgLength = errors.New("query argument length mismatch")