Browse Source

Add Sec-WebSocket-Extensions header parser

Also, improve token list header parser.
Gary Burd 9 years ago
parent
commit
8b29b78138
2 changed files with 223 additions and 13 deletions
  1. 183 13
      util.go
  2. 40 0
      util_test.go

+ 183 - 13
util.go

@@ -13,19 +13,6 @@ import (
 	"strings"
 )
 
-// tokenListContainsValue returns true if the 1#token header with the given
-// name contains token.
-func tokenListContainsValue(header http.Header, name string, value string) bool {
-	for _, v := range header[name] {
-		for _, s := range strings.Split(v, ",") {
-			if strings.EqualFold(value, strings.TrimSpace(s)) {
-				return true
-			}
-		}
-	}
-	return false
-}
-
 var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
 
 func computeAcceptKey(challengeKey string) string {
@@ -42,3 +29,186 @@ func generateChallengeKey() (string, error) {
 	}
 	return base64.StdEncoding.EncodeToString(p), nil
 }
+
+// Octet types from RFC 2616.
+var octetTypes [256]byte
+
+const (
+	isTokenOctet = 1 << iota
+	isSpaceOctet
+)
+
+func init() {
+	// From RFC 2616
+	//
+	// OCTET      = <any 8-bit sequence of data>
+	// CHAR       = <any US-ASCII character (octets 0 - 127)>
+	// CTL        = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
+	// CR         = <US-ASCII CR, carriage return (13)>
+	// LF         = <US-ASCII LF, linefeed (10)>
+	// SP         = <US-ASCII SP, space (32)>
+	// HT         = <US-ASCII HT, horizontal-tab (9)>
+	// <">        = <US-ASCII double-quote mark (34)>
+	// CRLF       = CR LF
+	// LWS        = [CRLF] 1*( SP | HT )
+	// TEXT       = <any OCTET except CTLs, but including LWS>
+	// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
+	//              | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
+	// token      = 1*<any CHAR except CTLs or separators>
+	// qdtext     = <any TEXT except <">>
+
+	for c := 0; c < 256; c++ {
+		var t byte
+		isCtl := c <= 31 || c == 127
+		isChar := 0 <= c && c <= 127
+		isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
+		if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
+			t |= isSpaceOctet
+		}
+		if isChar && !isCtl && !isSeparator {
+			t |= isTokenOctet
+		}
+		octetTypes[c] = t
+	}
+}
+
+func skipSpace(s string) (rest string) {
+	i := 0
+	for ; i < len(s); i++ {
+		if octetTypes[s[i]]&isSpaceOctet == 0 {
+			break
+		}
+	}
+	return s[i:]
+}
+
+func nextToken(s string) (token, rest string) {
+	i := 0
+	for ; i < len(s); i++ {
+		if octetTypes[s[i]]&isTokenOctet == 0 {
+			break
+		}
+	}
+	return s[:i], s[i:]
+}
+
+func nextTokenOrQuoted(s string) (value string, rest string) {
+	if !strings.HasPrefix(s, "\"") {
+		return nextToken(s)
+	}
+	s = s[1:]
+	for i := 0; i < len(s); i++ {
+		switch s[i] {
+		case '"':
+			return s[:i], s[i+1:]
+		case '\\':
+			p := make([]byte, len(s)-1)
+			j := copy(p, s[:i])
+			escape := true
+			for i = i + 1; i < len(s); i++ {
+				b := s[i]
+				switch {
+				case escape:
+					escape = false
+					p[j] = b
+					j += 1
+				case b == '\\':
+					escape = true
+				case b == '"':
+					return string(p[:j]), s[i+1:]
+				default:
+					p[j] = b
+					j += 1
+				}
+			}
+			return "", ""
+		}
+	}
+	return "", ""
+}
+
+// tokenListContainsValue returns true if the 1#token header with the given
+// name contains token.
+func tokenListContainsValue(header http.Header, name string, value string) bool {
+headers:
+	for _, s := range header[name] {
+		for {
+			var t string
+			t, s = nextToken(skipSpace(s))
+			if t == "" {
+				continue headers
+			}
+			s = skipSpace(s)
+			if s != "" && s[0] != ',' {
+				continue headers
+			}
+			if strings.EqualFold(t, value) {
+				return true
+			}
+			if s == "" {
+				continue headers
+			}
+			s = s[1:]
+		}
+	}
+	return false
+}
+
+// parseExtensiosn parses WebSocket extensions from a header.
+func parseExtensions(header http.Header) []map[string]string {
+
+	// From RFC 6455:
+	//
+	//  Sec-WebSocket-Extensions = extension-list
+	//  extension-list = 1#extension
+	//  extension = extension-token *( ";" extension-param )
+	//  extension-token = registered-token
+	//  registered-token = token
+	//  extension-param = token [ "=" (token | quoted-string) ]
+	//     ;When using the quoted-string syntax variant, the value
+	//     ;after quoted-string unescaping MUST conform to the
+	//     ;'token' ABNF.
+
+	var result []map[string]string
+headers:
+	for _, s := range header["Sec-Websocket-Extensions"] {
+		for {
+			var t string
+			t, s = nextToken(skipSpace(s))
+			if t == "" {
+				continue headers
+			}
+			ext := map[string]string{"": t}
+			for {
+				s = skipSpace(s)
+				if !strings.HasPrefix(s, ";") {
+					break
+				}
+				var k string
+				k, s = nextToken(skipSpace(s[1:]))
+				if k == "" {
+					continue headers
+				}
+				s = skipSpace(s)
+				var v string
+				if strings.HasPrefix(s, "=") {
+					v, s = nextTokenOrQuoted(skipSpace(s[1:]))
+					s = skipSpace(s)
+				}
+				if s != "" && s[0] != ',' && s[0] != ';' {
+					continue headers
+				}
+				ext[k] = v
+			}
+			if s != "" && s[0] != ',' {
+				continue headers
+			}
+			result = append(result, ext)
+			if s == "" {
+				continue headers
+			}
+			s = s[1:]
+		}
+	}
+	return result
+}

