Browse Source

use type inheritance

Xiang Li 12 years ago
parent
commit
981351c9d9
9 changed files with 60 additions and 57 deletions
  1. 2 2
      etcd.go
  2. 6 6
      etcd_handlers.go
  3. 18 15
      etcd_server.go
  4. 1 1
      etcd_test.go
  5. 2 2
      machines.go
  6. 6 6
      raft_handlers.go
  7. 22 22
      raft_server.go
  8. 1 1
      snapshot.go
  9. 2 2
      util.go

+ 2 - 2
etcd.go

@@ -197,7 +197,7 @@ func main() {
 	r = newRaftServer(info.Name, info.RaftURL, &raftTLSConfig, &info.RaftTLS)
 
 	startWebInterface()
-	r.start()
-	e.start()
+	r.run()
+	e.run()
 
 }

+ 6 - 6
etcd_handlers.go

@@ -122,8 +122,8 @@ func DeleteHttpHandler(w *http.ResponseWriter, req *http.Request) {
 // Dispatch the command to leader
 func dispatch(c Command, w *http.ResponseWriter, req *http.Request, etcd bool) {
 
-	if r.server.State() == raft.Leader {
-		if body, err := r.server.Do(c); err != nil {
+	if r.State() == raft.Leader {
+		if body, err := r.Do(c); err != nil {
 
 			if _, ok := err.(store.NotFoundError); ok {
 				(*w).WriteHeader(http.StatusNotFound)
@@ -167,7 +167,7 @@ func dispatch(c Command, w *http.ResponseWriter, req *http.Request, etcd bool) {
 			return
 		}
 	} else {
-		leader := r.server.Leader()
+		leader := r.Leader()
 		// current no leader
 		if leader == "" {
 			(*w).WriteHeader(http.StatusInternalServerError)
@@ -211,7 +211,7 @@ func dispatch(c Command, w *http.ResponseWriter, req *http.Request, etcd bool) {
 
 // Handler to return the current leader's raft address
 func LeaderHttpHandler(w http.ResponseWriter, req *http.Request) {
-	leader := r.server.Leader()
+	leader := r.Leader()
 
 	if leader != "" {
 		w.WriteHeader(http.StatusOK)
@@ -256,7 +256,7 @@ func GetHttpHandler(w *http.ResponseWriter, req *http.Request) {
 		Key: key,
 	}
 
-	if body, err := command.Apply(r.server); err != nil {
+	if body, err := command.Apply(r.Server); err != nil {
 
 		if _, ok := err.(store.NotFoundError); ok {
 			(*w).WriteHeader(http.StatusNotFound)
@@ -310,7 +310,7 @@ func WatchHttpHandler(w http.ResponseWriter, req *http.Request) {
 		return
 	}
 
-	if body, err := command.Apply(r.server); err != nil {
+	if body, err := command.Apply(r.Server); err != nil {
 		w.WriteHeader(http.StatusInternalServerError)
 		w.Write(newJsonError(500, key))
 	} else {

+ 18 - 15
etcd_server.go

@@ -6,6 +6,7 @@ import (
 )
 
 type etcdServer struct {
+	http.Server
 	name    string
 	url     string
 	tlsConf *TLSConfig
@@ -14,32 +15,34 @@ type etcdServer struct {
 
 var e *etcdServer
 
-func newEtcdServer(name string, url string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *etcdServer {
+func newEtcdServer(name string, urlStr string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *etcdServer {
+	u, err := url.Parse(urlStr)
+
+	if err != nil {
+		fatalf("invalid url '%s': %s", e.url, err)
+	}
+
 	return &etcdServer{
+		Server: http.Server{
+			Handler:   NewEtcdMuxer(),
+			TLSConfig: &tlsConf.Server,
+			Addr:      u.Host,
+		},
 		name:    name,
-		url:     url,
+		url:     urlStr,
 		tlsConf: tlsConf,
 		tlsInfo: tlsInfo,
 	}
 }
 
 // Start to listen and response etcd client command
-func (e *etcdServer) start() {
-	u, err := url.Parse(e.url)
-	if err != nil {
-		fatalf("invalid url '%s': %s", e.url, err)
-	}
-	infof("etcd server [%s:%s]", e.name, u)
+func (e *etcdServer) run() {
 
-	server := http.Server{
-		Handler:   NewEtcdMuxer(),
-		TLSConfig: &e.tlsConf.Server,
-		Addr:      u.Host,
-	}
+	infof("etcd server [%s:%s]", e.name, e.url)
 
 	if e.tlsConf.Scheme == "http" {
-		fatal(server.ListenAndServe())
+		fatal(e.ListenAndServe())
 	} else {
-		fatal(server.ListenAndServeTLS(e.tlsInfo.CertFile, e.tlsInfo.KeyFile))
+		fatal(e.ListenAndServeTLS(e.tlsInfo.CertFile, e.tlsInfo.KeyFile))
 	}
 }

+ 1 - 1
etcd_test.go

@@ -291,7 +291,7 @@ func TestKillRandom(t *testing.T) {
 
 	toKill := make(map[int]bool)
 
-	for i := 0; i < 200; i++ {
+	for i := 0; i < 20; i++ {
 		fmt.Printf("TestKillRandom Round[%d/200]\n", i)
 
 		j := 0

+ 2 - 2
machines.go

@@ -10,11 +10,11 @@ func machineNum() int {
 // getMachines gets the current machines in the cluster
 func getMachines() []string {
 
-	peers := r.server.Peers()
+	peers := r.Peers()
 
 	machines := make([]string, len(peers)+1)
 
-	leader, ok := nameToEtcdURL(r.server.Leader())
+	leader, ok := nameToEtcdURL(r.Leader())
 	self := e.url
 	i := 1
 

+ 6 - 6
raft_handlers.go

@@ -15,7 +15,7 @@ func GetLogHttpHandler(w http.ResponseWriter, req *http.Request) {
 	debugf("[recv] GET %s/log", r.url)
 	w.Header().Set("Content-Type", "application/json")
 	w.WriteHeader(http.StatusOK)
-	json.NewEncoder(w).Encode(r.server.LogEntries())
+	json.NewEncoder(w).Encode(r.LogEntries())
 }
 
 // Response to vote request
@@ -24,7 +24,7 @@ func VoteHttpHandler(w http.ResponseWriter, req *http.Request) {
 	err := decodeJsonRequest(req, rvreq)
 	if err == nil {
 		debugf("[recv] POST %s/vote [%s]", r.url, rvreq.CandidateName)
-		if resp := r.server.RequestVote(rvreq); resp != nil {
+		if resp := r.RequestVote(rvreq); resp != nil {
 			w.WriteHeader(http.StatusOK)
 			json.NewEncoder(w).Encode(resp)
 			return
@@ -41,7 +41,7 @@ func AppendEntriesHttpHandler(w http.ResponseWriter, req *http.Request) {
 
 	if err == nil {
 		debugf("[recv] POST %s/log/append [%d]", r.url, len(aereq.Entries))
-		if resp := r.server.AppendEntries(aereq); resp != nil {
+		if resp := r.AppendEntries(aereq); resp != nil {
 			w.WriteHeader(http.StatusOK)
 			json.NewEncoder(w).Encode(resp)
 			if !resp.Success {
@@ -60,7 +60,7 @@ func SnapshotHttpHandler(w http.ResponseWriter, req *http.Request) {
 	err := decodeJsonRequest(req, aereq)
 	if err == nil {
 		debugf("[recv] POST %s/snapshot/ ", r.url)
-		if resp := r.server.RequestSnapshot(aereq); resp != nil {
+		if resp := r.RequestSnapshot(aereq); resp != nil {
 			w.WriteHeader(http.StatusOK)
 			json.NewEncoder(w).Encode(resp)
 			return
@@ -76,7 +76,7 @@ func SnapshotRecoveryHttpHandler(w http.ResponseWriter, req *http.Request) {
 	err := decodeJsonRequest(req, aereq)
 	if err == nil {
 		debugf("[recv] POST %s/snapshotRecovery/ ", r.url)
-		if resp := r.server.SnapshotRecoveryRequest(aereq); resp != nil {
+		if resp := r.SnapshotRecoveryRequest(aereq); resp != nil {
 			w.WriteHeader(http.StatusOK)
 			json.NewEncoder(w).Encode(resp)
 			return
@@ -111,5 +111,5 @@ func JoinHttpHandler(w http.ResponseWriter, req *http.Request) {
 func NameHttpHandler(w http.ResponseWriter, req *http.Request) {
 	debugf("[recv] Get %s/name/ ", r.url)
 	w.WriteHeader(http.StatusOK)
-	w.Write([]byte(r.server.Name()))
+	w.Write([]byte(r.name))
 }

+ 22 - 22
raft_server.go

@@ -13,17 +13,27 @@ import (
 )
 
 type raftServer struct {
+	*raft.Server
 	name    string
 	url     string
 	tlsConf *TLSConfig
 	tlsInfo *TLSInfo
-	server  *raft.Server
 }
 
 var r *raftServer
 
 func newRaftServer(name string, url string, tlsConf *TLSConfig, tlsInfo *TLSInfo) *raftServer {
+
+	// Create transporter for raft
+	raftTransporter := newTransporter(tlsConf.Scheme, tlsConf.Client)
+
+	// Create raft server
+	server, err := raft.NewServer(name, dirPath, raftTransporter, etcdStore, nil)
+
+	check(err)
+
 	return &raftServer{
+		Server:  server,
 		name:    name,
 		url:     url,
 		tlsConf: tlsConf,
@@ -32,26 +42,14 @@ func newRaftServer(name string, url string, tlsConf *TLSConfig, tlsInfo *TLSInfo
 }
 
 // Start the raft server
-func (r *raftServer) start() {
+func (r *raftServer) run() {
 
 	// Setup commands.
 	registerCommands()
 
-	// Create transporter for raft
-	raftTransporter := newTransporter(r.tlsConf.Scheme, r.tlsConf.Client)
-
-	// Create raft server
-	server, err := raft.NewServer(r.name, dirPath, raftTransporter, etcdStore, nil)
-
-	if err != nil {
-		fatal(err)
-	}
-
-	r.server = server
-
 	// LoadSnapshot
 	if snapshot {
-		err = server.LoadSnapshot()
+		err := r.LoadSnapshot()
 
 		if err == nil {
 			debugf("%s finished load snapshot", r.name)
@@ -60,12 +58,12 @@ func (r *raftServer) start() {
 		}
 	}
 
-	server.SetElectionTimeout(ElectionTimeout)
-	server.SetHeartbeatTimeout(HeartbeatTimeout)
+	r.SetElectionTimeout(ElectionTimeout)
+	r.SetHeartbeatTimeout(HeartbeatTimeout)
 
-	server.Start()
+	r.Start()
 
-	if server.IsLogEmpty() {
+	if r.IsLogEmpty() {
 
 		// start as a leader in a new cluster
 		if len(cluster) == 0 {
@@ -74,7 +72,7 @@ func (r *raftServer) start() {
 
 			// leader need to join self as a peer
 			for {
-				_, err := server.Do(newJoinCommand())
+				_, err := r.Do(newJoinCommand())
 				if err == nil {
 					break
 				}
@@ -86,6 +84,8 @@ func (r *raftServer) start() {
 
 			time.Sleep(time.Millisecond * 20)
 
+			var err error
+
 			for i := 0; i < retryTimes; i++ {
 
 				success := false
@@ -93,7 +93,7 @@ func (r *raftServer) start() {
 					if len(machine) == 0 {
 						continue
 					}
-					err = joinCluster(server, machine, r.tlsConf.Scheme)
+					err = joinCluster(r.Server, machine, r.tlsConf.Scheme)
 					if err != nil {
 						if err.Error() == errors[103] {
 							fatal(err)
@@ -171,7 +171,7 @@ func joinCluster(s *raft.Server, raftURL string, scheme string) error {
 	json.NewEncoder(&b).Encode(newJoinCommand())
 
 	// t must be ok
-	t, ok := r.server.Transporter().(transporter)
+	t, ok := r.Transporter().(transporter)
 
 	if !ok {
 		panic("wrong type")

+ 1 - 1
snapshot.go

@@ -29,7 +29,7 @@ func monitorSnapshot() {
 		currentWrites := etcdStore.TotalWrites() - snapConf.lastWrites
 
 		if currentWrites > snapConf.writesThr {
-			r.server.TakeSnapshot()
+			r.TakeSnapshot()
 			snapConf.lastWrites = etcdStore.TotalWrites()
 		}
 	}

+ 2 - 2
util.go

@@ -55,7 +55,7 @@ func startWebInterface() {
 	if argInfo.WebURL != "" {
 		// start web
 		go webHelper()
-		go web.Start(r.server, argInfo.WebURL)
+		go web.Start(r.Server, argInfo.WebURL)
 	}
 }
 
@@ -198,7 +198,7 @@ func send(c chan bool) {
 		command.Key = "foo"
 		command.Value = "bar"
 		command.ExpireTime = time.Unix(0, 0)
-		r.server.Do(command)
+		r.Do(command)
 	}
 	c <- true
 }