From 3f6ea04048239cae6fde10e6bce413e5bcfb72d5 Mon Sep 17 00:00:00 2001 From: Daniel Tomcej Date: Fri, 12 Jul 2019 03:46:04 -0600 Subject: [PATCH] Properly add response headers for CORS --- integration/fixtures/headers/basic.toml | 15 -- integration/fixtures/headers/cors.toml | 12 +- integration/headers_test.go | 76 ++++----- integration/resources/compose/headers.yml | 4 - integration/try/condition.go | 13 +- pkg/middlewares/headers/headers.go | 183 +++++++++++++--------- pkg/middlewares/headers/headers_test.go | 31 +++- pkg/responsemodifiers/headers.go | 2 +- 8 files changed, 198 insertions(+), 138 deletions(-) delete mode 100644 integration/resources/compose/headers.yml diff --git a/integration/fixtures/headers/basic.toml b/integration/fixtures/headers/basic.toml index 5ed0ab4ba..29b1eb79e 100644 --- a/integration/fixtures/headers/basic.toml +++ b/integration/fixtures/headers/basic.toml @@ -8,18 +8,3 @@ [entryPoints] [entryPoints.web] address = ":8000" - -[providers] - [providers.file] - -## dynamic configuration ## - -[http.routers] - [http.routers.router1] - rule = "Host(`test.localhost`)" - service = "service1" - -[http.services] - [http.services.service1.loadBalancer] - [[http.services.service1.loadBalancer.servers]] - url = "http://172.17.0.2:80" diff --git a/integration/fixtures/headers/cors.toml b/integration/fixtures/headers/cors.toml index 7cc707cab..32b8d2995 100644 --- a/integration/fixtures/headers/cors.toml +++ b/integration/fixtures/headers/cors.toml @@ -17,6 +17,12 @@ [http.routers] [http.routers.router1] rule = "Host(`test.localhost`)" + middlewares = ["cors"] + service = "service1" + + [http.routers.router2] + rule = "Host(`test2.localhost`)" + middlewares = ["nocors"] service = "service1" [http.middlewares] @@ -26,7 +32,11 @@ accessControlMaxAge = 100 addVaryHeader = true + [http.middlewares.nocors.Headers] + [http.middlewares.nocors.Headers.CustomResponseHeaders] + X-Custom-Response-Header = "True" + [http.services] [http.services.service1.loadBalancer] [[http.services.service1.loadBalancer.servers]] - url = "http://172.17.0.2:80" + url = "http://127.0.0.1:9000" diff --git a/integration/headers_test.go b/integration/headers_test.go index 8ddc4f7a4..dd60a5876 100644 --- a/integration/headers_test.go +++ b/integration/headers_test.go @@ -12,12 +12,6 @@ import ( // Headers test suites type HeadersSuite struct{ BaseSuite } -func (s *HeadersSuite) SetUpSuite(c *check.C) { - s.createComposeProject(c, "headers") - - s.composeProject.Start(c) -} - func (s *HeadersSuite) TestSimpleConfiguration(c *check.C) { cmd, display := s.traefikCmd(withConfigFile("fixtures/headers/basic.toml")) defer display(c) @@ -38,10 +32,18 @@ func (s *HeadersSuite) TestCorsResponses(c *check.C) { c.Assert(err, checker.IsNil) defer cmd.Process.Kill() + backend := startTestServer("9000", http.StatusOK) + defer backend.Close() + + err = try.GetRequest(backend.URL, 500*time.Millisecond, try.StatusCodeIs(http.StatusOK)) + c.Assert(err, checker.IsNil) + testCase := []struct { desc string requestHeaders http.Header expected http.Header + reqHost string + method string }{ { desc: "simple access control allow origin", @@ -52,33 +54,9 @@ func (s *HeadersSuite) TestCorsResponses(c *check.C) { "Access-Control-Allow-Origin": {"https://foo.bar.org"}, "Vary": {"Origin"}, }, + reqHost: "test.localhost", + method: http.MethodGet, }, - } - - for _, test := range testCase { - req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/", nil) - c.Assert(err, checker.IsNil) - req.Host = "test.localhost" - req.Header = test.requestHeaders - - err = try.Request(req, 500*time.Millisecond, try.HasBody(), try.HasHeaderStruct(test.expected)) - c.Assert(err, checker.IsNil) - } -} - -func (s *HeadersSuite) TestCorsPreflightResponses(c *check.C) { - cmd, display := s.traefikCmd(withConfigFile("fixtures/headers/cors.toml")) - defer display(c) - - err := cmd.Start() - c.Assert(err, checker.IsNil) - defer cmd.Process.Kill() - - testCase := []struct { - desc string - requestHeaders http.Header - expected http.Header - }{ { desc: "simple preflight request", requestHeaders: http.Header{ @@ -91,16 +69,44 @@ func (s *HeadersSuite) TestCorsPreflightResponses(c *check.C) { "Access-Control-Max-Age": {"100"}, "Access-Control-Allow-Methods": {"GET,OPTIONS,PUT"}, }, + reqHost: "test.localhost", + method: http.MethodOptions, + }, + { + desc: "preflight Options request with no cors configured", + requestHeaders: http.Header{ + "Access-Control-Request-Headers": {"origin"}, + "Access-Control-Request-Method": {"GET", "OPTIONS"}, + "Origin": {"https://foo.bar.org"}, + }, + expected: http.Header{ + "X-Custom-Response-Header": {"True"}, + }, + reqHost: "test2.localhost", + method: http.MethodOptions, + }, + { + desc: "preflight Get request with no cors configured", + requestHeaders: http.Header{ + "Access-Control-Request-Headers": {"origin"}, + "Access-Control-Request-Method": {"GET", "OPTIONS"}, + "Origin": {"https://foo.bar.org"}, + }, + expected: http.Header{ + "X-Custom-Response-Header": {"True"}, + }, + reqHost: "test2.localhost", + method: http.MethodGet, }, } for _, test := range testCase { - req, err := http.NewRequest(http.MethodOptions, "http://127.0.0.1:8000/", nil) + req, err := http.NewRequest(test.method, "http://127.0.0.1:8000/", nil) c.Assert(err, checker.IsNil) - req.Host = "test.localhost" + req.Host = test.reqHost req.Header = test.requestHeaders - err = try.Request(req, 500*time.Millisecond, try.HasBody(), try.HasHeaderStruct(test.expected)) + err = try.Request(req, 500*time.Millisecond, try.HasHeaderStruct(test.expected)) c.Assert(err, checker.IsNil) } } diff --git a/integration/resources/compose/headers.yml b/integration/resources/compose/headers.yml deleted file mode 100644 index fba4da55d..000000000 --- a/integration/resources/compose/headers.yml +++ /dev/null @@ -1,4 +0,0 @@ -whoami1: - image: containous/whoami - ports: - - "8881:80" diff --git a/integration/try/condition.go b/integration/try/condition.go index ce3db4ca2..d7d94c4ee 100644 --- a/integration/try/condition.go +++ b/integration/try/condition.go @@ -168,18 +168,17 @@ func HasHeaderValue(header, value string, exactMatch bool) ResponseCondition { func HasHeaderStruct(header http.Header) ResponseCondition { return func(res *http.Response) error { for key := range header { - if _, ok := res.Header[key]; ok { - // Header exists in the response, test it. - eq := reflect.DeepEqual(header[key], res.Header[key]) - if !eq { - return fmt.Errorf("for header %s got values %v, wanted %v", key, res.Header[key], header[key]) - } + if _, ok := res.Header[key]; !ok { + return fmt.Errorf("header %s not present in the response. Expected headers: %v Got response headers: %v", key, header, res.Header) + } + // Header exists in the response, test it. + if !reflect.DeepEqual(header[key], res.Header[key]) { + return fmt.Errorf("for header %s got values %v, wanted %v", key, res.Header[key], header[key]) } } return nil } - } // DoCondition is a retry condition function. diff --git a/pkg/middlewares/headers/headers.go b/pkg/middlewares/headers/headers.go index 526bc8b8e..a4d7de259 100644 --- a/pkg/middlewares/headers/headers.go +++ b/pkg/middlewares/headers/headers.go @@ -16,8 +16,7 @@ import ( ) const ( - typeName = "Headers" - originHeaderKey = "X-Request-Origin" + typeName = "Headers" ) type headers struct { @@ -107,29 +106,127 @@ func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) { s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP) } -// Header is a middleware that helps setup a few basic security features. A single headerOptions struct can be -// provided to configure which features should be enabled, and the ability to override a few of the default values. +// Header is a middleware that helps setup a few basic security features. +// A single headerOptions struct can be provided to configure which features should be enabled, +// and the ability to override a few of the default values. type Header struct { - next http.Handler - headers *dynamic.Headers + next http.Handler + hasCustomHeaders bool + hasCorsHeaders bool + headers *dynamic.Headers } // NewHeader constructs a new header instance from supplied frontend header struct. func NewHeader(next http.Handler, headers dynamic.Headers) *Header { + hasCustomHeaders := headers.HasCustomHeadersDefined() + hasCorsHeaders := headers.HasCorsHeadersDefined() + return &Header{ - next: next, - headers: &headers, + next: next, + headers: &headers, + hasCustomHeaders: hasCustomHeaders, + hasCorsHeaders: hasCorsHeaders, } } func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // Handle Cors headers and preflight if configured. + if isPreflight := s.processCorsHeaders(rw, req); isPreflight { + return + } + + if s.hasCustomHeaders { + s.modifyCustomRequestHeaders(req) + } + + // If there is a next, call it. + if s.next != nil { + s.next.ServeHTTP(rw, req) + } +} + +// modifyCustomRequestHeaders sets or deletes custom request headers. +func (s *Header) modifyCustomRequestHeaders(req *http.Request) { + // Loop through Custom request headers + for header, value := range s.headers.CustomRequestHeaders { + if value == "" { + req.Header.Del(header) + } else { + req.Header.Set(header, value) + } + } +} + +// preRequestModifyCorsResponseHeaders sets during request processing time, +// all the CORS response headers that we already know that are supposed to be set, +// and which do not depend on a later state of the response. +// One notable example of a header that can only be modified later on is "Vary", +// And this is set in the post-response response modifier method +func (s *Header) preRequestModifyCorsResponseHeaders(rw http.ResponseWriter, req *http.Request) { + originHeader := req.Header.Get("Origin") + allowOrigin := s.getAllowOrigin(originHeader) + + if allowOrigin != "" { + rw.Header().Set("Access-Control-Allow-Origin", allowOrigin) + } + + if s.headers.AccessControlAllowCredentials { + rw.Header().Set("Access-Control-Allow-Credentials", "true") + } + + if len(s.headers.AccessControlExposeHeaders) > 0 { + exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",") + rw.Header().Set("Access-Control-Expose-Headers", exposeHeaders) + } +} + +// PostRequestModifyResponseHeaders set or delete response headers. +// This method is called AFTER the response is generated from the backend +// and can merge/override headers from the backend response. +func (s *Header) PostRequestModifyResponseHeaders(res *http.Response) error { + // Loop through Custom response headers + for header, value := range s.headers.CustomResponseHeaders { + if value == "" { + res.Header.Del(header) + } else { + res.Header.Set(header, value) + } + } + if !s.headers.AddVaryHeader { + return nil + } + + varyHeader := res.Header.Get("Vary") + if varyHeader == "Origin" { + return nil + } + + if varyHeader != "" { + varyHeader += "," + } + varyHeader += "Origin" + + res.Header.Set("Vary", varyHeader) + return nil +} + +// processCorsHeaders processes the incoming request, +// and returns if it is a preflight request. +// If not a preflight, it handles the preRequestModifyCorsResponseHeaders. +func (s *Header) processCorsHeaders(rw http.ResponseWriter, req *http.Request) bool { + if !s.hasCorsHeaders { + return false + } + reqAcMethod := req.Header.Get("Access-Control-Request-Method") reqAcHeaders := req.Header.Get("Access-Control-Request-Headers") originHeader := req.Header.Get("Origin") if reqAcMethod != "" && reqAcHeaders != "" && originHeader != "" && req.Method == http.MethodOptions { - // If the request is an OPTIONS request with an Access-Control-Request-Method header, and Access-Control-Request-Headers headers, - // and Origin headers, then it is a CORS preflight request, and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request + // If the request is an OPTIONS request with an Access-Control-Request-Method header, + // and Access-Control-Request-Headers headers, and Origin headers, + // then it is a CORS preflight request, + // and we need to build a custom response: https://www.w3.org/TR/cors/#preflight-request if s.headers.AccessControlAllowCredentials { rw.Header().Set("Access-Control-Allow-Credentials", "true") } @@ -151,71 +248,11 @@ func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } rw.Header().Set("Access-Control-Max-Age", strconv.Itoa(int(s.headers.AccessControlMaxAge))) - - return + return true } - if len(originHeader) > 0 { - rw.Header().Set(originHeaderKey, originHeader) - } - - s.modifyRequestHeaders(req) - // If there is a next, call it. - if s.next != nil { - s.next.ServeHTTP(rw, req) - } -} - -// modifyRequestHeaders sets or deletes request headers. -func (s *Header) modifyRequestHeaders(req *http.Request) { - // Loop through Custom request headers - for header, value := range s.headers.CustomRequestHeaders { - if value == "" { - req.Header.Del(header) - } else { - req.Header.Set(header, value) - } - } -} - -// ModifyResponseHeaders set or delete response headers -func (s *Header) ModifyResponseHeaders(res *http.Response) error { - // Loop through Custom response headers - for header, value := range s.headers.CustomResponseHeaders { - if value == "" { - res.Header.Del(header) - } else { - res.Header.Set(header, value) - } - } - originHeader := res.Header.Get(originHeaderKey) - allowOrigin := s.getAllowOrigin(originHeader) - // Delete the origin header key, since it is only used to pass data from the request for response handling - res.Header.Del(originHeaderKey) - if allowOrigin != "" { - res.Header.Set("Access-Control-Allow-Origin", allowOrigin) - - if s.headers.AddVaryHeader { - varyHeader := res.Header.Get("Vary") - if varyHeader != "" { - varyHeader += "," - } - varyHeader += "Origin" - - res.Header.Set("Vary", varyHeader) - } - } - - if s.headers.AccessControlAllowCredentials { - res.Header.Set("Access-Control-Allow-Credentials", "true") - } - - exposeHeaders := strings.Join(s.headers.AccessControlExposeHeaders, ",") - if exposeHeaders != "" { - res.Header.Set("Access-Control-Expose-Headers", exposeHeaders) - } - - return nil + s.preRequestModifyCorsResponseHeaders(rw, req) + return false } func (s *Header) getAllowOrigin(header string) string { diff --git a/pkg/middlewares/headers/headers_test.go b/pkg/middlewares/headers/headers_test.go index 6f637a6c0..d4ec57e73 100644 --- a/pkg/middlewares/headers/headers_test.go +++ b/pkg/middlewares/headers/headers_test.go @@ -333,6 +333,7 @@ func TestGetTracingInformation(t *testing.T) { func TestCORSResponses(t *testing.T) { emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) nonEmptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Testing") }) + existingOriginHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Vary", "Origin") }) testCases := []struct { desc string @@ -436,6 +437,32 @@ func TestCORSResponses(t *testing.T) { "Vary": {"Testing,Origin"}, }, }, + { + desc: "Test Simple Request with Vary Headers and existing vary:origin response", + header: NewHeader(existingOriginHandler, dynamic.Headers{ + AccessControlAllowOrigin: "origin-list-or-null", + AddVaryHeader: true, + }), + requestHeaders: map[string][]string{ + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{ + "Access-Control-Allow-Origin": {"https://foo.bar.org"}, + "Vary": {"Origin"}, + }, + }, + { + desc: "Test Simple CustomRequestHeaders Not Hijacked by CORS", + header: NewHeader(emptyHandler, dynamic.Headers{ + CustomRequestHeaders: map[string]string{"foo": "bar"}, + }), + requestHeaders: map[string][]string{ + "Access-Control-Request-Headers": {"origin"}, + "Access-Control-Request-Method": {"GET", "OPTIONS"}, + "Origin": {"https://foo.bar.org"}, + }, + expected: map[string][]string{}, + }, } for _, test := range testCases { @@ -445,7 +472,7 @@ func TestCORSResponses(t *testing.T) { rw := httptest.NewRecorder() test.header.ServeHTTP(rw, req) - err := test.header.ModifyResponseHeaders(rw.Result()) + err := test.header.PostRequestModifyResponseHeaders(rw.Result()) require.NoError(t, err) assert.Equal(t, test.expected, rw.Result().Header) }) @@ -492,7 +519,7 @@ func TestCustomResponseHeaders(t *testing.T) { req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil) rw := httptest.NewRecorder() test.header.ServeHTTP(rw, req) - err := test.header.ModifyResponseHeaders(rw.Result()) + err := test.header.PostRequestModifyResponseHeaders(rw.Result()) require.NoError(t, err) assert.Equal(t, test.expected, rw.Result().Header) }) diff --git a/pkg/responsemodifiers/headers.go b/pkg/responsemodifiers/headers.go index 5c104729a..aa209ad0a 100644 --- a/pkg/responsemodifiers/headers.go +++ b/pkg/responsemodifiers/headers.go @@ -34,7 +34,7 @@ func buildHeaders(hdrs *dynamic.Headers) func(*http.Response) error { return func(resp *http.Response) error { if hdrs.HasCustomHeadersDefined() || hdrs.HasCorsHeadersDefined() { - err := headers.NewHeader(nil, *hdrs).ModifyResponseHeaders(resp) + err := headers.NewHeader(nil, *hdrs).PostRequestModifyResponseHeaders(resp) if err != nil { return err }