Browse Source

Merge pull request #626 from xiangli-cmu/refactor_listener

refactor(listener) refactor listener related code
Xiang Li 11 years ago
parent
commit
79e4c838f4
4 changed files with 61 additions and 75 deletions
  1. 4 4
      config/config.go
  2. 7 41
      etcd.go
  3. 28 2
      server/listener.go
  4. 22 28
      tests/server_utils.go

+ 4 - 4
config/config.go

@@ -390,8 +390,8 @@ func (c *Config) Sanitize() error {
 }
 
 // EtcdTLSInfo retrieves a TLSInfo object for the etcd server
-func (c *Config) EtcdTLSInfo() server.TLSInfo {
-	return server.TLSInfo{
+func (c *Config) EtcdTLSInfo() *server.TLSInfo {
+	return &server.TLSInfo{
 		CAFile:   c.CAFile,
 		CertFile: c.CertFile,
 		KeyFile:  c.KeyFile,
@@ -399,8 +399,8 @@ func (c *Config) EtcdTLSInfo() server.TLSInfo {
 }
 
 // PeerRaftInfo retrieves a TLSInfo object for the peer server.
-func (c *Config) PeerTLSInfo() server.TLSInfo {
-	return server.TLSInfo{
+func (c *Config) PeerTLSInfo() *server.TLSInfo {
+	return &server.TLSInfo{
 		CAFile:   c.Peer.CAFile,
 		CertFile: c.Peer.CertFile,
 		KeyFile:  c.Peer.KeyFile,

+ 7 - 41
etcd.go

@@ -18,7 +18,6 @@ package main
 
 import (
 	"fmt"
-	"net"
 	"net/http"
 	"os"
 	"path/filepath"
@@ -126,24 +125,6 @@ func main() {
 	}
 	ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
 
-	var psListener net.Listener
-	if psConfig.Scheme == "https" {
-		peerServerTLSConfig, err := config.PeerTLSInfo().ServerConfig()
-		if err != nil {
-			log.Fatal("peer server TLS error: ", err)
-		}
-
-		psListener, err = server.NewTLSListener(config.Peer.BindAddr, peerServerTLSConfig)
-		if err != nil {
-			log.Fatal("Failed to create peer listener: ", err)
-		}
-	} else {
-		psListener, err = server.NewListener(config.Peer.BindAddr)
-		if err != nil {
-			log.Fatal("Failed to create peer listener: ", err)
-		}
-	}
-
 	// Create raft transporter and server
 	raftTransporter := server.NewTransporter(followersStats, serverStats, registry, heartbeatInterval, dialTimeout, responseHeaderTimeout)
 	if psConfig.Scheme == "https" {
@@ -168,34 +149,19 @@ func main() {
 		s.EnableTracing()
 	}
 
-	var sListener net.Listener
-	if config.EtcdTLSInfo().Scheme() == "https" {
-		etcdServerTLSConfig, err := config.EtcdTLSInfo().ServerConfig()
-		if err != nil {
-			log.Fatal("etcd TLS error: ", err)
-		}
-
-		sListener, err = server.NewTLSListener(config.BindAddr, etcdServerTLSConfig)
-		if err != nil {
-			log.Fatal("Failed to create TLS etcd listener: ", err)
-		}
-	} else {
-		sListener, err = server.NewListener(config.BindAddr)
-		if err != nil {
-			log.Fatal("Failed to create etcd listener: ", err)
-		}
-	}
-
 	ps.SetServer(s)
 	ps.Start(config.Snapshot, config.Discovery, config.Peers)
 
 	go func() {
-		log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL)
+		log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, config.Peer.BindAddr, ps.Config.URL)
+		l := server.NewListener(psConfig.Scheme, config.Peer.BindAddr, config.PeerTLSInfo())
+
 		sHTTP := &ehttp.CORSHandler{ps.HTTPHandler(), corsInfo}
-		log.Fatal(http.Serve(psListener, sHTTP))
+		log.Fatal(http.Serve(l, sHTTP))
 	}()
 
-	log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, sListener.Addr(), s.URL())
+	log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, config.BindAddr, s.URL())
+	l := server.NewListener(config.EtcdTLSInfo().Scheme(), config.BindAddr, config.EtcdTLSInfo())
 	sHTTP := &ehttp.CORSHandler{s.HTTPHandler(), corsInfo}
-	log.Fatal(http.Serve(sListener, sHTTP))
+	log.Fatal(http.Serve(l, sHTTP))
 }

+ 28 - 2
server/listener.go

@@ -3,9 +3,35 @@ package server
 import (
 	"crypto/tls"
 	"net"
+
+	"github.com/coreos/etcd/log"
 )
 
-func NewListener(addr string) (net.Listener, error) {
+// NewListener creates a net.Listener
+// If the given scheme is "https", it will generate TLS configuration based on TLSInfo.
+// If any error happens, this function will call log.Fatal
+func NewListener(scheme, addr string, tlsInfo *TLSInfo) net.Listener {
+	if scheme == "https" {
+		cfg, err := tlsInfo.ServerConfig()
+		if err != nil {
+			log.Fatal("TLS info error: ", err)
+		}
+
+		l, err := newTLSListener(addr, cfg)
+		if err != nil {
+			log.Fatal("Failed to create TLS listener: ", err)
+		}
+		return l
+	}
+
+	l, err := newListener(addr)
+	if err != nil {
+		log.Fatal("Failed to create listener: ", err)
+	}
+	return l
+}
+
+func newListener(addr string) (net.Listener, error) {
 	if addr == "" {
 		addr = ":http"
 	}
@@ -16,7 +42,7 @@ func NewListener(addr string) (net.Listener, error) {
 	return l, nil
 }
 
-func NewTLSListener(addr string, cfg *tls.Config) (net.Listener, error) {
+func newTLSListener(addr string, cfg *tls.Config) (net.Listener, error) {
 	if addr == "" {
 		addr = ":https"
 	}

+ 22 - 28
tests/server_utils.go

@@ -15,12 +15,12 @@ import (
 )
 
 const (
-	testName		= "ETCDTEST"
-	testClientURL		= "localhost:4401"
-	testRaftURL		= "localhost:7701"
-	testSnapshotCount	= 10000
-	testHeartbeatInterval	= time.Duration(50) * time.Millisecond
-	testElectionTimeout	= time.Duration(200) * time.Millisecond
+	testName              = "ETCDTEST"
+	testClientURL         = "localhost:4401"
+	testRaftURL           = "localhost:7701"
+	testSnapshotCount     = 10000
+	testHeartbeatInterval = time.Duration(50) * time.Millisecond
+	testElectionTimeout   = time.Duration(200) * time.Millisecond
 )
 
 // Starts a server in a temporary directory.
@@ -35,20 +35,17 @@ func RunServer(f func(*server.Server)) {
 	followersStats := server.NewRaftFollowersStats(testName)
 
 	psConfig := server.PeerServerConfig{
-		Name:		testName,
-		URL:		"http://" + testRaftURL,
-		Scheme:		"http",
-		SnapshotCount:	testSnapshotCount,
-		MaxClusterSize:	9,
+		Name:           testName,
+		URL:            "http://" + testRaftURL,
+		Scheme:         "http",
+		SnapshotCount:  testSnapshotCount,
+		MaxClusterSize: 9,
 	}
 
 	mb := metrics.NewBucket("")
 
 	ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
-	psListener, err := server.NewListener(testRaftURL)
-	if err != nil {
-		panic(err)
-	}
+	psListener := server.NewListener("http", testRaftURL, nil)
 
 	// Create Raft transporter and server
 	dialTimeout := (3 * testHeartbeatInterval) + testElectionTimeout
@@ -63,10 +60,7 @@ func RunServer(f func(*server.Server)) {
 	ps.SetRaftServer(raftServer)
 
 	s := server.New(testName, "http://"+testClientURL, ps, registry, store, nil)
-	sListener, err := server.NewListener(testClientURL)
-	if err != nil {
-		panic(err)
-	}
+	sListener := server.NewListener("http", testClientURL, nil)
 
 	ps.SetServer(s)
 
@@ -104,16 +98,16 @@ func RunServer(f func(*server.Server)) {
 }
 
 type waitHandler struct {
-        wg *sync.WaitGroup
-        handler http.Handler
+	wg      *sync.WaitGroup
+	handler http.Handler
 }
 
-func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){
-        h.wg.Add(1)
-        defer h.wg.Done()
-        h.handler.ServeHTTP(w, r)
+func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	h.wg.Add(1)
+	defer h.wg.Done()
+	h.handler.ServeHTTP(w, r)
 
-        //important to flush before decrementing the wait group.
-        //we won't get a chance to once main() ends.
-        w.(http.Flusher).Flush()
+	//important to flush before decrementing the wait group.
+	//we won't get a chance to once main() ends.
+	w.(http.Flusher).Flush()
 }