Selaa lähdekoodia

update conn_test to support framer

Chris Bannister 10 vuotta sitten
vanhempi
commit
4e87ecfa05
2 muutettua tiedostoa jossa 63 lisäystä ja 62 poistoa
  1. 55 53
      conn_test.go
  2. 8 9
      frame.go

+ 55 - 53
conn_test.go

@@ -438,13 +438,14 @@ func NewSSLTestServer(t testing.TB, protocol uint8) *TestServer {
 }
 
 type TestServer struct {
-	Address  string
-	t        testing.TB
-	nreq     uint64
-	listen   net.Listener
-	nKillReq int64
-
-	protocol   uint8
+	Address    string
+	t          testing.TB
+	nreq       uint64
+	listen     net.Listener
+	nKillReq   int64
+	compressor Compressor
+
+	protocol   byte
 	headerSize int
 }
 
@@ -458,16 +459,19 @@ func (srv *TestServer) serve() {
 		go func(conn net.Conn) {
 			defer conn.Close()
 			for {
-				frame, err := srv.readFrame(conn)
-				if err == io.EOF {
-					return
-				} else if err != nil {
+				framer, err := srv.readFrame(conn)
+				if err != nil {
+					if err == io.EOF {
+						return
+					}
+
 					srv.t.Error(err)
-					continue
+					return
 				}
 
 				atomic.AddUint64(&srv.nreq, 1)
-				go srv.process(frame, conn)
+
+				go srv.process(framer)
 			}
 		}(conn)
 	}
@@ -477,24 +481,21 @@ func (srv *TestServer) Stop() {
 	srv.listen.Close()
 }
 
-func (srv *TestServer) process(f frame, conn net.Conn) {
-	headerSize := headerProtoSize[srv.protocol]
-	stream := f.Stream(srv.protocol)
+func (srv *TestServer) process(f *framer) {
+	head := f.header
+	if head == nil {
+		srv.t.Error("process frame with a nil header")
+		return
+	}
 
-	switch f.Op(srv.protocol) {
+	switch head.op {
 	case opStartup:
-		f = f[:headerSize]
-		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opReady)
+		f.writeHeader(0, opReady, head.stream)
 	case opOptions:
-		f = f[:headerSize]
-		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opSupported)
+		f.writeHeader(0, opSupported, head.stream)
 		f.writeShort(0)
 	case opQuery:
-		input := f
-		input.skipHeader(srv.protocol)
-		query := strings.TrimSpace(input.readLongString())
-		f = f[:headerSize]
-		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opResult)
+		query := f.readLongString()
 		first := query
 		if n := strings.Index(query, " "); n > 0 {
 			first = first[:n]
@@ -502,62 +503,63 @@ func (srv *TestServer) process(f frame, conn net.Conn) {
 		switch strings.ToLower(first) {
 		case "kill":
 			atomic.AddInt64(&srv.nKillReq, 1)
-			f = f[:headerSize]
-			f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+			f.writeHeader(0, opError, head.stream)
 			f.writeInt(0x1001)
 			f.writeString("query killed")
 		case "slow":
 			go func() {
 				<-time.After(1 * time.Second)
+				f.writeHeader(0, opResult, head.stream)
+				f.buf[0] = srv.protocol | 0x80
 				f.writeInt(resultKindVoid)
-				f.setLength(len(f)-headerSize, srv.protocol)
-				if _, err := conn.Write(f); err != nil {
-					return
+				if err := f.finishWrite(); err != nil {
+					srv.t.Error(err)
 				}
 			}()
+
 			return
 		case "use":
-			f.writeInt(3)
+			f.writeInt(resultKindKeyspace)
 			f.writeString(strings.TrimSpace(query[3:]))
 		case "void":
+			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
 		default:
+			f.writeHeader(0, opResult, head.stream)
 			f.writeInt(resultKindVoid)
 		}
 	default:
-		f = f[:headerSize]
-		f.setHeader(protoDirectionMask|srv.protocol, 0, stream, opError)
+		f.writeHeader(0, opError, head.stream)
 		f.writeInt(0)
 		f.writeString("not supported")
 	}
 
-	f.setLength(len(f)-headerSize, srv.protocol)
-	if _, err := conn.Write(f); err != nil {
-		srv.t.Log(err)
-		return
+	f.buf[0] = srv.protocol | 0x80
+
+	if err := f.finishWrite(); err != nil {
+		srv.t.Error(err)
 	}
 }
 
-func (srv *TestServer) readFrame(conn net.Conn) (frame, error) {
-	frame := make(frame, srv.headerSize, srv.headerSize+512)
-	if _, err := io.ReadFull(conn, frame); err != nil {
+func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {
+	buf := make([]byte, srv.headerSize)
+	head, err := readHeader(conn, buf)
+	if err != nil {
 		return nil, err
 	}
+	framer := newFramer(conn, conn, nil, srv.protocol)
 
-	// should be a request frame
-	if frame[0]&protoDirectionMask != 0 {
-		return nil, fmt.Errorf("expected to read a request frame got version: 0x%x", frame[0])
-	}
-	if v := frame[0] & protoVersionMask; v != srv.protocol {
-		return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, v)
+	err = framer.readFrame(&head)
+	if err != nil {
+		return nil, err
 	}
 
-	if n := frame.Length(srv.protocol); n > 0 {
-		frame.grow(n)
-		if _, err := io.ReadFull(conn, frame[srv.headerSize:]); err != nil {
-			return nil, err
-		}
+	// should be a request frame
+	if head.version.response() {
+		return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version)
+	} else if head.version.version() != srv.protocol {
+		return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version())
 	}
 
-	return frame, nil
+	return framer, nil
 }

+ 8 - 9
frame.go

@@ -285,11 +285,7 @@ func readHeader(r io.Reader, p []byte) (frameHeader, error) {
 
 	head := frameHeader{}
 	version := p[0] & protoVersionMask
-	direction := p[0] & protoDirectionMask
 	head.version = protoVersion(p[0])
-	if direction == protoVersionRequest {
-		return frameHeader{}, NewErrProtocol("got a request frame from server: %v", head.version)
-	}
 
 	head.flags = p[1]
 	if version > protoVersion2 {
@@ -319,9 +315,6 @@ func (f *framer) readFrame(head *frameHeader) error {
 		return err
 	}
 
-	// TODO: move frame processing out of framer and onto the requesting callers
-	// this means that we will not be able to reuse buffers between streams, which
-	// may end up being slower than parsing on the IO thread.
 	if head.flags&flagCompress == flagCompress {
 		if f.compres == nil {
 			return NewErrProtocol("no compressor available with compressed frame body")
@@ -427,7 +420,8 @@ func (f *framer) parseErrorFrame() (frame, error) {
 }
 
 func (f *framer) writeHeader(flags byte, op frameOp, stream int) {
-	f.buf = append(f.buf[0:],
+	f.buf = f.buf[:0]
+	f.buf = append(f.buf,
 		f.proto,
 		flags,
 	)
@@ -467,7 +461,12 @@ func (f *framer) setLength(length int) {
 
 func (f *framer) finishWrite() error {
 	length := len(f.buf) - f.headSize
-	if f.flags&flagCompress == flagCompress && f.compres != nil {
+	if f.buf[1]&flagCompress == flagCompress {
+		if f.compres == nil {
+			panic("compress flag set with no compressor")
+		}
+
+		// TODO: only compress frames which are big enough
 		compressed, err := f.compres.Encode(f.buf[f.headSize:])
 		if err != nil {
 			return err