Browse Source

Resolve first code review issues

- In routingKeyInfo(string), refactored branching of normal path so that it is mostly on the first indent.
- Added error return value to routing key functions as appropriate and documented instances where errors
are returned and when nil may be returned when no error occurs.
- Added testing of type information returned in routingKeyInfo structs
Justin Corpron 10 years ago
parent
commit
5abf8446a1
2 changed files with 144 additions and 66 deletions
  1. 66 10
      cassandra_test.go
  2. 78 56
      session.go

+ 66 - 10
cassandra_test.go

@@ -1712,11 +1712,14 @@ func TestRoutingKey(t *testing.T) {
 	if err := createTable(session, "CREATE TABLE test_single_routing_key (first_id int, second_id int, PRIMARY KEY (first_id, second_id))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)
 	}
-	if err := createTable(session, "CREATE TABLE test_composite_routing_key (first_id int, second_id int, PRIMARY KEY ((first_id,second_id)))"); err != nil {
+	if err := createTable(session, "CREATE TABLE test_composite_routing_key (first_id int, second_id int, PRIMARY KEY ((first_id, second_id)))"); err != nil {
 		t.Fatalf("failed to create table with error '%v'", err)
 	}
 
-	routingKeyInfo := session.routingKeyInfo("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
+	routingKeyInfo, err := session.routingKeyInfo("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
+	if err != nil {
+		t.Fatalf("failed to get routing key info due to error: %v", err)
+	}
 	if routingKeyInfo == nil {
 		t.Fatal("Expected routing key info, but was nil")
 	}
@@ -1726,21 +1729,55 @@ func TestRoutingKey(t *testing.T) {
 	if routingKeyInfo.indexes[0] != 1 {
 		t.Errorf("Expected routing key index[0] to be 1 but was %d", routingKeyInfo.indexes[0])
 	}
-	query := session.Query("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", 1, 2)
-	routingKey := query.GetRoutingKey()
-	expectedRoutingKey := []byte{0, 0, 0, 2}
-	if !reflect.DeepEqual(expectedRoutingKey, routingKey) {
-		t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey)
+	if len(routingKeyInfo.types) != 1 {
+		t.Fatalf("Expected routing key types length to be 1 but was %d", len(routingKeyInfo.types))
+	}
+	if routingKeyInfo.types[0] == nil {
+		t.Fatal("Expected routing key types[0] to be non-nil")
+	}
+	if routingKeyInfo.types[0].Type != TypeInt {
+		t.Fatalf("Expected routing key types[0].Type to be %v but was %v", TypeInt, routingKeyInfo.types[0])
 	}
 
 	// verify the cache is working
-	session.routingKeyInfo("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
+	routingKeyInfo, err = session.routingKeyInfo("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?")
+	if err != nil {
+		t.Fatalf("failed to get routing key info due to error: %v", err)
+	}
+	if len(routingKeyInfo.indexes) != 1 {
+		t.Fatalf("Expected routing key indexes length to be 1 but was %d", len(routingKeyInfo.indexes))
+	}
+	if routingKeyInfo.indexes[0] != 1 {
+		t.Errorf("Expected routing key index[0] to be 1 but was %d", routingKeyInfo.indexes[0])
+	}
+	if len(routingKeyInfo.types) != 1 {
+		t.Fatalf("Expected routing key types length to be 1 but was %d", len(routingKeyInfo.types))
+	}
+	if routingKeyInfo.types[0] == nil {
+		t.Fatal("Expected routing key types[0] to be non-nil")
+	}
+	if routingKeyInfo.types[0].Type != TypeInt {
+		t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[0])
+	}
 	cacheSize := session.routingKeyInfoCache.lru.Len()
 	if cacheSize != 1 {
 		t.Errorf("Expected cache size to be 1 but was %d", cacheSize)
 	}
 
-	routingKeyInfo = session.routingKeyInfo("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?")
+	query := session.Query("SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", 1, 2)
+	routingKey, err := query.GetRoutingKey()
+	if err != nil {
+		t.Fatalf("Failed to get routing key due to error: %v", err)
+	}
+	expectedRoutingKey := []byte{0, 0, 0, 2}
+	if !reflect.DeepEqual(expectedRoutingKey, routingKey) {
+		t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey)
+	}
+
+	routingKeyInfo, err = session.routingKeyInfo("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?")
+	if err != nil {
+		t.Fatalf("failed to get routing key info due to error: %v", err)
+	}
 	if routingKeyInfo == nil {
 		t.Fatal("Expected routing key info, but was nil")
 	}
@@ -1753,8 +1790,27 @@ func TestRoutingKey(t *testing.T) {
 	if routingKeyInfo.indexes[1] != 0 {
 		t.Errorf("Expected routing key index[1] to be 0 but was %d", routingKeyInfo.indexes[1])
 	}
+	if len(routingKeyInfo.types) != 2 {
+		t.Fatalf("Expected routing key types length to be 1 but was %d", len(routingKeyInfo.types))
+	}
+	if routingKeyInfo.types[0] == nil {
+		t.Fatal("Expected routing key types[0] to be non-nil")
+	}
+	if routingKeyInfo.types[0].Type != TypeInt {
+		t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[0])
+	}
+	if routingKeyInfo.types[1] == nil {
+		t.Fatal("Expected routing key types[1] to be non-nil")
+	}
+	if routingKeyInfo.types[1].Type != TypeInt {
+		t.Fatalf("Expected routing key types[0] to be %v but was %v", TypeInt, routingKeyInfo.types[1])
+	}
+
 	query = session.Query("SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", 1, 2)
-	routingKey = query.GetRoutingKey()
+	routingKey, err = query.GetRoutingKey()
+	if err != nil {
+		t.Fatalf("Failed to get routing key due to error: %v", err)
+	}
 	expectedRoutingKey = []byte{0, 4, 0, 0, 0, 2, 0, 0, 4, 0, 0, 0, 1, 0}
 	if !reflect.DeepEqual(expectedRoutingKey, routingKey) {
 		t.Errorf("Expected routing key %v but was %v", expectedRoutingKey, routingKey)

+ 78 - 56
session.go

@@ -196,9 +196,10 @@ func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) {
 }
 
 // returns routing key indexes and type info
-func (s *Session) routingKeyInfo(stmt string) *routingKeyInfo {
+func (s *Session) routingKeyInfo(stmt string) (*routingKeyInfo, error) {
 	s.routingKeyInfoCache.mu.Lock()
 	cacheKey := s.cfg.Keyspace + stmt
+
 	entry, cached := s.routingKeyInfoCache.lru.Get(cacheKey)
 	if cached {
 		// done accessing the cache
@@ -206,15 +207,15 @@ func (s *Session) routingKeyInfo(stmt string) *routingKeyInfo {
 		// the entry is an inflight struct similiar to that used by
 		// Conn to prepare statements
 		inflight := entry.(*inflightCachedEntry)
+
 		// wait for any inflight work
 		inflight.wg.Wait()
 
 		if inflight.err != nil {
-			// return nil for any error
-			return nil
+			return nil, inflight.err
 		}
 
-		return inflight.value.(*routingKeyInfo)
+		return inflight.value.(*routingKeyInfo), nil
 	}
 
 	// create a new inflight entry while the data is created
@@ -229,68 +230,78 @@ func (s *Session) routingKeyInfo(stmt string) *routingKeyInfo {
 
 	// get the query info for the statement
 	conn := s.Pool.Pick(nil)
-	if conn != nil {
-		queryInfo, inflight.err = conn.prepareStatement(stmt, s.trace)
-		if inflight.err == nil {
-			if len(queryInfo.Args) == 0 {
-				// no arguments, no routing key, and no error
-				return nil
-			}
-
-			// get the table metadata
-			table := queryInfo.Args[0].Table
-			var keyspaceMetadata *KeyspaceMetadata
-			keyspaceMetadata, inflight.err = s.KeyspaceMetadata(s.cfg.Keyspace)
-			if inflight.err == nil {
-				tableMetadata, found := keyspaceMetadata.Tables[table]
-				if !found {
-					inflight.err = ErrNoMetadata
-				}
-
-				partitionKey = tableMetadata.PartitionKey
-			}
-		}
-	} else {
+	if conn == nil {
 		// no connections
 		inflight.err = ErrNoConnections
+		// don't cache this error
+		s.routingKeyInfoCache.Remove(cacheKey)
+		return nil, inflight.err
 	}
 
+	queryInfo, inflight.err = conn.prepareStatement(stmt, nil)
 	if inflight.err != nil {
-		// remove from the cache
-		s.routingKeyInfoCache.mu.Lock()
-		s.routingKeyInfoCache.lru.Remove(cacheKey)
-		s.routingKeyInfoCache.mu.Unlock()
-		return nil
+		// don't cache this error
+		s.routingKeyInfoCache.Remove(cacheKey)
+		return nil, inflight.err
+	}
+	if len(queryInfo.Args) == 0 {
+		// no arguments, no routing key, and no error
+		return nil, nil
 	}
 
+	// get the table metadata
+	table := queryInfo.Args[0].Table
+	var keyspaceMetadata *KeyspaceMetadata
+	keyspaceMetadata, inflight.err = s.KeyspaceMetadata(s.cfg.Keyspace)
+	if inflight.err != nil {
+		// don't cache this error
+		s.routingKeyInfoCache.Remove(cacheKey)
+		return nil, inflight.err
+	}
+
+	tableMetadata, found := keyspaceMetadata.Tables[table]
+	if !found {
+		// unlikely that the statement could be prepared and the metadata for
+		// the table couldn't be found, but this may indicate either a bug
+		// in the metadata code, or that the table was just dropped.
+		inflight.err = ErrNoMetadata
+		// don't cache this error
+		s.routingKeyInfoCache.Remove(cacheKey)
+		return nil, inflight.err
+	}
+
+	partitionKey = tableMetadata.PartitionKey
+
 	size := len(partitionKey)
 	routingKeyInfo := &routingKeyInfo{
 		indexes: make([]int, size),
 		types:   make([]*TypeInfo, size),
 	}
-	for i, keyColumn := range partitionKey {
-		routingKeyInfo.indexes[i] = -1
+	for keyIndex, keyColumn := range partitionKey {
+		// set an indicator for checking if the mapping is missing
+		routingKeyInfo.indexes[keyIndex] = -1
+
 		// find the column in the query info
-		for j, boundColumn := range queryInfo.Args {
+		for argIndex, boundColumn := range queryInfo.Args {
 			if keyColumn.Name == boundColumn.Name {
-				// there may be many such columns, pick the first
-				routingKeyInfo.indexes[i] = j
-				routingKeyInfo.types[i] = boundColumn.TypeInfo
+				// there may be many such bound columns, pick the first
+				routingKeyInfo.indexes[keyIndex] = argIndex
+				routingKeyInfo.types[keyIndex] = boundColumn.TypeInfo
 				break
 			}
 		}
 
-		if routingKeyInfo.indexes[i] == -1 {
+		if routingKeyInfo.indexes[keyIndex] == -1 {
 			// missing a routing key column mapping
-			// no error, but cache a nil result
-			return nil
+			// no routing key, and no error
+			return nil, nil
 		}
 	}
 
 	// cache this result
 	inflight.value = routingKeyInfo
 
-	return routingKeyInfo
+	return routingKeyInfo, nil
 }
 
 // ExecuteBatch executes a batch operation and returns nil if successful
@@ -406,16 +417,21 @@ func (q *Query) RoutingKey(routingKey []byte) *Query {
 // GetRoutingKey gets the routing key to use for routing this query. If
 // a routing key has not been explicitly set, then the routing key will
 // be constructed if possible using the keyspace's schema and the query
-// info for this query statement.
-func (q *Query) GetRoutingKey() []byte {
+// info for this query statement. If the routing key cannot be determined
+// then nil will be returned with no error. On any error condition,
+// an error description will be returned.
+func (q *Query) GetRoutingKey() ([]byte, error) {
 	if q.routingKey != nil {
-		return q.routingKey
+		return q.routingKey, nil
 	}
 
 	// try to determine the routing key
-	routingKeyInfo := q.session.routingKeyInfo(q.stmt)
+	routingKeyInfo, err := q.session.routingKeyInfo(q.stmt)
+	if err != nil {
+		return nil, err
+	}
 	if routingKeyInfo == nil {
-		return nil
+		return nil, nil
 	}
 
 	if len(routingKeyInfo.indexes) == 1 {
@@ -425,9 +441,9 @@ func (q *Query) GetRoutingKey() []byte {
 			q.values[routingKeyInfo.indexes[0]],
 		)
 		if err != nil {
-			return nil
+			return nil, err
 		}
-		return routingKey
+		return routingKey, nil
 	}
 
 	// composite routing key
@@ -438,14 +454,14 @@ func (q *Query) GetRoutingKey() []byte {
 			q.values[routingKeyInfo.indexes[i]],
 		)
 		if err != nil {
-			return nil
+			return nil, err
 		}
 		binary.Write(buf, binary.BigEndian, int16(len(encoded)))
 		buf.Write(encoded)
 		buf.WriteByte(0x00)
 	}
 	routingKey := buf.Bytes()
-	return routingKey
+	return routingKey, nil
 }
 
 func (q *Query) shouldPrepare() bool {
@@ -782,15 +798,21 @@ type routingKeyInfo struct {
 	types   []*TypeInfo
 }
 
+func (r *routingKeyInfoLRU) Remove(key string) {
+	r.mu.Lock()
+	r.lru.Remove(key)
+	r.mu.Unlock()
+}
+
 //Max adjusts the maximum size of the cache and cleans up the oldest records if
 //the new max is lower than the previous value. Not concurrency safe.
-func (q *routingKeyInfoLRU) Max(max int) {
-	q.mu.Lock()
-	for q.lru.Len() > max {
-		q.lru.RemoveOldest()
+func (r *routingKeyInfoLRU) Max(max int) {
+	r.mu.Lock()
+	for r.lru.Len() > max {
+		r.lru.RemoveOldest()
 	}
-	q.lru.MaxEntries = max
-	q.mu.Unlock()
+	r.lru.MaxEntries = max
+	r.mu.Unlock()
 }
 
 type inflightCachedEntry struct {