|
@@ -34,9 +34,10 @@ type Upgrader struct {
|
|
|
// default values will be used.
|
|
// default values will be used.
|
|
|
ReadBufferSize, WriteBufferSize int
|
|
ReadBufferSize, WriteBufferSize int
|
|
|
|
|
|
|
|
- // Subprotocols specifies the server's supported protocols. If Subprotocols
|
|
|
|
|
- // is nil, then Upgrade does not negotiate a subprotocol.
|
|
|
|
|
- Subprotocols []string
|
|
|
|
|
|
|
+ // NegotiateSubprotocol specifies the function to negotiate a subprotocol
|
|
|
|
|
+ // based on a request. If NegotiateSubprotocol is nil, then no subprotocol
|
|
|
|
|
+ // will be used.
|
|
|
|
|
+ NegotiateSubprotocol func(r *http.Request) (string, error)
|
|
|
|
|
|
|
|
// Error specifies the function for generating HTTP error responses. If Error
|
|
// Error specifies the function for generating HTTP error responses. If Error
|
|
|
// is nil, then http.Error is used to generate the HTTP response.
|
|
// is nil, then http.Error is used to generate the HTTP response.
|
|
@@ -59,21 +60,6 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// Check if the passed subprotocol is supported by the server
|
|
|
|
|
-func (u *Upgrader) hasSubprotocol(subprotocol string) bool {
|
|
|
|
|
- if u.Subprotocols == nil {
|
|
|
|
|
- return false
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- for _, s := range u.Subprotocols {
|
|
|
|
|
- if s == subprotocol {
|
|
|
|
|
- return true
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- return false
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
// Check if host in Origin header matches host of request
|
|
// Check if host in Origin header matches host of request
|
|
|
func (u *Upgrader) checkSameOrigin(r *http.Request) bool {
|
|
func (u *Upgrader) checkSameOrigin(r *http.Request) bool {
|
|
|
origin := r.Header.Get("Origin")
|
|
origin := r.Header.Get("Origin")
|
|
@@ -155,12 +141,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|
|
}
|
|
}
|
|
|
c := newConn(netConn, true, readBufSize, writeBufSize)
|
|
c := newConn(netConn, true, readBufSize, writeBufSize)
|
|
|
|
|
|
|
|
- if u.Subprotocols != nil {
|
|
|
|
|
- for _, proto := range Subprotocols(r) {
|
|
|
|
|
- if u.hasSubprotocol(proto) {
|
|
|
|
|
- c.subprotocol = proto
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ if u.NegotiateSubprotocol != nil {
|
|
|
|
|
+ c.subprotocol, err = u.NegotiateSubprotocol(r)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ netConn.Close()
|
|
|
|
|
+ return nil, err
|
|
|
}
|
|
}
|
|
|
} else if responseHeader != nil {
|
|
} else if responseHeader != nil {
|
|
|
c.subprotocol = responseHeader.Get("Sec-Websocket-Protocol")
|
|
c.subprotocol = responseHeader.Get("Sec-Websocket-Protocol")
|