Преглед изворни кода

Merge pull request #168 from ChannelMeter/reprepare_cass2_batchv2

Reprepare cassandra 2.x style batch if batch fails due to unprepared statement
Ben Hood пре 11 година
родитељ
комит
3de4b29e54
2 измењених фајлова са 81 додато и 11 уклоњено
  1. 57 0
      cassandra_test.go
  2. 24 11
      conn.go

+ 57 - 0
cassandra_test.go

@@ -572,3 +572,60 @@ func TestScanCASWithNilArguments(t *testing.T) {
 		t.Fatalf("expected %v but got %v", foo, cas)
 	}
 }
+
+func injectInvalidPreparedStatement(t *testing.T, session *Session, table string) (string, *Conn) {
+	if err := session.Query(`CREATE TABLE ` + table + ` (
+			foo   varchar,
+			bar   int,
+			PRIMARY KEY (foo, bar)
+	)`).Exec(); err != nil {
+		t.Fatal("create:", err)
+	}
+	stmt := "INSERT INTO " + table + " (foo, bar) VALUES (?, 7)"
+	conn := session.Pool.Pick(nil)
+	conn.prepMu.Lock()
+	flight := new(inflightPrepare)
+	conn.prep[stmt] = flight
+	flight.info = &queryInfo{
+		id: []byte{'f', 'o', 'o', 'b', 'a', 'r'},
+		args: []ColumnInfo{ColumnInfo{
+			Keyspace: "gocql_test",
+			Table:    table,
+			Name:     "foo",
+			TypeInfo: &TypeInfo{
+				Type: TypeVarchar,
+			},
+		}, ColumnInfo{
+			Keyspace: "gocql_test",
+			Table:    table,
+			Name:     "bar",
+			TypeInfo: &TypeInfo{
+				Type: TypeInt,
+			},
+		}},
+	}
+	conn.prepMu.Unlock()
+	return stmt, conn
+}
+
+func TestReprepareStatement(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement")
+	query := session.Query(stmt, "bar")
+	if err := conn.executeQuery(query).Close(); err != nil {
+		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
+	}
+}
+
+func TestReprepareBatch(t *testing.T) {
+	session := createSession(t)
+	defer session.Close()
+	stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
+	batch := session.NewBatch(UnloggedBatch)
+	batch.Query(stmt, "bar")
+	if err := conn.executeBatch(batch); err != nil {
+		t.Fatalf("Failed to execute query for reprepare statement: %v", err)
+	}
+
+}

+ 24 - 11
conn.go

@@ -6,6 +6,7 @@ package gocql
 
 import (
 	"bufio"
+	"bytes"
 	"fmt"
 	"net"
 	"sync"
@@ -398,19 +399,15 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		return iter
 	case resultKeyspaceFrame:
 		return &Iter{}
-	case errorFrame:
-		if x.Code() == errUnprepared && len(qry.values) > 0 {
-			c.prepMu.Lock()
-			if val, ok := c.prep[qry.stmt]; ok && val != nil {
-				delete(c.prep, qry.stmt)
-				c.prepMu.Unlock()
-				return c.executeQuery(qry)
-			}
+	case RequestErrUnprepared:
+		c.prepMu.Lock()
+		if val, ok := c.prep[qry.stmt]; ok && val != nil {
+			delete(c.prep, qry.stmt)
 			c.prepMu.Unlock()
-			return &Iter{err: x}
-		} else {
-			return &Iter{err: x}
+			return c.executeQuery(qry)
 		}
+		c.prepMu.Unlock()
+		return &Iter{err: x}
 	case error:
 		return &Iter{err: x}
 	default:
@@ -504,6 +501,22 @@ func (c *Conn) executeBatch(batch *Batch) error {
 	switch x := resp.(type) {
 	case resultVoidFrame:
 		return nil
+	case RequestErrUnprepared:
+		c.prepMu.Lock()
+		found := false
+		for stmt, flight := range c.prep {
+			if bytes.Equal(flight.info.id, x.StatementId) {
+				found = true
+				delete(c.prep, stmt)
+				break
+			}
+		}
+		c.prepMu.Unlock()
+		if found {
+			return c.executeBatch(batch)
+		} else {
+			return x
+		}
 	case error:
 		return x
 	default: