Browse Source

refactor(server): add Client struct

This is used to send request to web API.
It will do this behavior a lot in standby mode, so I abstract this
struct first.
Yicheng Qin 11 years ago
parent
commit
4e14604e5c
3 changed files with 233 additions and 84 deletions
  1. 6 1
      error/error.go
  2. 203 0
      server/client.go
  3. 24 83
      server/peer_server.go

+ 6 - 1
error/error.go

@@ -58,6 +58,9 @@ var errors = map[int]string{
 	EcodeInvalidActiveSize:   "Invalid active size",
 	EcodeInvalidPromoteDelay: "Standby promote delay",
 	EcodePromoteError:        "Standby promotion error",
+
+	// client related errors
+	EcodeClientInternal: "Client Internal Error",
 }
 
 const (
@@ -92,6 +95,8 @@ const (
 	EcodeInvalidActiveSize   = 403
 	EcodeInvalidPromoteDelay = 404
 	EcodePromoteError        = 405
+
+	EcodeClientInternal = 500
 )
 
 type Error struct {
@@ -116,7 +121,7 @@ func Message(code int) string {
 
 // Only for error interface
 func (e Error) Error() string {
-	return e.Message
+	return e.Message + " (" + e.Cause + ")"
 }
 
 func (e Error) toJsonString() string {

+ 203 - 0
server/client.go

@@ -0,0 +1,203 @@
+package server
+
+import (
+	"bytes"
+	"encoding/binary"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/url"
+	"strconv"
+
+	etcdErr "github.com/coreos/etcd/error"
+	"github.com/coreos/etcd/log"
+)
+
+// Client sends various requests using HTTP API.
+// It is different from raft communication, and doesn't record anything in the log.
+// Public functions return "etcd/error".Error intentionally to figure out
+// etcd error code easily.
+// TODO(yichengq): It is similar to go-etcd. But it could have many efforts
+// to integrate the two. Leave it for further discussion.
+type Client struct {
+	http.Client
+}
+
+func NewClient(transport http.RoundTripper) *Client {
+	return &Client{http.Client{Transport: transport}}
+}
+
+// CheckVersion checks whether the version is available.
+func (c *Client) CheckVersion(url string, version int) (bool, *etcdErr.Error) {
+	resp, err := c.Get(url + fmt.Sprintf("/version/%d/check", version))
+	if err != nil {
+		return false, clientError(err)
+	}
+
+	defer resp.Body.Close()
+
+	return resp.StatusCode == 200, nil
+}
+
+// GetVersion fetches the peer version of a cluster.
+func (c *Client) GetVersion(url string) (int, *etcdErr.Error) {
+	resp, err := c.Get(url + "/version")
+	if err != nil {
+		return 0, clientError(err)
+	}
+
+	defer resp.Body.Close()
+
+	body, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return 0, clientError(err)
+	}
+
+	// Parse version number.
+	version, _ := strconv.Atoi(string(body))
+	return version, nil
+}
+
+// AddMachine adds machine to the cluster.
+// The first return value is the commit index of join command.
+func (c *Client) AddMachine(url string, cmd *JoinCommand) (uint64, *etcdErr.Error) {
+	b, _ := json.Marshal(cmd)
+	url = url + "/join"
+
+	log.Infof("Send Join Request to %s", url)
+	resp, err := c.put(url, b)
+	if err != nil {
+		return 0, clientError(err)
+	}
+	defer resp.Body.Close()
+
+	if err := c.checkErrorResponse(resp); err != nil {
+		return 0, err
+	}
+	b, err = ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return 0, clientError(err)
+	}
+	index, numRead := binary.Uvarint(b)
+	if numRead < 0 {
+		return 0, clientError(fmt.Errorf("buf too small, or value too large"))
+	}
+	return index, nil
+}
+
+func (c *Client) parseJSONResponse(resp *http.Response, val interface{}) *etcdErr.Error {
+	defer resp.Body.Close()
+
+	if err := c.checkErrorResponse(resp); err != nil {
+		return err
+	}
+	if err := json.NewDecoder(resp.Body).Decode(val); err != nil {
+		log.Debugf("Error parsing join response: %v", err)
+		return clientError(err)
+	}
+	return nil
+}
+
+func (c *Client) checkErrorResponse(resp *http.Response) *etcdErr.Error {
+	if resp.StatusCode != http.StatusOK {
+		uerr := &etcdErr.Error{}
+		if err := json.NewDecoder(resp.Body).Decode(uerr); err != nil {
+			log.Debugf("Error parsing response to etcd error: %v", err)
+			return clientError(err)
+		}
+		return uerr
+	}
+	return nil
+}
+
+// put sends server side PUT request.
+// It always follows redirects instead of stopping according to RFC 2616.
+func (c *Client) put(urlStr string, body []byte) (*http.Response, error) {
+	req, err := http.NewRequest("PUT", urlStr, bytes.NewBuffer(body))
+	if err != nil {
+		return nil, err
+	}
+	return c.doAlwaysFollowingRedirects(req, body)
+}
+
+// doAlwaysFollowingRedirects provides similar functionality as standard one,
+// but it does redirect with the same method for PUT or POST requests.
+// Part of the code is borrowed from pkg/net/http/client.go.
+func (c *Client) doAlwaysFollowingRedirects(ireq *http.Request, body []byte) (resp *http.Response, err error) {
+	var base *url.URL
+	redirectChecker := c.CheckRedirect
+	if redirectChecker == nil {
+		redirectChecker = defaultCheckRedirect
+	}
+	var via []*http.Request
+
+	req := ireq
+	urlStr := "" // next relative or absolute URL to fetch (after first request)
+	for redirect := 0; ; redirect++ {
+		if redirect != 0 {
+			req, err = http.NewRequest(ireq.Method, urlStr, bytes.NewBuffer(body))
+			if err != nil {
+				break
+			}
+			req.URL = base.ResolveReference(req.URL)
+			if len(via) > 0 {
+				// Add the Referer header.
+				lastReq := via[len(via)-1]
+				if lastReq.URL.Scheme != "https" {
+					req.Header.Set("Referer", lastReq.URL.String())
+				}
+
+				err = redirectChecker(req, via)
+				if err != nil {
+					break
+				}
+			}
+		}
+
+		urlStr = req.URL.String()
+		// It uses exported Do method here.
+		// It is more elegant to use unexported send method, but that will
+		// introduce many redundant code.
+		if resp, err = c.Do(req); err != nil {
+			break
+		}
+
+		if shouldExtraRedirectPost(resp.StatusCode) {
+			resp.Body.Close()
+			if urlStr = resp.Header.Get("Location"); urlStr == "" {
+				err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode))
+				break
+			}
+			base = req.URL
+			via = append(via, req)
+			continue
+		}
+		return
+	}
+
+	if resp != nil {
+		resp.Body.Close()
+	}
+	return nil, err
+}
+
+func shouldExtraRedirectPost(statusCode int) bool {
+	switch statusCode {
+	case http.StatusMovedPermanently, http.StatusTemporaryRedirect:
+		return true
+	}
+	return false
+}
+
+func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
+	if len(via) >= 10 {
+		return errors.New("stopped after 10 redirects")
+	}
+	return nil
+}
+
+func clientError(err error) *etcdErr.Error {
+	return etcdErr.NewError(etcdErr.EcodeClientInternal, err.Error(), 0)
+}

