Browse Source

fix(watchhub.go) add a lock to protect the hashmap

Xiang Li 12 years ago
parent
commit
59ccefee0f
3 changed files with 19 additions and 10 deletions
  1. 2 2
      store/store.go
  2. 14 5
      store/watcher_hub.go
  3. 3 3
      store/watcher_test.go

+ 2 - 2
store/store.go

@@ -352,10 +352,10 @@ func (s *store) NewWatcher(key string, recursive bool, sinceIndex uint64) (*Watc
 	var err *etcdErr.Error
 	var err *etcdErr.Error
 
 
 	if sinceIndex == 0 {
 	if sinceIndex == 0 {
-		w, err = s.WatcherHub.watch(key, recursive, nextIndex)
+		w, err = s.WatcherHub.newWatcher(key, recursive, nextIndex)
 
 
 	} else {
 	} else {
-		w, err = s.WatcherHub.watch(key, recursive, sinceIndex)
+		w, err = s.WatcherHub.newWatcher(key, recursive, sinceIndex)
 	}
 	}
 
 
 	if err != nil {
 	if err != nil {

+ 14 - 5
store/watcher_hub.go

@@ -4,6 +4,7 @@ import (
 	"container/list"
 	"container/list"
 	"path"
 	"path"
 	"strings"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"sync/atomic"
 
 
 	etcdErr "github.com/coreos/etcd/error"
 	etcdErr "github.com/coreos/etcd/error"
@@ -16,6 +17,7 @@ import (
 // event happens between the end of the first watch command and the start
 // event happens between the end of the first watch command and the start
 // of the second command.
 // of the second command.
 type watcherHub struct {
 type watcherHub struct {
+	mutex        sync.Mutex // protect the hash map
 	watchers     map[string]*list.List
 	watchers     map[string]*list.List
 	count        int64 // current number of watchers.
 	count        int64 // current number of watchers.
 	EventHistory *EventHistory
 	EventHistory *EventHistory
@@ -32,11 +34,11 @@ func newWatchHub(capacity int) *watcherHub {
 	}
 	}
 }
 }
 
 
-// watch function returns an Event channel.
-// If recursive is true, the first change after index under key will be sent to the event channel.
-// If recursive is false, the first change after index at key will be sent to the event channel.
+// newWatcher function returns a watcher.
+// If recursive is true, the first change after index under key will be sent to the event channel of the watcher.
+// If recursive is false, the first change after index at key will be sent to the event channel of the watcher.
 // If index is zero, watch will start from the current index + 1.
 // If index is zero, watch will start from the current index + 1.
-func (wh *watcherHub) watch(key string, recursive bool, index uint64) (*Watcher, *etcdErr.Error) {
+func (wh *watcherHub) newWatcher(key string, recursive bool, index uint64) (*Watcher, *etcdErr.Error) {
 	event, err := wh.EventHistory.scan(key, recursive, index)
 	event, err := wh.EventHistory.scan(key, recursive, index)
 
 
 	if err != nil {
 	if err != nil {
@@ -51,10 +53,12 @@ func (wh *watcherHub) watch(key string, recursive bool, index uint64) (*Watcher,
 
 
 	if event != nil {
 	if event != nil {
 		w.EventChan <- event
 		w.EventChan <- event
-
 		return w, nil
 		return w, nil
 	}
 	}
 
 
+	wh.mutex.Lock()
+	defer wh.mutex.Unlock()
+
 	l, ok := wh.watchers[key]
 	l, ok := wh.watchers[key]
 
 
 	var elem *list.Element
 	var elem *list.Element
@@ -69,6 +73,8 @@ func (wh *watcherHub) watch(key string, recursive bool, index uint64) (*Watcher,
 	}
 	}
 
 
 	w.Remove = func() {
 	w.Remove = func() {
+		wh.mutex.Lock()
+		defer wh.mutex.Unlock()
 		l.Remove(elem)
 		l.Remove(elem)
 		if l.Len() == 0 {
 		if l.Len() == 0 {
 			delete(wh.watchers, key)
 			delete(wh.watchers, key)
@@ -100,6 +106,9 @@ func (wh *watcherHub) notify(e *Event) {
 }
 }
 
 
 func (wh *watcherHub) notifyWatchers(e *Event, path string, deleted bool) {
 func (wh *watcherHub) notifyWatchers(e *Event, path string, deleted bool) {
+	wh.mutex.Lock()
+	defer wh.mutex.Unlock()
+
 	l, ok := wh.watchers[path]
 	l, ok := wh.watchers[path]
 	if ok {
 	if ok {
 		curr := l.Front()
 		curr := l.Front()

+ 3 - 3
store/watcher_test.go

@@ -23,7 +23,7 @@ import (
 func TestWatcher(t *testing.T) {
 func TestWatcher(t *testing.T) {
 	s := newStore()
 	s := newStore()
 	wh := s.WatcherHub
 	wh := s.WatcherHub
-	w, err := wh.watch("/foo", true, 1)
+	w, err := wh.newWatcher("/foo", true, 1)
 	if err != nil {
 	if err != nil {
 		t.Fatalf("%v", err)
 		t.Fatalf("%v", err)
 	}
 	}
@@ -46,7 +46,7 @@ func TestWatcher(t *testing.T) {
 		t.Fatal("recv != send")
 		t.Fatal("recv != send")
 	}
 	}
 
 
-	w, _ = wh.watch("/foo", false, 2)
+	w, _ = wh.newWatcher("/foo", false, 2)
 	c = w.EventChan
 	c = w.EventChan
 
 
 	e = newEvent(Create, "/foo/bar", 2, 2)
 	e = newEvent(Create, "/foo/bar", 2, 2)
@@ -71,7 +71,7 @@ func TestWatcher(t *testing.T) {
 	}
 	}
 
 
 	// ensure we are doing exact matching rather than prefix matching
 	// ensure we are doing exact matching rather than prefix matching
-	w, _ = wh.watch("/fo", true, 1)
+	w, _ = wh.newWatcher("/fo", true, 1)
 	c = w.EventChan
 	c = w.EventChan
 
 
 	select {
 	select {