epsilon_decay.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. package hostpool
  2. import (
  3. "time"
  4. )
// This file implements a fairly generic store that tracks scores over time,
// together with the number of trials behind them, so that callers can make an
// epsilon-greedy choice among options.
// Because it is designed around response times, higher scores are "worse".
// It is not yet clear whether that ordering matters at this level.
  10. const epsilonBuckets = 120
  11. const defaultDecayDuration = time.Duration(5) * time.Minute
  12. type EpsilonDecayStore interface {
  13. Record(score float64)
  14. GetWeightedAvgScore() float64
  15. }
  16. type defEpsDecayStore struct {
  17. epsilonCounts []int64
  18. epsilonValues []float64
  19. epsilonIndex int
  20. decayDuration time.Duration
  21. // incoming request channels
  22. recordReqChan chan *recordRequest
  23. getWAScoreReqChan chan *getWAScoreRequest
  24. }
  25. type recordRequest struct {
  26. score float64
  27. respChan chan struct{}
  28. }
  29. type getWAScoreRequest struct {
  30. respChan chan float64
  31. }
  32. // -- "Constructor" --
  33. func New() EpsilonDecayStore {
  34. store := &defEpsDecayStore{
  35. epsilonCounts: make([]int64, epsilonBuckets),
  36. epsilonValues: make([]float64, epsilonBuckets),
  37. decayDuration: defaultDecayDuration,
  38. recordReqChan: make(chan *recordRequest),
  39. getWAScoreReqChan: make(chan *getWAScoreRequest),
  40. }
  41. var numBuckets int64 = int64(len(store.epsilonCounts))
  42. durationPerBucket := time.Duration(int64(store.decayDuration) / numBuckets)
  43. ticker := time.Tick(durationPerBucket)
  44. go store.muxRequests(ticker)
  45. return store
  46. }
  47. // -- Public Methods --
  48. func (ds *defEpsDecayStore) Record(score float64) {
  49. req := &recordRequest{
  50. score: score,
  51. respChan: make(chan struct{}),
  52. }
  53. ds.recordReqChan <- req
  54. <-req.respChan
  55. }
  56. func (ds *defEpsDecayStore) GetWeightedAvgScore() float64 {
  57. req := &getWAScoreRequest{
  58. respChan: make(chan float64),
  59. }
  60. ds.getWAScoreReqChan <- req
  61. avgScore := <-req.respChan
  62. return avgScore
  63. }
  64. // -- Internal Methods --
  65. func (ds *defEpsDecayStore) muxRequests(decayTicker <-chan time.Time) {
  66. for {
  67. select {
  68. case <-decayTicker:
  69. ds.performDecay()
  70. case req := <-ds.getWAScoreReqChan:
  71. avgScore := ds.GetWeightedAvgScore()
  72. req.respChan <- avgScore
  73. case req := <-ds.recordReqChan:
  74. newScore := req.score
  75. ds.epsilonCounts[ds.epsilonIndex]++
  76. ds.epsilonValues[ds.epsilonIndex] += newScore
  77. req.respChan <- struct{}{}
  78. }
  79. }
  80. }
// The methods below read and write the bucket state without any locking, so
// they must only be called from the muxRequests goroutine above.
  82. func (ds *defEpsDecayStore) performDecay() {
  83. ds.epsilonIndex += 1
  84. ds.epsilonIndex = ds.epsilonIndex % epsilonBuckets
  85. ds.epsilonCounts[ds.epsilonIndex] = 0
  86. ds.epsilonValues[ds.epsilonIndex] = 0.0
  87. }
  88. func (ds *defEpsDecayStore) getWeightedAverageScore() float64 {
  89. var value float64
  90. var lastValue float64
  91. // start at 1 so we start with the oldest entry
  92. for i := 1; i <= epsilonBuckets; i += 1 {
  93. pos := (ds.epsilonIndex + i) % epsilonBuckets
  94. bucketCount := ds.epsilonCounts[pos]
  95. weight := float64(i) / float64(epsilonBuckets)
  96. if bucketCount > 0 {
  97. currentValue := float64(ds.epsilonValues[pos]) / float64(bucketCount)
  98. value += currentValue * weight
  99. lastValue = currentValue
  100. } else {
  101. value += lastValue * weight
  102. }
  103. }
  104. return value
  105. }