소스 검색

conn: replace streams chan with an ID generator

Replace the bounded chan uniq with a stream ID generator based
very closely on the StreamIDGenerator in the java-driver.
Chris Bannister 10 년 전
부모
커밋
753dd6d05f
3개의 변경된 파일359개의 추가작업 그리고 34개의 파일을 삭제
  1. 18 34
      conn.go
  2. 140 0
      internal/streams/streams.go
  3. 201 0
      internal/streams/streams_test.go

+ 18 - 34
conn.go

@@ -18,6 +18,8 @@ import (
 	"sync"
 	"sync/atomic"
 	"time"
+
+	"github.com/gocql/gocql/internal/streams"
 )
 
 var (
@@ -120,9 +122,9 @@ type Conn struct {
 
 	headerBuf []byte
 
-	uniq  chan int
-	mu    sync.RWMutex
-	calls map[int]*callReq
+	streams *streams.IDGenerator
+	mu      sync.RWMutex
+	calls   map[int]*callReq
 
 	errorHandler    ConnErrorHandler
 	compressor      Compressor
@@ -171,25 +173,14 @@ func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, sessio
 	}
 
 	headerSize := 8
-
-	maxStreams := 128
 	if cfg.ProtoVersion > protoVersion2 {
-		maxStreams = 32768
 		headerSize = 9
 	}
 
-	streams := cfg.NumStreams
-	if streams <= 0 || streams >= maxStreams {
-		streams = maxStreams
-	} else {
-		streams++
-	}
-
 	c := &Conn{
 		conn:         conn,
 		r:            bufio.NewReader(conn),
 		cfg:          cfg,
-		uniq:         make(chan int, streams),
 		calls:        make(map[int]*callReq),
 		timeout:      cfg.Timeout,
 		version:      uint8(cfg.ProtoVersion),
@@ -200,19 +191,13 @@ func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, sessio
 		headerBuf:    make([]byte, headerSize),
 		quit:         make(chan struct{}),
 		session:      session,
-		numStreams:   streams,
+		streams:      streams.New(cfg.ProtoVersion),
 	}
 
 	if cfg.Keepalive > 0 {
 		c.setKeepalive(cfg.Keepalive)
 	}
 
-	// reserve stream 0 incase cassandra returns an error on it without us sending
-	// a request.
-	for i := 1; i < streams; i++ {
-		c.uniq <- i
-	}
-
 	go c.serve()
 
 	if err := c.startup(); err != nil {
@@ -410,7 +395,7 @@ func (c *Conn) recv() error {
 		return err
 	}
 
-	if head.stream > c.numStreams {
+	if head.stream > c.streams.NumStreams {
 		return fmt.Errorf("gocql: frame header stream is beyond call exepected bounds: %d", head.stream)
 	} else if head.stream == -1 {
 		// TODO: handle cassandra event frames, we shouldnt get any currently
@@ -479,16 +464,14 @@ func (c *Conn) releaseStream(stream int) {
 	call := c.calls[stream]
 	if call != nil && stream != call.streamID {
 		panic(fmt.Sprintf("attempt to release streamID with ivalid stream: %d -> %+v\n", stream, call))
+	} else if call == nil {
+		panic(fmt.Sprintf("releasing a stream not in use: %d", stream))
 	}
 	delete(c.calls, stream)
 	c.mu.Unlock()
 
 	streamPool.Put(call)
-
-	select {
-	case c.uniq <- stream:
-	case <-c.quit:
-	}
+	c.streams.Clear(stream)
 }
 
 func (c *Conn) handleTimeout() {
@@ -509,11 +492,10 @@ var (
 
 func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 	// TODO: move tracer onto conn
-	var stream int
-	select {
-	case stream = <-c.uniq:
-	case <-c.quit:
-		return nil, ErrConnectionClosed
+	stream, ok := c.streams.GetStream()
+	if !ok {
+		fmt.Println(c.streams)
+		return nil, ErrNoStreams
 	}
 
 	// resp is basically a waiting semaphore protecting the framer
@@ -522,7 +504,8 @@ func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
 	c.mu.Lock()
 	call := c.calls[stream]
 	if call != nil {
-		panic(fmt.Sprintf("attempting to use stream already in use: %d -> %+v\n", stream, call))
+		c.mu.Unlock()
+		return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, call.streamID)
 	} else {
 		call = streamPool.Get().(*callReq)
 	}
@@ -799,7 +782,7 @@ func (c *Conn) Address() string {
 }
 
 func (c *Conn) AvailableStreams() int {
-	return len(c.uniq)
+	return c.streams.Available()
 }
 
 func (c *Conn) UseKeyspace(keyspace string) error {
@@ -1011,4 +994,5 @@ var (
 	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
 	ErrTooManyTimeouts   = errors.New("gocql: too many query timeouts on the connection")
 	ErrConnectionClosed  = errors.New("gocql: connection closed waiting for response")
+	ErrNoStreams         = errors.New("gocql: no streams available on connection")
 )

+ 140 - 0
internal/streams/streams.go

@@ -0,0 +1,140 @@
+package streams
+
+import (
+	"math"
+	"strconv"
+	"sync/atomic"
+)
+
+const bucketBits = 64 // number of stream bits held in each uint64 bucket of the bitset
+
+// IDGenerator tracks and allocates streams which are in use.
+type IDGenerator struct {
+	NumStreams   int    // total number of stream IDs covered by the bitset, including the reserved stream 0
+	inuseStreams int32  // number of streams handed out by GetStream and not yet cleared; accessed atomically
+	numBuckets   uint32 // number of uint64 buckets in streams
+
+	// streams is a bitset where each bit represents a stream, a 1 implies in use
+	streams []uint64
+	offset  uint32 // bucket index GetStream rotates on every call to pick where its search starts; accessed atomically
+}
+
+// New returns an IDGenerator sized for the given protocol version: 128
+// streams for versions up to 2 and 32768 streams for later versions.
+//
+// Stream 0 is marked as in use from the start, so GetStream never hands it
+// out.
+func New(protocol int) *IDGenerator {
+	maxStreams := 128
+	if protocol > 2 {
+		maxStreams = 32768
+	}
+
+	buckets := maxStreams / bucketBits
+	streams := make([]uint64, buckets)
+	// reserve stream 0, which occupies the most significant bit of bucket 0
+	streams[0] = 1 << (bucketBits - 1)
+
+	return &IDGenerator{
+		NumStreams: maxStreams,
+		streams:    streams,
+		numBuckets: uint32(buckets),
+		// start on the last bucket so the first GetStream call, which
+		// advances the offset before searching, begins at bucket 0
+		offset: uint32(buckets) - 1,
+	}
+}
+
+// streamFromBucket converts a bucket index and a stream position inside that
+// bucket into the overall stream ID.
+func streamFromBucket(bucket, streamInBucket int) int {
+	base := bucket * bucketBits
+	return base + streamInBucket
+}
+
+// GetStream reserves a free stream and returns its ID together with true. It
+// returns 0 and false when no free stream could be found.
+//
+// The bitset is only ever modified with compare-and-swap, so no lock is held
+// while searching.
+func (s *IDGenerator) GetStream() (int, bool) {
+	// based closely on the java-driver stream ID generator
+	// rotate the starting bucket on every call so that consecutive requests
+	// begin their search in different buckets.
+	offset := atomic.LoadUint32(&s.offset)
+	for !atomic.CompareAndSwapUint32(&s.offset, offset, (offset+1)%s.numBuckets) {
+		offset = atomic.LoadUint32(&s.offset)
+	}
+	offset = (offset + 1) % s.numBuckets
+
+	for i := uint32(0); i < s.numBuckets; i++ {
+		pos := int((i + offset) % s.numBuckets)
+
+		bucket := atomic.LoadUint64(&s.streams[pos])
+		if bucket == math.MaxUint64 {
+			// all streams in use
+			continue
+		}
+
+		for j := 0; j < bucketBits; j++ {
+			mask := uint64(1) << streamOffset(j)
+			// keep trying this bit for as long as it looks free
+			for bucket&mask == 0 {
+				if atomic.CompareAndSwapUint64(&s.streams[pos], bucket, bucket|mask) {
+					atomic.AddInt32(&s.inuseStreams, 1)
+					return streamFromBucket(pos, j), true
+				}
+				// another goroutine changed this bucket first; reload the
+				// same bucket (pos, not the rotating start offset) and
+				// re-check the bit against the fresh value
+				bucket = atomic.LoadUint64(&s.streams[pos])
+			}
+		}
+	}
+
+	return 0, false
+}
+
+func bitfmt(b uint64) string {
+	return strconv.FormatUint(b, 16)
+}
+
+// bucketOffset returns the index of the bucket that holds stream i.
+func bucketOffset(i int) int {
+	bucket := i / bucketBits
+	return bucket
+}
+
+// streamOffset returns the bit position of stream inside its bucket. Streams
+// are laid out from the most significant bit downwards: the first stream of a
+// bucket maps to bit 63 and the last one to bit 0.
+func streamOffset(stream int) uint64 {
+	return uint64(bucketBits - 1 - stream%bucketBits)
+}
+
+// isSet reports whether the bit belonging to stream is 1 in the bucket value
+// bits.
+func isSet(bits uint64, stream int) bool {
+	mask := uint64(1) << streamOffset(stream)
+	return bits&mask != 0
+}
+
+// isSet reports whether stream is currently marked as in use in the bitset.
+func (s *IDGenerator) isSet(stream int) bool {
+	word := &s.streams[bucketOffset(stream)]
+	return isSet(atomic.LoadUint64(word), stream)
+}
+
+// String returns the contents of every bucket, in order, formatted with
+// bitfmt and separated by single spaces. A generator without buckets yields
+// the empty string.
+func (s *IDGenerator) String() string {
+	// bitfmt emits hexadecimal, so a bucket needs at most bucketBits/4 digits
+	// plus one separating space
+	buf := make([]byte, 0, int(s.numBuckets)*(bucketBits/4+1))
+	for i := 0; i < int(s.numBuckets); i++ {
+		if i > 0 {
+			buf = append(buf, ' ')
+		}
+		bits := atomic.LoadUint64(&s.streams[i])
+		buf = append(buf, bitfmt(bits)...)
+	}
+	// return exactly what was written; re-slicing up to the capacity would
+	// pad the result with NUL bytes
+	return string(buf)
+}
+
+// Clear marks stream as no longer in use and reports whether it was in use
+// before the call. It returns false, changing nothing, when the stream was
+// already clear or when stream lies outside the range covered by the bitset.
+//
+// Clear panics if the in-use counter would drop below zero.
+func (s *IDGenerator) Clear(stream int) (inuse bool) {
+	offset := bucketOffset(stream)
+	// A negative stream can yield a zero mask, which would let the CAS below
+	// succeed and decrement the in-use counter without clearing any bit; a
+	// stream past the end would index outside the bitset.
+	if stream < 0 || offset >= len(s.streams) {
+		return false
+	}
+
+	mask := uint64(1) << streamOffset(stream)
+	for {
+		bucket := atomic.LoadUint64(&s.streams[offset])
+		if bucket&mask == 0 {
+			// already cleared
+			return false
+		}
+		if atomic.CompareAndSwapUint64(&s.streams[offset], bucket, bucket&^mask) {
+			break
+		}
+		// the bucket changed underneath us; reload and re-check the bit
+	}
+
+	// TODO: make this account for 0 stream being reserved
+	if atomic.AddInt32(&s.inuseStreams, -1) < 0 {
+		// TODO(zariel): remove this
+		panic("negative streams inuse")
+	}
+
+	return true
+}
+
+// Available returns NumStreams minus the current in-use count, minus one.
+func (s *IDGenerator) Available() int {
+	inuse := int(atomic.LoadInt32(&s.inuseStreams))
+	// NOTE(review): the extra 1 appears to account for stream 0, which New
+	// reserves without counting it in inuseStreams - confirm
+	return s.NumStreams - inuse - 1
+}

+ 201 - 0
internal/streams/streams_test.go

@@ -0,0 +1,201 @@
+package streams
+
+import (
+	"math"
+	"strconv"
+	"sync/atomic"
+	"testing"
+)
+
+// TestUsesAllStreams checks that a fresh generator hands out every stream
+// except stream 0 exactly once and that every bucket ends up completely set.
+func TestUsesAllStreams(t *testing.T) {
+	gen := New(1)
+	seen := make(map[int]struct{})
+
+	for n := 1; n < gen.NumStreams; n++ {
+		id, ok := gen.GetStream()
+		if !ok {
+			t.Fatalf("unable to get stream %d", n)
+		}
+
+		if _, dup := seen[id]; dup {
+			t.Fatalf("got an already allocated stream: %d", id)
+		}
+		seen[id] = struct{}{}
+
+		if gen.isSet(id) {
+			continue
+		}
+		// dump the bucket in binary to show which bit is missing
+		word := atomic.LoadUint64(&gen.streams[bucketOffset(id)])
+		t.Logf("bucket=%d: %s\n", word, strconv.FormatUint(word, 2))
+		t.Fatalf("stream not set: %d", id)
+	}
+
+	for id := 1; id < gen.NumStreams; id++ {
+		if _, ok := seen[id]; !ok {
+			t.Errorf("did not use stream %d", id)
+		}
+	}
+	if _, ok := seen[0]; ok {
+		t.Fatal("expected to not use stream 0")
+	}
+
+	for idx, word := range gen.streams {
+		if word != math.MaxUint64 {
+			t.Errorf("did not use all streams in offset=%d bucket=%s", idx, bitfmt(word))
+		}
+	}
+}
+
+// TestFullStreams checks that GetStream reports failure when every bit of
+// every bucket is already set.
+func TestFullStreams(t *testing.T) {
+	gen := New(1)
+	for idx := range gen.streams {
+		gen.streams[idx] = math.MaxUint64
+	}
+
+	if id, ok := gen.GetStream(); ok {
+		t.Fatalf("should not get stream when all in use: stream=%d", id)
+	}
+}
+
+// TestClearStreams fills the whole bitset by hand, clears every stream and
+// checks that all buckets end up zero.
+func TestClearStreams(t *testing.T) {
+	gen := New(1)
+	for idx := range gen.streams {
+		gen.streams[idx] = math.MaxUint64
+	}
+	// keep the counter in step with the hand-filled bitset so that Clear does
+	// not drive it negative
+	gen.inuseStreams = int32(gen.NumStreams)
+
+	for id := 0; id < gen.NumStreams; id++ {
+		gen.Clear(id)
+	}
+
+	for idx, word := range gen.streams {
+		if word == 0 {
+			continue
+		}
+		t.Errorf("did not clear streams in offset=%d bucket=%s", idx, bitfmt(word))
+	}
+}
+
+// TestDoubleClear checks that clearing an allocated stream reports it as in
+// use the first time and as already clear the second time.
+func TestDoubleClear(t *testing.T) {
+	streams := New(1)
+	stream, ok := streams.GetStream()
+	if !ok {
+		t.Fatal("did not get stream")
+	}
+
+	if !streams.Clear(stream) {
+		t.Fatalf("stream not indicated as in use: %d", stream)
+	}
+	// a second Clear must report that the stream was no longer in use
+	if streams.Clear(stream) {
+		t.Fatalf("stream still indicated as in use after clear: %d", stream)
+	}
+}
+
+// BenchmarkConcurrentUse measures parallel GetStream/Clear round trips on a
+// single shared generator.
+func BenchmarkConcurrentUse(b *testing.B) {
+	gen := New(2)
+
+	worker := func(pb *testing.PB) {
+		for pb.Next() {
+			id, ok := gen.GetStream()
+			if !ok {
+				b.Error("unable to get stream")
+				return
+			}
+
+			if released := gen.Clear(id); !released {
+				b.Errorf("stream was already cleared: %d", id)
+				return
+			}
+		}
+	}
+	b.RunParallel(worker)
+}
+
+// TestStreamOffset checks the mapping from a stream ID to its bit position
+// inside a bucket, including IDs that fall into later buckets.
+func TestStreamOffset(t *testing.T) {
+	tests := [...]struct {
+		n   int
+		off uint64
+	}{
+		{0, 63},
+		{1, 62},
+		{2, 61},
+		{3, 60},
+		{63, 0},
+		{64, 63},
+
+		{128, 63},
+	}
+
+	for _, test := range tests {
+		if off := streamOffset(test.n); off != test.off {
+			// expected value first, then the value actually computed
+			t.Errorf("n=%d expected %d got %d", test.n, test.off, off)
+		}
+	}
+}
+
+// TestIsSet checks isSet against hand-built bucket values and then verifies
+// that every position of a completely full bucket is reported as set.
+func TestIsSet(t *testing.T) {
+	cases := []struct {
+		stream int
+		bucket uint64
+		set    bool
+	}{
+		{0, 0, false},
+		{0, 1 << 63, true},
+		{1, 0, false},
+		{1, 1 << 62, true},
+		{63, 1, true},
+		{64, 1 << 63, true},
+		{0, 0x8000000000000000, true},
+	}
+
+	for idx, tc := range cases {
+		got := isSet(tc.bucket, tc.stream)
+		if got != tc.set {
+			t.Errorf("[%d] stream=%d expected %v got %v", idx, tc.stream, tc.set, got)
+		}
+	}
+
+	for pos := 0; pos < bucketBits; pos++ {
+		if isSet(math.MaxUint64, pos) {
+			continue
+		}
+		var shifted uint64 = math.MaxUint64 >> streamOffset(pos)
+		t.Errorf("expected isSet for all i=%d got=%d", pos, shifted)
+	}
+}
+
+// TestBucketOfset checks which bucket index bucketOffset reports for stream
+// IDs on both sides of the first bucket boundary.
+func TestBucketOfset(t *testing.T) {
+	cases := []struct {
+		n      int
+		bucket int
+	}{
+		{0, 0},
+		{1, 0},
+		{63, 0},
+		{64, 1},
+	}
+
+	for _, tc := range cases {
+		got := bucketOffset(tc.n)
+		if got == tc.bucket {
+			continue
+		}
+		t.Errorf("n=%d expected %v got %v", tc.n, tc.bucket, got)
+	}
+}
+
+// TestStreamFromBucket checks that a bucket index and a position inside the
+// bucket combine into the expected stream ID.
+func TestStreamFromBucket(t *testing.T) {
+	cases := []struct {
+		bucket int
+		pos    int
+		stream int
+	}{
+		{0, 0, 0},
+		{0, 1, 1},
+		{0, 2, 2},
+		{0, 63, 63},
+		{1, 0, 64},
+		{1, 1, 65},
+	}
+
+	for _, tc := range cases {
+		got := streamFromBucket(tc.bucket, tc.pos)
+		if got == tc.stream {
+			continue
+		}
+		t.Errorf("bucket=%d pos=%d expected %v got %v", tc.bucket, tc.pos, tc.stream, got)
+	}
+}