Browse Source

only use a single framer for reading and writing on a stream

Chris Bannister 10 years ago
parent
commit
6352de5a0f
1 changed files with 93 additions and 81 deletions
  1. 93 81
      conn.go

+ 93 - 81
conn.go

@@ -177,6 +177,15 @@ func Connect(addr string, cfg ConnConfig, pool ConnectionPool) (*Conn, error) {
 	return c, nil
 	return c, nil
 }
 }
 
 
+func (c *Conn) Write(p []byte) (int, error) {
+	c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
+	return c.conn.Write(p)
+}
+
+func (c *Conn) Read(p []byte) (int, error) {
+	return c.r.Read(p)
+}
+
 func (c *Conn) startup(cfg *ConnConfig) error {
 func (c *Conn) startup(cfg *ConnConfig) error {
 	m := map[string]string{
 	m := map[string]string{
 		"CQL_VERSION": cfg.CQLVersion,
 		"CQL_VERSION": cfg.CQLVersion,
@@ -242,74 +251,28 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 	}
 	}
 }
 }
 
 
-func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
-	// TODO: move tracer onto conn
-	stream := <-c.uniq
-
-	call := &c.calls[stream]
-	atomic.StoreInt32(&call.active, 1)
-	defer atomic.StoreInt32(&call.active, 0)
-
-	call.resp = make(chan callResp, 1)
-
-	// log.Printf("%v: OUT stream=%d (%T) req=%v\n", c.conn.LocalAddr(), stream, req, req)
-	framer := newFramer(c, c, c.compressor, c.version)
-
-	// unfortunatly this part of the protocol leaks in conn, somehow move this
-	// out into framer. One way to do it would be to use the same framer to send
-	// and recv for a single stream.
-	if tracer != nil {
-		framer.flags |= flagTracing
-	}
-
-	err := req.writeFrame(framer, stream)
-	framerPool.Put(framer)
-
-	if err != nil {
-		return nil, err
-	}
-
-	resp := <-call.resp
-	if resp.err != nil {
-		return nil, resp.err
-	}
-	defer framerPool.Put(resp.framer)
-
-	frame, err := resp.framer.parseFrame()
-	if err != nil {
-		return nil, err
-	}
-
-	if len(framer.traceID) > 0 {
-		tracer.Trace(framer.traceID)
-	}
-	// log.Printf("%v: IN stream=%d (%T) resp=%v\n", c.conn.LocalAddr(), stream, frame, frame)
-
-	return frame, nil
-}
-
 // Serve starts the stream multiplexer for this connection, which is required
 // Serve starts the stream multiplexer for this connection, which is required
 // to execute any queries. This method runs as long as the connection is
 // to execute any queries. This method runs as long as the connection is
 // open and is therefore usually called in a separate goroutine.
 // open and is therefore usually called in a separate goroutine.
 func (c *Conn) serve() {
 func (c *Conn) serve() {
 	var (
 	var (
 		err    error
 		err    error
-		framer *framer
+		header frameHeader
 	)
 	)
 
 
 	for {
 	for {
-		framer, err = c.recv()
+		header, err = c.recv()
 		if err != nil {
 		if err != nil {
 			break
 			break
 		}
 		}
-		c.dispatch(framer)
+		c.dispatch(header)
 	}
 	}
 
 
 	c.Close()
 	c.Close()
 	for id := 0; id < len(c.calls); id++ {
 	for id := 0; id < len(c.calls); id++ {
 		req := &c.calls[id]
 		req := &c.calls[id]
 		if atomic.CompareAndSwapInt32(&req.active, 1, 0) {
 		if atomic.CompareAndSwapInt32(&req.active, 1, 0) {
-			req.resp <- callResp{nil, err}
+			req.resp <- struct{}{}
 			close(req.resp)
 			close(req.resp)
 		}
 		}
 	}
 	}
@@ -319,52 +282,111 @@ func (c *Conn) serve() {
 	}
 	}
 }
 }
 
 
