Browse Source

simplify lock

Xiang Li 12 years ago
parent
commit
558d30f33f
3 changed files with 14 additions and 36 deletions
  1. 2 24
      store/node.go
  2. 9 9
      store/store.go
  3. 3 3
      store/store_test.go

+ 2 - 24
store/node.go

@@ -3,7 +3,6 @@ package store
 import (
 	"path"
 	"sort"
-	"sync"
 	"time"
 
 	etcdErr "github.com/coreos/etcd/error"
@@ -30,7 +29,6 @@ type Node struct {
 	Value         string           // for key-value pair
 	Children      map[string]*Node // for directory
 	status        int
-	mu            sync.Mutex
 	stopExpire    chan bool // stop expire routine channel
 }
 
@@ -66,10 +64,7 @@ func newDir(nodePath string, createIndex uint64, createTerm uint64, parent *Node
 // If the node is a directory and recursive is true, the function will recursively remove
 // add nodes under the receiver node.
 func (n *Node) Remove(recursive bool, callback func(path string)) error {
-	n.mu.Lock()
-	defer n.mu.Unlock()
-
-	if n.status == removed {
+	if n.status == removed { // check race between remove and expire
 		return nil
 	}
 
@@ -144,8 +139,6 @@ func (n *Node) Write(value string, index uint64, term uint64) error {
 // List function return a slice of nodes under the receiver node.
 // If the receiver node is not a directory, a "Not A Directory" error will be returned.
 func (n *Node) List() ([]*Node, error) {
-	n.mu.Lock()
-	defer n.mu.Unlock()
 	if !n.IsDir() {
 		return nil, etcdErr.NewError(etcdErr.EcodeNotDir, "")
 	}
@@ -168,9 +161,6 @@ func (n *Node) List() ([]*Node, error) {
 // If the node corresponding to the name string is not file, it returns
 // Not File Error
 func (n *Node) GetFile(name string) (*Node, error) {
-	n.mu.Lock()
-	defer n.mu.Unlock()
-
 	if !n.IsDir() {
 		return nil, etcdErr.NewError(etcdErr.EcodeNotDir, n.Path)
 	}
@@ -193,12 +183,6 @@ func (n *Node) GetFile(name string) (*Node, error) {
 // If there is a existing node with the same name under the directory, a "Already Exist"
 // error will be returned
 func (n *Node) Add(child *Node) error {
-	n.mu.Lock()
-	defer n.mu.Unlock()
-	if n.status == removed {
-		return etcdErr.NewError(etcdErr.EcodeKeyNotFound, "")
-	}
-
 	if !n.IsDir() {
 		return etcdErr.NewError(etcdErr.EcodeNotDir, "")
 	}
@@ -220,8 +204,6 @@ func (n *Node) Add(child *Node) error {
 // If the node is a directory, it will clone all the content under this directory.
 // If the node is a key-value pair, it will clone the pair.
 func (n *Node) Clone() *Node {
-	n.mu.Lock()
-	defer n.mu.Unlock()
 	if !n.IsDir() {
 		return newFile(n.Path, n.Value, n.CreateIndex, n.CreateTerm, n.Parent, n.ACL, n.ExpireTime)
 	}
@@ -256,7 +238,6 @@ func (n *Node) Expire(s *Store) {
 	expired, duration := n.IsExpired()
 
 	if expired { // has been expired
-
 		// since the parent function of Expire() runs serially,
 		// there is no need for lock here
 		e := newEvent(Expire, n.Path, UndefIndex, UndefTerm)
@@ -277,9 +258,8 @@ func (n *Node) Expire(s *Store) {
 		// if timeout, delete the node
 		case <-time.After(duration):
 
-			// Lock the worldLock to avoid race on s.WatchHub,
-			// and the race with other slibling nodes on their common parent.
 			s.worldLock.Lock()
+			defer s.worldLock.Unlock()
 
 			e := newEvent(Expire, n.Path, UndefIndex, UndefTerm)
 			s.WatcherHub.notify(e)
@@ -287,8 +267,6 @@ func (n *Node) Expire(s *Store) {
 			n.Remove(true, nil)
 			s.Stats.Inc(ExpireCount)
 
-			s.worldLock.Unlock()
-
 			return
 
 		// if stopped, return

+ 9 - 9
store/store.go

@@ -18,7 +18,7 @@ type Store struct {
 	Index      uint64
 	Term       uint64
 	Stats      *Stats
-	worldLock  sync.RWMutex // stop the world lock. Used to do snapshot
+	worldLock  sync.RWMutex // stop the world lock
 }
 
 func New() *Store {
@@ -95,8 +95,8 @@ func (s *Store) Get(nodePath string, recursive, sorted bool, index uint64, term
 // If the node has already existed, create will fail.
 // If any node on the path is a file, create will fail.
 func (s *Store) Create(nodePath string, value string, expireTime time.Time, index uint64, term uint64) (*Event, error) {
-	s.worldLock.RLock()
-	defer s.worldLock.RUnlock()
+	s.worldLock.Lock()
+	defer s.worldLock.Unlock()
 
 	nodePath = path.Clean(path.Join("/", nodePath))
 
@@ -164,8 +164,8 @@ func (s *Store) Create(nodePath string, value string, expireTime time.Time, inde
 // If the node is a file, the value and the ttl can be updated.
 // If the node is a directory, only the ttl can be updated.
 func (s *Store) Update(nodePath string, value string, expireTime time.Time, index uint64, term uint64) (*Event, error) {
-	s.worldLock.RLock()
-	defer s.worldLock.RUnlock()
+	s.worldLock.Lock()
+	defer s.worldLock.Unlock()
 
 	n, err := s.internalGet(nodePath, index, term)
 
@@ -209,8 +209,8 @@ func (s *Store) Update(nodePath string, value string, expireTime time.Time, inde
 func (s *Store) TestAndSet(nodePath string, prevValue string, prevIndex uint64,
 	value string, expireTime time.Time, index uint64, term uint64) (*Event, error) {
 
-	s.worldLock.RLock()
-	defer s.worldLock.RUnlock()
+	s.worldLock.Lock()
+	defer s.worldLock.Unlock()
 
 	n, err := s.internalGet(nodePath, index, term)
 
@@ -246,8 +246,8 @@ func (s *Store) TestAndSet(nodePath string, prevValue string, prevIndex uint64,
 // Delete function deletes the node at the given path.
 // If the node is a directory, recursive must be true to delete it.
 func (s *Store) Delete(nodePath string, recursive bool, index uint64, term uint64) (*Event, error) {
-	s.worldLock.RLock()
-	defer s.worldLock.RUnlock()
+	s.worldLock.Lock()
+	defer s.worldLock.Unlock()
 
 	n, err := s.internalGet(nodePath, index, term)
 

+ 3 - 3
store/store_test.go

@@ -244,7 +244,7 @@ func TestExpire(t *testing.T) {
 
 	s.Create("/foo", "bar", expire, 1, 1)
 
-	_, err := s.internalGet("/foo", 1, 1)
+	_, err := s.Get("/foo", false, false, 1, 1)
 
 	if err != nil {
 		t.Fatalf("can not get the node")
@@ -252,7 +252,7 @@ func TestExpire(t *testing.T) {
 
 	time.Sleep(time.Second * 2)
 
-	_, err = s.internalGet("/foo", 1, 1)
+	_, err = s.Get("/foo", false, false, 1, 1)
 
 	if err == nil {
 		t.Fatalf("can get the node after expiration time")
@@ -263,7 +263,7 @@ func TestExpire(t *testing.T) {
 	s.Create("/foo", "bar", expire, 1, 1)
 
 	time.Sleep(time.Millisecond * 50)
-	_, err = s.internalGet("/foo", 1, 1)
+	_, err = s.Get("/foo", false, false, 1, 1)
 
 	if err != nil {
 		t.Fatalf("cannot get the node before expiration", err.Error())