Browse Source

support watch delete

Xiang Li 12 years ago
parent
commit
bd8ec6d67b

+ 12 - 6
file_system/file_system.go

@@ -198,12 +198,6 @@ func (fs *FileSystem) Delete(keyPath string, recurisive bool, index uint64, term
 		return nil, err
 	}
 
-	err = n.Remove(recurisive)
-
-	if err != nil {
-		return nil, err
-	}
-
 	e := newEvent(Delete, keyPath, index, term)
 
 	if n.IsDir() {
@@ -212,6 +206,18 @@ func (fs *FileSystem) Delete(keyPath string, recurisive bool, index uint64, term
 		e.PrevValue = n.Value
 	}
 
+	callback := func(path string) {
+		fs.WatcherHub.notifyWithPath(e, path, true)
+	}
+
+	err = n.Remove(recurisive, callback)
+
+	if err != nil {
+		return nil, err
+	}
+
+	fs.WatcherHub.notify(e)
+
 	return e, nil
 }
 

+ 23 - 0
file_system/file_system_test.go

@@ -230,6 +230,29 @@ func TestTestAndSet(t *testing.T) {
 	}
 }
 
+func TestWatchRemove(t *testing.T) {
+	fs := New()
+	fs.Create("/foo/foo/foo", "bar", Permanent, 1, 1)
+
+	// watch at a deeper path
+	c, _ := fs.WatcherHub.watch("/foo/foo/foo", false, 0)
+	fs.Delete("/foo", true, 2, 1)
+	e := <-c
+	if e.Key != "/foo" {
+		t.Fatal("watch for delete fails")
+	}
+
+	fs.Create("/foo/foo/foo", "bar", Permanent, 3, 1)
+	// watch at a prefix
+	c, _ = fs.WatcherHub.watch("/foo", true, 0)
+	fs.Delete("/foo/foo/foo", false, 4, 1)
+	e = <-c
+	if e.Key != "/foo/foo/foo" {
+		t.Fatal("watch for delete fails")
+	}
+
+}
+
 func createAndGet(fs *FileSystem, path string, t *testing.T) {
 	_, err := fs.Create(path, "bar", Permanent, 1, 1)
 

+ 14 - 4
file_system/node.go

@@ -65,7 +65,7 @@ func newDir(keyPath string, createIndex uint64, createTerm uint64, parent *Node,
 // Remove function remove the node.
 // If the node is a directory and recursive is true, the function will recursively remove
 // add nodes under the receiver node.
-func (n *Node) Remove(recursive bool) error {
+func (n *Node) Remove(recursive bool, callback func(path string)) error {
 	n.mu.Lock()
 	defer n.mu.Unlock()
 
@@ -80,6 +80,11 @@ func (n *Node) Remove(recursive bool) error {
 			// This is the only pointer to Node object
 			// Handled by garbage collector
 			delete(n.Parent.Children, name)
+
+			if callback != nil {
+				callback(n.Path)
+			}
+
 			n.stopExpire <- true
 			n.status = removed
 		}
@@ -92,13 +97,18 @@ func (n *Node) Remove(recursive bool) error {
 	}
 
 	for _, child := range n.Children { // delete all children
-		child.Remove(true)
+		child.Remove(true, callback)
 	}
 
 	// delete self
 	_, name := path.Split(n.Path)
 	if n.Parent.Children[name] == n {
 		delete(n.Parent.Children, name)
+
+		if callback != nil {
+			callback(n.Path)
+		}
+
 		n.stopExpire <- true
 		n.status = removed
 	}
@@ -235,14 +245,14 @@ func (n *Node) IsDir() bool {
 func (n *Node) Expire() {
 	duration := n.ExpireTime.Sub(time.Now())
 	if duration <= 0 {
-		n.Remove(true)
+		n.Remove(true, nil)
 		return
 	}
 
 	select {
 	// if timeout, delete the node
 	case <-time.After(duration):
-		n.Remove(true)
+		n.Remove(true, nil)
 		return
 
 	// if stopped, return

+ 37 - 39
file_system/watcher.go

@@ -28,18 +28,18 @@ func newWatchHub(capacity int) *watcherHub {
 // If recursive is true, the first change after index under prefix will be sent to the event channel.
 // If recursive is false, the first change after index at prefix will be sent to the event channel.
 // If index is zero, watch will start from the current index + 1.
-func (wh *watcherHub) watch(prefix string, recursive bool, index uint64) (error, <-chan *Event) {
+func (wh *watcherHub) watch(prefix string, recursive bool, index uint64) (<-chan *Event, error) {
 	eventChan := make(chan *Event, 1)
 
 	e, err := wh.EventHistory.scan(prefix, index)
 
 	if err != nil {
-		return err, nil
+		return nil, err
 	}
 
 	if e != nil {
 		eventChan <- e
-		return nil, eventChan
+		return eventChan, nil
 	}
 
 	w := &watcher{
@@ -58,57 +58,55 @@ func (wh *watcherHub) watch(prefix string, recursive bool, index uint64) (error,
 		wh.watchers[prefix] = l
 	}
 
-	return nil, eventChan
+	return eventChan, nil
 }
 
-func (wh *watcherHub) notify(e *Event) {
+func (wh *watcherHub) notifyWithPath(e *Event, path string, force bool) {
+	l, ok := wh.watchers[path]
 
-	segments := strings.Split(e.Key, "/")
-	currPath := "/"
+	if ok {
 
-	// walk through all the paths
-	for _, segment := range segments {
-		currPath = path.Join(currPath, segment)
+		curr := l.Front()
+		notifiedAll := true
 
-		l, ok := wh.watchers[currPath]
+		for {
 
-		if ok {
+			if curr == nil { // we have reached the end of the list
 
-			curr := l.Front()
-			notifiedAll := true
+				if notifiedAll {
+					// if we have notified all watcher in the list
+					// we can delete the list
+					delete(wh.watchers, path)
+				}
+				break
+			}
 
-			for {
+			next := curr.Next() // save the next
 
-				if curr == nil { // we have reached the end of the list
+			w, _ := curr.Value.(*watcher)
 
-					if notifiedAll {
-						// if we have notified all watcher in the list
-						// we can delete the list
-						delete(wh.watchers, currPath)
-					}
-					break
-				}
+			if w.recursive || force || e.Key == path {
+				w.eventChan <- e
+				l.Remove(curr)
+			} else {
+				notifiedAll = false
+			}
 
-				next := curr.Next() // save the next
+			curr = next // go to the next one
 
-				w, _ := curr.Value.(*watcher)
+		}
+	}
+}
 
-				if w.recursive {
-					w.eventChan <- e
-					l.Remove(curr)
-				} else {
-					if e.Key == currPath { // only notify the same path
-						w.eventChan <- e
-						l.Remove(curr)
-					} else { // we do not notify all watcher in the list
-						notifiedAll = false
-					}
-				}
+func (wh *watcherHub) notify(e *Event) {
 
-				curr = next // go to the next one
+	segments := strings.Split(e.Key, "/")
 
-			}
-		}
+	currPath := "/"
 
+	// walk through all the paths
+	for _, segment := range segments {
+		currPath = path.Join(currPath, segment)
+		wh.notifyWithPath(e, currPath, false)
 	}
 }

+ 2 - 2
file_system/watcher_test.go

@@ -6,7 +6,7 @@ import (
 
 func TestWatch(t *testing.T) {
 	wh := newWatchHub(100)
-	err, c := wh.watch("/foo", true, 0)
+	c, err := wh.watch("/foo", true, 0)
 
 	if err != nil {
 		t.Fatal("%v", err)
@@ -29,7 +29,7 @@ func TestWatch(t *testing.T) {
 		t.Fatal("recv != send")
 	}
 
-	_, c = wh.watch("/foo", false, 0)
+	c, _ = wh.watch("/foo", false, 0)
 
 	e = newEvent(Set, "/foo/bar", 1, 0)