Explorar o código

server: use transporter as raft HTTP handler

Yicheng Qin %!s(int64=11) %!d(string=hai) anos
pai
achega
429b9487f7
Modificáronse 2 ficheiros con 36 adicións e 20 borrados
  1. 12 20
      etcd/etcd.go
  2. 24 0
      etcd/transporter.go

+ 12 - 20
etcd/etcd.go

@@ -62,13 +62,13 @@ type Server struct {
 
 
 	nodes  map[string]bool
 	nodes  map[string]bool
 	client *v2client
 	client *v2client
+	t      *transporter
 
 
 	// participant mode vars
 	// participant mode vars
 	proposal    chan v2Proposal
 	proposal    chan v2Proposal
 	node        *v2Raft
 	node        *v2Raft
 	addNodeC    chan raft.Config
 	addNodeC    chan raft.Config
 	removeNodeC chan raft.Config
 	removeNodeC chan raft.Config
-	t           *transporter
 
 
 	// standby mode vars
 	// standby mode vars
 	leader      int64
 	leader      int64
@@ -107,6 +107,7 @@ func New(c *config.Config, id int64) *Server {
 		raftPubAddr:  c.Peer.Addr,
 		raftPubAddr:  c.Peer.Addr,
 		nodes:        make(map[string]bool),
 		nodes:        make(map[string]bool),
 		client:       newClient(tc),
 		client:       newClient(tc),
+		t:            newTransporter(tc),
 		tickDuration: defaultTickDuration,
 		tickDuration: defaultTickDuration,
 
 
 		Store: store.New(),
 		Store: store.New(),
@@ -118,8 +119,7 @@ func New(c *config.Config, id int64) *Server {
 		Node:   raft.New(id, defaultHeartbeat, defaultElection),
 		Node:   raft.New(id, defaultHeartbeat, defaultElection),
 		result: make(map[wait]chan interface{}),
 		result: make(map[wait]chan interface{}),
 	}
 	}
-	t := newTransporter(tc)
-	s.initParticipant(node, t)
+	s.initParticipant(node)
 
 
 	for _, seed := range c.Peers {
 	for _, seed := range c.Peers {
 		s.nodes[seed] = true
 		s.nodes[seed] = true
@@ -145,7 +145,7 @@ func (s *Server) SetTick(d time.Duration) {
 }
 }
 
 
 func (s *Server) RaftHandler() http.Handler {
 func (s *Server) RaftHandler() http.Handler {
-	return http.HandlerFunc(s.ServeHTTPRaft)
+	return s.t
 }
 }
 
 
 func (s *Server) ClusterConfig() *config.ClusterConfig {
 func (s *Server) ClusterConfig() *config.ClusterConfig {
@@ -171,7 +171,7 @@ func (s *Server) Stop() {
 		return
 		return
 	}
 	}
 	s.mode = stop
 	s.mode = stop
-	s.t.stop()
+	s.t.closeConnections()
 	close(s.stop)
 	close(s.stop)
 }
 }
 
 
@@ -305,23 +305,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 }
 }
 
 
