From 2b4d33e9198296b8abe8d2f419b34edd9a9c15fa Mon Sep 17 00:00:00 2001 From: Kendrick Erickson Date: Thu, 2 Nov 2017 05:06:03 -0500 Subject: [PATCH] Pass through certain forward auth negative response headers --- middlewares/auth/forward.go | 30 +++++++++++- middlewares/auth/forward_test.go | 79 ++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 1 deletion(-) diff --git a/middlewares/auth/forward.go b/middlewares/auth/forward.go index 30bcd6452..5e9de62ca 100644 --- a/middlewares/auth/forward.go +++ b/middlewares/auth/forward.go @@ -14,7 +14,13 @@ import ( // Forward the authentication to a external server func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - httpClient := http.Client{} + + // Ensure our request client does not follow redirects + httpClient := http.Client{ + CheckRedirect: func(r *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } if config.TLS != nil { tlsConfig, err := config.TLS.CreateTLSConfig() @@ -52,8 +58,30 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next } defer forwardResponse.Body.Close() + // Pass the forward response's body and selected headers if it + // didn't return a response within the range of [200, 300). if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices { log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode) + + // Grab the location header, if any. + redirectURL, err := forwardResponse.Location() + + if err != nil { + if err != http.ErrNoLocation { + log.Debugf("Error reading response location header %s. Cause: %s", config.Address, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + } else if redirectURL.String() != "" { + // Set the location in our response if one was sent back. + w.Header().Add("Location", redirectURL.String()) + } + + // Pass any Set-Cookie headers the forward auth server provides + for _, cookie := range forwardResponse.Cookies() { + w.Header().Add("Set-Cookie", cookie.String()) + } + w.WriteHeader(forwardResponse.StatusCode) w.Write(body) return diff --git a/middlewares/auth/forward_test.go b/middlewares/auth/forward_test.go index 44b619689..7f6534674 100644 --- a/middlewares/auth/forward_test.go +++ b/middlewares/auth/forward_test.go @@ -77,6 +77,85 @@ func TestForwardAuthSuccess(t *testing.T) { assert.Equal(t, "traefik\n", string(body), "they should be equal") } +func TestForwardAuthRedirect(t *testing.T) { + authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound) + })) + defer authTs.Close() + + authMiddleware, err := NewAuthenticator(&types.Auth{ + Forward: &types.Forward{ + Address: authTs.URL, + }, + }) + assert.NoError(t, err, "there should be no error") + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "traefik") + }) + n := negroni.New(authMiddleware) + n.UseHandler(handler) + ts := httptest.NewServer(n) + defer ts.Close() + + client := &http.Client{ + CheckRedirect: func(r *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) + res, err := client.Do(req) + assert.NoError(t, err, "there should be no error") + assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal") + + location, err := res.Location() + + assert.NoError(t, err, "there should be no error") + assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal") + + body, err := ioutil.ReadAll(res.Body) + assert.NoError(t, err, "there should be no error") + assert.NotEmpty(t, string(body), "there should be something in the body") +} + +func TestForwardAuthCookie(t *testing.T) { + authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"} + http.SetCookie(w, cookie) + http.Error(w, "Forbidden", http.StatusForbidden) + })) + defer authTs.Close() + + authMiddleware, err := NewAuthenticator(&types.Auth{ + Forward: &types.Forward{ + Address: authTs.URL, + }, + }) + assert.NoError(t, err, "there should be no error") + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "traefik") + }) + n := negroni.New(authMiddleware) + n.UseHandler(handler) + ts := httptest.NewServer(n) + defer ts.Close() + + client := &http.Client{} + req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) + res, err := client.Do(req) + assert.NoError(t, err, "there should be no error") + assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal") + + for _, cookie := range res.Cookies() { + assert.Equal(t, "testing", cookie.Value, "they should be equal") + } + + body, err := ioutil.ReadAll(res.Body) + assert.NoError(t, err, "there should be no error") + assert.Equal(t, "Forbidden\n", string(body), "they should be equal") +} + func Test_writeHeader(t *testing.T) { testCases := []struct {