-func (c *Conn) Write(p []byte) (int, error) {
-	c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
-	return c.conn.Write(p)
-}
-
-func (c *Conn) Read(p []byte) (int, error) {
-	return c.r.Read(p)
-}
+func (c *Conn) recv() (frameHeader, error) {
+	// not safe for concurrent reads
 
 
-func (c *Conn) recv() (*framer, error) {
 	// read a full header, ignore timeouts, as this is being ran in a loop
 	// read a full header, ignore timeouts, as this is being ran in a loop
-	// TODO: TCP level deadlines? or just query level dealines?
+	// TODO: TCP level deadlines? or just query level deadlines?
 
 
 	// were just reading headers over and over and copy bodies
 	// were just reading headers over and over and copy bodies
 	head, err := readHeader(c.r, c.headerBuf)
 	head, err := readHeader(c.r, c.headerBuf)
 	if err != nil {
 	if err != nil {
-		return nil, err
+		return frameHeader{}, err
 	}
 	}
 
 
-	// log.Printf("header=%v\n", head)
-	if head.version.version() != c.version {
-		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", head.version.version(), c.version)
+	call := &c.calls[head.stream]
+
+	call.mu.Lock()
+	log.Printf("readframe stream=%v\n", head.stream)
+	err = call.framer.readFrame(&head)
+	call.mu.Unlock()
+	if err != nil {
+		return frameHeader{}, err
 	}
 	}
 
 
-	framer := newFramer(c.r, c, c.compressor, c.version)
-	if err := framer.readFrame(&head); err != nil {
-		return nil, err
+	if head.version.version() != c.version {
+		return frameHeader{}, NewErrProtocol("unexpected protocol version in response: got %d expected %d", head.version.version(), c.version)
 	}
 	}
 
 
-	return framer, nil
+	return head, nil
 }
 }
 
 
-func (c *Conn) dispatch(f *framer) {
-	id := f.header.stream
+func (c *Conn) dispatch(header frameHeader) {
+	id := header.stream
 	if id >= len(c.calls) {
 	if id >= len(c.calls) {
+		// should this panic?
 		return
 		return
 	}
 	}
 
 
 	// TODO: replace this with a sparse map
 	// TODO: replace this with a sparse map
 	call := &c.calls[id]
 	call := &c.calls[id]
 
 
-	call.resp <- callResp{f, nil}
+	call.resp <- struct{}{}
 	atomic.AddInt32(&c.nwait, -1)
 	atomic.AddInt32(&c.nwait, -1)
+
 	c.uniq <- id
 	c.uniq <- id
 }
 }
 
 
+type callReq struct {
+	active int32
+	// could use a waitgroup but this allows us to do timeouts on the read/send
+	resp   chan struct{}
+	mu     sync.Mutex
+	framer *framer
+	err    error
+}
+
+type callResp struct {
+	framer *framer
+	err    error
+}
+
+func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
+	// TODO: move tracer onto conn
+	stream := <-c.uniq
+
+	call := &c.calls[stream]
+	atomic.StoreInt32(&call.active, 1)
+	defer atomic.StoreInt32(&call.active, 0)
+
+	// resp is basically a waiting semafore protecting the framer
+	call.resp = make(chan struct{}, 1)
+
+	// log.Printf("%v: OUT stream=%d (%T) req=%v\n", c.conn.LocalAddr(), stream, req, req)
+	framer := newFramer(c, c, c.compressor, c.version)
+	defer framerPool.Put(framer)
+	call.framer = framer
+
+	if tracer != nil {
+		framer.trace()
+	}
+
+	log.Printf("writing frame stream=%v\n", stream)
+	call.mu.Lock()
+	err := req.writeFrame(framer, stream)
+	call.mu.Unlock()
+	if err != nil {
+		return nil, err
+	}
+
+	<-call.resp
+	if call.err != nil {
+		return nil, call.err
+	}
+
+	frame, err := framer.parseFrame()
+	if err != nil {
+		return nil, err
+	}
+
+	if len(framer.traceID) > 0 {
+		tracer.Trace(framer.traceID)
+	}
+	// log.Printf("%v: IN stream=%d (%T) resp=%v\n", c.conn.LocalAddr(), stream, frame, frame)
+
+	return frame, nil
+}
+
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*resultPreparedFrame, error) {
 func (c *Conn) prepareStatement(stmt string, trace Tracer) (*resultPreparedFrame, error) {
 	stmtsLRU.Lock()
 	stmtsLRU.Lock()
 	if stmtsLRU.lru == nil {
 	if stmtsLRU.lru == nil {
@@ -688,16 +710,6 @@ func (c *Conn) setKeepalive(d time.Duration) error {
 	return nil
 	return nil
 }
 }
 
 
-type callReq struct {
-	active int32
-	resp   chan callResp
-}
-
-type callResp struct {
-	framer *framer
-	err    error
-}
-
 type inflightPrepare struct {
 type inflightPrepare struct {
 	info *resultPreparedFrame
 	info *resultPreparedFrame
 	err  error
 	err  error