diff --git a/auth/forward.go b/auth/forward.go deleted file mode 100644 index 5c4262257..000000000 --- a/auth/forward.go +++ /dev/null @@ -1,60 +0,0 @@ -package auth - -import ( - "io/ioutil" - "net/http" - - "github.com/containous/traefik/log" - "github.com/containous/traefik/types" -) - -// Forward the authentication to a external server -func Forward(forward *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - httpClient := http.Client{} - - if forward.TLS != nil { - tlsConfig, err := forward.TLS.CreateTLSConfig() - if err != nil { - log.Debugf("Impossible to configure TLS to call %s. Cause %s", forward.Address, err) - w.WriteHeader(http.StatusInternalServerError) - return - } - httpClient.Transport = &http.Transport{ - TLSClientConfig: tlsConfig, - } - - } - - forwardReq, err := http.NewRequest(http.MethodGet, forward.Address, nil) - if err != nil { - log.Debugf("Error calling %s. Cause %s", forward.Address, err) - w.WriteHeader(http.StatusInternalServerError) - return - } - forwardReq.Header = r.Header - - forwardResponse, forwardErr := httpClient.Do(forwardReq) - if forwardErr != nil { - log.Debugf("Error calling %s. Cause: %s", forward.Address, forwardErr) - w.WriteHeader(http.StatusInternalServerError) - return - } - - body, readError := ioutil.ReadAll(forwardResponse.Body) - if readError != nil { - log.Debugf("Error reading body %s. Cause: %s", forward.Address, readError) - w.WriteHeader(http.StatusInternalServerError) - return - } - defer forwardResponse.Body.Close() - - if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices { - log.Debugf("Remote error %s. StatusCode: %d", forward.Address, forwardResponse.StatusCode) - w.WriteHeader(forwardResponse.StatusCode) - w.Write(body) - return - } - - r.RequestURI = r.URL.RequestURI() - next(w, r) -} diff --git a/docs/configuration/entrypoints.md b/docs/configuration/entrypoints.md index 7ed721258..ed8a08ede 100644 --- a/docs/configuration/entrypoints.md +++ b/docs/configuration/entrypoints.md @@ -116,28 +116,31 @@ This configuration will first forward the request to `http://authserver.com/auth If the response code is 2XX, access is granted and the original request is performed. Otherwise, the response from the auth server is returned. -```toml -[entryPoints] - [entryPoints.http] - # ... - # To enable forward auth on an entrypoint - [entryPoints.http.auth.forward] - address = "http://authserver.com/auth" -``` - ```toml [entryPoints] [entrypoints.http] # ... - # To enable forward auth on an entrypoint (HTTPS) + # To enable forward auth on an entrypoint [entrypoints.http.auth.forward] address = "https://authserver.com/auth" + + # Trust existing X-Forwarded-* headers. + # Useful with another reverse proxy in front of Traefik. + # + # Optional + # Default: false + # + trustForwardHeader = true + + # Enable forward auth TLS connection. + # + # Optional + # [entrypoints.http.auth.forward.tls] cert = "authserver.crt" key = "authserver.key" ``` - ## Specify Minimum TLS Version To specify an https entry point with a minimum TLS version, and specifying an array of cipher suites (from crypto/tls). diff --git a/middlewares/authenticator.go b/middlewares/auth/authenticator.go similarity index 71% rename from middlewares/authenticator.go rename to middlewares/auth/authenticator.go index cdd4203f5..5482948d7 100644 --- a/middlewares/authenticator.go +++ b/middlewares/auth/authenticator.go @@ -1,4 +1,4 @@ -package middlewares +package auth import ( "fmt" @@ -7,7 +7,6 @@ import ( "strings" goauth "github.com/abbot/go-http-auth" - "github.com/containous/traefik/auth" "github.com/containous/traefik/log" "github.com/containous/traefik/types" "github.com/urfave/negroni" @@ -64,52 +63,12 @@ func NewAuthenticator(authConfig *types.Auth) (*Authenticator, error) { }) } else if authConfig.Forward != nil { authenticator.handler = negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - auth.Forward(authConfig.Forward, w, r, next) + Forward(authConfig.Forward, w, r, next) }) } return &authenticator, nil } -func parserBasicUsers(basic *types.Basic) (map[string]string, error) { - var userStrs []string - if basic.UsersFile != "" { - var err error - if userStrs, err = getLinesFromFile(basic.UsersFile); err != nil { - return nil, err - } - } - userStrs = append(basic.Users, userStrs...) - userMap := make(map[string]string) - for _, user := range userStrs { - split := strings.Split(user, ":") - if len(split) != 2 { - return nil, fmt.Errorf("Error parsing Authenticator user: %v", user) - } - userMap[split[0]] = split[1] - } - return userMap, nil -} - -func parserDigestUsers(digest *types.Digest) (map[string]string, error) { - var userStrs []string - if digest.UsersFile != "" { - var err error - if userStrs, err = getLinesFromFile(digest.UsersFile); err != nil { - return nil, err - } - } - userStrs = append(digest.Users, userStrs...) - userMap := make(map[string]string) - for _, user := range userStrs { - split := strings.Split(user, ":") - if len(split) != 3 { - return nil, fmt.Errorf("Error parsing Authenticator user: %v", user) - } - userMap[split[0]+":"+split[1]] = split[2] - } - return userMap, nil -} - func getLinesFromFile(filename string) ([]string, error) { dat, err := ioutil.ReadFile(filename) if err != nil { diff --git a/middlewares/authenticator_test.go b/middlewares/auth/authenticator_test.go similarity index 72% rename from middlewares/authenticator_test.go rename to middlewares/auth/authenticator_test.go index d27e5d1f3..e27834a6c 100644 --- a/middlewares/authenticator_test.go +++ b/middlewares/auth/authenticator_test.go @@ -1,4 +1,4 @@ -package middlewares +package auth import ( "fmt" @@ -157,7 +157,7 @@ func TestDigestAuthFail(t *testing.T) { } func TestBasicAuthUserHeader(t *testing.T) { - authMiddleware, err := NewAuthenticator(&types.Auth{ + middleware, err := NewAuthenticator(&types.Auth{ Basic: &types.Basic{ Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"}, }, @@ -169,7 +169,7 @@ func TestBasicAuthUserHeader(t *testing.T) { assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set") fmt.Fprintln(w, "traefik") }) - n := negroni.New(authMiddleware) + n := negroni.New(middleware) n.UseHandler(handler) ts := httptest.NewServer(n) defer ts.Close() @@ -186,67 +186,3 @@ func TestBasicAuthUserHeader(t *testing.T) { assert.NoError(t, err, "there should be no error") assert.Equal(t, "traefik\n", string(body), "they should be equal") } - -func TestForwardAuthFail(t *testing.T) { - authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Forbidden", http.StatusForbidden) - })) - defer authTs.Close() - - authMiddleware, err := NewAuthenticator(&types.Auth{ - Forward: &types.Forward{ - Address: authTs.URL, - }, - }) - assert.NoError(t, err, "there should be no error") - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "traefik") - }) - n := negroni.New(authMiddleware) - n.UseHandler(handler) - ts := httptest.NewServer(n) - defer ts.Close() - - client := &http.Client{} - req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) - res, err := client.Do(req) - assert.NoError(t, err, "there should be no error") - assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal") - - body, err := ioutil.ReadAll(res.Body) - assert.NoError(t, err, "there should be no error") - assert.Equal(t, "Forbidden\n", string(body), "they should be equal") -} - -func TestForwardAuthSuccess(t *testing.T) { - authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Success") - })) - defer authTs.Close() - - authMiddleware, err := NewAuthenticator(&types.Auth{ - Forward: &types.Forward{ - Address: authTs.URL, - }, - }) - assert.NoError(t, err, "there should be no error") - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "traefik") - }) - n := negroni.New(authMiddleware) - n.UseHandler(handler) - ts := httptest.NewServer(n) - defer ts.Close() - - client := &http.Client{} - req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) - res, err := client.Do(req) - assert.NoError(t, err, "there should be no error") - assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal") - - body, err := ioutil.ReadAll(res.Body) - assert.NoError(t, err, "there should be no error") - assert.Equal(t, "traefik\n", string(body), "they should be equal") -} diff --git a/middlewares/auth/forward.go b/middlewares/auth/forward.go new file mode 100644 index 000000000..30bcd6452 --- /dev/null +++ b/middlewares/auth/forward.go @@ -0,0 +1,97 @@ +package auth + +import ( + "io/ioutil" + "net" + "net/http" + "strings" + + "github.com/containous/traefik/log" + "github.com/containous/traefik/types" + "github.com/vulcand/oxy/forward" + "github.com/vulcand/oxy/utils" +) + +// Forward the authentication to a external server +func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + httpClient := http.Client{} + + if config.TLS != nil { + tlsConfig, err := config.TLS.CreateTLSConfig() + if err != nil { + log.Debugf("Impossible to configure TLS to call %s. Cause %s", config.Address, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + httpClient.Transport = &http.Transport{ + TLSClientConfig: tlsConfig, + } + } + + forwardReq, err := http.NewRequest(http.MethodGet, config.Address, nil) + if err != nil { + log.Debugf("Error calling %s. Cause %s", config.Address, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + writeHeader(r, forwardReq, config.TrustForwardHeader) + + forwardResponse, forwardErr := httpClient.Do(forwardReq) + if forwardErr != nil { + log.Debugf("Error calling %s. Cause: %s", config.Address, forwardErr) + w.WriteHeader(http.StatusInternalServerError) + return + } + + body, readError := ioutil.ReadAll(forwardResponse.Body) + if readError != nil { + log.Debugf("Error reading body %s. Cause: %s", config.Address, readError) + w.WriteHeader(http.StatusInternalServerError) + return + } + defer forwardResponse.Body.Close() + + if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices { + log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode) + w.WriteHeader(forwardResponse.StatusCode) + w.Write(body) + return + } + + r.RequestURI = r.URL.RequestURI() + next(w, r) +} + +func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) { + utils.CopyHeaders(forwardReq.Header, req.Header) + + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + if trustForwardHeader { + if prior, ok := req.Header[forward.XForwardedFor]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + } + forwardReq.Header.Set(forward.XForwardedFor, clientIP) + } + + if xfp := req.Header.Get(forward.XForwardedProto); xfp != "" && trustForwardHeader { + forwardReq.Header.Set(forward.XForwardedProto, xfp) + } else if req.TLS != nil { + forwardReq.Header.Set(forward.XForwardedProto, "https") + } else { + forwardReq.Header.Set(forward.XForwardedProto, "http") + } + + if xfp := req.Header.Get(forward.XForwardedPort); xfp != "" && trustForwardHeader { + forwardReq.Header.Set(forward.XForwardedPort, xfp) + } + + if xfh := req.Header.Get(forward.XForwardedHost); xfh != "" && trustForwardHeader { + forwardReq.Header.Set(forward.XForwardedHost, xfh) + } else if req.Host != "" { + forwardReq.Header.Set(forward.XForwardedHost, req.Host) + } else { + forwardReq.Header.Del(forward.XForwardedHost) + } +} diff --git a/middlewares/auth/forward_test.go b/middlewares/auth/forward_test.go new file mode 100644 index 000000000..44b619689 --- /dev/null +++ b/middlewares/auth/forward_test.go @@ -0,0 +1,162 @@ +package auth + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/containous/traefik/testhelpers" + "github.com/containous/traefik/types" + "github.com/stretchr/testify/assert" + "github.com/urfave/negroni" +) + +func TestForwardAuthFail(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Forbidden", http.StatusForbidden) + })) + defer server.Close() + + middleware, err := NewAuthenticator(&types.Auth{ + Forward: &types.Forward{ + Address: server.URL, + }, + }) + assert.NoError(t, err, "there should be no error") + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "traefik") + }) + n := negroni.New(middleware) + n.UseHandler(handler) + ts := httptest.NewServer(n) + defer ts.Close() + + client := &http.Client{} + req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) + res, err := client.Do(req) + assert.NoError(t, err, "there should be no error") + assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal") + + body, err := ioutil.ReadAll(res.Body) + assert.NoError(t, err, "there should be no error") + assert.Equal(t, "Forbidden\n", string(body), "they should be equal") +} + +func TestForwardAuthSuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Success") + })) + defer server.Close() + + middleware, err := NewAuthenticator(&types.Auth{ + Forward: &types.Forward{ + Address: server.URL, + }, + }) + assert.NoError(t, err, "there should be no error") + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "traefik") + }) + n := negroni.New(middleware) + n.UseHandler(handler) + ts := httptest.NewServer(n) + defer ts.Close() + + client := &http.Client{} + req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) + res, err := client.Do(req) + assert.NoError(t, err, "there should be no error") + assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal") + + body, err := ioutil.ReadAll(res.Body) + assert.NoError(t, err, "there should be no error") + assert.Equal(t, "traefik\n", string(body), "they should be equal") +} + +func Test_writeHeader(t *testing.T) { + + testCases := []struct { + name string + headers map[string]string + trustForwardHeader bool + emptyHost bool + expectedHeaders map[string]string + }{ + { + name: "trust Forward Header", + headers: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "fii.bir", + }, + trustForwardHeader: true, + expectedHeaders: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "fii.bir", + }, + }, + { + name: "not trust Forward Header", + headers: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "fii.bir", + }, + trustForwardHeader: false, + expectedHeaders: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "foo.bar", + }, + }, + { + name: "trust Forward Header with empty Host", + headers: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "fii.bir", + }, + trustForwardHeader: true, + emptyHost: true, + expectedHeaders: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "fii.bir", + }, + }, + { + name: "not trust Forward Header with empty Host", + headers: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "fii.bir", + }, + trustForwardHeader: false, + emptyHost: true, + expectedHeaders: map[string]string{ + "Accept": "application/json", + "X-Forwarded-Host": "", + }, + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + + req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar", nil) + for key, value := range test.headers { + req.Header.Set(key, value) + } + + if test.emptyHost { + req.Host = "" + } + + forwardReq := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar", nil) + + writeHeader(req, forwardReq, test.trustForwardHeader) + + for key, value := range test.expectedHeaders { + assert.Equal(t, value, forwardReq.Header.Get(key)) + } + }) + } +} diff --git a/middlewares/auth/parser.go b/middlewares/auth/parser.go new file mode 100644 index 000000000..9bf82e4a7 --- /dev/null +++ b/middlewares/auth/parser.go @@ -0,0 +1,48 @@ +package auth + +import ( + "fmt" + "strings" + + "github.com/containous/traefik/types" +) + +func parserBasicUsers(basic *types.Basic) (map[string]string, error) { + var userStrs []string + if basic.UsersFile != "" { + var err error + if userStrs, err = getLinesFromFile(basic.UsersFile); err != nil { + return nil, err + } + } + userStrs = append(basic.Users, userStrs...) + userMap := make(map[string]string) + for _, user := range userStrs { + split := strings.Split(user, ":") + if len(split) != 2 { + return nil, fmt.Errorf("Error parsing Authenticator user: %v", user) + } + userMap[split[0]] = split[1] + } + return userMap, nil +} + +func parserDigestUsers(digest *types.Digest) (map[string]string, error) { + var userStrs []string + if digest.UsersFile != "" { + var err error + if userStrs, err = getLinesFromFile(digest.UsersFile); err != nil { + return nil, err + } + } + userStrs = append(digest.Users, userStrs...) + userMap := make(map[string]string) + for _, user := range userStrs { + split := strings.Split(user, ":") + if len(split) != 3 { + return nil, fmt.Errorf("Error parsing Authenticator user: %v", user) + } + userMap[split[0]+":"+split[1]] = split[2] + } + return userMap, nil +} diff --git a/provider/web/web.go b/provider/web/web.go index 035052d6f..0e2d655e1 100644 --- a/provider/web/web.go +++ b/provider/web/web.go @@ -12,6 +12,7 @@ import ( "github.com/containous/traefik/autogen" "github.com/containous/traefik/log" "github.com/containous/traefik/middlewares" + mauth "github.com/containous/traefik/middlewares/auth" "github.com/containous/traefik/safe" "github.com/containous/traefik/types" "github.com/containous/traefik/version" @@ -135,7 +136,7 @@ func (provider *Provider) Provide(configurationChan chan<- types.ConfigMessage, var err error var negroniInstance = negroni.New() if provider.Auth != nil { - authMiddleware, err := middlewares.NewAuthenticator(provider.Auth) + authMiddleware, err := mauth.NewAuthenticator(provider.Auth) if err != nil { log.Fatal("Error creating Auth: ", err) } diff --git a/server/server.go b/server/server.go index 7bc45acb9..a188dd8cb 100644 --- a/server/server.go +++ b/server/server.go @@ -27,6 +27,7 @@ import ( "github.com/containous/traefik/metrics" "github.com/containous/traefik/middlewares" "github.com/containous/traefik/middlewares/accesslog" + mauth "github.com/containous/traefik/middlewares/auth" "github.com/containous/traefik/provider" "github.com/containous/traefik/safe" "github.com/containous/traefik/types" @@ -282,7 +283,7 @@ func (server *Server) setupServerEntryPoint(newServerEntryPointName string, newS } } if server.globalConfiguration.EntryPoints[newServerEntryPointName].Auth != nil { - authMiddleware, err := middlewares.NewAuthenticator(server.globalConfiguration.EntryPoints[newServerEntryPointName].Auth) + authMiddleware, err := mauth.NewAuthenticator(server.globalConfiguration.EntryPoints[newServerEntryPointName].Auth) if err != nil { log.Fatal("Error starting server: ", err) } @@ -935,7 +936,7 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf auth.Basic = &types.Basic{ Users: users, } - authMiddleware, err := middlewares.NewAuthenticator(auth) + authMiddleware, err := mauth.NewAuthenticator(auth) if err != nil { log.Errorf("Error creating Auth: %s", err) } else { diff --git a/types/types.go b/types/types.go index 1732e5f4d..9f593b4e2 100644 --- a/types/types.go +++ b/types/types.go @@ -326,8 +326,9 @@ type Digest struct { // Forward authentication type Forward struct { - Address string `description:"Authentication server address"` - TLS *ClientTLS `description:"Enable TLS support"` + Address string `description:"Authentication server address"` + TLS *ClientTLS `description:"Enable TLS support"` + TrustForwardHeader bool `description:"Trust X-Forwarded-* headers"` } // CanonicalDomain returns a lower case domain with trim space