-func (s *Server) ServeHTTPRaft(w http.ResponseWriter, r *http.Request) {
-	switch s.mode {
-	case participant:
-		s.t.ServeHTTP(w, r)
-	case standby:
-		http.NotFound(w, r)
-	case stop:
-		http.Error(w, "server is stopped", http.StatusInternalServerError)
-	}
-}
-
-func (s *Server) initParticipant(node *v2Raft, t *transporter) {
+func (s *Server) initParticipant(node *v2Raft) {
 	s.proposal = make(chan v2Proposal, maxBufferedProposal)
 	s.proposal = make(chan v2Proposal, maxBufferedProposal)
 	s.node = node
 	s.node = node
 	s.addNodeC = make(chan raft.Config, 1)
 	s.addNodeC = make(chan raft.Config, 1)
 	s.removeNodeC = make(chan raft.Config, 1)
 	s.removeNodeC = make(chan raft.Config, 1)
-	s.t = t
+	s.t.start()
 }
 }
 
 
 func (s *Server) initStandby(leader int64, leaderAddr string, conf *config.ClusterConfig) {
 func (s *Server) initStandby(leader int64, leaderAddr string, conf *config.ClusterConfig) {
@@ -351,11 +340,14 @@ func (s *Server) run() {
 }
 }
 
 
 func (s *Server) runParticipant() {
 func (s *Server) runParticipant() {
+	defer func() {
+		s.node.StopProposalWaiters()
+		s.t.stop()
+	}()
 	node := s.node
 	node := s.node
 	recv := s.t.recv
 	recv := s.t.recv
 	ticker := time.NewTicker(s.tickDuration)
 	ticker := time.NewTicker(s.tickDuration)
 	v2SyncTicker := time.NewTicker(time.Millisecond * 500)
 	v2SyncTicker := time.NewTicker(time.Millisecond * 500)
-	defer s.node.StopProposalWaiters()
 
 
 	var proposal chan v2Proposal
 	var proposal chan v2Proposal
 	var addNodeC, removeNodeC chan raft.Config
 	var addNodeC, removeNodeC chan raft.Config
@@ -434,7 +426,7 @@ func (s *Server) runStandby() {
 		Node:   raft.New(s.id, defaultHeartbeat, defaultElection),
 		Node:   raft.New(s.id, defaultHeartbeat, defaultElection),
 		result: make(map[wait]chan interface{}),
 		result: make(map[wait]chan interface{}),
 	}
 	}
-	s.initParticipant(node, s.t)
+	s.initParticipant(node)
 	s.mode = participant
 	s.mode = participant
 	return
 	return
 }
 }

+ 24 - 0
etcd/transporter.go

@@ -48,11 +48,19 @@ func newTransporter(tc *tls.Config) *transporter {
 	return t
 	return t
 }
 }
 
 
+func (t *transporter) start() {
+	t.mu.Lock()
+	t.stopped = false
+	t.mu.Unlock()
+}
+
 func (t *transporter) stop() {
 func (t *transporter) stop() {
 	t.mu.Lock()
 	t.mu.Lock()
 	t.stopped = true
 	t.stopped = true
 	t.mu.Unlock()
 	t.mu.Unlock()
+}
 
 
+func (t *transporter) closeConnections() {
 	t.wg.Wait()
 	t.wg.Wait()
 	tr := t.client.Transport.(*http.Transport)
 	tr := t.client.Transport.(*http.Transport)
 	tr.CloseIdleConnections()
 	tr.CloseIdleConnections()
@@ -125,6 +133,14 @@ func (t *transporter) fetchAddr(seedurl string, id int64) error {
 }
 }
 
 
 func (t *transporter) serveRaft(w http.ResponseWriter, r *http.Request) {
 func (t *transporter) serveRaft(w http.ResponseWriter, r *http.Request) {
+	t.mu.RLock()
+	if t.stopped {
+		t.mu.RUnlock()
+		http.Error(w, "404 page not found", http.StatusNotFound)
+		return
+	}
+	t.mu.RUnlock()
+
 	msg := new(raft.Message)
 	msg := new(raft.Message)
 	if err := json.NewDecoder(r.Body).Decode(msg); err != nil {
 	if err := json.NewDecoder(r.Body).Decode(msg); err != nil {
 		log.Println(err)
 		log.Println(err)
@@ -143,6 +159,14 @@ func (t *transporter) serveRaft(w http.ResponseWriter, r *http.Request) {
 }
 }
 
 
 func (t *transporter) serveCfg(w http.ResponseWriter, r *http.Request) {
 func (t *transporter) serveCfg(w http.ResponseWriter, r *http.Request) {
+	t.mu.RLock()
+	if t.stopped {
+		t.mu.RUnlock()
+		http.Error(w, "404 page not found", http.StatusNotFound)
+		return
+	}
+	t.mu.RUnlock()
+
 	id, err := strconv.ParseInt(r.URL.Path[len("/raft/cfg/"):], 10, 64)
 	id, err := strconv.ParseInt(r.URL.Path[len("/raft/cfg/"):], 10, 64)
 	if err != nil {
 	if err != nil {
 		http.Error(w, err.Error(), http.StatusBadRequest)
 		http.Error(w, err.Error(), http.StatusBadRequest)