瀏覽代碼

More close mechanisms, and some refactoring

namely, needed to move some epsilon decay stuff away from interfaces,
or from being public for that matter. oh well, didn't need to override
it anwyay
Dan Frank 12 年之前
父節點
當前提交
9be2838f29
共有 5 個文件被更改,包括 46 次插入27 次删除
  1. 21 14
      epsilon_decay.go
  2. 1 1
      epsilon_decay_test.go
  3. 16 12
      epsilon_greedy.go
  4. 5 0
      host_entry.go
  5. 3 0
      hostpool.go

+ 21 - 14
epsilon_decay.go

@@ -1,6 +1,7 @@
 package hostpool
 
 import (
+	"sync"
 	"time"
 )
 
@@ -14,13 +15,7 @@ import (
 const epsilonBuckets = 120
 const defaultDecayDuration = time.Duration(5) * time.Minute
 
-type EpsilonDecayStore interface {
-	Record(score float64)
-	GetWeightedAvgScore() float64
-	performDecay() // this is only exposed in the interface for testing
-}
-
-type defEpsDecayStore struct {
+type epsilonDecayStore struct {
 	epsilonCounts []int64
 	epsilonValues []float64
 	epsilonIndex  int
@@ -30,6 +25,9 @@ type defEpsDecayStore struct {
 	// incoming request channels
 	recordReqChan     chan *recordRequest
 	getWAScoreReqChan chan *getWAScoreRequest
+
+	closeChan chan struct{}
+	wg        sync.WaitGroup
 }
 
 type recordRequest struct {
@@ -43,8 +41,8 @@ type getWAScoreRequest struct {
 
 // -- "Constructor" --
 
-func NewDecayStore() EpsilonDecayStore {
-	store := &defEpsDecayStore{
+func newDecayStore() *epsilonDecayStore {
+	store := &epsilonDecayStore{
 		epsilonCounts: make([]int64, epsilonBuckets),
 		epsilonValues: make([]float64, epsilonBuckets),
 		decayDuration: defaultDecayDuration,
@@ -55,13 +53,14 @@ func NewDecayStore() EpsilonDecayStore {
 	var numBuckets int64 = int64(len(store.epsilonCounts))
 	durationPerBucket := time.Duration(int64(store.decayDuration) / numBuckets)
 	ticker := time.Tick(durationPerBucket)
+	store.wg.Add(1)
 	go store.muxRequests(ticker)
 	return store
 }
 
 // -- Public Methods --
 
-func (ds *defEpsDecayStore) Record(score float64) {
+func (ds *epsilonDecayStore) Record(score float64) {
 	req := &recordRequest{
 		score:    score,
 		respChan: make(chan struct{}),
@@ -70,7 +69,7 @@ func (ds *defEpsDecayStore) Record(score float64) {
 	<-req.respChan
 }
 
-func (ds *defEpsDecayStore) GetWeightedAvgScore() float64 {
+func (ds *epsilonDecayStore) GetWeightedAvgScore() float64 {
 	req := &getWAScoreRequest{
 		respChan: make(chan float64),
 	}
@@ -79,9 +78,14 @@ func (ds *defEpsDecayStore) GetWeightedAvgScore() float64 {
 	return avgScore
 }
 
+func (ds *epsilonDecayStore) close() {
+	ds.closeChan <- struct{}{}
+	ds.wg.Wait()
+}
+
 // -- Internal Methods --
 
-func (ds *defEpsDecayStore) muxRequests(decayTicker <-chan time.Time) {
+func (ds *epsilonDecayStore) muxRequests(decayTicker <-chan time.Time) {
 	for {
 		select {
 		case <-decayTicker:
@@ -94,6 +98,9 @@ func (ds *defEpsDecayStore) muxRequests(decayTicker <-chan time.Time) {
 			ds.epsilonCounts[ds.epsilonIndex]++
 			ds.epsilonValues[ds.epsilonIndex] += newScore
 			req.respChan <- struct{}{}
+		case <-ds.closeChan:
+			ds.wg.Done()
+			return
 		}
 
 	}
@@ -101,14 +108,14 @@ func (ds *defEpsDecayStore) muxRequests(decayTicker <-chan time.Time) {
 
 // Methods below should only be called from muxRequests above
 
-func (ds *defEpsDecayStore) performDecay() {
+func (ds *epsilonDecayStore) performDecay() {
 	ds.epsilonIndex += 1
 	ds.epsilonIndex = ds.epsilonIndex % epsilonBuckets
 	ds.epsilonCounts[ds.epsilonIndex] = 0
 	ds.epsilonValues[ds.epsilonIndex] = 0.0
 }
 
-func (ds *defEpsDecayStore) getWeightedAverageScore() float64 {
+func (ds *epsilonDecayStore) getWeightedAverageScore() float64 {
 	var value float64
 	var lastValue float64
 

+ 1 - 1
epsilon_decay_test.go

@@ -6,7 +6,7 @@ import (
 )
 
 func TestEDS(t *testing.T) {
-	eds := NewDecayStore()
+	eds := newDecayStore()
 	eds.Record(1.5)
 	assert.Equal(t, eds.GetWeightedAvgScore(), 1.5)
 }

+ 16 - 12
epsilon_greedy.go

@@ -6,16 +6,23 @@ import (
 	"time"
 )
 
-type epsilonGreedyHostEntry interface {
+type epsilonGreedyHostEntry struct {
 	HostEntry
-	EpsilonDecayStore
+	*epsilonDecayStore
 }
 
+func (egHostEntry *epsilonGreedyHostEntry) Close() {
+	egHostEntry.HostEntry.Close()
+	egHostEntry.epsilonDecayStore.close()
+}
+
+// -------------------------------
+
 type epsilonGreedyHostPool struct {
 	HostPool
-	hosts                  map[string]epsilonGreedyHostEntry // this basically just mirrors the underlying host map
-	epsilon                float32                           // this is our exploration factor
-	EpsilonValueCalculator                                   // embed the epsilonValueCalculator
+	hosts                  map[string]*epsilonGreedyHostEntry // this basically just mirrors the underlying host map
+	epsilon                float32                            // this is our exploration factor
+	EpsilonValueCalculator                                    // embed the epsilonValueCalculator
 	timer
 }
 
@@ -31,13 +38,10 @@ const epsilonDecay = 0.90 // decay the exploration rate
 const minEpsilon = 0.01   // explore one percent of the time
 const initialEpsilon = 0.3
 
-func toEGHostEntry(fromHE HostEntry) epsilonGreedyHostEntry {
-	return &struct {
-		HostEntry
-		EpsilonDecayStore
-	}{
+func toEGHostEntry(fromHE HostEntry) *epsilonGreedyHostEntry {
+	return &epsilonGreedyHostEntry{
 		fromHE,
-		NewDecayStore(),
+		newDecayStore(),
 	}
 }
 
@@ -73,7 +77,7 @@ func ToEpsilonGreedy(pool HostPool, decayDuration time.Duration, calc EpsilonVal
 		timer:                  &realTimer{},
 	}
 
-	p.hosts = make(map[string]epsilonGreedyHostEntry)
+	p.hosts = make(map[string]*epsilonGreedyHostEntry)
 	for _, hostName := range pool.Hosts() {
 		p.hosts[hostName] = toEGHostEntry(pool.lookupHost(hostName))
 	}

+ 5 - 0
host_entry.go

@@ -12,6 +12,7 @@ type HostEntry interface {
 	SetDead(bool)
 	canTryHost(time.Time) bool
 	willRetryHost()
+	Close()
 }
 
 // -- Requests
@@ -145,3 +146,7 @@ func (he *hostEntry) willRetryHost() {
 	he.incomingRequests <- req
 	<-req.respChan
 }
+
+func (he *hostEntry) Close() {
+	close(he.incomingRequests)
+}

+ 3 - 0
hostpool.go

@@ -230,6 +230,9 @@ func (p *standardHostPool) hostList() []HostEntry {
 
 func (p *standardHostPool) Close() {
 	p.closeChan <- struct{}{}
+	for _, he := range p.hosts {
+		he.Close()
+	}
 	p.wg.Wait()
 }