Selaa lähdekoodia

support tracing

Chris Bannister 10 vuotta sitten
vanhempi
commit
050b936a96
2 muutettua tiedostoa jossa 43 lisäystä ja 10 poistoa
  1. 21 7
      conn.go
  2. 22 3
      frame.go

+ 21 - 7
conn.go

@@ -186,7 +186,7 @@ func (c *Conn) startup(cfg *ConnConfig) error {
 		m["COMPRESSION"] = c.compressor.Name()
 	}
 
-	frame, err := c.exec(&writeStartupFrame{opts: m})
+	frame, err := c.exec(&writeStartupFrame{opts: m}, nil)
 	if err != nil {
 		return err
 	}
@@ -216,7 +216,7 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 	req := &writeAuthResponseFrame{data: resp}
 
 	for {
-		frame, err := c.exec(req)
+		frame, err := c.exec(req, nil)
 		if err != nil {
 			return err
 		}
@@ -242,7 +242,8 @@ func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
 	}
 }
 
-func (c *Conn) exec(req frameWriter) (frame, error) {
+func (c *Conn) exec(req frameWriter, tracer Tracer) (frame, error) {
+	// TODO: move tracer onto conn
 	stream := <-c.uniq
 
 	call := &c.calls[stream]
@@ -253,6 +254,14 @@ func (c *Conn) exec(req frameWriter) (frame, error) {
 
 	// log.Printf("%v: OUT stream=%d (%T) req=%v\n", c.conn.LocalAddr(), stream, req, req)
 	framer := newFramer(c, c, c.compressor, c.version)
+
+	// unfortunatly this part of the protocol leaks in conn, somehow move this
+	// out into framer. One way to do it would be to use the same framer to send
+	// and recv for a single stream.
+	if tracer != nil {
+		framer.flags |= flagTracing
+	}
+
 	err := req.writeFrame(framer, stream)
 	framerPool.Put(framer)
 
@@ -271,6 +280,10 @@ func (c *Conn) exec(req frameWriter) (frame, error) {
 	if err != nil {
 		return nil, err
 	}
+
+	if len(framer.traceID) > 0 {
+		tracer.Trace(framer.traceID)
+	}
 	// log.Printf("%v: IN stream=%d (%T) resp=%v\n", c.conn.LocalAddr(), stream, frame, frame)
 
 	return frame, nil
@@ -377,7 +390,7 @@ func (c *Conn) prepareStatement(stmt string, trace Tracer) (*resultPreparedFrame
 		statement: stmt,
 	}
 
-	resp, err := c.exec(prep)
+	resp, err := c.exec(prep, trace)
 	if err != nil {
 		flight.err = err
 		flight.wg.Done()
@@ -469,7 +482,7 @@ func (c *Conn) executeQuery(qry *Query) *Iter {
 		}
 	}
 
-	resp, err := c.exec(frame)
+	resp, err := c.exec(frame, qry.trace)
 	if err != nil {
 		return &Iter{err: err}
 	}
@@ -557,7 +570,7 @@ func (c *Conn) UseKeyspace(keyspace string) error {
 	q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
 	q.params.consistency = Any
 
-	resp, err := c.exec(q)
+	resp, err := c.exec(q, nil)
 	if err != nil {
 		return err
 	}
@@ -636,7 +649,8 @@ func (c *Conn) executeBatch(batch *Batch) error {
 		}
 	}
 
-	resp, err := c.exec(req)
+	// TODO: should batch support tracing?
+	resp, err := c.exec(req, nil)
 	if err != nil {
 		return err
 	}

+ 22 - 3
frame.go

@@ -240,6 +240,9 @@ type framer struct {
 	// if this frame was read then the header will be here
 	header *frameHeader
 
+	// if tracing flag is set this is not nil
+	traceID []byte
+
 	buf []byte
 }
 
@@ -263,6 +266,7 @@ func newFramer(r io.Reader, w io.Writer, compressor Compressor, version byte) *f
 	f.headSize = headSize
 	f.buf = f.buf[:0]
 	f.header = nil
+	f.traceID = nil
 
 	return f
 }
@@ -332,6 +336,14 @@ func (f *framer) readFrame(head *frameHeader) error {
 }
 
 func (f *framer) parseFrame() (frame, error) {
+	if f.header.version.request() {
+		return frameHeader{}, NewErrProtocol("got a request frame from server: %v", f.header.version)
+	}
+
+	if f.header.flags&flagTracing == flagTracing {
+		f.readTrace()
+	}
+
 	// asumes that the frame body has been read into buf
 	switch f.header.op {
 	case opError:
@@ -476,6 +488,10 @@ func (f *framer) finishWrite() error {
 	return nil
 }
 
+func (f *framer) readTrace() {
+	f.traceID = f.readUUID().Bytes()
+}
+
 type readyFrame struct {
 	frameHeader
 }
@@ -1083,15 +1099,18 @@ func (f *framer) readString() (s string) {
 	return
 }
 
-func (f *framer) longString() (s string) {
+func (f *framer) readLongString() (s string) {
 	size := f.readInt()
 	s = string(f.buf[:size])
 	f.buf = f.buf[size:]
 	return
 }
 
-func (f *framer) readUUID() (u *UUID) {
-	return
+func (f *framer) readUUID() *UUID {
+	// TODO: how to handle this error, if it is a uuid, then sureley, problems?
+	u, _ := UUIDFromBytes(f.buf[:16])
+	f.buf = f.buf[16:]
+	return &u
 }
 
 func (f *framer) readStringList() []string {