Browse Source

added password authentication support for Cassandra 2

Christoph Hack 12 years ago
parent
commit
788922259a
4 changed files with 133 additions and 42 deletions
  1. 4 0
      cassandra_test.go
  2. 20 18
      cluster.go
  3. 84 24
      conn.go
  4. 25 0
      frame.go

+ 4 - 0
cassandra_test.go

@@ -27,6 +27,10 @@ func createSession(t *testing.T) *Session {
 	cluster := NewCluster(strings.Split(*flagCluster, ",")...)
 	cluster.ProtoVersion = *flagProto
 	cluster.CQLVersion = *flagCQL
+	cluster.Authenticator = PasswordAuthenticator{
+		Username: "cassandra",
+		Password: "cassandra",
+	}
 
 	session, err := cluster.CreateSession()
 	if err != nil {

+ 20 - 18
cluster.go

@@ -18,19 +18,20 @@ import (
 // behavior to fit the most common use cases. Applications that requre a
 // different setup must implement their own cluster.
 type ClusterConfig struct {
-	Hosts        []string      // addresses for the initial connections
-	CQLVersion   string        // CQL version (default: 3.0.0)
-	ProtoVersion int           // version of the native protocol (default: 2)
-	Timeout      time.Duration // connection timeout (default: 200ms)
-	DefaultPort  int           // default port (default: 9042)
-	Keyspace     string        // initial keyspace (optional)
-	NumConns     int           // number of connections per host (default: 2)
-	NumStreams   int           // number of streams per connection (default: 128)
-	DelayMin     time.Duration // minimum reconnection delay (default: 1s)
-	DelayMax     time.Duration // maximum reconnection delay (default: 10min)
-	StartupMin   int           // wait for StartupMin hosts (default: len(Hosts)/2+1)
-	Consistency  Consistency   // default consistency level (default: Quorum)
-	Compressor   Compressor    // compression algorithm (default: nil)
+	Hosts         []string      // addresses for the initial connections
+	CQLVersion    string        // CQL version (default: 3.0.0)
+	ProtoVersion  int           // version of the native protocol (default: 2)
+	Timeout       time.Duration // connection timeout (default: 200ms)
+	DefaultPort   int           // default port (default: 9042)
+	Keyspace      string        // initial keyspace (optional)
+	NumConns      int           // number of connections per host (default: 2)
+	NumStreams    int           // number of streams per connection (default: 128)
+	DelayMin      time.Duration // minimum reconnection delay (default: 1s)
+	DelayMax      time.Duration // maximum reconnection delay (default: 10min)
+	StartupMin    int           // wait for StartupMin hosts (default: len(Hosts)/2+1)
+	Consistency   Consistency   // default consistency level (default: Quorum)
+	Compressor    Compressor    // compression algorithm (default: nil)
+	Authenticator Authenticator // authenticator (default: nil)
 }
 
 // NewCluster generates a new config for the default cluster implementation.
@@ -102,11 +103,12 @@ type clusterImpl struct {
 
 func (c *clusterImpl) connect(addr string) {
 	cfg := ConnConfig{
-		ProtoVersion: c.cfg.ProtoVersion,
-		CQLVersion:   c.cfg.CQLVersion,
-		Timeout:      c.cfg.Timeout,
-		NumStreams:   c.cfg.NumStreams,
-		Compressor:   c.cfg.Compressor,
+		ProtoVersion:  c.cfg.ProtoVersion,
+		CQLVersion:    c.cfg.CQLVersion,
+		Timeout:       c.cfg.Timeout,
+		NumStreams:    c.cfg.NumStreams,
+		Compressor:    c.cfg.Compressor,
+		Authenticator: c.cfg.Authenticator,
 	}
 	delay := c.cfg.DelayMin
 	for {

+ 84 - 24
conn.go

@@ -6,6 +6,7 @@ package gocql
 
 import (
 	"bufio"
+	"fmt"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -19,22 +20,43 @@ const flagResponse = 0x80
 const maskVersion = 0x7F
 
 type Cluster interface {
-	//HandleAuth(addr, method string) ([]byte, Challenger, error)
 	HandleError(conn *Conn, err error, closed bool)
 	HandleKeyspace(conn *Conn, keyspace string)
-	// Authenticate(addr string)
 }
 
-/* type Challenger interface {
-	Challenge(data []byte) ([]byte, error)
-} */
+type Authenticator interface {
+	Challenge(req []byte) (resp []byte, auth Authenticator, err error)
+	Success(data []byte) error
+}
+
+type PasswordAuthenticator struct {
+	Username string
+	Password string
+}
+
+func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
+	if string(req) != "org.apache.cassandra.auth.PasswordAuthenticator" {
+		return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
+	}
+	resp := make([]byte, 2+len(p.Username)+len(p.Password))
+	resp[0] = 0
+	copy(resp[1:], p.Username)
+	resp[len(p.Username)+1] = 0
+	copy(resp[2+len(p.Username):], p.Password)
+	return resp, nil, nil
+}
+
+func (p PasswordAuthenticator) Success(data []byte) error {
+	return nil
+}
 
 type ConnConfig struct {
-	ProtoVersion int
-	CQLVersion   string
-	Timeout      time.Duration
-	NumStreams   int
-	Compressor   Compressor
+	ProtoVersion  int
+	CQLVersion    string
+	Timeout       time.Duration
+	NumStreams    int
+	Compressor    Compressor
+	Authenticator Authenticator
 }
 
 // Conn is a single connection to a Cassandra node. It can be used to execute
@@ -54,6 +76,7 @@ type Conn struct {
 
 	cluster    Cluster
 	compressor Compressor
+	auth       Authenticator
 	addr       string
 	version    uint8
 }
@@ -82,6 +105,7 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
 		addr:       conn.RemoteAddr().String(),
 		cluster:    cluster,
 		compressor: cfg.Compressor,
+		auth:       cfg.Authenticator,
 	}
 	for i := 0; i < cap(c.uniq); i++ {
 		c.uniq <- uint8(i)
@@ -97,24 +121,54 @@ func Connect(addr string, cfg ConnConfig, cluster Cluster) (*Conn, error) {
 }
 
 func (c *Conn) startup(cfg *ConnConfig) error {
-	req := &startupFrame{
-		CQLVersion: cfg.CQLVersion,
-	}
+	compression := ""
 	if c.compressor != nil {
-		req.Compression = c.compressor.Name()
+		compression = c.compressor.Name()
 	}
-	resp, err := c.execSimple(req)
-	if err != nil {
-		return err
+	var req operation = &startupFrame{
+		CQLVersion:  cfg.CQLVersion,
+		Compression: compression,
 	}
-	switch x := resp.(type) {
-	case readyFrame:
-	case error:
-		return x
-	default:
-		return ErrProtocol
+	var challenger Authenticator
+	for {
+		resp, err := c.execSimple(req)
+		if err != nil {
+			return err
+		}
+		switch x := resp.(type) {
+		case readyFrame:
+			return nil
+		case error:
+			return x
+		case authenticateFrame:
+			if c.auth == nil {
+				return fmt.Errorf("authentication required (using %q)", x.Authenticator)
+			}
+			var resp []byte
+			resp, challenger, err = c.auth.Challenge([]byte(x.Authenticator))
+			if err != nil {
+				return err
+			}
+			req = &authResponseFrame{resp}
+		case authChallengeFrame:
+			if challenger == nil {
+				return fmt.Errorf("authentication error (invalid challenge)")
+			}
+			var resp []byte
+			resp, challenger, err = challenger.Challenge(x.Data)
+			if err != nil {
+				return err
+			}
+			req = &authResponseFrame{resp}
+		case authSuccessFrame:
+			if challenger != nil {
+				return challenger.Success(x.Data)
+			}
+			return nil
+		default:
+			return ErrProtocol
+		}
 	}
-	return nil
 }
 
 // Serve starts the stream multiplexer for this connection, which is required
@@ -493,6 +547,12 @@ func (c *Conn) decodeFrame(f frame, trace Tracer) (rval interface{}, err error)
 		default:
 			return nil, ErrProtocol
 		}
+	case opAuthenticate:
+		return authenticateFrame{f.readString()}, nil
+	case opAuthChallenge:
+		return authChallengeFrame{f.readBytes()}, nil
+	case opAuthSuccess:
+		return authSuccessFrame{f.readBytes()}, nil
 	case opSupported:
 		return supportedFrame{}, nil
 	case opError:

+ 25 - 0
frame.go

@@ -432,3 +432,28 @@ func (op *optionsFrame) encodeFrame(version uint8, f frame) (frame, error) {
 	f.setHeader(version, 0, 0, opOptions)
 	return f, nil
 }
+
+type authenticateFrame struct {
+	Authenticator string
+}
+
+type authResponseFrame struct {
+	Data []byte
+}
+
+func (op *authResponseFrame) encodeFrame(version uint8, f frame) (frame, error) {
+	if f == nil {
+		f = make(frame, headerSize, defaultFrameSize)
+	}
+	f.setHeader(version, 0, 0, opAuthResponse)
+	f.writeBytes(op.Data)
+	return f, nil
+}
+
+type authSuccessFrame struct {
+	Data []byte
+}
+
+type authChallengeFrame struct {
+	Data []byte
+}