Browse Source

make hijack more stable (#565)

Kevin Wan 3 năm trước cách đây
mục cha
commit
3c6951577d

+ 5 - 1
rest/handler/authhandler.go

@@ -143,7 +143,11 @@ func (grw *guardedResponseWriter) Header() http.Header {
 // Hijack implements the http.Hijacker interface.
 // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
 func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	return grw.writer.(http.Hijacker).Hijack()
+	if hijacked, ok := grw.writer.(http.Hijacker); ok {
+		return hijacked.Hijack()
+	}
+
+	return nil, nil, errors.New("server doesn't support hijacking")
 }
 
 func (grw *guardedResponseWriter) Write(body []byte) (int, error) {

+ 30 - 0
rest/handler/authhandler_test.go

@@ -1,6 +1,8 @@
 package handler
 
 import (
+	"bufio"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"testing"
@@ -87,6 +89,26 @@ func TestAuthHandler_NilError(t *testing.T) {
 	})
 }
 
+func TestAuthHandler_Flush(t *testing.T) {
+	resp := httptest.NewRecorder()
+	handler := newGuardedResponseWriter(resp)
+	handler.Flush()
+	assert.True(t, resp.Flushed)
+}
+
+func TestAuthHandler_Hijack(t *testing.T) {
+	resp := httptest.NewRecorder()
+	writer := newGuardedResponseWriter(resp)
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+
+	writer = newGuardedResponseWriter(mockedHijackable{resp})
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+}
+
 func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
 	now := time.Now().Unix()
 	claims := make(jwt.MapClaims)
@@ -101,3 +123,11 @@ func buildToken(secretKey string, payloads map[string]interface{}, seconds int64
 
 	return token.SignedString([]byte(secretKey))
 }
+
+type mockedHijackable struct {
+	*httptest.ResponseRecorder
+}
+
+func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	return nil, nil, nil
+}

+ 5 - 1
rest/handler/cryptionhandler.go

@@ -99,7 +99,11 @@ func (w *cryptionResponseWriter) Header() http.Header {
 // Hijack implements the http.Hijacker interface.
 // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
 func (w *cryptionResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	return w.ResponseWriter.(http.Hijacker).Hijack()
+	if hijacked, ok := w.ResponseWriter.(http.Hijacker); ok {
+		return hijacked.Hijack()
+	}
+
+	return nil, nil, errors.New("server doesn't support hijacking")
 }
 
 func (w *cryptionResponseWriter) Write(p []byte) (int, error) {

+ 13 - 0
rest/handler/cryptionhandler_test.go

@@ -103,3 +103,16 @@ func TestCryptionHandlerFlush(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
 }
+
+func TestCryptionHandler_Hijack(t *testing.T) {
+	resp := httptest.NewRecorder()
+	writer := newCryptionResponseWriter(resp)
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+
+	writer = newCryptionResponseWriter(mockedHijackable{resp})
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+}

+ 16 - 1
rest/handler/loghandler.go

@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -40,7 +41,11 @@ func (w *loggedResponseWriter) Header() http.Header {
 // Hijack implements the http.Hijacker interface.
 // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
 func (w *loggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
-	return w.w.(http.Hijacker).Hijack()
+	if hijacked, ok := w.w.(http.Hijacker); ok {
+		return hijacked.Hijack()
+	}
+
+	return nil, nil, errors.New("server doesn't support hijacking")
 }
 
 func (w *loggedResponseWriter) Write(bytes []byte) (int, error) {
@@ -91,6 +96,16 @@ func (w *detailLoggedResponseWriter) Header() http.Header {
 	return w.writer.Header()
 }
 
+// Hijack implements the http.Hijacker interface.
+// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
+func (w *detailLoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	if hijacked, ok := w.writer.w.(http.Hijacker); ok {
+		return hijacked.Hijack()
+	}
+
+	return nil, nil, errors.New("server doesn't support hijacking")
+}
+
 func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) {
 	w.buf.Write(bs)
 	return w.writer.Write(bs)

+ 38 - 0
rest/handler/loghandler_test.go

@@ -62,6 +62,44 @@ func TestLogHandlerSlow(t *testing.T) {
 	}
 }
 
+func TestLogHandler_Hijack(t *testing.T) {
+	resp := httptest.NewRecorder()
+	writer := &loggedResponseWriter{
+		w: resp,
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+
+	writer = &loggedResponseWriter{
+		w: mockedHijackable{resp},
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+}
+
+func TestDetailedLogHandler_Hijack(t *testing.T) {
+	resp := httptest.NewRecorder()
+	writer := &detailLoggedResponseWriter{
+		writer: &loggedResponseWriter{
+			w: resp,
+		},
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+
+	writer = &detailLoggedResponseWriter{
+		writer: &loggedResponseWriter{
+			w: mockedHijackable{resp},
+		},
+	}
+	assert.NotPanics(t, func() {
+		writer.Hijack()
+	})
+}
+
 func BenchmarkLogHandler(b *testing.B) {
 	b.ReportAllocs()