Browse Source

Add Store interface.

Standard storage is now known as memoryStore (created with
NewMemoryStore exported function). It's still default, however there's
now an option to replace it with a custom store by implementing Store
interface and calling SetCustomStore.
Dmitry Chestnykh 14 years ago
parent
commit
dab967324b
4 changed files with 77 additions and 54 deletions
  1. 16 9
      captcha.go
  2. 3 3
      captcha_test.go
  3. 45 29
      store.go
  4. 13 13
      store_test.go

+ 16 - 9
captcha.go

@@ -1,3 +1,5 @@
+// Package captcha implements generation and verification of image and audio
+// CAPTCHAs.
 package captcha
 package captcha
 
 
 import (
 import (
@@ -24,7 +26,12 @@ const (
 var ErrNotFound = os.NewError("captcha with the given id not found")
 var ErrNotFound = os.NewError("captcha with the given id not found")
 
 
 // globalStore is a shared storage for captchas, generated by New function.
 // globalStore is a shared storage for captchas, generated by New function.
-var globalStore = newStore(StdCollectNum, StdExpiration)
+var globalStore = NewMemoryStore(StdCollectNum, StdExpiration)
+
+// SetCustomStore sets custom storage for captchas.
+func SetCustomStore(s Store) {
+	globalStore = s
+}
 
 
 // RandomDigits returns a byte slice of the given length containing random
 // RandomDigits returns a byte slice of the given length containing random
 // digits in range 0-9.
 // digits in range 0-9.
@@ -43,7 +50,7 @@ func RandomDigits(length int) []byte {
 // storage, and returns its id.
 // storage, and returns its id.
 func New(length int) (id string) {
 func New(length int) (id string) {
 	id = uniuri.New()
 	id = uniuri.New()
-	globalStore.saveCaptcha(id, RandomDigits(length))
+	globalStore.Set(id, RandomDigits(length))
 	return
 	return
 }
 }
 
 
@@ -54,18 +61,18 @@ func New(length int) (id string) {
 // refreshed to show the new captcha representation (WriteImage and WriteAudio
 // refreshed to show the new captcha representation (WriteImage and WriteAudio
 // will write the new one).
 // will write the new one).
 func Reload(id string) bool {
 func Reload(id string) bool {
-	old := globalStore.getDigits(id)
+	old := globalStore.Get(id, false)
 	if old == nil {
 	if old == nil {
 		return false
 		return false
 	}
 	}
-	globalStore.saveCaptcha(id, RandomDigits(len(old)))
+	globalStore.Set(id, RandomDigits(len(old)))
 	return true
 	return true
 }
 }
 
 
 // WriteImage writes PNG-encoded image representation of the captcha with the
 // WriteImage writes PNG-encoded image representation of the captcha with the
 // given id. The image will have the given width and height.
 // given id. The image will have the given width and height.
 func WriteImage(w io.Writer, id string, width, height int) os.Error {
 func WriteImage(w io.Writer, id string, width, height int) os.Error {
-	d := globalStore.getDigits(id)
+	d := globalStore.Get(id, false)
 	if d == nil {
 	if d == nil {
 		return ErrNotFound
 		return ErrNotFound
 	}
 	}
@@ -76,7 +83,7 @@ func WriteImage(w io.Writer, id string, width, height int) os.Error {
 // WriteAudio writes WAV-encoded audio representation of the captcha with the
 // WriteAudio writes WAV-encoded audio representation of the captcha with the
 // given id.
 // given id.
 func WriteAudio(w io.Writer, id string) os.Error {
 func WriteAudio(w io.Writer, id string) os.Error {
-	d := globalStore.getDigits(id)
+	d := globalStore.Get(id, false)
 	if d == nil {
 	if d == nil {
 		return ErrNotFound
 		return ErrNotFound
 	}
 	}
@@ -93,7 +100,7 @@ func Verify(id string, digits []byte) bool {
 	if digits == nil || len(digits) == 0 {
 	if digits == nil || len(digits) == 0 {
 		return false
 		return false
 	}
 	}
-	reald := globalStore.getDigitsClear(id)
+	reald := globalStore.Get(id, true)
 	if reald == nil {
 	if reald == nil {
 		return false
 		return false
 	}
 	}
@@ -128,7 +135,7 @@ func VerifyString(id string, digits string) bool {
 //
 //
 // Collection is launched in a new goroutine.
 // Collection is launched in a new goroutine.
 func Collect() {
 func Collect() {
-	go globalStore.collect()
+	go globalStore.Collect()
 }
 }
 
 
 type captchaHandler struct {
 type captchaHandler struct {
@@ -171,7 +178,7 @@ func (h *captchaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		//err = WriteAudio(buf, id)
 		//err = WriteAudio(buf, id)
 		//XXX(dchest) Workaround for Chrome: it wants content-length,
 		//XXX(dchest) Workaround for Chrome: it wants content-length,
 		//or else will start playing NOT from the beginning.
 		//or else will start playing NOT from the beginning.
-		d := globalStore.getDigits(id)
+		d := globalStore.Get(id, false)
 		if d == nil {
 		if d == nil {
 			err = ErrNotFound
 			err = ErrNotFound
 		} else {
 		} else {

+ 3 - 3
captcha_test.go

@@ -18,7 +18,7 @@ func TestVerify(t *testing.T) {
 		t.Errorf("verified wrong captcha")
 		t.Errorf("verified wrong captcha")
 	}
 	}
 	id = New(StdLength)
 	id = New(StdLength)
-	d := globalStore.getDigits(id) // cheating
+	d := globalStore.Get(id, false) // cheating
 	if !Verify(id, d) {
 	if !Verify(id, d) {
 		t.Errorf("proper captcha not verified")
 		t.Errorf("proper captcha not verified")
 	}
 	}
@@ -26,9 +26,9 @@ func TestVerify(t *testing.T) {
 
 
 func TestReload(t *testing.T) {
 func TestReload(t *testing.T) {
 	id := New(StdLength)
 	id := New(StdLength)
-	d1 := globalStore.getDigits(id) // cheating
+	d1 := globalStore.Get(id, false) // cheating
 	Reload(id)
 	Reload(id)
-	d2 := globalStore.getDigits(id) // cheating again
+	d2 := globalStore.Get(id, false) // cheating again
 	if bytes.Equal(d1, d2) {
 	if bytes.Equal(d1, d2) {
 		t.Errorf("reload didn't work: %v = %v", d1, d2)
 		t.Errorf("reload didn't work: %v = %v", d1, d2)
 	}
 	}

+ 45 - 29
store.go

@@ -6,16 +6,34 @@ import (
 	"time"
 	"time"
 )
 )
 
 
-// expValue stores timestamp and id of captchas. It is used in a list inside
-// store for indexing generated captchas by timestamp to enable garbage
+// An object implementing Store interface can be registered with SetCustomStore
+// function to handle storage and retrieval of captcha ids and solutions for
+// them, replacing the default memory store.
+type Store interface {
+	// Set sets the digits for the captcha id.
+	Set(id string, digits []byte)
+
+	// Get returns stored digits for the captcha id. Clear indicates
+	// whether the captcha must be deleted from the store.
+	Get(id string, clear bool) (digits []byte)
+
+	// Collect deletes expired captchas from the store.  For custom stores
+	// this method is not called automatically, it is only wired to the
+	// package's Collect function.  Custom stores must implement their own
+	// procedure for calling Collect, for example, in Set method.
+	Collect()
+}
+
+// expValue stores timestamp and id of captchas. It is used in the list inside
+// memoryStore for indexing generated captchas by timestamp to enable garbage
 // collection of expired captchas.
 // collection of expired captchas.
 type expValue struct {
 type expValue struct {
 	timestamp int64
 	timestamp int64
 	id        string
 	id        string
 }
 }
 
 
-// store is an internal store for captcha ids and their values.
-type store struct {
+// memoryStore is an internal store for captcha ids and their values.
+type memoryStore struct {
 	mu  sync.RWMutex
 	mu  sync.RWMutex
 	ids map[string][]byte
 	ids map[string][]byte
 	exp *list.List
 	exp *list.List
@@ -27,9 +45,11 @@ type store struct {
 	expiration int64
 	expiration int64
 }
 }
 
 
-// newStore initializes and returns a new store.
-func newStore(collectNum int, expiration int64) *store {
-	s := new(store)
+// NewMemoryStore returns a new standard memory store for captchas with the
+// given collection threshold and expiration time in seconds. The returned
+// store must be registered with SetCustomStore to replace the default one.
+func NewMemoryStore(collectNum int, expiration int64) Store {
+	s := new(memoryStore)
 	s.ids = make(map[string][]byte)
 	s.ids = make(map[string][]byte)
 	s.exp = list.New()
 	s.exp = list.New()
 	s.collectNum = collectNum
 	s.collectNum = collectNum
@@ -37,44 +57,40 @@ func newStore(collectNum int, expiration int64) *store {
 	return s
 	return s
 }
 }
 
 
-// saveCaptcha saves the captcha id and the corresponding digits.
-func (s *store) saveCaptcha(id string, digits []byte) {
+func (s *memoryStore) Set(id string, digits []byte) {
 	s.mu.Lock()
 	s.mu.Lock()
 	s.ids[id] = digits
 	s.ids[id] = digits
 	s.exp.PushBack(expValue{time.Seconds(), id})
 	s.exp.PushBack(expValue{time.Seconds(), id})
 	s.numStored++
 	s.numStored++
 	s.mu.Unlock()
 	s.mu.Unlock()
 	if s.numStored > s.collectNum {
 	if s.numStored > s.collectNum {
-		go s.collect()
+		go s.Collect()
 	}
 	}
 }
 }
 
 
-// getDigits returns the digits for the given id.
-func (s *store) getDigits(id string) (digits []byte) {
-	s.mu.RLock()
-	defer s.mu.RUnlock()
-	digits, _ = s.ids[id]
-	return
-}
-
-// getDigitsClear returns the digits for the given id, and removes them from
-// the store.
-func (s *store) getDigitsClear(id string) (digits []byte) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
+func (s *memoryStore) Get(id string, clear bool) (digits []byte) {
+	if !clear {
+		// When we don't need to clear captcha, acquire read lock.
+		s.mu.RLock()
+		defer s.mu.RUnlock()
+	} else {
+		s.mu.Lock()
+		defer s.mu.Unlock()
+	}
 	digits, ok := s.ids[id]
 	digits, ok := s.ids[id]
 	if !ok {
 	if !ok {
 		return
 		return
 	}
 	}
-	s.ids[id] = nil, false
-	// XXX(dchest) Index (s.exp) will be cleaned when collecting expired
-	// captchas.  Can't clean it here, because we don't store reference to
-	// expValue in the map. Maybe store it?
+	if clear {
+		s.ids[id] = nil, false
+		// XXX(dchest) Index (s.exp) will be cleaned when collecting expired
+		// captchas.  Can't clean it here, because we don't store reference to
+		// expValue in the map. Maybe store it?
+	}
 	return
 	return
 }
 }
 
 
-// collect deletes expired captchas from the store.
-func (s *store) collect() {
+func (s *memoryStore) Collect() {
 	now := time.Seconds()
 	now := time.Seconds()
 	s.mu.Lock()
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	defer s.mu.Unlock()

+ 13 - 13
store_test.go

@@ -6,27 +6,27 @@ import (
 	"testing"
 	"testing"
 )
 )
 
 
-func TestSaveAndGetDigits(t *testing.T) {
-	s := newStore(StdCollectNum, StdExpiration)
+func TestSetGet(t *testing.T) {
+	s := NewMemoryStore(StdCollectNum, StdExpiration)
 	id := "captcha id"
 	id := "captcha id"
 	d := RandomDigits(10)
 	d := RandomDigits(10)
-	s.saveCaptcha(id, d)
-	d2 := s.getDigits(id)
+	s.Set(id, d)
+	d2 := s.Get(id, false)
 	if d2 == nil || !bytes.Equal(d, d2) {
 	if d2 == nil || !bytes.Equal(d, d2) {
 		t.Errorf("saved %v, getDigits returned got %v", d, d2)
 		t.Errorf("saved %v, getDigits returned got %v", d, d2)
 	}
 	}
 }
 }
 
 
-func TestGetDigitsClear(t *testing.T) {
-	s := newStore(StdCollectNum, StdExpiration)
+func TestGetClear(t *testing.T) {
+	s := NewMemoryStore(StdCollectNum, StdExpiration)
 	id := "captcha id"
 	id := "captcha id"
 	d := RandomDigits(10)
 	d := RandomDigits(10)
-	s.saveCaptcha(id, d)
-	d2 := s.getDigitsClear(id)
+	s.Set(id, d)
+	d2 := s.Get(id, true)
 	if d2 == nil || !bytes.Equal(d, d2) {
 	if d2 == nil || !bytes.Equal(d, d2) {
 		t.Errorf("saved %v, getDigitsClear returned got %v", d, d2)
 		t.Errorf("saved %v, getDigitsClear returned got %v", d, d2)
 	}
 	}
-	d2 = s.getDigits(id)
+	d2 = s.Get(id, false)
 	if d2 != nil {
 	if d2 != nil {
 		t.Errorf("getDigitClear didn't clear (%q=%v)", id, d2)
 		t.Errorf("getDigitClear didn't clear (%q=%v)", id, d2)
 	}
 	}
@@ -35,19 +35,19 @@ func TestGetDigitsClear(t *testing.T) {
 func TestCollect(t *testing.T) {
 func TestCollect(t *testing.T) {
 	//TODO(dchest): can't test automatic collection when saving, because
 	//TODO(dchest): can't test automatic collection when saving, because
 	//it's currently launched in a different goroutine.
 	//it's currently launched in a different goroutine.
-	s := newStore(10, -1)
+	s := NewMemoryStore(10, -1)
 	// create 10 ids
 	// create 10 ids
 	ids := make([]string, 10)
 	ids := make([]string, 10)
 	d := RandomDigits(10)
 	d := RandomDigits(10)
 	for i := range ids {
 	for i := range ids {
 		ids[i] = uniuri.New()
 		ids[i] = uniuri.New()
-		s.saveCaptcha(ids[i], d)
+		s.Set(ids[i], d)
 	}
 	}
-	s.collect()
+	s.Collect()
 	// Must be already collected
 	// Must be already collected
 	nc := 0
 	nc := 0
 	for i := range ids {
 	for i := range ids {
-		d2 := s.getDigits(ids[i])
+		d2 := s.Get(ids[i], false)
 		if d2 != nil {
 		if d2 != nil {
 			t.Errorf("%d: not collected", i)
 			t.Errorf("%d: not collected", i)
 			nc++
 			nc++