diff --git a/middlewares/auth/forward.go b/middlewares/auth/forward.go index 5e9de62ca..233d97fb0 100644 --- a/middlewares/auth/forward.go +++ b/middlewares/auth/forward.go @@ -12,6 +12,10 @@ import ( "github.com/vulcand/oxy/utils" ) +const ( + xForwardedURI = "X-Forwarded-Uri" +) + // Forward the authentication to a external server func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { @@ -122,4 +126,12 @@ func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader } else { forwardReq.Header.Del(forward.XForwardedHost) } + + if xfURI := req.Header.Get(xForwardedURI); xfURI != "" && trustForwardHeader { + forwardReq.Header.Set(xForwardedURI, xfURI) + } else if req.URL.RequestURI() != "" { + forwardReq.Header.Set(xForwardedURI, req.URL.RequestURI()) + } else { + forwardReq.Header.Del(xForwardedURI) + } } diff --git a/middlewares/auth/forward_test.go b/middlewares/auth/forward_test.go index 7f6534674..6c6147ace 100644 --- a/middlewares/auth/forward_test.go +++ b/middlewares/auth/forward_test.go @@ -215,12 +215,40 @@ func Test_writeHeader(t *testing.T) { "X-Forwarded-Host": "", }, }, + { + name: "trust Forward Header with forwarded URI", + headers: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "fii.bir", + "X-Forwarded-Uri": "/forward?q=1", + }, + trustForwardHeader: true, + expectedHeaders: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "fii.bir", + "X-Forwarded-Uri": "/forward?q=1", + }, + }, + { + name: "not trust Forward Header with forward requested URI", + headers: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "fii.bir", + "X-Forwarded-Uri": "/forward?q=1", + }, + trustForwardHeader: false, + expectedHeaders: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "foo.bar", + "X-Forwarded-Uri": "/path?q=1", + }, + }, } for _, test := range testCases { t.Run(test.name, func(t *testing.T) { - req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar", nil) + req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil) for key, value := range test.headers { req.Header.Set(key, value) } @@ -229,7 +257,7 @@ func Test_writeHeader(t *testing.T) { req.Host = "" } - forwardReq := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar", nil) + forwardReq := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil) writeHeader(req, forwardReq, test.trustForwardHeader)