+ 24 - 83
server/peer_server.go

@@ -1,16 +1,12 @@
 package server
 
 import (
-	"bytes"
-	"encoding/binary"
 	"encoding/json"
 	"fmt"
-	"io/ioutil"
 	"math/rand"
 	"net/http"
 	"net/url"
 	"sort"
-	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -52,6 +48,7 @@ type PeerServerConfig struct {
 
 type PeerServer struct {
 	Config         PeerServerConfig
+	client         *Client
 	clusterConfig  *ClusterConfig
 	raftServer     raft.Server
 	server         *Server
@@ -250,6 +247,11 @@ func (s *PeerServer) Start(snapshot bool, discoverURL string, peers []string) er
 		}
 	}
 
+	// TODO(yichengq): client for HTTP API usage could use transport other
+	// than the raft one. The transport should have longer timeout because
+	// it doesn't have fault tolerance of raft protocol.
+	s.client = NewClient(s.raftServer.Transporter().(*transporter).transport)
+
 	s.raftServer.Init()
 
 	// Set NOCOW for data directory in btrfs
@@ -359,24 +361,6 @@ func (s *PeerServer) startAsFollower(cluster []string, retryTimes int) error {
 	return nil
 }
 
-// getVersion fetches the peer version of a cluster.
-func getVersion(t *transporter, versionURL url.URL) (int, error) {
-	resp, _, err := t.Get(versionURL.String())
-	if err != nil {
-		return 0, err
-	}
-	defer resp.Body.Close()
-
-	body, err := ioutil.ReadAll(resp.Body)
-	if err != nil {
-		return 0, err
-	}
-
-	// Parse version number.
-	version, _ := strconv.Atoi(string(body))
-	return version, nil
-}
-
 // Upgradable checks whether all peers in a cluster support an upgrade to the next store version.
 func (s *PeerServer) Upgradable() error {
 	nextVersion := s.store.Version() + 1
@@ -386,13 +370,12 @@ func (s *PeerServer) Upgradable() error {
 			return fmt.Errorf("PeerServer: Cannot parse URL: '%s' (%s)", peerURL, err)
 		}
 
-		t, _ := s.raftServer.Transporter().(*transporter)
-		checkURL := (&url.URL{Host: u.Host, Scheme: s.Config.Scheme, Path: fmt.Sprintf("/version/%d/check", nextVersion)}).String()
-		resp, _, err := t.Get(checkURL)
+		url := (&url.URL{Host: u.Host, Scheme: s.Config.Scheme}).String()
+		ok, err := s.client.CheckVersion(url, nextVersion)
 		if err != nil {
-			return fmt.Errorf("PeerServer: Cannot check version compatibility: %s", u.Host)
+			return err
 		}
-		if resp.StatusCode != 200 {
+		if !ok {
 			return fmt.Errorf("PeerServer: Version %d is not compatible with peer: %s", nextVersion, u.Host)
 		}
 	}
@@ -501,12 +484,10 @@ func (s *PeerServer) joinCluster(cluster []string) bool {
 
 // Send join requests to peer.
 func (s *PeerServer) joinByPeer(server raft.Server, peer string, scheme string) error {
-	// t must be ok
-	t, _ := server.Transporter().(*transporter)
+	u := (&url.URL{Host: peer, Scheme: scheme}).String()
 
 	// Our version must match the leaders version
-	versionURL := url.URL{Host: peer, Scheme: scheme, Path: "/version"}
-	version, err := getVersion(t, versionURL)
+	version, err := s.client.GetVersion(u)
 	if err != nil {
 		return fmt.Errorf("Error during join version check: %v", err)
 	}
@@ -514,60 +495,20 @@ func (s *PeerServer) joinByPeer(server raft.Server, peer string, scheme string)
 		return fmt.Errorf("Unable to join: cluster version is %d; version compatibility is %d - %d", version, store.MinVersion(), store.MaxVersion())
 	}
 
-	var b bytes.Buffer
-	c := &JoinCommand{
-		MinVersion: store.MinVersion(),
-		MaxVersion: store.MaxVersion(),
-		Name:       server.Name(),
-		RaftURL:    s.Config.URL,
-		EtcdURL:    s.server.URL(),
+	joinIndex, err := s.client.AddMachine(u,
+		&JoinCommand{
+			MinVersion: store.MinVersion(),
+			MaxVersion: store.MaxVersion(),
+			Name:       server.Name(),
+			RaftURL:    s.Config.URL,
+			EtcdURL:    s.server.URL(),
+		})
+	if err != nil {
+		return err
 	}
-	json.NewEncoder(&b).Encode(c)
 
-	joinURL := url.URL{Host: peer, Scheme: scheme, Path: "/join"}
-	log.Infof("Send Join Request to %s", joinURL.String())
-
-	req, _ := http.NewRequest("PUT", joinURL.String(), &b)
-	resp, err := t.client.Do(req)
-
-	for {
-		if err != nil {
-			return fmt.Errorf("Unable to join: %v", err)
-		}
-		if resp != nil {
-			defer resp.Body.Close()
-
-			log.Infof("»»»» %d", resp.StatusCode)
-			if resp.StatusCode == http.StatusOK {
-				b, _ := ioutil.ReadAll(resp.Body)
-				s.joinIndex, _ = binary.Uvarint(b)
-				return nil
-			}
-			if resp.StatusCode == http.StatusTemporaryRedirect {
-				address := resp.Header.Get("Location")
-				log.Debugf("Send Join Request to %s", address)
-				c := &JoinCommand{
-					MinVersion: store.MinVersion(),
-					MaxVersion: store.MaxVersion(),
-					Name:       server.Name(),
-					RaftURL:    s.Config.URL,
-					EtcdURL:    s.server.URL(),
-				}
-				json.NewEncoder(&b).Encode(c)
-				resp, _, err = t.Put(address, &b)
-
-			} else if resp.StatusCode == http.StatusBadRequest {
-				log.Debug("Reach max number peers in the cluster")
-				decoder := json.NewDecoder(resp.Body)
-				err := &etcdErr.Error{}
-				decoder.Decode(err)
-				return *err
-			} else {
-				return fmt.Errorf("Unable to join")
-			}
-		}
-
-	}
+	s.joinIndex = joinIndex
+	return nil
 }
 
 func (s *PeerServer) Stats() []byte {