// epsilon_decay.go (3.2 KB)

  1. package hostpool
  2. import (
  3. "sync"
  4. "time"
  5. )
// This file implements a store that can, somewhat generically, keep track of
// scores over time together with the number of trials, so that callers can make
// an epsilon-greedy choice among options.
// Because it is designed around response times, higher scores should be "worse".
// It is not yet clear whether that detail will come into play at this level.
  11. const epsilonBuckets = 120
  12. const defaultDecayDuration = time.Duration(5) * time.Minute
  13. type epsilonDecayStore struct {
  14. epsilonCounts []int64
  15. epsilonValues []float64
  16. epsilonIndex int
  17. decayDuration time.Duration
  18. // incoming request channels
  19. recordReqChan chan *recordRequest
  20. getWAScoreReqChan chan *getWAScoreRequest
  21. closeChan chan struct{}
  22. wg sync.WaitGroup
  23. }
  24. type recordRequest struct {
  25. score float64
  26. respChan chan struct{}
  27. }
  28. type getWAScoreRequest struct {
  29. respChan chan float64
  30. }
  31. // -- "Constructor" --
  32. func newDecayStore() *epsilonDecayStore {
  33. store := &epsilonDecayStore{
  34. epsilonCounts: make([]int64, epsilonBuckets),
  35. epsilonValues: make([]float64, epsilonBuckets),
  36. decayDuration: defaultDecayDuration,
  37. recordReqChan: make(chan *recordRequest),
  38. getWAScoreReqChan: make(chan *getWAScoreRequest),
  39. }
  40. var numBuckets int64 = int64(len(store.epsilonCounts))
  41. durationPerBucket := time.Duration(int64(store.decayDuration) / numBuckets)
  42. ticker := time.Tick(durationPerBucket)
  43. store.wg.Add(1)
  44. go store.muxRequests(ticker)
  45. return store
  46. }
  47. // -- Public Methods --
  48. func (ds *epsilonDecayStore) Record(score float64) {
  49. req := &recordRequest{
  50. score: score,
  51. respChan: make(chan struct{}),
  52. }
  53. ds.recordReqChan <- req
  54. <-req.respChan
  55. }
  56. func (ds *epsilonDecayStore) GetWeightedAvgScore() float64 {
  57. req := &getWAScoreRequest{
  58. respChan: make(chan float64),
  59. }
  60. ds.getWAScoreReqChan <- req
  61. avgScore := <-req.respChan
  62. return avgScore
  63. }
  64. func (ds *epsilonDecayStore) close() {
  65. ds.closeChan <- struct{}{}
  66. ds.wg.Wait()
  67. }
  68. // -- Internal Methods --
  69. func (ds *epsilonDecayStore) muxRequests(decayTicker <-chan time.Time) {
  70. for {
  71. select {
  72. case <-decayTicker:
  73. ds.performDecay()
  74. case req := <-ds.getWAScoreReqChan:
  75. avgScore := ds.getWeightedAverageScore()
  76. req.respChan <- avgScore
  77. case req := <-ds.recordReqChan:
  78. newScore := req.score
  79. ds.epsilonCounts[ds.epsilonIndex]++
  80. ds.epsilonValues[ds.epsilonIndex] += newScore
  81. req.respChan <- struct{}{}
  82. case <-ds.closeChan:
  83. ds.wg.Done()
  84. return
  85. }
  86. }
  87. }
  88. // Methods below should only be called from muxRequests above
  89. func (ds *epsilonDecayStore) performDecay() {
  90. ds.epsilonIndex += 1
  91. ds.epsilonIndex = ds.epsilonIndex % epsilonBuckets
  92. ds.epsilonCounts[ds.epsilonIndex] = 0
  93. ds.epsilonValues[ds.epsilonIndex] = 0.0
  94. }
  95. func (ds *epsilonDecayStore) getWeightedAverageScore() float64 {
  96. var value float64
  97. var lastValue float64
  98. // start at 1 so we start with the oldest entry
  99. for i := 1; i <= epsilonBuckets; i += 1 {
  100. pos := (ds.epsilonIndex + i) % epsilonBuckets
  101. bucketCount := ds.epsilonCounts[pos]
  102. weight := float64(i) / float64(epsilonBuckets)
  103. if bucketCount > 0 {
  104. currentValue := float64(ds.epsilonValues[pos]) / float64(bucketCount)
  105. value += currentValue * weight
  106. lastValue = currentValue
  107. } else {
  108. value += lastValue * weight
  109. }
  110. }
  111. return value
  112. }