etcd.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. package main
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "encoding/json"
  7. "encoding/pem"
  8. "flag"
  9. "fmt"
  10. "github.com/xiangli-cmu/go-raft"
  11. "github.com/xiangli-cmu/raft-etcd/store"
  12. "github.com/xiangli-cmu/raft-etcd/web"
  13. //"io"
  14. "io/ioutil"
  15. "log"
  16. "net/http"
  17. "os"
  18. //"strconv"
  19. "strings"
  20. "time"
  21. )
  22. //------------------------------------------------------------------------------
  23. //
  24. // Initialization
  25. //
  26. //------------------------------------------------------------------------------
  // Command-line configuration. Each variable is bound to a flag in init and
// populated by flag.Parse in main.
var (
	// Logging.
	verbose bool // -v

	// Cluster membership and addressing.
	cluster string // -C
	address string // -a

	// Listening ports.
	clientPort int // -c
	serverPort int // -s
	webPort    int // -w

	// TLS material for peer (server) traffic; empty means not configured.
	serverCertFile string // -serverCert
	serverKeyFile  string // -serverKey
	serverCAFile   string // -serverCAFile

	// TLS material for client traffic; empty means not configured.
	clientCertFile string // -clientCert
	clientKeyFile  string // -clientKey
	clientCAFile   string // -clientCAFile

	// Storage.
	dirPath string // -d
	ignore  bool   // -i
	maxSize int    // -m
)
  42. func init() {
  43. flag.BoolVar(&verbose, "v", false, "verbose logging")
  44. flag.StringVar(&cluster, "C", "", "the ip address and port of a existing cluster")
  45. flag.StringVar(&address, "a", "", "the ip address of the local machine")
  46. flag.IntVar(&clientPort, "c", 4001, "the port to communicate with clients")
  47. flag.IntVar(&serverPort, "s", 7001, "the port to communicate with servers")
  48. flag.IntVar(&webPort, "w", -1, "the port of web interface")
  49. flag.StringVar(&serverCAFile, "serverCAFile", "", "the path of the CAFile")
  50. flag.StringVar(&serverCertFile, "serverCert", "", "the cert file of the server")
  51. flag.StringVar(&serverKeyFile, "serverKey", "", "the key file of the server")
  52. flag.StringVar(&clientCAFile, "clientCAFile", "", "the path of the client CAFile")
  53. flag.StringVar(&clientCertFile, "clientCert", "", "the cert file of the client")
  54. flag.StringVar(&clientKeyFile, "clientKey", "", "the key file of the client")
  55. flag.StringVar(&dirPath, "d", "./", "the directory to store log and snapshot")
  56. flag.BoolVar(&ignore, "i", false, "ignore the old configuration, create a new node")
  57. flag.IntVar(&maxSize, "m", 1024, "the max size of result buffer")
  58. }
  59. // CONSTANTS
  // Transport security modes, as returned by securityType.
const (
	HTTP           = 0 // plain HTTP, no TLS flags given
	HTTPS          = 1 // TLS with a cert/key pair
	HTTPSANDVERIFY = 2 // TLS that also requires and verifies client certificates
)

// Which side's TLS flags securityType should inspect.
const (
	SERVER = 0 // peer (raft) transport: -serverCert/-serverKey/-serverCAFile
	CLIENT = 1 // client transport: -clientCert/-clientKey/-clientCAFile
)

// Raft timing passed to SetElectionTimeout / SetHeartbeatTimeout in main.
const (
	ELECTIONTIMTOUT  = time.Millisecond * 200
	HEARTBEATTIMEOUT = time.Millisecond * 50
)
  73. //------------------------------------------------------------------------------
  74. //
  75. // Typedefs
  76. //
  77. //------------------------------------------------------------------------------
  // Info is this node's identity, persisted by getInfo as one line of JSON in
