Browse Source

refactor(peer-server): move listener init out of peer_server.go

Brian Waldon 12 years ago
parent
commit
7bd4d05a38
3 changed files with 43 additions and 93 deletions
  1. 12 2
      etcd.go
  2. 25 89
      server/peer_server.go
  3. 6 2
      tests/server_utils.go

+ 12 - 2
etcd.go

@@ -114,8 +114,8 @@ func main() {
 	psConfig := server.PeerServerConfig{
 	psConfig := server.PeerServerConfig{
 		Name:             info.Name,
 		Name:             info.Name,
 		Path:             config.DataDir,
 		Path:             config.DataDir,
+		Scheme:           peerTLSConfig.Scheme,
 		URL:              info.RaftURL,
 		URL:              info.RaftURL,
-		BindAddr:         info.RaftListenHost,
 		SnapshotCount:    config.SnapshotCount,
 		SnapshotCount:    config.SnapshotCount,
 		HeartbeatTimeout: time.Duration(config.Peer.HeartbeatTimeout) * time.Millisecond,
 		HeartbeatTimeout: time.Duration(config.Peer.HeartbeatTimeout) * time.Millisecond,
 		ElectionTimeout:  time.Duration(config.Peer.ElectionTimeout) * time.Millisecond,
 		ElectionTimeout:  time.Duration(config.Peer.ElectionTimeout) * time.Millisecond,
@@ -125,6 +125,16 @@ func main() {
 	}
 	}
 	ps := server.NewPeerServer(psConfig, &peerTLSConfig, &info.RaftTLS, registry, store, &mb)
 	ps := server.NewPeerServer(psConfig, &peerTLSConfig, &info.RaftTLS, registry, store, &mb)
 
 
+	var psListener net.Listener
+	if psConfig.Scheme == "https" {
+		psListener, err = server.NewTLSListener(info.RaftListenHost, info.RaftTLS.CertFile, info.RaftTLS.KeyFile)
+	} else {
+		psListener, err = server.NewListener(info.RaftListenHost)
+	}
+	if err != nil {
+		panic(err)
+	}
+
 	// Create client server.
 	// Create client server.
 	sConfig := server.ServerConfig{
 	sConfig := server.ServerConfig{
 		Name:     info.Name,
 		Name:     info.Name,
@@ -151,7 +161,7 @@ func main() {
 
 
 	// Run peer server in separate thread while the client server blocks.
 	// Run peer server in separate thread while the client server blocks.
 	go func() {
 	go func() {
-		log.Fatal(ps.ListenAndServe(config.Snapshot, config.Peers))
+		log.Fatal(ps.Serve(psListener, config.Snapshot, config.Peers))
 	}()
 	}()
 	log.Fatal(s.Serve(sListener))
 	log.Fatal(s.Serve(sListener))
 }
 }

+ 25 - 89
server/peer_server.go

@@ -2,7 +2,6 @@ package server
 
 
 import (
 import (
 	"bytes"
 	"bytes"
-	"crypto/tls"
 	"encoding/binary"
 	"encoding/binary"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
@@ -29,8 +28,8 @@ const ThresholdMonitorTimeout = 5 * time.Second
 type PeerServerConfig struct {
 type PeerServerConfig struct {
 	Name             string
 	Name             string
 	Path             string
 	Path             string
+	Scheme           string
 	URL              string
 	URL              string
-	BindAddr         string
 	SnapshotCount    int
 	SnapshotCount    int
 	HeartbeatTimeout time.Duration
 	HeartbeatTimeout time.Duration
 	ElectionTimeout  time.Duration
 	ElectionTimeout  time.Duration
@@ -43,8 +42,6 @@ type PeerServer struct {
 	Config         PeerServerConfig
 	Config         PeerServerConfig
 	raftServer     raft.Server
 	raftServer     raft.Server
 	server         *Server
 	server         *Server
-	httpServer     *http.Server
-	listener       net.Listener
 	joinIndex      uint64
 	joinIndex      uint64
 	tlsConf        *TLSConfig
 	tlsConf        *TLSConfig
 	tlsInfo        *TLSInfo
 	tlsInfo        *TLSInfo
@@ -54,6 +51,8 @@ type PeerServer struct {
 	store          store.Store
 	store          store.Store
 	snapConf       *snapshotConf
 	snapConf       *snapshotConf
 
 
+	listener net.Listener
+
 	closeChan            chan bool
 	closeChan            chan bool
 	timeoutThresholdChan chan interface{}
 	timeoutThresholdChan chan interface{}
 
 
@@ -77,8 +76,6 @@ func NewPeerServer(psConfig PeerServerConfig, tlsConf *TLSConfig, tlsInfo *TLSIn
 	s := &PeerServer{
 	s := &PeerServer{
 		Config: psConfig,
 		Config: psConfig,
 
 
-		tlsConf:  tlsConf,
-		tlsInfo:  tlsInfo,
 		registry: registry,
 		registry: registry,
 		store:    store,
 		store:    store,
 		followersStats: &raftFollowersStats{
 		followersStats: &raftFollowersStats{
@@ -132,7 +129,7 @@ func NewPeerServer(psConfig PeerServerConfig, tlsConf *TLSConfig, tlsInfo *TLSIn
 }
 }
 
 
 // Start the raft server
 // Start the raft server
-func (s *PeerServer) ListenAndServe(snapshot bool, cluster []string) error {
+func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []string) error {
 	// LoadSnapshot
 	// LoadSnapshot
 	if snapshot {
 	if snapshot {
 		err := s.raftServer.LoadSnapshot()
 		err := s.raftServer.LoadSnapshot()
@@ -185,56 +182,29 @@ func (s *PeerServer) ListenAndServe(snapshot bool, cluster []string) error {
 		go s.monitorSnapshot()
 		go s.monitorSnapshot()
 	}
 	}
 
 
-	// start to response to raft requests
-	return s.startTransport(s.tlsConf.Scheme, s.tlsConf.Server)
-}
-
-// Overridden version of net/http added so we can manage the listener.
-func (s *PeerServer) listenAndServe() error {
-	addr := s.httpServer.Addr
-	if addr == "" {
-		addr = ":http"
-	}
-	l, e := net.Listen("tcp", addr)
-	if e != nil {
-		return e
-	}
-	s.listener = l
-	return s.httpServer.Serve(l)
-}
-
-// Overridden version of net/http added so we can manage the listener.
-func (s *PeerServer) listenAndServeTLS(certFile, keyFile string) error {
-	addr := s.httpServer.Addr
-	if addr == "" {
-		addr = ":https"
-	}
-	config := &tls.Config{}
-	if s.httpServer.TLSConfig != nil {
-		*config = *s.httpServer.TLSConfig
-	}
-	if config.NextProtos == nil {
-		config.NextProtos = []string{"http/1.1"}
-	}
-
-	var err error
-	config.Certificates = make([]tls.Certificate, 1)
-	config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
-	if err != nil {
-		return err
-	}
+	router := mux.NewRouter()
+	httpServer := &http.Server{Handler: router}
 
 
-	conn, err := net.Listen("tcp", addr)
-	if err != nil {
-		return err
-	}
+	// internal commands
+	router.HandleFunc("/name", s.NameHttpHandler)
+	router.HandleFunc("/version", s.VersionHttpHandler)
+	router.HandleFunc("/version/{version:[0-9]+}/check", s.VersionCheckHttpHandler)
+	router.HandleFunc("/upgrade", s.UpgradeHttpHandler)
+	router.HandleFunc("/join", s.JoinHttpHandler)
+	router.HandleFunc("/remove/{name:.+}", s.RemoveHttpHandler)
+	router.HandleFunc("/vote", s.VoteHttpHandler)
+	router.HandleFunc("/log", s.GetLogHttpHandler)
+	router.HandleFunc("/log/append", s.AppendEntriesHttpHandler)
+	router.HandleFunc("/snapshot", s.SnapshotHttpHandler)
+	router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
+	router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
 
 
-	tlsListener := tls.NewListener(conn, config)
-	s.listener = tlsListener
-	return s.httpServer.Serve(tlsListener)
+	s.listener = listener
+	log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, listener.Addr(), s.Config.URL)
+	httpServer.Serve(listener)
+	return nil
 }
 }
 
 
-// Stops the server.
 func (s *PeerServer) Close() {
 func (s *PeerServer) Close() {
 	if s.closeChan != nil {
 	if s.closeChan != nil {
 		close(s.closeChan)
 		close(s.closeChan)
@@ -281,40 +251,6 @@ func (s *PeerServer) startAsFollower(cluster []string) {
 	log.Fatalf("Cannot join the cluster via given peers after %x retries", s.Config.RetryTimes)
 	log.Fatalf("Cannot join the cluster via given peers after %x retries", s.Config.RetryTimes)
 }
 }
 
 
-// Start to listen and response raft command
-func (s *PeerServer) startTransport(scheme string, tlsConf tls.Config) error {
-	log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, s.Config.BindAddr, s.Config.URL)
-
-	router := mux.NewRouter()
-
-	s.httpServer = &http.Server{
-		Handler:   router,
-		TLSConfig: &tlsConf,
-		Addr:      s.Config.BindAddr,
-	}
-
-	// internal commands
-	router.HandleFunc("/name", s.NameHttpHandler)
-	router.HandleFunc("/version", s.VersionHttpHandler)
-	router.HandleFunc("/version/{version:[0-9]+}/check", s.VersionCheckHttpHandler)
-	router.HandleFunc("/upgrade", s.UpgradeHttpHandler)
-	router.HandleFunc("/join", s.JoinHttpHandler)
-	router.HandleFunc("/remove/{name:.+}", s.RemoveHttpHandler)
-	router.HandleFunc("/vote", s.VoteHttpHandler)
-	router.HandleFunc("/log", s.GetLogHttpHandler)
-	router.HandleFunc("/log/append", s.AppendEntriesHttpHandler)
-	router.HandleFunc("/snapshot", s.SnapshotHttpHandler)
-	router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
-	router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
-
-	if scheme == "http" {
-		return s.listenAndServe()
-	} else {
-		return s.listenAndServeTLS(s.tlsInfo.CertFile, s.tlsInfo.KeyFile)
-	}
-
-}
-
 // getVersion fetches the peer version of a cluster.
 // getVersion fetches the peer version of a cluster.
 func getVersion(t *transporter, versionURL url.URL) (int, error) {
 func getVersion(t *transporter, versionURL url.URL) (int, error) {
 	resp, req, err := t.Get(versionURL.String())
 	resp, req, err := t.Get(versionURL.String())
@@ -344,7 +280,7 @@ func (s *PeerServer) Upgradable() error {
 		}
 		}
 
 
 		t, _ := s.raftServer.Transporter().(*transporter)
 		t, _ := s.raftServer.Transporter().(*transporter)
-		checkURL := (&url.URL{Host: u.Host, Scheme: s.tlsConf.Scheme, Path: fmt.Sprintf("/version/%d/check", nextVersion)}).String()
+		checkURL := (&url.URL{Host: u.Host, Scheme: s.Config.Scheme, Path: fmt.Sprintf("/version/%d/check", nextVersion)}).String()
 		resp, _, err := t.Get(checkURL)
 		resp, _, err := t.Get(checkURL)
 		if err != nil {
 		if err != nil {
 			return fmt.Errorf("PeerServer: Cannot check version compatibility: %s", u.Host)
 			return fmt.Errorf("PeerServer: Cannot check version compatibility: %s", u.Host)
@@ -363,7 +299,7 @@ func (s *PeerServer) joinCluster(cluster []string) bool {
 			continue
 			continue
 		}
 		}
 
 
-		err := s.joinByPeer(s.raftServer, peer, s.tlsConf.Scheme)
+		err := s.joinByPeer(s.raftServer, peer, s.Config.Scheme)
 		if err == nil {
 		if err == nil {
 			log.Debugf("%s success join to the cluster via peer %s", s.Config.Name, peer)
 			log.Debugf("%s success join to the cluster via peer %s", s.Config.Name, peer)
 			return true
 			return true

+ 6 - 2
tests/server_utils.go

@@ -31,7 +31,7 @@ func RunServer(f func(*server.Server)) {
 		Name: testName,
 		Name: testName,
 		Path: path,
 		Path: path,
 		URL: "http://"+testRaftURL,
 		URL: "http://"+testRaftURL,
-		BindAddr: testRaftURL,
+		Scheme: "http",
 		SnapshotCount: testSnapshotCount,
 		SnapshotCount: testSnapshotCount,
 		HeartbeatTimeout: testHeartbeatTimeout,
 		HeartbeatTimeout: testHeartbeatTimeout,
 		ElectionTimeout: testElectionTimeout,
 		ElectionTimeout: testElectionTimeout,
@@ -39,6 +39,10 @@ func RunServer(f func(*server.Server)) {
 		CORS: corsInfo,
 		CORS: corsInfo,
 	}
 	}
 	ps := server.NewPeerServer(psConfig, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, registry, store, nil)
 	ps := server.NewPeerServer(psConfig, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, registry, store, nil)
+	psListener, err := server.NewListener(testRaftURL)
+	if err != nil {
+		panic(err)
+	}
 
 
 	sConfig := server.ServerConfig{
 	sConfig := server.ServerConfig{
 		Name: testName,
 		Name: testName,
@@ -57,7 +61,7 @@ func RunServer(f func(*server.Server)) {
 	c := make(chan bool)
 	c := make(chan bool)
 	go func() {
 	go func() {
 		c <- true
 		c <- true
-		ps.ListenAndServe(false, []string{})
+		ps.Serve(psListener, false, []string{})
 	}()
 	}()
 	<-c
 	<-c