diff --git a/pkg/middlewares/auth/forward.go b/pkg/middlewares/auth/forward.go index 381c687b4..65722316a 100644 --- a/pkg/middlewares/auth/forward.go +++ b/pkg/middlewares/auth/forward.go @@ -26,6 +26,18 @@ const ( forwardedTypeName = "ForwardedAuthType" ) +// hopHeaders Hop-by-hop headers to be removed in the authentication request. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +// Proxy-Authorization header is forwarded to the authentication server (see https://tools.ietf.org/html/rfc7235#section-4.4). +var hopHeaders = []string{ + forward.Connection, + forward.KeepAlive, + forward.Te, // canonicalized version of "TE" + forward.Trailers, + forward.TransferEncoding, + forward.Upgrade, +} + type forwardAuth struct { address string authResponseHeaders []string @@ -131,7 +143,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { logger.Debugf("Remote error %s. StatusCode: %d", fa.address, forwardResponse.StatusCode) utils.CopyHeaders(rw.Header(), forwardResponse.Header) - utils.RemoveHeaders(rw.Header(), forward.HopHeaders...) + utils.RemoveHeaders(rw.Header(), hopHeaders...) // Grab the location header, if any. redirectURL, err := forwardResponse.Location() @@ -187,7 +199,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool, allowedHeaders []string) { utils.CopyHeaders(forwardReq.Header, req.Header) - utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...) + utils.RemoveHeaders(forwardReq.Header, hopHeaders...) forwardReq.Header = filterForwardRequestHeaders(forwardReq.Header, allowedHeaders) diff --git a/pkg/middlewares/auth/forward_test.go b/pkg/middlewares/auth/forward_test.go index 65e6ad9cd..13a8f5a5f 100644 --- a/pkg/middlewares/auth/forward_test.go +++ b/pkg/middlewares/auth/forward_test.go @@ -26,6 +26,7 @@ func TestForwardAuthFail(t *testing.T) { }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set(forward.ProxyAuthenticate, "test") http.Error(w, "Forbidden", http.StatusForbidden) })) t.Cleanup(server.Close) @@ -48,6 +49,7 @@ func TestForwardAuthFail(t *testing.T) { err = res.Body.Close() require.NoError(t, err) + assert.Equal(t, "test", res.Header.Get(forward.ProxyAuthenticate)) assert.Equal(t, "Forbidden\n", string(body)) } @@ -142,7 +144,7 @@ func TestForwardAuthRedirect(t *testing.T) { func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) { authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { headers := w.Header() - for _, header := range forward.HopHeaders { + for _, header := range hopHeaders { if header == forward.TransferEncoding { headers.Set(header, "chunked") } else { @@ -367,11 +369,13 @@ func Test_writeHeader(t *testing.T) { }, trustForwardHeader: false, expectedHeaders: map[string]string{ - "X-CustomHeader": "CustomHeader", - "X-Forwarded-Proto": "http", - "X-Forwarded-Host": "foo.bar", - "X-Forwarded-Uri": "/path?q=1", - "X-Forwarded-Method": "GET", + "X-CustomHeader": "CustomHeader", + "X-Forwarded-Proto": "http", + "X-Forwarded-Host": "foo.bar", + "X-Forwarded-Uri": "/path?q=1", + "X-Forwarded-Method": "GET", + forward.ProxyAuthenticate: "ProxyAuthenticate", + forward.ProxyAuthorization: "ProxyAuthorization", }, checkForUnexpectedHeaders: true, },