| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490 |
- package main
- import (
- "bytes"
- "crypto/tls"
- "crypto/x509"
- "encoding/json"
- "encoding/pem"
- "flag"
- "fmt"
- "github.com/xiangli-cmu/go-raft"
- "github.com/xiangli-cmu/raft-etcd/store"
- "github.com/xiangli-cmu/raft-etcd/web"
- //"io"
- "io/ioutil"
- "log"
- "net/http"
- "os"
- //"strconv"
- "strings"
- "time"
- )
- //------------------------------------------------------------------------------
- //
- // Initialization
- //
- //------------------------------------------------------------------------------
// Command-line configuration. Each variable is bound to a flag in init and
// is read by the rest of the program after flag.Parse has run in main.
var (
	// verbose is set by -v.
	verbose bool

	// cluster (-C) is the address of an existing cluster member to join;
	// main starts a new cluster when it is empty and the log is empty.
	cluster string
	// address (-a) is the address of the local machine.
	address string

	// Listening ports: -c for clients, -s for peers, -w for the web
	// interface.
	clientPort int
	serverPort int
	webPort    int

	// TLS material for the peer (server) transport.
	serverCertFile string
	serverKeyFile  string
	serverCAFile   string

	// TLS material for the client transport.
	clientCertFile string
	clientKeyFile  string
	clientCAFile   string

	// dirPath (-d) is the directory holding the info file, log and snapshot.
	dirPath string
	// ignore (-i) makes getInfo delete the stored configuration first.
	ignore bool
	// maxSize (-m) is handed to store.CreateStore in main.
	maxSize int
)
- func init() {
- flag.BoolVar(&verbose, "v", false, "verbose logging")
- flag.StringVar(&cluster, "C", "", "the ip address and port of a existing cluster")
- flag.StringVar(&address, "a", "", "the ip address of the local machine")
- flag.IntVar(&clientPort, "c", 4001, "the port to communicate with clients")
- flag.IntVar(&serverPort, "s", 7001, "the port to communicate with servers")
- flag.IntVar(&webPort, "w", -1, "the port of web interface")
- flag.StringVar(&serverCAFile, "serverCAFile", "", "the path of the CAFile")
- flag.StringVar(&serverCertFile, "serverCert", "", "the cert file of the server")
- flag.StringVar(&serverKeyFile, "serverKey", "", "the key file of the server")
- flag.StringVar(&clientCAFile, "clientCAFile", "", "the path of the client CAFile")
- flag.StringVar(&clientCertFile, "clientCert", "", "the cert file of the client")
- flag.StringVar(&clientKeyFile, "clientKey", "", "the key file of the client")
- flag.StringVar(&dirPath, "d", "./", "the directory to store log and snapshot")
- flag.BoolVar(&ignore, "i", false, "ignore the old configuration, create a new node")
- flag.IntVar(&maxSize, "m", 1024, "the max size of result buffer")
- }
- // CONSTANTS
// Transport security levels, as returned by securityType.
const (
	// HTTP: no key, certificate or CA file configured.
	HTTP = 0
	// HTTPS: key and certificate configured, no CA file.
	HTTPS = 1
	// HTTPSANDVERIFY: key, certificate and CA file configured.
	HTTPSANDVERIFY = 2
)

// Which set of TLS flags securityType should inspect.
const (
	SERVER = 0
	CLIENT = 1
)

