Browse Source

refactor(client): remove useless logic in redirection

Yicheng Qin 11 years ago
parent
commit
ae81f843f1
1 changed files with 16 additions and 66 deletions
  1. 16 66
      server/client.go

+ 16 - 66
server/client.go

@@ -8,7 +8,6 @@ import (
 	"fmt"
 	"io/ioutil"
 	"net/http"
-	"net/url"
 	"strconv"
 
 	etcdErr "github.com/coreos/etcd/error"
@@ -141,87 +140,38 @@ func (c *Client) checkErrorResponse(resp *http.Response) *etcdErr.Error {
 // put sends server side PUT request.
 // It always follows redirects instead of stopping according to RFC 2616.
 func (c *Client) put(urlStr string, body []byte) (*http.Response, error) {
-	req, err := http.NewRequest("PUT", urlStr, bytes.NewBuffer(body))
-	if err != nil {
-		return nil, err
-	}
-	return c.doAlwaysFollowingRedirects(req, body)
+	return c.doAlwaysFollowingRedirects("PUT", urlStr, body)
 }
 
-// doAlwaysFollowingRedirects provides similar functionality as standard one,
-// but it does redirect with the same method for PUT or POST requests.
-// Part of the code is borrowed from pkg/net/http/client.go.
-func (c *Client) doAlwaysFollowingRedirects(ireq *http.Request, body []byte) (resp *http.Response, err error) {
-	var base *url.URL
-	redirectChecker := c.CheckRedirect
-	if redirectChecker == nil {
-		redirectChecker = defaultCheckRedirect
-	}
-	var via []*http.Request
-
-	req := ireq
-	urlStr := "" // next relative or absolute URL to fetch (after first request)
-	for redirect := 0; ; redirect++ {
-		if redirect != 0 {
-			req, err = http.NewRequest(ireq.Method, urlStr, bytes.NewBuffer(body))
-			if err != nil {
-				break
-			}
-			req.URL = base.ResolveReference(req.URL)
-			if len(via) > 0 {
-				// Add the Referer header.
-				lastReq := via[len(via)-1]
-				if lastReq.URL.Scheme != "https" {
-					req.Header.Set("Referer", lastReq.URL.String())
-				}
-
-				err = redirectChecker(req, via)
-				if err != nil {
-					break
-				}
-			}
+func (c *Client) doAlwaysFollowingRedirects(method string, urlStr string, body []byte) (resp *http.Response, err error) {
+	var req *http.Request
+
+	for redirect := 0; redirect < 10; redirect++ {
+		req, err = http.NewRequest(method, urlStr, bytes.NewBuffer(body))
+		if err != nil {
+			return
 		}
 
-		urlStr = req.URL.String()
-		// It uses exported Do method here.
-		// It is more elegant to use unexported send method, but that will
-		// introduce many redundant code.
 		if resp, err = c.Do(req); err != nil {
-			break
+			if resp != nil {
+				resp.Body.Close()
+			}
+			return
 		}
 
-		if shouldExtraRedirectPost(resp.StatusCode) {
+		if resp.StatusCode == http.StatusMovedPermanently || resp.StatusCode == http.StatusTemporaryRedirect {
 			resp.Body.Close()
 			if urlStr = resp.Header.Get("Location"); urlStr == "" {
 				err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode))
-				break
+				return
 			}
-			base = req.URL
-			via = append(via, req)
 			continue
 		}
 		return
 	}
 
-	if resp != nil {
-		resp.Body.Close()
-	}
-	return nil, err
-}
-
-func shouldExtraRedirectPost(statusCode int) bool {
-	switch statusCode {
-	case http.StatusMovedPermanently, http.StatusTemporaryRedirect:
-		return true
-	}
-	return false
-}
-
-func defaultCheckRedirect(req *http.Request, via []*http.Request) error {
-	if len(via) >= 10 {
-		return errors.New("stopped after 10 redirects")
-	}
-	return nil
+	err = errors.New("stopped after 10 redirects")
+	return
 }
 
 func clientError(err error) *etcdErr.Error {