diff --git a/middlewares/replace_path.go b/middlewares/replace_path.go index 208415a51..dcaece1ff 100644 --- a/middlewares/replace_path.go +++ b/middlewares/replace_path.go @@ -4,15 +4,15 @@ import ( "net/http" ) +// ReplacedPathHeader is the default header to set the old path to +const ReplacedPathHeader = "X-Replaced-Path" + // ReplacePath is a middleware used to replace the path of a URL request type ReplacePath struct { Handler http.Handler Path string } -// ReplacedPathHeader is the default header to set the old path to -const ReplacedPathHeader = "X-Replaced-Path" - func (s *ReplacePath) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.Header.Add(ReplacedPathHeader, r.URL.Path) r.URL.Path = s.Path diff --git a/middlewares/replace_path_test.go b/middlewares/replace_path_test.go index 179ebdabb..7c1204f8a 100644 --- a/middlewares/replace_path_test.go +++ b/middlewares/replace_path_test.go @@ -1,10 +1,10 @@ -package middlewares_test +package middlewares import ( "net/http" "testing" - "github.com/containous/traefik/middlewares" + "github.com/stretchr/testify/assert" ) func TestReplacePath(t *testing.T) { @@ -17,28 +17,22 @@ func TestReplacePath(t *testing.T) { for _, path := range paths { t.Run(path, func(t *testing.T) { - var newPath, oldPath string - handler := &middlewares.ReplacePath{ + + var expectedPath, actualHeader string + handler := &ReplacePath{ Path: replacementPath, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - newPath = r.URL.Path - oldPath = r.Header.Get("X-Replaced-Path") + expectedPath = r.URL.Path + actualHeader = r.Header.Get(ReplacedPathHeader) }), } req, err := http.NewRequest("GET", "http://localhost"+path, nil) - if err != nil { - t.Error(err) - } + assert.NoError(t, err, "%s: unexpected error.", path) handler.ServeHTTP(nil, req) - if newPath != replacementPath { - t.Fatalf("new path should be '%s'", replacementPath) - } - - if oldPath != path { - t.Fatalf("old path should be '%s'", path) - } + assert.Equal(t, expectedPath, replacementPath, "%s: unexpected path.", path) + assert.Equal(t, path, actualHeader, "%s: unexpected '%s' header.", path, ReplacedPathHeader) }) } } diff --git a/middlewares/stripPrefix.go b/middlewares/stripPrefix.go index 291dda40d..2f2529796 100644 --- a/middlewares/stripPrefix.go +++ b/middlewares/stripPrefix.go @@ -5,9 +5,8 @@ import ( "strings" ) -const ( - forwardedPrefixHeader = "X-Forwarded-Prefix" -) +// ForwardedPrefixHeader is the default header to set prefix +const ForwardedPrefixHeader = "X-Forwarded-Prefix" // StripPrefix is a middleware used to strip prefix from an URL request type StripPrefix struct { @@ -35,7 +34,7 @@ func (s *StripPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (s *StripPrefix) serveRequest(w http.ResponseWriter, r *http.Request, prefix string) { - r.Header[forwardedPrefixHeader] = []string{prefix} + r.Header.Add(ForwardedPrefixHeader, prefix) r.RequestURI = r.URL.RequestURI() s.Handler.ServeHTTP(w, r) } diff --git a/middlewares/stripPrefixRegex.go b/middlewares/stripPrefixRegex.go index f2a85834e..bf2130ca8 100644 --- a/middlewares/stripPrefixRegex.go +++ b/middlewares/stripPrefixRegex.go @@ -40,7 +40,7 @@ func (s *StripPrefixRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) { } r.URL.Path = r.URL.Path[len(prefix.Path):] - r.Header[forwardedPrefixHeader] = []string{prefix.Path} + r.Header.Add(ForwardedPrefixHeader, prefix.Path) r.RequestURI = r.URL.RequestURI() s.Handler.ServeHTTP(w, r) return diff --git a/middlewares/stripPrefixRegex_test.go b/middlewares/stripPrefixRegex_test.go index cb1bc1eb1..03c8c8b44 100644 --- a/middlewares/stripPrefixRegex_test.go +++ b/middlewares/stripPrefixRegex_test.go @@ -1,55 +1,90 @@ package middlewares import ( - "fmt" - "io/ioutil" "net/http" "net/http/httptest" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStripPrefixRegex(t *testing.T) { - handlerPath := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, r.URL.Path) - }) - - handler := NewStripPrefixRegex(handlerPath, []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"}) - server := httptest.NewServer(handler) - defer server.Close() + testPrefixRegex := []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"} tests := []struct { - expectedCode int - expectedResponse string - url string + path string + expectedStatusCode int + expectedPath string + expectedHeader string }{ - {url: "/a/test", expectedCode: 404, expectedResponse: "404 page not found\n"}, - {url: "/a/api/test", expectedCode: 200, expectedResponse: "test"}, - - {url: "/b/api/", expectedCode: 200, expectedResponse: ""}, - {url: "/b/api/test1", expectedCode: 200, expectedResponse: "test1"}, - {url: "/b/api2/test2", expectedCode: 200, expectedResponse: "test2"}, - - {url: "/c/api/123/", expectedCode: 200, expectedResponse: ""}, - {url: "/c/api/123/test3", expectedCode: 200, expectedResponse: "test3"}, - {url: "/c/api/abc/test4", expectedCode: 404, expectedResponse: "404 page not found\n"}, + { + path: "/a/test", + expectedStatusCode: 404, + }, + { + path: "/a/api/test", + expectedStatusCode: 200, + expectedPath: "test", + expectedHeader: "/a/api/", + }, + { + path: "/b/api/", + expectedStatusCode: 200, + expectedHeader: "/b/api/", + }, + { + path: "/b/api/test1", + expectedStatusCode: 200, + expectedPath: "test1", + expectedHeader: "/b/api/", + }, + { + path: "/b/api2/test2", + expectedStatusCode: 200, + expectedPath: "test2", + expectedHeader: "/b/api2/", + }, + { + path: "/c/api/123/", + expectedStatusCode: 200, + expectedHeader: "/c/api/123/", + }, + { + path: "/c/api/123/test3", + expectedStatusCode: 200, + expectedPath: "test3", + expectedHeader: "/c/api/123/", + }, + { + path: "/c/api/abc/test4", + expectedStatusCode: 404, + }, } for _, test := range tests { - resp, err := http.Get(server.URL + test.url) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != test.expectedCode { - t.Fatalf("Received non-%d response: %d\n", test.expectedCode, resp.StatusCode) - } - response, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - if test.expectedResponse != string(response) { - t.Errorf("Expected '%s' : '%s'\n", test.expectedResponse, response) - } + test := test + t.Run(test.path, func(t *testing.T) { + t.Parallel() + + var actualPath, actualHeader string + handlerPath := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + actualPath = r.URL.Path + actualHeader = r.Header.Get(ForwardedPrefixHeader) + }) + + handler := NewStripPrefixRegex(handlerPath, testPrefixRegex) + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL + test.path) + require.NoError(t, err, "%s: unexpected error.", test.path) + + assert.Equal(t, test.expectedStatusCode, resp.StatusCode, "%s: unexpected status code.", test.path) + assert.Equal(t, test.expectedPath, actualPath, "%s: unexpected path.", test.path) + assert.Equal(t, test.expectedHeader, actualHeader, "%s: unexpected '%s' header.", test.path, ForwardedPrefixHeader) + }) } } diff --git a/middlewares/stripPrefix_test.go b/middlewares/stripPrefix_test.go index d7fe348f5..807de926a 100644 --- a/middlewares/stripPrefix_test.go +++ b/middlewares/stripPrefix_test.go @@ -16,6 +16,7 @@ func TestStripPrefix(t *testing.T) { path string expectedStatusCode int expectedPath string + expectedHeader string }{ { desc: "no prefixes configured", @@ -29,6 +30,7 @@ func TestStripPrefix(t *testing.T) { path: "/", expectedStatusCode: http.StatusOK, expectedPath: "/", + expectedHeader: "/", }, { desc: "prefix and path matching", @@ -36,6 +38,7 @@ func TestStripPrefix(t *testing.T) { path: "/stat", expectedStatusCode: http.StatusOK, expectedPath: "/", + expectedHeader: "/stat", }, { desc: "path prefix on exactly matching path", @@ -43,6 +46,7 @@ func TestStripPrefix(t *testing.T) { path: "/stat/", expectedStatusCode: http.StatusOK, expectedPath: "/", + expectedHeader: "/stat/", }, { desc: "path prefix on matching longer path", @@ -50,6 +54,7 @@ func TestStripPrefix(t *testing.T) { path: "/stat/us", expectedStatusCode: http.StatusOK, expectedPath: "/us", + expectedHeader: "/stat/", }, { desc: "path prefix on mismatching path", @@ -63,6 +68,7 @@ func TestStripPrefix(t *testing.T) { path: "/stat/", expectedStatusCode: http.StatusOK, expectedPath: "/", + expectedHeader: "/stat", }, { desc: "earlier prefix matching", @@ -70,6 +76,7 @@ func TestStripPrefix(t *testing.T) { path: "/stat/us", expectedStatusCode: http.StatusOK, expectedPath: "/us", + expectedHeader: "/stat", }, { desc: "later prefix matching", @@ -77,6 +84,7 @@ func TestStripPrefix(t *testing.T) { path: "/stat", expectedStatusCode: http.StatusOK, expectedPath: "/", + expectedHeader: "/stat", }, } @@ -84,20 +92,23 @@ func TestStripPrefix(t *testing.T) { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() - var gotPath string + + var actualPath, actualHeader string server := httptest.NewServer(&StripPrefix{ Prefixes: test.prefixes, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotPath = r.URL.Path + actualPath = r.URL.Path + actualHeader = r.Header.Get(ForwardedPrefixHeader) }), }) defer server.Close() resp, err := http.Get(server.URL + test.path) - require.NoError(t, err, "Failed to send GET request") - assert.Equal(t, test.expectedStatusCode, resp.StatusCode, "Unexpected status code") + require.NoError(t, err, "%s: failed to send GET request.", test.desc) - assert.Equal(t, test.expectedPath, gotPath, "Unexpected path") + assert.Equal(t, test.expectedStatusCode, resp.StatusCode, "%s: unexpected status code.", test.desc) + assert.Equal(t, test.expectedPath, actualPath, "%s: unexpected path.", test.desc) + assert.Equal(t, test.expectedHeader, actualHeader, "%s: unexpected '%s' header.", test.desc, ForwardedPrefixHeader) }) } } diff --git a/server/rules_test.go b/server/rules_test.go index 2568910fa..2ac0c3a2f 100644 --- a/server/rules_test.go +++ b/server/rules_test.go @@ -3,10 +3,11 @@ package server import ( "net/http" "net/url" - "reflect" "testing" "github.com/containous/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParseOneRule(t *testing.T) { @@ -17,17 +18,12 @@ func TestParseOneRule(t *testing.T) { expression := "Host:foo.bar" routeResult, err := rules.Parse(expression) - - if err != nil { - t.Fatalf("Error while building route for Host:foo.bar: %s", err) - } + require.NoError(t, err, "Error while building route for %s", expression) request, err := http.NewRequest("GET", "http://foo.bar", nil) routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult}) - if !routeMatch { - t.Fatalf("Rule Host:foo.bar don't match: %s", err) - } + assert.True(t, routeMatch, "Rule %s don't match.", expression) } func TestParseTwoRules(t *testing.T) { @@ -39,47 +35,54 @@ func TestParseTwoRules(t *testing.T) { expression := "Host: Foo.Bar ; Path:/FOObar" routeResult, err := rules.Parse(expression) - if err != nil { - t.Fatalf("Error while building route for Host:foo.bar;Path:/FOObar: %s", err) - } + require.NoError(t, err, "Error while building route for %s.", expression) - request, err := http.NewRequest("GET", "http://foo.bar/foobar", nil) + request, _ := http.NewRequest("GET", "http://foo.bar/foobar", nil) routeMatch := routeResult.Match(request, &mux.RouteMatch{Route: routeResult}) - if routeMatch { - t.Fatalf("Rule Host:foo.bar;Path:/FOObar don't match: %s", err) - } + assert.False(t, routeMatch, "Rule %s don't match.", expression) - request, err = http.NewRequest("GET", "http://foo.bar/FOObar", nil) + request, _ = http.NewRequest("GET", "http://foo.bar/FOObar", nil) routeMatch = routeResult.Match(request, &mux.RouteMatch{Route: routeResult}) - if !routeMatch { - t.Fatalf("Rule Host:foo.bar;Path:/FOObar don't match: %s", err) - } + assert.True(t, routeMatch, "Rule %s don't match.", expression) } func TestParseDomains(t *testing.T) { rules := &Rules{} - expressionsSlice := []string{ - "Host:foo.bar,test.bar", - "Path:/test", - "Host:foo.bar;Path:/test", - "Host: Foo.Bar ;Path:/test", + + tests := []struct { + expression string + domain []string + }{ + { + expression: "Host:foo.bar,test.bar", + domain: []string{"foo.bar", "test.bar"}, + }, + { + expression: "Path:/test", + domain: []string{}, + }, + { + expression: "Host:foo.bar;Path:/test", + domain: []string{"foo.bar"}, + }, + { + expression: "Host: Foo.Bar ;Path:/test", + domain: []string{"foo.bar"}, + }, } - domainsSlice := [][]string{ - {"foo.bar", "test.bar"}, - {}, - {"foo.bar"}, - {"foo.bar"}, - } - for i, expression := range expressionsSlice { - domains, err := rules.ParseDomains(expression) - if err != nil { - t.Fatalf("Error while parsing domains: %v", err) - } - if !reflect.DeepEqual(domains, domainsSlice[i]) { - t.Fatalf("Error parsing domains: expected %+v, got %+v", domainsSlice[i], domains) - } + + for _, test := range tests { + test := test + t.Run(test.expression, func(t *testing.T) { + t.Parallel() + + domains, err := rules.ParseDomains(test.expression) + require.NoError(t, err, "%s: Error while parsing domain.", test.expression) + + assert.EqualValues(t, test.domain, domains, "%s: Error parsing domains from expression.", test.expression) + }) } } @@ -87,70 +90,49 @@ func TestPriorites(t *testing.T) { router := mux.NewRouter() router.StrictSlash(true) rules := &Rules{route: &serverRoute{route: router.NewRoute()}} - routeFoo, err := rules.Parse("PathPrefix:/foo") - if err != nil { - t.Fatalf("Error while building route for PathPrefix:/foo: %s", err) - } + expression01 := "PathPrefix:/foo" + + routeFoo, err := rules.Parse(expression01) + require.NoError(t, err, "Error while building route for %s", expression01) + fooHandler := &fakeHandler{name: "fooHandler"} routeFoo.Handler(fooHandler) - if !router.Match(&http.Request{URL: &url.URL{ - Path: "/foo", - }}, &mux.RouteMatch{}) { - t.Fatal("Error matching route") - } + routeMatch := router.Match(&http.Request{URL: &url.URL{Path: "/foo"}}, &mux.RouteMatch{}) + assert.True(t, routeMatch, "Error matching route") - if router.Match(&http.Request{URL: &url.URL{ - Path: "/fo", - }}, &mux.RouteMatch{}) { - t.Fatal("Error matching route") - } + routeMatch = router.Match(&http.Request{URL: &url.URL{Path: "/fo"}}, &mux.RouteMatch{}) + assert.False(t, routeMatch, "Error matching route") multipleRules := &Rules{route: &serverRoute{route: router.NewRoute()}} - routeFoobar, err := multipleRules.Parse("PathPrefix:/foobar") - if err != nil { - t.Fatalf("Error while building route for PathPrefix:/foobar: %s", err) - } + expression02 := "PathPrefix:/foobar" + + routeFoobar, err := multipleRules.Parse(expression02) + require.NoError(t, err, "Error while building route for %s", expression02) + foobarHandler := &fakeHandler{name: "foobarHandler"} routeFoobar.Handler(foobarHandler) - if !router.Match(&http.Request{URL: &url.URL{ - Path: "/foo", - }}, &mux.RouteMatch{}) { - t.Fatal("Error matching route") - } + routeMatch = router.Match(&http.Request{URL: &url.URL{Path: "/foo"}}, &mux.RouteMatch{}) + + assert.True(t, routeMatch, "Error matching route") + fooMatcher := &mux.RouteMatch{} - if !router.Match(&http.Request{URL: &url.URL{ - Path: "/foobar", - }}, fooMatcher) { - t.Fatal("Error matching route") - } + routeMatch = router.Match(&http.Request{URL: &url.URL{Path: "/foobar"}}, fooMatcher) - if fooMatcher.Handler == foobarHandler { - t.Fatal("Error matching priority") - } - - if fooMatcher.Handler != fooHandler { - t.Fatal("Error matching priority") - } + assert.True(t, routeMatch, "Error matching route") + assert.NotEqual(t, fooMatcher.Handler, foobarHandler, "Error matching priority") + assert.Equal(t, fooMatcher.Handler, fooHandler, "Error matching priority") routeFoo.Priority(1) routeFoobar.Priority(10) router.SortRoutes() foobarMatcher := &mux.RouteMatch{} - if !router.Match(&http.Request{URL: &url.URL{ - Path: "/foobar", - }}, foobarMatcher) { - t.Fatal("Error matching route") - } + routeMatch = router.Match(&http.Request{URL: &url.URL{Path: "/foobar"}}, foobarMatcher) - if foobarMatcher.Handler != foobarHandler { - t.Fatal("Error matching priority") - } - - if foobarMatcher.Handler == fooHandler { - t.Fatal("Error matching priority") - } + assert.True(t, routeMatch, "Error matching route") + assert.Equal(t, foobarMatcher.Handler, foobarHandler, "Error matching priority") + assert.NotEqual(t, foobarMatcher.Handler, fooHandler, "Error matching priority") } type fakeHandler struct {