// <dir>/info. Field order and tags define that file's layout.
type Info struct {
	// Address is the local machine's address (from -a, whitespace-trimmed).
	Address string `json:"address"`
	// ServerPort is the peer (raft) transport port.
	ServerPort int `json:"serverPort"`
	// ClientPort is the client transport port.
	ClientPort int `json:"clientPort"`
	// WebPort is the web interface port; -1 when not enabled.
	WebPort int `json:"webPort"`
}
  84. //------------------------------------------------------------------------------
  85. //
  86. // Variables
  87. //
  88. //------------------------------------------------------------------------------
  89. var server *raft.Server
  90. var serverTransHandler transHandler
  91. var logger *log.Logger
  92. var storeMsg chan string
  93. //------------------------------------------------------------------------------
  94. //
  95. // Functions
  96. //
  97. //------------------------------------------------------------------------------
  98. //--------------------------------------
  99. // Main
  100. //--------------------------------------
  101. func main() {
  102. var err error
  103. logger = log.New(os.Stdout, "", log.LstdFlags)
  104. flag.Parse()
  105. // Setup commands.
  106. raft.RegisterCommand(&JoinCommand{})
  107. raft.RegisterCommand(&SetCommand{})
  108. raft.RegisterCommand(&GetCommand{})
  109. raft.RegisterCommand(&DeleteCommand{})
  110. if err := os.MkdirAll(dirPath, 0744); err != nil {
  111. fatal("Unable to create path: %v", err)
  112. }
  113. // Read server info from file or grab it from user.
  114. var info *Info = getInfo(dirPath)
  115. name := fmt.Sprintf("%s:%d", info.Address, info.ServerPort)
  116. fmt.Printf("ServerName: %s\n\n", name)
  117. // secrity type
  118. st := securityType(SERVER)
  119. if st == -1 {
  120. panic("ERROR type")
  121. }
  122. serverTransHandler = createTranHandler(st)
  123. // Setup new raft server.
  124. s := store.CreateStore(maxSize)
  125. // create raft server
  126. server, err = raft.NewServer(name, dirPath, serverTransHandler, s, nil)
  127. if err != nil {
  128. fatal("%v", err)
  129. }
  130. err = server.LoadSnapshot()
  131. if err == nil {
  132. debug("%s finished load snapshot", server.Name())
  133. } else {
  134. fmt.Println(err)
  135. debug("%s bad snapshot", server.Name())
  136. }
  137. server.Initialize()
  138. debug("%s finished init", server.Name())
  139. server.SetElectionTimeout(ELECTIONTIMTOUT)
  140. server.SetHeartbeatTimeout(HEARTBEATTIMEOUT)
  141. debug("%s finished set timeout", server.Name())
  142. if server.IsLogEmpty() {
  143. // start as a leader in a new cluster
  144. if cluster == "" {
  145. server.StartLeader()
  146. // join self as a peer
  147. command := &JoinCommand{}
  148. command.Name = server.Name()
  149. server.Do(command)
  150. debug("%s start as a leader", server.Name())
  151. // start as a fellower in a existing cluster
  152. } else {
  153. server.StartFollower()
  154. err := Join(server, cluster)
  155. if err != nil {
  156. panic(err)
  157. }
  158. fmt.Println("success join")
  159. }
  160. // rejoin the previous cluster
  161. } else {
  162. server.StartFollower()
  163. debug("%s start as a follower", server.Name())
  164. }
  165. // open the snapshot
  166. //go server.Snapshot()
  167. if webPort != -1 {
  168. // start web
  169. s.SetMessager(&storeMsg)
  170. go webHelper()
  171. go web.Start(server, webPort)
  172. }
  173. go startServTransport(info.ServerPort, st)
  174. startClientTransport(info.ClientPort, securityType(CLIENT))
  175. }
  176. func usage() {
  177. fatal("usage: raftd [PATH]")
  178. }
  179. func createTranHandler(st int) transHandler {
  180. t := transHandler{}
  181. switch st {
  182. case HTTP:
  183. t := transHandler{}
  184. t.client = nil
  185. return t
  186. case HTTPS:
  187. fallthrough
  188. case HTTPSANDVERIFY:
  189. tlsCert, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
  190. if err != nil {
  191. panic(err)
  192. }
  193. tr := &http.Transport{
  194. TLSClientConfig: &tls.Config{
  195. Certificates: []tls.Certificate{tlsCert},
  196. InsecureSkipVerify: true,
  197. },
  198. DisableCompression: true,
  199. }
  200. t.client = &http.Client{Transport: tr}
  201. return t
  202. }
  203. // for complier
  204. return transHandler{}
  205. }
  206. func startServTransport(port int, st int) {
  207. // internal commands
  208. http.HandleFunc("/join", JoinHttpHandler)
  209. http.HandleFunc("/vote", VoteHttpHandler)
  210. http.HandleFunc("/log", GetLogHttpHandler)
  211. http.HandleFunc("/log/append", AppendEntriesHttpHandler)
  212. http.HandleFunc("/snapshot", SnapshotHttpHandler)
  213. http.HandleFunc("/client", clientHttpHandler)
  214. switch st {
  215. case HTTP:
  216. debug("raft server [%s] listen on http", server.Name())
  217. log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil))
  218. case HTTPS:
  219. http.ListenAndServeTLS(fmt.Sprintf(":%d", port), serverCertFile, serverKeyFile, nil)
  220. case HTTPSANDVERIFY:
  221. pemByte, _ := ioutil.ReadFile(serverCAFile)
  222. block, pemByte := pem.Decode(pemByte)
  223. cert, err := x509.ParseCertificate(block.Bytes)
  224. if err != nil {
  225. fmt.Println(err)
  226. }
  227. certPool := x509.NewCertPool()
  228. certPool.AddCert(cert)
  229. server := &http.Server{
  230. TLSConfig: &tls.Config{
  231. ClientAuth: tls.RequireAndVerifyClientCert,
  232. ClientCAs: certPool,
  233. },
  234. Addr: fmt.Sprintf(":%d", port),
  235. }
  236. err = server.ListenAndServeTLS(serverCertFile, serverKeyFile)
  237. if err != nil {
  238. log.Fatal(err)
  239. }
  240. }
  241. }
  242. func startClientTransport(port int, st int) {
  243. // external commands
  244. http.HandleFunc("/v1/keys/", Multiplexer)
  245. http.HandleFunc("/v1/watch/", WatchHttpHandler)
  246. http.HandleFunc("/master", MasterHttpHandler)
  247. switch st {
  248. case HTTP:
  249. debug("etcd [%s] listen on http", server.Name())
  250. log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil))
  251. case HTTPS:
  252. http.ListenAndServeTLS(fmt.Sprintf(":%d", port), clientCertFile, clientKeyFile, nil)
  253. case HTTPSANDVERIFY:
  254. pemByte, _ := ioutil.ReadFile(clientCAFile)
  255. block, pemByte := pem.Decode(pemByte)
  256. cert, err := x509.ParseCertificate(block.Bytes)
  257. if err != nil {
  258. fmt.Println(err)
  259. }
  260. certPool := x509.NewCertPool()
  261. certPool.AddCert(cert)
  262. server := &http.Server{
  263. TLSConfig: &tls.Config{
  264. ClientAuth: tls.RequireAndVerifyClientCert,
  265. ClientCAs: certPool,
  266. },
  267. Addr: fmt.Sprintf(":%d", port),
  268. }
  269. err = server.ListenAndServeTLS(clientCertFile, clientKeyFile)
  270. if err != nil {
  271. log.Fatal(err)
  272. }
  273. }
  274. }
  275. //--------------------------------------
  276. // Config
  277. //--------------------------------------
  278. func securityType(source int) int {
  279. var keyFile, certFile, CAFile string
  280. switch source {
  281. case SERVER:
  282. keyFile = serverKeyFile
  283. certFile = serverCertFile
  284. CAFile = serverCAFile
  285. case CLIENT:
  286. keyFile = clientKeyFile
  287. certFile = clientCertFile
  288. CAFile = clientCAFile
  289. }
  290. if keyFile == "" && certFile == "" && CAFile == "" {
  291. return HTTP
  292. }
  293. if keyFile != "" && certFile != "" {
  294. if CAFile != "" {
  295. return HTTPSANDVERIFY
  296. }
  297. return HTTPS
  298. }
  299. return -1
  300. }
  301. func getInfo(path string) *Info {
  302. info := &Info{}
  303. // Read in the server info if available.
  304. infoPath := fmt.Sprintf("%s/info", path)
  305. // delete the old configuration if exist
  306. if ignore {
  307. logPath := fmt.Sprintf("%s/log", path)
  308. snapshotPath := fmt.Sprintf("%s/snapshotPath", path)
  309. os.Remove(infoPath)
  310. os.Remove(logPath)
  311. os.RemoveAll(snapshotPath)
  312. }
  313. if file, err := os.Open(infoPath); err == nil {
  314. if content, err := ioutil.ReadAll(file); err != nil {
  315. fatal("Unable to read info: %v", err)
  316. } else {
  317. if err = json.Unmarshal(content, &info); err != nil {
  318. fatal("Unable to parse info: %v", err)
  319. }
  320. }
  321. file.Close()
  322. // Otherwise ask user for info and write it to file.
  323. } else {
  324. if address == "" {
  325. fatal("Please give the address of the local machine")
  326. }
  327. info.Address = address
  328. info.Address = strings.TrimSpace(info.Address)
  329. fmt.Println("address ", info.Address)
  330. info.ServerPort = serverPort
  331. info.ClientPort = clientPort
  332. info.WebPort = webPort
  333. // Write to file.
  334. content, _ := json.Marshal(info)
  335. content = []byte(string(content) + "\n")
  336. if err := ioutil.WriteFile(infoPath, content, 0644); err != nil {
  337. fatal("Unable to write info to file: %v", err)
  338. }
  339. }
  340. return info
  341. }
  342. //--------------------------------------
  343. // Handlers
  344. //--------------------------------------
  345. // Send join requests to the leader.
  346. func Join(s *raft.Server, serverName string) error {
  347. var b bytes.Buffer
  348. command := &JoinCommand{}
  349. command.Name = s.Name()
  350. json.NewEncoder(&b).Encode(command)
  351. // t must be ok
  352. t, _ := server.Transporter().(transHandler)
  353. debug("Send Join Request to %s", serverName)
  354. resp, err := Post(&t, fmt.Sprintf("%s/join", serverName), &b)
  355. for {
  356. if resp != nil {
  357. defer resp.Body.Close()
  358. if resp.StatusCode == http.StatusOK {
  359. return nil
  360. }
  361. if resp.StatusCode == http.StatusServiceUnavailable {
  362. address, err := ioutil.ReadAll(resp.Body)
  363. if err != nil {
  364. warn("Cannot Read Leader info: %v", err)
  365. }
  366. debug("Leader is %s", address)
  367. debug("Send Join Request to %s", address)
  368. json.NewEncoder(&b).Encode(command)
  369. resp, err = Post(&t, fmt.Sprintf("%s/join", address), &b)
  370. }
  371. }
  372. }
  373. return fmt.Errorf("Unable to join: %v", err)
  374. }