diff --git a/pkg/middlewares/forwardedheaders/forwarded_header.go b/pkg/middlewares/forwardedheaders/forwarded_header.go index 5d2355dfd..2a36859c4 100644 --- a/pkg/middlewares/forwardedheaders/forwarded_header.go +++ b/pkg/middlewares/forwardedheaders/forwarded_header.go @@ -104,9 +104,10 @@ func isWebsocketRequest(req *http.Request) bool { return true } - h = h[pos:] + h = h[pos+1:] } } + return containsHeader(connection, "upgrade") && containsHeader(upgrade, "websocket") } diff --git a/pkg/middlewares/forwardedheaders/forwarded_header_test.go b/pkg/middlewares/forwardedheaders/forwarded_header_test.go index 8fc290f63..40e2c323a 100644 --- a/pkg/middlewares/forwardedheaders/forwarded_header_test.go +++ b/pkg/middlewares/forwardedheaders/forwarded_header_test.go @@ -3,6 +3,7 @@ package forwardedheaders import ( "crypto/tls" "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -299,3 +300,71 @@ func TestServeHTTP(t *testing.T) { }) } } + +func Test_isWebsocketRequest(t *testing.T) { + testCases := []struct { + desc string + connectionHeader string + upgradeHeader string + assert assert.BoolAssertionFunc + }{ + { + desc: "connection Header multiple values middle", + connectionHeader: "foo,upgrade,bar", + upgradeHeader: "websocket", + assert: assert.True, + }, + { + desc: "connection Header multiple values end", + connectionHeader: "foo,bar,upgrade", + upgradeHeader: "websocket", + assert: assert.True, + }, + { + desc: "connection Header multiple values begin", + connectionHeader: "upgrade,foo,bar", + upgradeHeader: "websocket", + assert: assert.True, + }, + { + desc: "connection Header no upgrade", + connectionHeader: "foo,bar", + upgradeHeader: "websocket", + assert: assert.False, + }, + { + desc: "connection Header empty", + connectionHeader: "", + upgradeHeader: "websocket", + assert: assert.False, + }, + { + desc: "no header values", + connectionHeader: "foo,bar", + upgradeHeader: "foo,bar", + assert: assert.False, + }, + { + desc: "upgrade header multiple values", + connectionHeader: "upgrade", + upgradeHeader: "foo,bar,websocket", + assert: assert.True, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + + req.Header.Set(connection, test.connectionHeader) + req.Header.Set(upgrade, test.upgradeHeader) + + ok := isWebsocketRequest(req) + + test.assert(t, ok) + }) + } +}