Browse Source

Merge pull request #352 from xiangli-cmu/fix_watch_prefix

fix(event_history.go) should not scan prefix
Xiang Li 12 years ago
parent
commit
577d08ea7d
5 changed files with 50 additions and 18 deletions
  1. 16 3
      store/event_history.go
  2. 4 4
      store/event_test.go
  3. 4 5
      store/store.go
  4. 6 6
      store/watcher_hub.go
  5. 20 0
      store/watcher_test.go

+ 16 - 3
store/event_history.go

@@ -2,6 +2,7 @@ package store
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"path"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 
 
@@ -39,8 +40,8 @@ func (eh *EventHistory) addEvent(e *Event) *Event {
 }
 }
 
 
 // scan function is enumerating events from the index in history and
 // scan function is enumerating events from the index in history and
-// stops till the first point where the key has identified prefix
-func (eh *EventHistory) scan(prefix string, index uint64) (*Event, *etcdErr.Error) {
+// stops till the first point where the key has identified key
+func (eh *EventHistory) scan(key string, recursive bool, index uint64) (*Event, *etcdErr.Error) {
 	eh.rwl.RLock()
 	eh.rwl.RLock()
 	defer eh.rwl.RUnlock()
 	defer eh.rwl.RUnlock()
 
 
@@ -62,7 +63,19 @@ func (eh *EventHistory) scan(prefix string, index uint64) (*Event, *etcdErr.Erro
 	for {
 	for {
 		e := eh.Queue.Events[i]
 		e := eh.Queue.Events[i]
 
 
-		if strings.HasPrefix(e.Key, prefix) && index <= e.Index() { // make sure we bypass the smaller one
+		ok := (e.Key == key)
+
+		if recursive {
+			// add tailing slash
+			key := path.Clean(key)
+			if key[len(key)-1] != '/' {
+				key = key + "/"
+			}
+
+			ok = ok || strings.HasPrefix(e.Key, key)
+		}
+
+		if ok && index <= e.Index() { // make sure we bypass the smaller one
 			return e, nil
 			return e, nil
 		}
 		}
 
 

+ 4 - 4
store/event_test.go

@@ -41,24 +41,24 @@ func TestScanHistory(t *testing.T) {
 	eh.addEvent(newEvent(Create, "/foo/bar/bar", 4))
 	eh.addEvent(newEvent(Create, "/foo/bar/bar", 4))
 	eh.addEvent(newEvent(Create, "/foo/foo/foo", 5))
 	eh.addEvent(newEvent(Create, "/foo/foo/foo", 5))
 
 
-	e, err := eh.scan("/foo", 1)
+	e, err := eh.scan("/foo", false, 1)
 	if err != nil || e.Index() != 1 {
 	if err != nil || e.Index() != 1 {
 		t.Fatalf("scan error [/foo] [1] %v", e.Index)
 		t.Fatalf("scan error [/foo] [1] %v", e.Index)
 	}
 	}
 
 
-	e, err = eh.scan("/foo/bar", 1)
+	e, err = eh.scan("/foo/bar", false, 1)
 
 
 	if err != nil || e.Index() != 2 {
 	if err != nil || e.Index() != 2 {
 		t.Fatalf("scan error [/foo/bar] [2] %v", e.Index)
 		t.Fatalf("scan error [/foo/bar] [2] %v", e.Index)
 	}
 	}
 
 
-	e, err = eh.scan("/foo/bar", 3)
+	e, err = eh.scan("/foo/bar", true, 3)
 
 
 	if err != nil || e.Index() != 4 {
 	if err != nil || e.Index() != 4 {
 		t.Fatalf("scan error [/foo/bar/bar] [4] %v", e.Index)
 		t.Fatalf("scan error [/foo/bar/bar] [4] %v", e.Index)
 	}
 	}
 
 
-	e, err = eh.scan("/foo/bar", 6)
+	e, err = eh.scan("/foo/bar", true, 6)
 
 
 	if e != nil {
 	if e != nil {
 		t.Fatalf("bad index shoud reuturn nil")
 		t.Fatalf("bad index shoud reuturn nil")

+ 4 - 5
store/store.go

@@ -280,8 +280,8 @@ func (s *store) Delete(nodePath string, recursive bool) (*Event, error) {
 	return e, nil
 	return e, nil
 }
 }
 
 
-func (s *store) Watch(prefix string, recursive bool, sinceIndex uint64) (<-chan *Event, error) {
-	prefix = path.Clean(path.Join("/", prefix))
+func (s *store) Watch(key string, recursive bool, sinceIndex uint64) (<-chan *Event, error) {
+	key = path.Clean(path.Join("/", key))
 
 
 	nextIndex := s.CurrentIndex + 1
 	nextIndex := s.CurrentIndex + 1
 
 
@@ -292,10 +292,10 @@ func (s *store) Watch(prefix string, recursive bool, sinceIndex uint64) (<-chan
 	var err *etcdErr.Error
 	var err *etcdErr.Error
 
 
 	if sinceIndex == 0 {
 	if sinceIndex == 0 {
-		c, err = s.WatcherHub.watch(prefix, recursive, nextIndex)
+		c, err = s.WatcherHub.watch(key, recursive, nextIndex)
 
 
 	} else {
 	} else {
-		c, err = s.WatcherHub.watch(prefix, recursive, sinceIndex)
+		c, err = s.WatcherHub.watch(key, recursive, sinceIndex)
 	}
 	}
 
 
 	if err != nil {
 	if err != nil {
@@ -396,7 +396,6 @@ func (s *store) internalCreate(nodePath string, value string, unique bool, repla
 		expireTime = Permanent
 		expireTime = Permanent
 	}
 	}
 
 
-
 	dir, newNodeName := path.Split(nodePath)
 	dir, newNodeName := path.Split(nodePath)
 
 
 	// walk through the nodePath, create dirs and get the last directory node
 	// walk through the nodePath, create dirs and get the last directory node

+ 6 - 6
store/watcher_hub.go

@@ -33,11 +33,11 @@ func newWatchHub(capacity int) *watcherHub {
 }
 }
 
 
 // watch function returns an Event channel.
 // watch function returns an Event channel.
-// If recursive is true, the first change after index under prefix will be sent to the event channel.
-// If recursive is false, the first change after index at prefix will be sent to the event channel.
+// If recursive is true, the first change after index under key will be sent to the event channel.
+// If recursive is false, the first change after index at key will be sent to the event channel.
 // If index is zero, watch will start from the current index + 1.
 // If index is zero, watch will start from the current index + 1.
-func (wh *watcherHub) watch(prefix string, recursive bool, index uint64) (<-chan *Event, *etcdErr.Error) {
-	event, err := wh.EventHistory.scan(prefix, index)
+func (wh *watcherHub) watch(key string, recursive bool, index uint64) (<-chan *Event, *etcdErr.Error) {
+	event, err := wh.EventHistory.scan(key, recursive, index)
 
 
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -57,7 +57,7 @@ func (wh *watcherHub) watch(prefix string, recursive bool, index uint64) (<-chan
 		sinceIndex: index,
 		sinceIndex: index,
 	}
 	}
 
 
-	l, ok := wh.watchers[prefix]
+	l, ok := wh.watchers[key]
 
 
 	if ok { // add the new watcher to the back of the list
 	if ok { // add the new watcher to the back of the list
 		l.PushBack(w)
 		l.PushBack(w)
@@ -65,7 +65,7 @@ func (wh *watcherHub) watch(prefix string, recursive bool, index uint64) (<-chan
 	} else { // create a new list and add the new watcher
 	} else { // create a new list and add the new watcher
 		l := list.New()
 		l := list.New()
 		l.PushBack(w)
 		l.PushBack(w)
-		wh.watchers[prefix] = l
+		wh.watchers[key] = l
 	}
 	}
 
 
 	atomic.AddInt64(&wh.count, 1)
 	atomic.AddInt64(&wh.count, 1)

+ 20 - 0
store/watcher_test.go

@@ -68,4 +68,24 @@ func TestWatcher(t *testing.T) {
 		t.Fatal("recv != send")
 		t.Fatal("recv != send")
 	}
 	}
 
 
+	// ensure we are doing exact matching rather than prefix matching
+	c, _ = wh.watch("/fo", true, 1)
+
+	select {
+	case re = <-c:
+		t.Fatal("should not receive from channel:", re)
+	default:
+		// do nothing
+	}
+
+	e = newEvent(Create, "/fo/bar", 3)
+
+	wh.notify(e)
+
+	re = <-c
+
+	if e != re {
+		t.Fatal("recv != send")
+	}
+
 }
 }