Browse Source

feat(standby_server): write the standby info file atomically

Save the info to a temporary file and then rename it into place, to avoid the risk of writing out a corrupted file.
Yicheng Qin 11 years ago
parent
commit
71679bcf56
2 changed files with 27 additions and 35 deletions
  1. 25 35
      server/standby_server.go
  2. 2 0
      tests/functional/remove_node_test.go

+ 25 - 35
server/standby_server.go

@@ -3,6 +3,7 @@ package server
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
+	"io/ioutil"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"os"
 	"os"
@@ -41,8 +42,6 @@ type StandbyServer struct {
 	standbyInfo
 	standbyInfo
 	joinIndex uint64
 	joinIndex uint64
 
 
-	file *os.File
-
 	removeNotify chan bool
 	removeNotify chan bool
 	started      bool
 	started      bool
 	closeChan    chan bool
 	closeChan    chan bool
@@ -57,10 +56,9 @@ func NewStandbyServer(config StandbyServerConfig, client *Client) (*StandbyServe
 		client:      client,
 		client:      client,
 		standbyInfo: standbyInfo{SyncInterval: DefaultSyncInterval},
 		standbyInfo: standbyInfo{SyncInterval: DefaultSyncInterval},
 	}
 	}
-	if err := s.openStandbyInfo(); err != nil {
-		return nil, fmt.Errorf("error open/create cluster info file: %v", err)
+	if err := s.loadInfo(); err != nil {
+		return nil, fmt.Errorf("error load standby info file: %v", err)
 	}
 	}
-	s.loadStandbyInfo()
 	return s, nil
 	return s, nil
 }
 }
 
 
@@ -95,8 +93,8 @@ func (s *StandbyServer) Stop() {
 	close(s.closeChan)
 	close(s.closeChan)
 	s.routineGroup.Wait()
 	s.routineGroup.Wait()
 
 
-	if err := s.clearStandbyInfo(); err != nil {
-		log.Warnf("error clearing cluster info for standby")
+	if err := s.saveInfo(); err != nil {
+		log.Warnf("error saving cluster info for standby")
 	}
 	}
 	s.Running = false
 	s.Running = false
 }
 }
@@ -228,7 +226,7 @@ func (s *StandbyServer) syncCluster(peerURLs []string) error {
 
 
 		s.setCluster(machines)
 		s.setCluster(machines)
 		s.SetSyncInterval(config.SyncInterval)
 		s.SetSyncInterval(config.SyncInterval)
-		if err := s.saveStandbyInfo(); err != nil {
+		if err := s.saveInfo(); err != nil {
 			log.Warnf("fail saving cluster info into disk: %v", err)
 			log.Warnf("fail saving cluster info into disk: %v", err)
 		}
 		}
 		return nil
 		return nil
@@ -286,47 +284,39 @@ func (s *StandbyServer) fullPeerURL(urlStr string) string {
 	return u.String()
 	return u.String()
 }
 }
 
 
-func (s *StandbyServer) openStandbyInfo() error {
-	var err error
+func (s *StandbyServer) loadInfo() error {
+	var info standbyInfo
+
 	path := filepath.Join(s.Config.DataDir, standbyInfoName)
 	path := filepath.Join(s.Config.DataDir, standbyInfoName)
-	s.file, err = os.OpenFile(path, os.O_RDWR, 0600)
+	file, err := os.OpenFile(path, os.O_RDONLY, 0600)
 	if err != nil {
 	if err != nil {
 		if os.IsNotExist(err) {
 		if os.IsNotExist(err) {
-			s.file, err = os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0600)
+			return nil
 		}
 		}
 		return err
 		return err
 	}
 	}
-	return nil
-}
-
-func (s *StandbyServer) loadStandbyInfo() ([]*machineMessage, error) {
-	if _, err := s.file.Seek(0, os.SEEK_SET); err != nil {
-		return nil, err
-	}
-	if err := json.NewDecoder(s.file).Decode(&s.standbyInfo); err != nil {
-		return nil, err
+	defer file.Close()
+	if err = json.NewDecoder(file).Decode(&info); err != nil {
+		return err
 	}
 	}
-	return s.standbyInfo.Cluster, nil
+	s.standbyInfo = info
+	return nil
 }
 }
 
 
-func (s *StandbyServer) saveStandbyInfo() error {
-	if err := s.clearStandbyInfo(); err != nil {
-		return nil
-	}
-	if err := json.NewEncoder(s.file).Encode(s.standbyInfo); err != nil {
+func (s *StandbyServer) saveInfo() error {
+	tmpFile, err := ioutil.TempFile(s.Config.DataDir, standbyInfoName)
+	if err != nil {
 		return err
 		return err
 	}
 	}
-	if err := s.file.Sync(); err != nil {
+	if err = json.NewEncoder(tmpFile).Encode(s.standbyInfo); err != nil {
+		tmpFile.Close()
+		os.Remove(tmpFile.Name())
 		return err
 		return err
 	}
 	}
-	return nil
-}
+	tmpFile.Close()
 
 
-func (s *StandbyServer) clearStandbyInfo() error {
-	if _, err := s.file.Seek(0, os.SEEK_SET); err != nil {
-		return err
-	}
-	if err := s.file.Truncate(0); err != nil {
+	path := filepath.Join(s.Config.DataDir, standbyInfoName)
+	if err = os.Rename(tmpFile.Name(), path); err != nil {
 		return err
 		return err
 	}
 	}
 	return nil
 	return nil

+ 2 - 0
tests/functional/remove_node_test.go

@@ -105,6 +105,8 @@ func TestRemoveNode(t *testing.T) {
 
 
 			client.Do(rmReq)
 			client.Do(rmReq)
 
 
+			time.Sleep(100 * time.Millisecond)
+
 			resp, err := c.Get("_etcd/machines", false, false)
 			resp, err := c.Get("_etcd/machines", false, false)
 
 
 			if err != nil {
 			if err != nil {