Browse Source

Merge pull request #819 from Zariel/protocol-discovery

Add support for discovering the protocol version
Chris Bannister 9 năm trước
mục cha
commit
64566e11bb
9 tập tin đã thay đổi với 211 bổ sung và 88 xóa
  1. 14 0
      cassandra_test.go
  2. 11 4
      cluster.go
  3. 13 5
      conn.go
  4. 2 1
      conn_test.go
  5. 66 27
      control.go
  6. 35 0
      control_test.go
  7. 1 1
      frame.go
  8. 9 1
      ring.go
  9. 60 49
      session.go

+ 14 - 0
cassandra_test.go

@@ -2502,5 +2502,19 @@ func TestCreateSession_DontSwallowError(t *testing.T) {
 			t.Fatalf(`expcted to get error "unsupported response version" got: %q`, err)
 		}
 	}
+}
+
+func TestControl_DiscoverProtocol(t *testing.T) {
+	cluster := createCluster()
+	cluster.ProtoVersion = 0
+
+	session, err := cluster.CreateSession()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer session.Close()
 
+	if session.cfg.ProtoVersion == 0 {
+		t.Fatal("did not discovery protocol")
+	}
 }

+ 11 - 4
cluster.go

@@ -33,9 +33,17 @@ type ClusterConfig struct {
 	// address, which is used to index connected hosts. If the domain name specified
 	// resolves to more than 1 IP address then the driver may connect multiple times to
 	// the same host, and will not mark the node being down or up from events.
-	Hosts             []string
-	CQLVersion        string            // CQL version (default: 3.0.0)
-	ProtoVersion      int               // version of the native protocol (default: 2)
+	Hosts      []string
+	CQLVersion string // CQL version (default: 3.0.0)
+
+	// ProtoVersion sets the version of the native protocol to use, this will
+	// enable features in the driver for specific protocol versions, generally this
+	// should be set to a known version (2,3,4) for the cluster being connected to.
+	//
+	// If it is 0 or unset (the default) then the driver will attempt to discover the
+	// highest supported protocol for the cluster. In clusters with nodes of different
+	// versions the protocol selected is not defined (ie, it can be any of the supported in the cluster)
+	ProtoVersion      int
 	Timeout           time.Duration     // connection timeout (default: 600ms)
 	Port              int               // port (default: 9042)
 	Keyspace          string            // initial keyspace (optional)
@@ -118,7 +126,6 @@ func NewCluster(hosts ...string) *ClusterConfig {
 	cfg := &ClusterConfig{
 		Hosts:                  hosts,
 		CQLVersion:             "3.0.0",
-		ProtoVersion:           2,
 		Timeout:                600 * time.Millisecond,
 		Port:                   9042,
 		NumConns:               2,

+ 13 - 5
conn.go

@@ -440,6 +440,17 @@ func (c *Conn) discardFrame(head frameHeader) error {
 	return nil
 }
 
+type protocolError struct {
+	frame frame
+}
+
+func (p *protocolError) Error() string {
+	if err, ok := p.frame.(error); ok {
+		return err.Error()
+	}
+	return fmt.Sprintf("gocql: received unexpected frame on stream %d: %v", p.frame.Header().stream, p.frame)
+}
+
 func (c *Conn) recv() error {
 	// not safe for concurrent reads
 
@@ -479,11 +490,8 @@ func (c *Conn) recv() error {
 			return err
 		}
 
-		switch v := frame.(type) {
-		case error:
-			return fmt.Errorf("gocql: error on stream %d: %v", head.stream, v)
-		default:
-			return fmt.Errorf("gocql: received frame on stream %d: %v", head.stream, frame)
+		return &protocolError{
+			frame: frame,
 		}
 	}
 

+ 2 - 1
conn_test.go

@@ -454,7 +454,8 @@ func TestQueryTimeoutClose(t *testing.T) {
 }
 
 func TestStream0(t *testing.T) {
-	const expErr = "gocql: received frame on stream 0"
+	// TODO: replace this with type check
+	const expErr = "gocql: received unexpected frame on stream 0"
 
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()

+ 66 - 27
control.go

@@ -7,6 +7,7 @@ import (
 	"log"
 	"math/rand"
 	"net"
+	"regexp"
 	"strconv"
 	"sync/atomic"
 	"time"
@@ -134,53 +135,91 @@ func hostInfo(addr string, defaultPort int) (*HostInfo, error) {
 	return &HostInfo{peer: ip, port: port}, nil
 }
 
-func (c *controlConn) shuffleDial(endpoints []string) (conn *Conn, err error) {
-	// TODO: accept a []*HostInfo
-	perm := randr.Perm(len(endpoints))
-	shuffled := make([]string, len(endpoints))
+func shuffleHosts(hosts []*HostInfo) []*HostInfo {
+	perm := randr.Perm(len(hosts))
+	shuffled := make([]*HostInfo, len(hosts))
 
-	for i, endpoint := range endpoints {
-		shuffled[perm[i]] = endpoint
+	for i, host := range hosts {
+		shuffled[perm[i]] = host
 	}
 
+	return shuffled
+}
+
+func (c *controlConn) shuffleDial(endpoints []*HostInfo) (*Conn, error) {
 	// shuffle endpoints so not all drivers will connect to the same initial
 	// node.
-	for _, addr := range shuffled {
-		if addr == "" {
-			return nil, fmt.Errorf("invalid address: %q", addr)
+	shuffled := shuffleHosts(endpoints)
+
+	var err error
+	for _, host := range shuffled {
+		var conn *Conn
+		conn, err = c.session.connect(host, c)
+		if err == nil {
+			return conn, nil
 		}
 
-		port := c.session.cfg.Port
-		addr = JoinHostPort(addr, port)
+		log.Printf("gocql: unable to dial control conn %v: %v\n", host.Peer(), err)
+	}
 
-		var host *HostInfo
-		host, err = hostInfo(addr, port)
-		if err != nil {
-			return nil, fmt.Errorf("invalid address: %q: %v", addr, err)
-		}
+	return nil, err
+}
 
-		hostInfo, _ := c.session.ring.addHostIfMissing(host)
-		conn, err = c.session.connect(hostInfo, c)
-		if err == nil {
-			return conn, err
-		}
+// this is going to be version dependant and a nightmare to maintain :(
+var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
 
-		log.Printf("gocql: unable to dial control conn %v: %v\n", addr, err)
+func parseProtocolFromError(err error) int {
+	// I really wish this had the actual info in the error frame...
+	matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1)
+	if len(matches) != 1 || len(matches[0]) != 2 {
+		if verr, ok := err.(*protocolError); ok {
+			return int(verr.frame.Header().version.version())
+		}
+		return 0
 	}
 
+	max, err := strconv.Atoi(matches[0][1])
 	if err != nil {
-		return nil, err
+		return 0
+	}
+
+	return max
+}
+
+func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
+	hosts = shuffleHosts(hosts)
+
+	connCfg := *c.session.connCfg
+	connCfg.ProtoVersion = 4 // TODO: define maxProtocol
+
+	handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
+		// we should never get here, but if we do it means we connected to a
+		// host successfully which means our attempted protocol version worked
+	})
+
+	var err error
+	for _, host := range hosts {
+		var conn *Conn
+		conn, err = Connect(host, &connCfg, handler, c.session)
+		if err == nil {
+			conn.Close()
+			return connCfg.ProtoVersion, nil
+		}
+
+		if proto := parseProtocolFromError(err); proto > 0 {
+			return proto, nil
+		}
 	}
 
-	return conn, nil
+	return 0, err
 }
 
-func (c *controlConn) connect(endpoints []string) error {
-	if len(endpoints) == 0 {
+func (c *controlConn) connect(hosts []*HostInfo) error {
+	if len(hosts) == 0 {
 		return errors.New("control: no endpoints specified")
 	}
 
-	conn, err := c.shuffleDial(endpoints)
+	conn, err := c.shuffleDial(hosts)
 	if err != nil {
 		return fmt.Errorf("control: unable to connect to initial hosts: %v", err)
 	}

+ 35 - 0
control_test.go

@@ -29,3 +29,38 @@ func TestHostInfo_Lookup(t *testing.T) {
 		}
 	}
 }
+
+func TestParseProtocol(t *testing.T) {
+	tests := [...]struct {
+		err   error
+		proto int
+	}{
+		{
+			err: &protocolError{
+				frame: errorFrame{
+					code:    0x10,
+					message: "Invalid or unsupported protocol version (5); the lowest supported version is 3 and the greatest is 4",
+				},
+			},
+			proto: 4,
+		},
+		{
+			err: &protocolError{
+				frame: errorFrame{
+					frameHeader: frameHeader{
+						version: 0x83,
+					},
+					code:    0x10,
+					message: "Invalid or unsupported protocol version: 5",
+				},
+			},
+			proto: 3,
+		},
+	}
+
+	for i, test := range tests {
+		if proto := parseProtocolFromError(test.err); proto != test.proto {
+			t.Errorf("%d: exepcted proto %d got %d", i, test.proto, proto)
+		}
+	}
+}

+ 1 - 1
frame.go

@@ -365,7 +365,7 @@ func readHeader(r io.Reader, p []byte) (head frameHeader, err error) {
 	version := p[0] & protoVersionMask
 
 	if version < protoVersion1 || version > protoVersion4 {
-		return frameHeader{}, fmt.Errorf("gocql: unsupported response version: %d", version)
+		return frameHeader{}, fmt.Errorf("gocql: unsupported protocol response version: %d", version)
 	}
 
 	headSize := 9

+ 9 - 1
ring.go

@@ -10,7 +10,7 @@ type ring struct {
 	// endpoints are the set of endpoints which the driver will attempt to connect
 	// to in the case it can not reach any of its hosts. They are also used to boot
 	// strap the initial connection.
-	endpoints []string
+	endpoints []*HostInfo
 
 	// hosts are the set of all hosts in the cassandra ring that we know of
 	mu    sync.RWMutex
@@ -70,6 +70,14 @@ func (r *ring) addHost(host *HostInfo) bool {
 	return ok
 }
 
+func (r *ring) addOrUpdate(host *HostInfo) *HostInfo {
+	if existingHost, ok := r.addHostIfMissing(host); ok {
+		existingHost.update(host)
+		host = existingHost
+	}
+	return host
+}
+
 func (r *ring) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
 	ip := host.Peer().String()
 

+ 60 - 49
session.go

@@ -92,7 +92,7 @@ func addrsToHosts(addrs []string, defaultPort int) ([]*HostInfo, error) {
 
 // NewSession wraps an existing Node.
 func NewSession(cfg ClusterConfig) (*Session, error) {
-	//Check that hosts in the ClusterConfig is not empty
+	// Check that hosts in the ClusterConfig is not empty
 	if len(cfg.Hosts) < 1 {
 		return nil, ErrNoHosts
 	}
@@ -103,106 +103,117 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
 		cfg:      cfg,
 		pageSize: cfg.PageSize,
 		stmtsLRU: &preparedLRU{lru: lru.New(cfg.MaxPreparedStmts)},
+		quit:     make(chan struct{}),
 	}
 
-	connCfg, err := connConfig(s)
-	if err != nil {
-		s.Close()
-		return nil, fmt.Errorf("gocql: unable to create session: %v", err)
-	}
-	s.connCfg = connCfg
-
 	s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent)
 	s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent)
 
 	s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
-	// I think it might be a good idea to simplify this and make it always discover
-	// hosts, maybe with more filters.
 	s.hostSource = &ringDescriber{
 		session:   s,
 		closeChan: make(chan bool),
 	}
 
-	pool := cfg.PoolConfig.buildPool(s)
 	if cfg.PoolConfig.HostSelectionPolicy == nil {
 		cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
 	}
+	s.pool = cfg.PoolConfig.buildPool(s)
 
-	s.pool = pool
 	s.policy = cfg.PoolConfig.HostSelectionPolicy
 	s.executor = &queryExecutor{
-		pool:   pool,
+		pool:   s.pool,
 		policy: cfg.PoolConfig.HostSelectionPolicy,
 	}
 
-	var hosts []*HostInfo
-	if !cfg.disableControlConn {
+	if err := s.init(); err != nil {
+		// TODO(zariel): dont wrap this error in fmt.Errorf, return a typed error
+		s.Close()
+		return nil, fmt.Errorf("gocql: unable to create session: %v", err)
+	}
+
+	if s.pool.Size() == 0 {
+		// TODO(zariel): move this to init
+		s.Close()
+		return nil, ErrNoConnectionsStarted
+	}
+
+	return s, nil
+}
+
+func (s *Session) init() error {
+	hosts, err := addrsToHosts(s.cfg.Hosts, s.cfg.Port)
+	if err != nil {
+		return err
+	}
+
+	connCfg, err := connConfig(s)
+	if err != nil {
+		return err
+	}
+	s.connCfg = connCfg
+
+	if !s.cfg.disableControlConn {
 		s.control = createControlConn(s)
-		if err := s.control.connect(cfg.Hosts); err != nil {
-			s.Close()
-			return nil, fmt.Errorf("gocql: unable to create session: %v", err)
+		if s.cfg.ProtoVersion == 0 {
+			proto, err := s.control.discoverProtocol(hosts)
+			if err != nil {
+				return fmt.Errorf("unable to discover protocol version: %v", err)
+			} else if proto == 0 {
+				return errors.New("unable to discovery protocol version")
+			}
+
+			// TODO(zariel): we really only need this in 1 place
+			s.cfg.ProtoVersion = proto
+			connCfg.ProtoVersion = proto
+		}
+
+		if err := s.control.connect(hosts); err != nil {
+			return err
 		}
 
 		// need to setup host source to check for broadcast_address in system.local
 		localHasRPCAddr, _ := checkSystemLocal(s.control)
 		s.hostSource.localHasRpcAddr = localHasRPCAddr
 
-		if cfg.DisableInitialHostLookup {
-			// TODO: we could look at system.local to get token and other metadata
-			// in this case.
-			hosts, err = addrsToHosts(cfg.Hosts, cfg.Port)
-		} else {
-			hosts, _, err = s.hostSource.GetHosts()
+		if !s.cfg.DisableInitialHostLookup {
+			// TODO(zariel): we need to get the partitioner from here
+			var p string
+			hosts, p, err = s.hostSource.GetHosts()
+			if err != nil {
+				return err
+			}
+			s.policy.SetPartitioner(p)
 		}
-
-	} else {
-		// we dont get host info
-		hosts, err = addrsToHosts(cfg.Hosts, cfg.Port)
-	}
-
-	if err != nil {
-		s.Close()
-		return nil, fmt.Errorf("gocql: unable to create session: %v", err)
 	}
 
 	for _, host := range hosts {
 		if s.cfg.HostFilter == nil || s.cfg.HostFilter.Accept(host) {
-			if existingHost, ok := s.ring.addHostIfMissing(host); ok {
-				existingHost.update(host)
-			}
-
+			host = s.ring.addOrUpdate(host)
 			s.handleNodeUp(host.Peer(), host.Port(), false)
 		}
 	}
 
-	s.quit = make(chan struct{})
-
-	if cfg.ReconnectInterval > 0 {
-		go s.reconnectDownedHosts(cfg.ReconnectInterval)
-	}
-
 	// TODO(zariel): we probably dont need this any more as we verify that we
 	// can connect to one of the endpoints supplied by using the control conn.
 	// See if there are any connections in the pool
-	if s.pool.Size() == 0 {
-		s.Close()
-		return nil, ErrNoConnectionsStarted
+	if s.cfg.ReconnectInterval > 0 {
+		go s.reconnectDownedHosts(s.cfg.ReconnectInterval)
 	}
 
 	// If we disable the initial host lookup, we need to still check if the
 	// cluster is using the newer system schema or not... however, if control
 	// connection is disable, we really have no choice, so we just make our
 	// best guess...
-	if !cfg.disableControlConn && cfg.DisableInitialHostLookup {
-		// TODO(zariel): we dont need to do this twice
+	if !s.cfg.disableControlConn && s.cfg.DisableInitialHostLookup {
 		newer, _ := checkSystemSchema(s.control)
 		s.useSystemSchema = newer
 	} else {
 		s.useSystemSchema = hosts[0].Version().Major >= 3
 	}
 
-	return s, nil
+	return nil
 }
 
 func (s *Session) reconnectDownedHosts(intv time.Duration) {