From 336135c3927f072301578c4d6c159d2ecb162acb Mon Sep 17 00:00:00 2001 From: mpl Date: Tue, 2 Apr 2019 16:56:05 +0200 Subject: [PATCH] Set X-Forwarded-* headers Co-authored-by: Julien Salleyron --- integration/simple_test.go | 28 ++++ .../forwardedheaders/forwarded_header.go | 124 +++++++++++++++++- .../forwardedheaders/forwarded_header_test.go | 111 +++++++++++++++- 3 files changed, 258 insertions(+), 5 deletions(-) diff --git a/integration/simple_test.go b/integration/simple_test.go index 7eb3e3544..e3c4060ca 100644 --- a/integration/simple_test.go +++ b/integration/simple_test.go @@ -404,6 +404,34 @@ func (s *SimpleSuite) TestIPStrategyWhitelist(c *check.C) { } } +func (s *SimpleSuite) TestXForwardedHeaders(c *check.C) { + s.createComposeProject(c, "whitelist") + s.composeProject.Start(c) + + cmd, output := s.traefikCmd(withConfigFile("fixtures/simple_whitelist.toml")) + defer output(c) + + err := cmd.Start() + c.Assert(err, checker.IsNil) + defer cmd.Process.Kill() + + err = try.GetRequest("http://127.0.0.1:8080/api/providers/docker/routers", 2*time.Second, + try.BodyContains("override.remoteaddr.whitelist.docker.local")) + c.Assert(err, checker.IsNil) + + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000", nil) + c.Assert(err, checker.IsNil) + + req.Host = "override.depth.whitelist.docker.local" + req.Header.Set("X-Forwarded-For", "8.8.8.8,10.0.0.1,127.0.0.1") + + err = try.Request(req, 1*time.Second, + try.StatusCodeIs(http.StatusOK), + try.BodyContains("X-Forwarded-Proto", "X-Forwarded-For", "X-Forwarded-Host", + "X-Forwarded-Host", "X-Forwarded-Port", "X-Forwarded-Server", "X-Real-Ip")) + c.Assert(err, checker.IsNil) +} + func (s *SimpleSuite) TestKeepTrailingSlash(c *check.C) { file := s.adaptFile(c, "fixtures/keep_trailing_slash.toml", struct { KeepTrailingSlash bool diff --git a/pkg/middlewares/forwardedheaders/forwarded_header.go b/pkg/middlewares/forwardedheaders/forwarded_header.go index 7512d2da8..74bdc3941 100644 --- a/pkg/middlewares/forwardedheaders/forwarded_header.go +++ b/pkg/middlewares/forwardedheaders/forwarded_header.go @@ -1,19 +1,43 @@ package forwardedheaders import ( + "net" "net/http" + "os" + "strings" "github.com/containous/traefik/pkg/ip" - "github.com/vulcand/oxy/forward" - "github.com/vulcand/oxy/utils" ) -// XForwarded filter for XForwarded headers. +const ( + xForwardedProto = "X-Forwarded-Proto" + xForwardedFor = "X-Forwarded-For" + xForwardedHost = "X-Forwarded-Host" + xForwardedPort = "X-Forwarded-Port" + xForwardedServer = "X-Forwarded-Server" + xRealIP = "X-Real-Ip" + connection = "Connection" + upgrade = "Upgrade" +) + +var xHeaders = []string{ + xForwardedProto, + xForwardedFor, + xForwardedHost, + xForwardedPort, + xForwardedServer, + xRealIP, +} + +// XForwarded is an HTTP handler wrapper that sets the X-Forwarded headers, and other relevant headers for a +// reverse-proxy. Unless insecure is set, it first removes all the existing values for those headers if the remote +// address is not one of the trusted ones. type XForwarded struct { insecure bool trustedIps []string ipChecker *ip.Checker next http.Handler + hostname string } // NewXForwarded creates a new XForwarded. @@ -27,11 +51,17 @@ func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XFor } } + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + return &XForwarded{ insecure: insecure, trustedIps: trustedIps, ipChecker: ipChecker, next: next, + hostname: hostname, }, nil } @@ -42,10 +72,96 @@ func (x *XForwarded) isTrustedIP(ip string) bool { return x.ipChecker.IsAuthorized(ip) == nil } +// removeIPv6Zone removes the zone if the given IP is an ipv6 address and it has +// {zone} information in it, like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 +// Ethernet Adapter - Virtual Switch)]:64692" +func removeIPv6Zone(clientIP string) string { + return strings.Split(clientIP, "%")[0] +} + +// isWebsocketRequest returns whether the specified HTTP request is a +// websocket handshake request +func isWebsocketRequest(req *http.Request) bool { + containsHeader := func(name, value string) bool { + items := strings.Split(req.Header.Get(name), ",") + for _, item := range items { + if value == strings.ToLower(strings.TrimSpace(item)) { + return true + } + } + return false + } + return containsHeader(connection, "upgrade") && containsHeader(upgrade, "websocket") +} + +func forwardedPort(req *http.Request) string { + if req == nil { + return "" + } + + if _, port, err := net.SplitHostPort(req.Host); err == nil && port != "" { + return port + } + + if req.Header.Get(xForwardedProto) == "https" || req.Header.Get(xForwardedProto) == "wss" { + return "443" + } + + if req.TLS != nil { + return "443" + } + + return "80" +} + +func (x *XForwarded) rewrite(outreq *http.Request) { + if clientIP, _, err := net.SplitHostPort(outreq.RemoteAddr); err == nil { + clientIP = removeIPv6Zone(clientIP) + + if outreq.Header.Get(xRealIP) == "" { + outreq.Header.Set(xRealIP, clientIP) + } + } + + xfProto := outreq.Header.Get(xForwardedProto) + if xfProto == "" { + if outreq.TLS != nil { + outreq.Header.Set(xForwardedProto, "https") + } else { + outreq.Header.Set(xForwardedProto, "http") + } + } + + if isWebsocketRequest(outreq) { + if outreq.Header.Get(xForwardedProto) == "https" { + outreq.Header.Set(xForwardedProto, "wss") + } else { + outreq.Header.Set(xForwardedProto, "ws") + } + } + + if xfPort := outreq.Header.Get(xForwardedPort); xfPort == "" { + outreq.Header.Set(xForwardedPort, forwardedPort(outreq)) + } + + if xfHost := outreq.Header.Get(xForwardedHost); xfHost == "" && outreq.Host != "" { + outreq.Header.Set(xForwardedHost, outreq.Host) + } + + if x.hostname != "" { + outreq.Header.Set(xForwardedServer, x.hostname) + } +} + +// ServeHTTP implements http.Handler func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !x.insecure && !x.isTrustedIP(r.RemoteAddr) { - utils.RemoveHeaders(r.Header, forward.XHeaders...) + for _, h := range xHeaders { + r.Header.Del(h) + } } + x.rewrite(r) + x.next.ServeHTTP(w, r) } diff --git a/pkg/middlewares/forwardedheaders/forwarded_header_test.go b/pkg/middlewares/forwardedheaders/forwarded_header_test.go index 1062c8ea1..dbcf04b41 100644 --- a/pkg/middlewares/forwardedheaders/forwarded_header_test.go +++ b/pkg/middlewares/forwardedheaders/forwarded_header_test.go @@ -1,6 +1,7 @@ package forwardedheaders import ( + "crypto/tls" "net/http" "testing" @@ -16,6 +17,9 @@ func TestServeHTTP(t *testing.T) { incomingHeaders map[string]string remoteAddr string expectedHeaders map[string]string + tls bool + websocket bool + host string }{ { desc: "all Empty", @@ -99,6 +103,93 @@ func TestServeHTTP(t *testing.T) { "X-Forwarded-for": "", }, }, + { + desc: "xRealIP populated from remote address", + remoteAddr: "10.0.1.101:80", + expectedHeaders: map[string]string{ + xRealIP: "10.0.1.101", + }, + }, + { + desc: "xRealIP was already populated from previous headers", + insecure: true, + remoteAddr: "10.0.1.101:80", + incomingHeaders: map[string]string{ + xRealIP: "10.0.1.12", + }, + expectedHeaders: map[string]string{ + xRealIP: "10.0.1.12", + }, + }, + { + desc: "xForwardedProto with no tls", + tls: false, + expectedHeaders: map[string]string{ + xForwardedProto: "http", + }, + }, + { + desc: "xForwardedProto with tls", + tls: true, + expectedHeaders: map[string]string{ + xForwardedProto: "https", + }, + }, + { + desc: "xForwardedProto with websocket", + tls: false, + websocket: true, + expectedHeaders: map[string]string{ + xForwardedProto: "ws", + }, + }, + { + desc: "xForwardedProto with websocket and tls", + tls: true, + websocket: true, + expectedHeaders: map[string]string{ + xForwardedProto: "wss", + }, + }, + { + desc: "xForwardedPort with explicit port", + host: "foo.com:8080", + expectedHeaders: map[string]string{ + xForwardedPort: "8080", + }, + }, + { + desc: "xForwardedPort with implicit tls port from proto header", + // setting insecure just so our initial xForwardedProto does not get cleaned + insecure: true, + incomingHeaders: map[string]string{ + xForwardedProto: "https", + }, + expectedHeaders: map[string]string{ + xForwardedProto: "https", + xForwardedPort: "443", + }, + }, + { + desc: "xForwardedPort with implicit tls port from TLS in req", + tls: true, + expectedHeaders: map[string]string{ + xForwardedPort: "443", + }, + }, + { + desc: "xForwardedHost from req host", + host: "foo.com:8080", + expectedHeaders: map[string]string{ + xForwardedHost: "foo.com:8080", + }, + }, { + desc: "xForwardedServer from req XForwarded", + host: "foo.com:8080", + expectedHeaders: map[string]string{ + xForwardedServer: "foo.com:8080", + }, + }, } for _, test := range testCases { @@ -111,13 +202,31 @@ func TestServeHTTP(t *testing.T) { req.RemoteAddr = test.remoteAddr + if test.tls { + req.TLS = &tls.ConnectionState{} + } + + if test.websocket { + req.Header.Set(connection, "upgrade") + req.Header.Set(upgrade, "websocket") + } + + if test.host != "" { + req.Host = test.host + } + for k, v := range test.incomingHeaders { req.Header.Set(k, v) } - m, err := NewXForwarded(test.insecure, test.trustedIps, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + m, err := NewXForwarded(test.insecure, test.trustedIps, + http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) require.NoError(t, err) + if test.host != "" { + m.hostname = test.host + } + m.ServeHTTP(nil, req) for k, v := range test.expectedHeaders {