+ 40 - 0
util_test.go

@@ -6,6 +6,7 @@ package websocket
 
 import (
 	"net/http"
+	"reflect"
 	"testing"
 )
 
@@ -32,3 +33,42 @@ func TestTokenListContainsValue(t *testing.T) {
 		}
 	}
 }
+
+var parseExtensionTests = []struct {
+	value      string
+	extensions []map[string]string
+}{
+	{`foo`, []map[string]string{map[string]string{"": "foo"}}},
+	{`foo, bar; baz=2`, []map[string]string{
+		map[string]string{"": "foo"},
+		map[string]string{"": "bar", "baz": "2"}}},
+	{`foo; bar="b,a;z"`, []map[string]string{
+		map[string]string{"": "foo", "bar": "b,a;z"}}},
+	{`foo , bar; baz = 2`, []map[string]string{
+		map[string]string{"": "foo"},
+		map[string]string{"": "bar", "baz": "2"}}},
+	{`foo, bar; baz=2 junk`, []map[string]string{
+		map[string]string{"": "foo"}}},
+	{`foo junk, bar; baz=2 junk`, nil},
+	{`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{
+		map[string]string{"": "mux", "max-channels": "4", "flow-control": ""},
+		map[string]string{"": "deflate-stream"}}},
+	{`permessage-foo; x="10"`, []map[string]string{
+		map[string]string{"": "permessage-foo", "x": "10"}}},
+	{`permessage-foo; use_y, permessage-foo`, []map[string]string{
+		map[string]string{"": "permessage-foo", "use_y": ""},
+		map[string]string{"": "permessage-foo"}}},
+	{`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{
+		map[string]string{"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"},
+		map[string]string{"": "permessage-deflate", "client_max_window_bits": ""}}},
+}
+
+func TestParseExtensions(t *testing.T) {
+	for _, tt := range parseExtensionTests {
+		h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
+		extensions := parseExtensions(h)
+		if !reflect.DeepEqual(extensions, tt.extensions) {
+			t.Errorf("parseExtensions(%q)\n    = %v,\nwant %v", tt.value, extensions, tt.extensions)
+		}
+	}
+}