浏览代码

test: move Stream0 test to conn_test and mock

mock out the stream0 test
Chris Bannister 10 年之前
父节点
当前提交
fe0ee93162
共有 4 个文件被更改,包括 61 次插入46 次删除
  1. 0 45
      cassandra_test.go
  2. 12 1
      conn.go
  3. 43 0
      conn_test.go
  4. 6 0
      frame.go

+ 0 - 45
cassandra_test.go

@@ -1906,51 +1906,6 @@ func TestTokenAwareConnPool(t *testing.T) {
 	// TODO add verification that the query went to the correct host
 }
 
-type frameWriterFunc func(framer *framer, streamID int) error
-
-func (f frameWriterFunc) writeFrame(framer *framer, streamID int) error {
-	return f(framer, streamID)
-}
-
-func TestStream0(t *testing.T) {
-	session := createSession(t)
-	defer session.Close()
-
-	var conn *Conn
-	for i := 0; i < 5; i++ {
-		if conn != nil {
-			break
-		}
-
-		_, conn = session.pool.Pick(nil)
-	}
-
-	if conn == nil {
-		t.Fatal("no connections available in the pool")
-	}
-
-	writer := frameWriterFunc(func(f *framer, streamID int) error {
-		if streamID == 0 {
-			t.Fatal("should not use stream 0 for requests")
-		}
-		f.writeHeader(0, opError, streamID)
-		f.writeString("i am a bad frame")
-		f.wbuf[0] = 0xFF
-		return f.finishWrite()
-	})
-
-	const expErr = "gocql: error on stream 0:"
-	// need to write out an invalid frame, which we need a connection to do
-	frame, err := conn.exec(writer, nil)
-	if err == nil {
-		t.Fatal("expected to get an error on stream 0")
-	} else if !strings.HasPrefix(err.Error(), expErr) {
-		t.Fatalf("expected to get error prefix %q got %q", expErr, err.Error())
-	} else if frame != nil {
-		t.Fatalf("expected to get nil frame got %+v", frame)
-	}
-}
-
 func TestNegativeStream(t *testing.T) {
 	session := createSession(t)
 	defer session.Close()

+ 12 - 1
conn.go

@@ -103,6 +103,12 @@ type ConnErrorHandler interface {
 	HandleError(conn *Conn, err error, closed bool)
 }
 
+type connErrorHandlerFn func(conn *Conn, err error, closed bool)
+
+func (fn connErrorHandlerFn) HandleError(conn *Conn, err error, closed bool) {
+	fn(conn, err, closed)
+}
+
 // How many timeouts we will allow to occur before the connection is closed
 // and restarted. This is to prevent a single query timeout from killing a connection
 // which may be serving more queries just fine.
@@ -533,6 +539,11 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 		return nil, err
 	}
 
+	var timeoutCh <-chan time.Time
+	if c.timeout > 0 {
+		timeoutCh = time.After(c.timeout)
+	}
+
 	select {
 	case err := <-call.resp:
 		if err != nil {
@@ -545,7 +556,7 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 			}
 			return nil, err
 		}
-	case <-time.After(c.timeout):
+	case <-timeoutCh:
 		close(call.timeout)
 		c.handleTimeout()
 		return nil, ErrTimeoutNoResponse

+ 43 - 0
conn_test.go

@@ -404,6 +404,46 @@ func TestQueryTimeoutClose(t *testing.T) {
 	}
 }
 
+func TestStream0(t *testing.T) {
+	const expErr = "gocql: error on stream 0:"
+
+	srv := NewTestServer(t, defaultProto)
+	defer srv.Stop()
+
+	errorHandler := connErrorHandlerFn(func(conn *Conn, err error, closed bool) {
+		if !strings.HasPrefix(err.Error(), expErr) {
+			t.Errorf("expected to get error prefix %q got %q", expErr, err.Error())
+		}
+	})
+
+	conn, err := Connect(srv.Address, &ConnConfig{ProtoVersion: int(srv.protocol)}, errorHandler, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	writer := frameWriterFunc(func(f *framer, streamID int) error {
+		f.writeHeader(0, opError, 0)
+		f.writeInt(0)
+		f.writeString("i am a bad frame")
+		// f.wbuf[0] = 2
+		return f.finishWrite()
+	})
+
+	// need to write out an invalid frame, which we need a connection to do
+	framer, err := conn.exec(writer, nil)
+	if err == nil {
+		t.Fatal("expected to get an error on stream 0")
+	} else if !strings.HasPrefix(err.Error(), expErr) {
+		t.Fatalf("expected to get error prefix %q got %q", expErr, err.Error())
+	} else if framer != nil {
+		frame, err := framer.parseFrame()
+		if err != nil {
+			t.Fatal(err)
+		}
+		t.Fatalf("got frame %v", frame)
+	}
+}
+
 func NewTestServer(t testing.TB, protocol uint8) *TestServer {
 	laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
 	if err != nil {
@@ -567,6 +607,9 @@ func (srv *TestServer) process(f *framer) {
 			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
 		}
+	case opError:
+		f.writeHeader(0, opError, head.stream)
+		f.wbuf = append(f.wbuf, f.rbuf...)
 	default:
 		f.writeHeader(0, opError, head.stream)
 		f.writeInt(0)

+ 6 - 0
frame.go

@@ -1341,6 +1341,12 @@ type frameWriter interface {
 	writeFrame(framer *framer, streamID int) error
 }
 
+type frameWriterFunc func(framer *framer, streamID int) error
+
+func (f frameWriterFunc) writeFrame(framer *framer, streamID int) error {
+	return f(framer, streamID)
+}
+
 type writeExecuteFrame struct {
 	preparedID []byte
 	params     queryParams