Преглед изворни кода

Merge pull request #166 from MaximeHeckel/add-cookiejar

add cookie jar to dialer
Gary Burd пре 9 година
родитељ
комит
5df680c89f
2 измењених фајлова са 68 додато и 1 уклоњено
  1. 19 1
      client.go
  2. 49 0
      client_server_test.go

+ 19 - 1
client.go

@@ -78,6 +78,11 @@ type Dialer struct {
 	// guarantee that compression will be supported. Currently only "no context
 	// takeover" modes are supported.
 	EnableCompression bool
+
+	// Jar specifies the cookie jar.
+	// If Jar is nil, cookies are not sent in requests and ignored
+	// in responses.
+	Jar http.CookieJar
 }
 
 var errMalformedURL = errors.New("malformed ws or wss URL")
@@ -91,7 +96,6 @@ func parseURL(s string) (*url.URL, error) {
 	//
 	// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
 	// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
-
 	var u url.URL
 	switch {
 	case strings.HasPrefix(s, "ws://"):
@@ -201,6 +205,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 		Host:       u.Host,
 	}
 
+	// Set the cookies present in the cookie jar of the dialer
+	if d.Jar != nil {
+		for _, cookie := range d.Jar.Cookies(u) {
+			req.AddCookie(cookie)
+		}
+	}
+
 	// Set the request headers using the capitalization for names and values in
 	// RFC examples. Although the capitalization shouldn't matter, there are
 	// servers that depend on it. The Header.Set method is not used because the
@@ -337,6 +348,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
 	if err != nil {
 		return nil, nil, err
 	}
+
+	if d.Jar != nil {
+		if rc := resp.Cookies(); len(rc) > 0 {
+			d.Jar.SetCookies(u, rc)
+		}
+	}
+
 	if resp.StatusCode != 101 ||
 		!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
 		!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||

+ 49 - 0
client_server_test.go

@@ -11,6 +11,7 @@ import (
 	"io"
 	"io/ioutil"
 	"net/http"
+	"net/http/cookiejar"
 	"net/http/httptest"
 	"net/url"
 	"reflect"
@@ -228,6 +229,54 @@ func TestDial(t *testing.T) {
 	sendRecv(t, ws)
 }
 
+func TestDialCookieJar(t *testing.T) {
+	s := newServer(t)
+	defer s.Close()
+
+	jar, _ := cookiejar.New(nil)
+	d := cstDialer
+	d.Jar = jar
+
+	u, _ := parseURL(s.URL)
+
+	switch u.Scheme {
+	case "ws":
+		u.Scheme = "http"
+	case "wss":
+		u.Scheme = "https"
+	}
+
+	cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}}
+	d.Jar.SetCookies(u, cookies)
+
+	ws, _, err := d.Dial(s.URL, nil)
+	if err != nil {
+		t.Fatalf("Dial: %v", err)
+	}
+	defer ws.Close()
+
+	var gorilla string
+	var sessionID string
+	for _, c := range d.Jar.Cookies(u) {
+		if c.Name == "gorilla" {
+			gorilla = c.Value
+		}
+
+		if c.Name == "sessionID" {
+			sessionID = c.Value
+		}
+	}
+	if gorilla != "ws" {
+		t.Error("Cookie not present in jar.")
+	}
+
+	if sessionID != "1234" {
+		t.Error("Set-Cookie not received from the server.")
+	}
+
+	sendRecv(t, ws)
+}
+
 func TestDialTLS(t *testing.T) {
 	s := newTLSServer(t)
 	defer s.Close()