epsilon_decay.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package hostpool
  2. import (
  3. "time"
  4. )
// This file implements a fairly generic store that keeps track of scores over
// time, together with the number of trials, so that callers can make an
// epsilon-greedy choice among options.
//
// Because it was designed around response times, higher scores are "worse".
// It is not yet clear whether that detail matters at this level.
  10. const epsilonBuckets = 120
  11. const defaultDecayDuration = time.Duration(5) * time.Minute
  12. type EpsilonDecayStore interface {
  13. Record(score float64)
  14. GetWeightedAvgScore() float64
  15. performDecay() // this is only exposed in the interface for testing
  16. }
  17. type defEpsDecayStore struct {
  18. epsilonCounts []int64
  19. epsilonValues []float64
  20. epsilonIndex int
  21. decayDuration time.Duration
  22. // incoming request channels
  23. recordReqChan chan *recordRequest
  24. getWAScoreReqChan chan *getWAScoreRequest
  25. }
  26. type recordRequest struct {
  27. score float64
  28. respChan chan struct{}
  29. }
  30. type getWAScoreRequest struct {
  31. respChan chan float64
  32. }
  33. // -- "Constructor" --
  34. func NewDecayStore() EpsilonDecayStore {
  35. store := &defEpsDecayStore{
  36. epsilonCounts: make([]int64, epsilonBuckets),
  37. epsilonValues: make([]float64, epsilonBuckets),
  38. decayDuration: defaultDecayDuration,
  39. recordReqChan: make(chan *recordRequest),
  40. getWAScoreReqChan: make(chan *getWAScoreRequest),
  41. }
  42. var numBuckets int64 = int64(len(store.epsilonCounts))
  43. durationPerBucket := time.Duration(int64(store.decayDuration) / numBuckets)
  44. ticker := time.Tick(durationPerBucket)
  45. go store.muxRequests(ticker)
  46. return store
  47. }
  48. // -- Public Methods --
  49. func (ds *defEpsDecayStore) Record(score float64) {
  50. req := &recordRequest{
  51. score: score,
  52. respChan: make(chan struct{}),
  53. }
  54. ds.recordReqChan <- req
  55. <-req.respChan
  56. }
  57. func (ds *defEpsDecayStore) GetWeightedAvgScore() float64 {
  58. req := &getWAScoreRequest{
  59. respChan: make(chan float64),
  60. }
  61. ds.getWAScoreReqChan <- req
  62. avgScore := <-req.respChan
  63. return avgScore
  64. }
  65. // -- Internal Methods --
  66. func (ds *defEpsDecayStore) muxRequests(decayTicker <-chan time.Time) {
  67. for {
  68. select {
  69. case <-decayTicker:
  70. ds.performDecay()
  71. case req := <-ds.getWAScoreReqChan:
  72. avgScore := ds.getWeightedAverageScore()
  73. req.respChan <- avgScore
  74. case req := <-ds.recordReqChan:
  75. newScore := req.score
  76. ds.epsilonCounts[ds.epsilonIndex]++
  77. ds.epsilonValues[ds.epsilonIndex] += newScore
  78. req.respChan <- struct{}{}
  79. }
  80. }
  81. }
  82. // Methods below should only be called from muxRequests above
  83. func (ds *defEpsDecayStore) performDecay() {
  84. ds.epsilonIndex += 1
  85. ds.epsilonIndex = ds.epsilonIndex % epsilonBuckets
  86. ds.epsilonCounts[ds.epsilonIndex] = 0
  87. ds.epsilonValues[ds.epsilonIndex] = 0.0
  88. }
  89. func (ds *defEpsDecayStore) getWeightedAverageScore() float64 {
  90. var value float64
  91. var lastValue float64
  92. // start at 1 so we start with the oldest entry
  93. for i := 1; i <= epsilonBuckets; i += 1 {
  94. pos := (ds.epsilonIndex + i) % epsilonBuckets
  95. bucketCount := ds.epsilonCounts[pos]
  96. weight := float64(i) / float64(epsilonBuckets)
  97. if bucketCount > 0 {
  98. currentValue := float64(ds.epsilonValues[pos]) / float64(bucketCount)
  99. value += currentValue * weight
  100. lastValue = currentValue
  101. } else {
  102. value += lastValue * weight
  103. }
  104. }
  105. return value
  106. }