Browse Source

refactor(listener) refactor listener related code
Remove duplicate code around creating http listener.
Start to listen incoming http requests just before serving them.

Xiang Li 11 years ago
parent
commit
3ae792b159
4 changed files with 61 additions and 75 deletions
  1. 4 4
      config/config.go
  2. 7 41
      etcd.go
  3. 28 2
      server/listener.go
  4. 22 28
      tests/server_utils.go

+ 4 - 4
config/config.go

@@ -390,8 +390,8 @@ func (c *Config) Sanitize() error {
 }
 
 // EtcdTLSInfo retrieves a TLSInfo object for the etcd server
-func (c *Config) EtcdTLSInfo() server.TLSInfo {
-	return server.TLSInfo{
+func (c *Config) EtcdTLSInfo() *server.TLSInfo {
+	return &server.TLSInfo{
 		CAFile:   c.CAFile,
 		CertFile: c.CertFile,
 		KeyFile:  c.KeyFile,
@@ -399,8 +399,8 @@ func (c *Config) EtcdTLSInfo() server.TLSInfo {
 }
 
 // PeerRaftInfo retrieves a TLSInfo object for the peer server.
-func (c *Config) PeerTLSInfo() server.TLSInfo {
-	return server.TLSInfo{
+func (c *Config) PeerTLSInfo() *server.TLSInfo {
+	return &server.TLSInfo{
 		CAFile:   c.Peer.CAFile,
 		CertFile: c.Peer.CertFile,
 		KeyFile:  c.Peer.KeyFile,

+ 7 - 41
etcd.go

@@ -18,7 +18,6 @@ package main
 
 import (
 	"fmt"
-	"net"
 	"net/http"
 	"os"
 	"path/filepath"
@@ -126,24 +125,6 @@ func main() {
 	}
 	ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
 
-	var psListener net.Listener
-	if psConfig.Scheme == "https" {
-		peerServerTLSConfig, err := config.PeerTLSInfo().ServerConfig()
-		if err != nil {
-			log.Fatal("peer server TLS error: ", err)
-		}
-
-		psListener, err = server.NewTLSListener(config.Peer.BindAddr, peerServerTLSConfig)
-		if err != nil {
-			log.Fatal("Failed to create peer listener: ", err)
-		}
-	} else {
-		psListener, err = server.NewListener(config.Peer.BindAddr)
-		if err != nil {
-			log.Fatal("Failed to create peer listener: ", err)
-		}
-	}
-
 	// Create raft transporter and server
 	raftTransporter := server.NewTransporter(followersStats, serverStats, registry, heartbeatInterval, dialTimeout, responseHeaderTimeout)
 	if psConfig.Scheme == "https" {
@@ -168,34 +149,19 @@ func main() {
 		s.EnableTracing()
 	}
 
-	var sListener net.Listener
-	if config.EtcdTLSInfo().Scheme() == "https" {
-		etcdServerTLSConfig, err := config.EtcdTLSInfo().ServerConfig()
-		if err != nil {
-			log.Fatal("etcd TLS error: ", err)
-		}
-
-		sListener, err = server.NewTLSListener(config.BindAddr, etcdServerTLSConfig)
-		if err != nil {
-			log.Fatal("Failed to create TLS etcd listener: ", err)
-		}
-	} else {
-		sListener, err = server.NewListener(config.BindAddr)
-		if err != nil {
-			log.Fatal("Failed to create etcd listener: ", err)
-		}
-	}
-
 	ps.SetServer(s)
 	ps.Start(config.Snapshot, config.Discovery, config.Peers)
 
 	go func() {
-		log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL)
+		log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, config.Peer.BindAddr, ps.Config.URL)
+		l := server.NewListener(psConfig.Scheme, config.Peer.BindAddr, config.PeerTLSInfo())
+
 		sHTTP := &ehttp.CORSHandler{ps.HTTPHandler(), corsInfo}
-		log.Fatal(http.Serve(psListener, sHTTP))
+		log.Fatal(http.Serve(l, sHTTP))
 	}()
 
-	log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, sListener.Addr(), s.URL())
+	log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Name, config.BindAddr, s.URL())
+	l := server.NewListener(config.EtcdTLSInfo().Scheme(), config.BindAddr, config.EtcdTLSInfo())
 	sHTTP := &ehttp.CORSHandler{s.HTTPHandler(), corsInfo}
-	log.Fatal(http.Serve(sListener, sHTTP))
+	log.Fatal(http.Serve(l, sHTTP))
 }

+ 28 - 2
server/listener.go

@@ -3,9 +3,35 @@ package server
 import (
 	"crypto/tls"
 	"net"
+
+	"github.com/coreos/etcd/log"
 )
 
-func NewListener(addr string) (net.Listener, error) {
+// NewListener creates a net.Listener
+// If the given scheme is "https", it will generate TLS configuration based on TLSInfo.
+// If any error happens, this function will call log.Fatal
+func NewListener(scheme, addr string, tlsInfo *TLSInfo) net.Listener {
+	if scheme == "https" {
+		cfg, err := tlsInfo.ServerConfig()
+		if err != nil {
+			log.Fatal("TLS info error: ", err)
+		}
+
+		l, err := newTLSListener(addr, cfg)
+		if err != nil {
+			log.Fatal("Failed to create TLS listener: ", err)
+		}
+		return l
+	}
+
+	l, err := newListener(addr)
+	if err != nil {
+		log.Fatal("Failed to create listener: ", err)
+	}
+	return l
+}
+
+func newListener(addr string) (net.Listener, error) {
 	if addr == "" {
 		addr = ":http"
 	}
@@ -16,7 +42,7 @@ func NewListener(addr string) (net.Listener, error) {
 	return l, nil
 }
 
-func NewTLSListener(addr string, cfg *tls.Config) (net.Listener, error) {
+func newTLSListener(addr string, cfg *tls.Config) (net.Listener, error) {
 	if addr == "" {
 		addr = ":https"
 	}

+ 22 - 28
tests/server_utils.go

@@ -15,12 +15,12 @@ import (
 )
 
 const (
-	testName		= "ETCDTEST"
-	testClientURL		= "localhost:4401"
-	testRaftURL		= "localhost:7701"
-	testSnapshotCount	= 10000
-	testHeartbeatInterval	= time.Duration(50) * time.Millisecond
-	testElectionTimeout	= time.Duration(200) * time.Millisecond
+	testName              = "ETCDTEST"
+	testClientURL         = "localhost:4401"
+	testRaftURL           = "localhost:7701"
+	testSnapshotCount     = 10000
+	testHeartbeatInterval = time.Duration(50) * time.Millisecond
+	testElectionTimeout   = time.Duration(200) * time.Millisecond
 )
 
 // Starts a server in a temporary directory.
@@ -35,20 +35,17 @@ func RunServer(f func(*server.Server)) {
 	followersStats := server.NewRaftFollowersStats(testName)
 
 	psConfig := server.PeerServerConfig{
-		Name:		testName,
-		URL:		"http://" + testRaftURL,
-		Scheme:		"http",
-		SnapshotCount:	testSnapshotCount,
-		MaxClusterSize:	9,
+		Name:           testName,
+		URL:            "http://" + testRaftURL,
+		Scheme:         "http",
+		SnapshotCount:  testSnapshotCount,
+		MaxClusterSize: 9,
 	}
 
 	mb := metrics.NewBucket("")
 
 	ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
-	psListener, err := server.NewListener(testRaftURL)
-	if err != nil {
-		panic(err)
-	}
+	psListener := server.NewListener("http", testRaftURL, nil)
 
 	// Create Raft transporter and server
 	dialTimeout := (3 * testHeartbeatInterval) + testElectionTimeout
@@ -63,10 +60,7 @@ func RunServer(f func(*server.Server)) {
 	ps.SetRaftServer(raftServer)
 
 	s := server.New(testName, "http://"+testClientURL, ps, registry, store, nil)
-	sListener, err := server.NewListener(testClientURL)
-	if err != nil {
-		panic(err)
-	}
+	sListener := server.NewListener("http", testClientURL, nil)
 
 	ps.SetServer(s)
 
@@ -104,16 +98,16 @@ func RunServer(f func(*server.Server)) {
 }
 
 type waitHandler struct {
-        wg *sync.WaitGroup
-        handler http.Handler
+	wg      *sync.WaitGroup
+	handler http.Handler
 }
 
-func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){
-        h.wg.Add(1)
-        defer h.wg.Done()
-        h.handler.ServeHTTP(w, r)
+func (h *waitHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	h.wg.Add(1)
+	defer h.wg.Done()
+	h.handler.ServeHTTP(w, r)
 
-        //important to flush before decrementing the wait group.
-        //we won't get a chance to once main() ends.
-        w.(http.Flusher).Flush()
+	//important to flush before decrementing the wait group.
+	//we won't get a chance to once main() ends.
+	w.(http.Flusher).Flush()
 }