// Raft timing, applied to the server in main.
const (
	ELECTIONTIMTOUT  = time.Millisecond * 200
	HEARTBEATTIMEOUT = time.Millisecond * 50
)
- //------------------------------------------------------------------------------
- //
- // Typedefs
- //
- //------------------------------------------------------------------------------
// Info is the identity of the local node. getInfo stores it as JSON in the
// "info" file of the data directory and reads it back on later starts.
type Info struct {
	// Address of the local machine, as given with -a.
	Address string `json:"address"`

	// Ports the node listens on for peers, clients and the web interface.
	ServerPort int `json:"serverPort"`
	ClientPort int `json:"clientPort"`
	WebPort    int `json:"webPort"`
}
- //------------------------------------------------------------------------------
- //
- // Variables
- //
- //------------------------------------------------------------------------------
- var server *raft.Server
- var serverTransHandler transHandler
- var logger *log.Logger
- var storeMsg chan string
- //------------------------------------------------------------------------------
- //
- // Functions
- //
- //------------------------------------------------------------------------------
- //--------------------------------------
- // Main
- //--------------------------------------
- func main() {
- var err error
- logger = log.New(os.Stdout, "", log.LstdFlags)
- flag.Parse()
- // Setup commands.
- raft.RegisterCommand(&JoinCommand{})
- raft.RegisterCommand(&SetCommand{})
- raft.RegisterCommand(&GetCommand{})
- raft.RegisterCommand(&DeleteCommand{})
- if err := os.MkdirAll(dirPath, 0744); err != nil {
- fatal("Unable to create path: %v", err)
- }
- // Read server info from file or grab it from user.
- var info *Info = getInfo(dirPath)
- name := fmt.Sprintf("%s:%d", info.Address, info.ServerPort)
- fmt.Printf("ServerName: %s\n\n", name)
- // secrity type
- st := securityType(SERVER)
- if st == -1 {
- panic("ERROR type")
- }
- serverTransHandler = createTranHandler(st)
- // Setup new raft server.
- s := store.CreateStore(maxSize)
- // create raft server
- server, err = raft.NewServer(name, dirPath, serverTransHandler, s, nil)
- if err != nil {
- fatal("%v", err)
- }
- err = server.LoadSnapshot()
- if err == nil {
- debug("%s finished load snapshot", server.Name())
- } else {
- fmt.Println(err)
- debug("%s bad snapshot", server.Name())
- }
- server.Initialize()
- debug("%s finished init", server.Name())
- server.SetElectionTimeout(ELECTIONTIMTOUT)
- server.SetHeartbeatTimeout(HEARTBEATTIMEOUT)
- debug("%s finished set timeout", server.Name())
- if server.IsLogEmpty() {
- // start as a leader in a new cluster
- if cluster == "" {
- server.StartLeader()
- // join self as a peer
- command := &JoinCommand{}
- command.Name = server.Name()
- server.Do(command)
- debug("%s start as a leader", server.Name())
- // start as a fellower in a existing cluster
- } else {
- server.StartFollower()
- err := Join(server, cluster)
- if err != nil {
- panic(err)
- }
- fmt.Println("success join")
- }
- // rejoin the previous cluster
- } else {
- server.StartFollower()
- debug("%s start as a follower", server.Name())
- }
- // open the snapshot
- //go server.Snapshot()
- if webPort != -1 {
- // start web
- s.SetMessager(&storeMsg)
- go webHelper()
- go web.Start(server, webPort)
- }
- go startServTransport(info.ServerPort, st)
- startClientTransport(info.ClientPort, securityType(CLIENT))
- }
- func usage() {
- fatal("usage: raftd [PATH]")
- }
- func createTranHandler(st int) transHandler {
- t := transHandler{}
- switch st {
- case HTTP:
- t := transHandler{}
- t.client = nil
- return t
- case HTTPS:
- fallthrough
- case HTTPSANDVERIFY:
- tlsCert, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
- if err != nil {
- panic(err)
- }
- tr := &http.Transport{
- TLSClientConfig: &tls.Config{
- Certificates: []tls.Certificate{tlsCert},
- InsecureSkipVerify: true,
- },
- DisableCompression: true,
- }
- t.client = &http.Client{Transport: tr}
- return t
- }
- // for complier
- return transHandler{}
- }
- func startServTransport(port int, st int) {
- // internal commands
- http.HandleFunc("/join", JoinHttpHandler)
- http.HandleFunc("/vote", VoteHttpHandler)
- http.HandleFunc("/log", GetLogHttpHandler)
- http.HandleFunc("/log/append", AppendEntriesHttpHandler)
- http.HandleFunc("/snapshot", SnapshotHttpHandler)
- http.HandleFunc("/client", clientHttpHandler)
- switch st {
- case HTTP:
- debug("raft server [%s] listen on http", server.Name())
- log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil))
- case HTTPS:
- http.ListenAndServeTLS(fmt.Sprintf(":%d", port), serverCertFile, serverKeyFile, nil)
- case HTTPSANDVERIFY:
- pemByte, _ := ioutil.ReadFile(serverCAFile)
- block, pemByte := pem.Decode(pemByte)
- cert, err := x509.ParseCertificate(block.Bytes)
- if err != nil {
- fmt.Println(err)
- }
- certPool := x509.NewCertPool()
- certPool.AddCert(cert)
- server := &http.Server{
- TLSConfig: &tls.Config{
- ClientAuth: tls.RequireAndVerifyClientCert,
- ClientCAs: certPool,
- },
- Addr: fmt.Sprintf(":%d", port),
- }
- err = server.ListenAndServeTLS(serverCertFile, serverKeyFile)
- if err != nil {
- log.Fatal(err)
- }
- }
- }
- func startClientTransport(port int, st int) {
- // external commands
- http.HandleFunc("/v1/keys/", Multiplexer)
- http.HandleFunc("/v1/watch/", WatchHttpHandler)
- http.HandleFunc("/v1/list/", ListHttpHandler)
- http.HandleFunc("/master", MasterHttpHandler)
- switch st {
- case HTTP:
- debug("etcd [%s] listen on http", server.Name())
- log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil))
- case HTTPS:
- http.ListenAndServeTLS(fmt.Sprintf(":%d", port), clientCertFile, clientKeyFile, nil)
- case HTTPSANDVERIFY:
- pemByte, _ := ioutil.ReadFile(clientCAFile)
- block, pemByte := pem.Decode(pemByte)
- cert, err := x509.ParseCertificate(block.Bytes)
- if err != nil {
- fmt.Println(err)
- }
- certPool := x509.NewCertPool()
- certPool.AddCert(cert)
- server := &http.Server{
- TLSConfig: &tls.Config{
- ClientAuth: tls.RequireAndVerifyClientCert,
- ClientCAs: certPool,
- },
- Addr: fmt.Sprintf(":%d", port),
- }
- err = server.ListenAndServeTLS(clientCertFile, clientKeyFile)
- if err != nil {
- log.Fatal(err)
- }
- }
- }
- //--------------------------------------
- // Config
- //--------------------------------------
- func securityType(source int) int {
- var keyFile, certFile, CAFile string
- switch source {
- case SERVER:
- keyFile = serverKeyFile
- certFile = serverCertFile
- CAFile = serverCAFile
- case CLIENT:
- keyFile = clientKeyFile
- certFile = clientCertFile
- CAFile = clientCAFile
- }
- if keyFile == "" && certFile == "" && CAFile == "" {
- return HTTP
- }
- if keyFile != "" && certFile != "" {
- if CAFile != "" {
- return HTTPSANDVERIFY
- }
- return HTTPS
- }
- return -1
- }
- func getInfo(path string) *Info {
- info := &Info{}
- // Read in the server info if available.
- infoPath := fmt.Sprintf("%s/info", path)
- // delete the old configuration if exist
- if ignore {
- logPath := fmt.Sprintf("%s/log", path)
- snapshotPath := fmt.Sprintf("%s/snapshotPath", path)
- os.Remove(infoPath)
- os.Remove(logPath)
- os.RemoveAll(snapshotPath)
- }
- if file, err := os.Open(infoPath); err == nil {
- if content, err := ioutil.ReadAll(file); err != nil {
- fatal("Unable to read info: %v", err)
- } else {
- if err = json.Unmarshal(content, &info); err != nil {
- fatal("Unable to parse info: %v", err)
- }
- }
- file.Close()
- // Otherwise ask user for info and write it to file.
- } else {
- if address == "" {
- fatal("Please give the address of the local machine")
- }
- info.Address = address
- info.Address = strings.TrimSpace(info.Address)
- fmt.Println("address ", info.Address)
- info.ServerPort = serverPort
- info.ClientPort = clientPort
- info.WebPort = webPort
- // Write to file.
- content, _ := json.Marshal(info)
- content = []byte(string(content) + "\n")
- if err := ioutil.WriteFile(infoPath, content, 0644); err != nil {
- fatal("Unable to write info to file: %v", err)
- }
- }
- return info
- }
- //--------------------------------------
- // Handlers
- //--------------------------------------
- // Send join requests to the leader.
- func Join(s *raft.Server, serverName string) error {
- var b bytes.Buffer
- command := &JoinCommand{}
- command.Name = s.Name()
- json.NewEncoder(&b).Encode(command)
- // t must be ok
- t, _ := server.Transporter().(transHandler)
- debug("Send Join Request to %s", serverName)
- resp, err := Post(&t, fmt.Sprintf("%s/join", serverName), &b)
- for {
- if resp != nil {
- defer resp.Body.Close()
- if resp.StatusCode == http.StatusOK {
- return nil
- }
- if resp.StatusCode == http.StatusServiceUnavailable {
- address, err := ioutil.ReadAll(resp.Body)
- if err != nil {
- warn("Cannot Read Leader info: %v", err)
- }
- debug("Leader is %s", address)
- debug("Send Join Request to %s", address)
- json.NewEncoder(&b).Encode(command)
- resp, err = Post(&t, fmt.Sprintf("%s/join", address), &b)
- }
- }
- }
- return fmt.Errorf("Unable to join: %v", err)
- }
|