浏览代码

Merge pull request #659 from retailnext/fix-crashes-on-bad-peers

Don't crash on bad peers.
Chris Bannister 9 年之前
父节点
当前提交
a572b53bea
共有 3 个文件被更改,包括 61 次插入3 次删除
  1. 35 0
      cassandra_test.go
  2. 7 1
      conn.go
  3. 19 2
      host_source.go

+ 35 - 0
cassandra_test.go

@@ -77,6 +77,41 @@ func TestEmptyHosts(t *testing.T) {
 	}
 }
 
+func TestInvalidPeerEntry(t *testing.T) {
+	session := createSession(t)
+
+	// rack, release_version, schema_version, tokens are all null
+	query := session.Query("INSERT into system.peers (peer, data_center, host_id, rpc_address) VALUES (?, ?, ?, ?)",
+		"169.254.235.45",
+		"datacenter1",
+		"35c0ec48-5109-40fd-9281-9e9d4add2f1e",
+		"169.254.235.45",
+	)
+
+	// clean up naughty peer
+	defer session.Query("DELETE from system.peers where peer == ?", "169.254.235.45").Exec()
+
+	if err := query.Exec(); err != nil {
+		t.Fatal(err)
+	}
+
+	session.Close()
+
+	cluster := createCluster()
+	cluster.PoolConfig.HostSelectionPolicy = TokenAwareHostPolicy(RoundRobinHostPolicy())
+	session = createSessionFromCluster(cluster, t)
+	defer session.Close()
+
+	// check we can perform a query
+	iter := session.Query("select peer from system.peers").Iter()
+	var peer string
+	for iter.Scan(&peer) {
+	}
+	if err := iter.Close(); err != nil {
+		t.Fatal(err)
+	}
+}
+
 //TestUseStatementError checks to make sure the correct error is returned when the user tries to execute a use statement.
 func TestUseStatementError(t *testing.T) {
 	session := createSession(t)

+ 7 - 1
conn.go

@@ -9,7 +9,6 @@ import (
 	"crypto/tls"
 	"errors"
 	"fmt"
-	"github.com/gocql/gocql/internal/lru"
 	"io"
 	"io/ioutil"
 	"log"
@@ -20,6 +19,8 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/gocql/gocql/internal/lru"
+
 	"github.com/gocql/gocql/internal/streams"
 )
 
@@ -1001,6 +1002,11 @@ func (c *Conn) awaitSchemaAgreement() (err error) {
 
 		var schemaVersion string
 		for iter.Scan(&schemaVersion) {
+			if schemaVersion == "" {
+				log.Println("skipping peer entry with empty schema_version")
+				continue
+			}
+
 			versions[schemaVersion] = struct{}{}
 			schemaVersion = ""
 		}

+ 19 - 2
host_source.go

@@ -2,6 +2,7 @@ package gocql
 
 import (
 	"fmt"
+	"log"
 	"net"
 	"strconv"
 	"strings"
@@ -38,10 +39,18 @@ func (c *cassVersion) Set(v string) error {
 }
 
 func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error {
+	return c.unmarshal(data)
+}
+
+func (c *cassVersion) unmarshal(data []byte) error {
 	version := strings.TrimSuffix(string(data), "-SNAPSHOT")
 	version = strings.TrimPrefix(version, "v")
 	v := strings.Split(version, ".")
 
+	if len(v) < 2 {
+		return fmt.Errorf("invalid version string: %s", data)
+	}
+
 	var err error
 	c.Major, err = strconv.Atoi(v[0])
 	if err != nil {
@@ -319,8 +328,16 @@ func (r *ringDescriber) GetHosts() (hosts []*HostInfo, partitioner string, err e
 		return r.prevHosts, r.prevPartitioner, nil
 	}
 
-	host := &HostInfo{port: r.session.cfg.Port}
-	for iter.Scan(&host.peer, &host.dataCenter, &host.rack, &host.hostId, &host.tokens, &host.version) {
+	var (
+		host         = &HostInfo{port: r.session.cfg.Port}
+		versionBytes []byte
+	)
+	for iter.Scan(&host.peer, &host.dataCenter, &host.rack, &host.hostId, &host.tokens, &versionBytes) {
+		if err = host.version.unmarshal(versionBytes); err != nil {
+			log.Printf("invalid peer entry: peer=%s host_id=%s tokens=%v version=%s\n", host.peer, host.hostId, host.tokens, versionBytes)
+			continue
+		}
+
 		if r.matchFilter(host) {
 			hosts = append(hosts, host)
 		}