Explorar o código

fix empty request body with spnego

replay request body when needed
Jonathan Turner %!s(int64=7) %!d(string=hai) anos
pai
achega
bae8ea1f6f
Modificáronse 2 ficheiros con 25 adicións e 7 borrados
  1. 22 0
      spnego/http.go
  2. 3 7
      spnego/http_test.go

+ 22 - 0
spnego/http.go

@@ -1,11 +1,13 @@
 package spnego
 
 import (
+	"bytes"
 	"context"
 	"encoding/base64"
 	"errors"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"net"
 	"net/http"
 	"net/http/cookiejar"
@@ -39,6 +41,11 @@ func (e redirectErr) Error() string {
 	return fmt.Sprintf("redirect to %v", e.reqTarget.URL)
 }
 
+type teeReadCloser struct {
+	io.Reader
+	io.Closer
+}
+
 // NewClient returns an SPNEGO enabled HTTP client.
 func NewClient(krb5Cl *client.Client, httpCl *http.Client, spn string) *Client {
 	if httpCl == nil {
@@ -68,6 +75,13 @@ func NewClient(krb5Cl *client.Client, httpCl *http.Client, spn string) *Client {
 
 // Do is the SPNEGO enabled HTTP client's equivalent of the http.Client's Do method.
 func (c *Client) Do(req *http.Request) (resp *http.Response, err error) {
+	var body bytes.Buffer
+	if req.Body != nil {
+		// Use a tee reader to capture any body sent in case we have to replay it again
+		teeR := io.TeeReader(req.Body, &body)
+		teeRC := teeReadCloser{teeR, req.Body}
+		req.Body = teeRC
+	}
 	resp, err = c.Client.Do(req)
 	if err != nil {
 		if ue, ok := err.(*url.Error); ok {
@@ -78,6 +92,10 @@ func (c *Client) Do(req *http.Request) (resp *http.Response, err error) {
 				if len(c.reqs) >= 10 {
 					return resp, errors.New("stopped after 10 redirects")
 				}
+				if req.Body != nil {
+					// Refresh the body reader so the body can be sent again
+					e.reqTarget.Body = ioutil.NopCloser(&body)
+				}
 				return c.Do(e.reqTarget)
 			}
 		}
@@ -88,6 +106,10 @@ func (c *Client) Do(req *http.Request) (resp *http.Response, err error) {
 		if err != nil {
 			return resp, err
 		}
+		if req.Body != nil {
+			// Refresh the body reader so the body can be sent again
+			req.Body = ioutil.NopCloser(&body)
+		}
 		return c.Do(req)
 	}
 	return resp, err

+ 3 - 7
spnego/http_test.go

@@ -271,15 +271,11 @@ func TestService_SPNEGOKRB_Upload(t *testing.T) {
 	bodyWriter.Close()
 
 	r, _ := http.NewRequest("POST", s.URL, bodyBuf)
+	r.Header.Set("Content-Type", bodyWriter.FormDataContentType())
 
 	cl := getClient()
-	err = SetSPNEGOHeader(cl, r, "HTTP/host.test.gokrb5")
-	if err != nil {
-		t.Fatalf("error setting client's SPNEGO header: %v", err)
-	}
-
-	r.Header.Set("Content-Type", bodyWriter.FormDataContentType())
-	httpResp, err := http.DefaultClient.Do(r)
+	spnegoCl := NewClient(cl, nil, "HTTP/host.test.gokrb5")
+	httpResp, err := spnegoCl.Do(r)
 	if err != nil {
 		t.Fatalf("Request error: %v\n", err)
 	}