diff --git a/pkg/middlewares/auth/forward.go b/pkg/middlewares/auth/forward.go index ed9e7d2ef..df1895f09 100644 --- a/pkg/middlewares/auth/forward.go +++ b/pkg/middlewares/auth/forward.go @@ -15,6 +15,7 @@ import ( "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/middlewares" + "github.com/traefik/traefik/v2/pkg/middlewares/connectionheader" "github.com/traefik/traefik/v2/pkg/tracing" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/utils" @@ -89,7 +90,7 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu fa.authResponseHeadersRegex = re } - return fa, nil + return connectionheader.Remover(fa), nil } func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) { diff --git a/pkg/middlewares/connectionheader/connectionheader.go b/pkg/middlewares/connectionheader/connectionheader.go new file mode 100644 index 000000000..b7a64910a --- /dev/null +++ b/pkg/middlewares/connectionheader/connectionheader.go @@ -0,0 +1,46 @@ +package connectionheader + +import ( + "net/http" + "net/textproto" + "strings" + + "golang.org/x/net/http/httpguts" +) + +const ( + connectionHeader = "Connection" + upgradeHeader = "Upgrade" +) + +// Remover removes hop-by-hop headers listed in the "Connection" header. +// See RFC 7230, section 6.1. +func Remover(next http.Handler) http.HandlerFunc { + return func(rw http.ResponseWriter, req *http.Request) { + var reqUpType string + if httpguts.HeaderValuesContainsToken(req.Header[connectionHeader], upgradeHeader) { + reqUpType = req.Header.Get(upgradeHeader) + } + + removeConnectionHeaders(req.Header) + + if reqUpType != "" { + req.Header.Set(connectionHeader, upgradeHeader) + req.Header.Set(upgradeHeader, reqUpType) + } else { + req.Header.Del(connectionHeader) + } + + next.ServeHTTP(rw, req) + } +} + +func removeConnectionHeaders(h http.Header) { + for _, f := range h[connectionHeader] { + for _, sf := range strings.Split(f, ",") { + if sf = textproto.TrimString(sf); sf != "" { + h.Del(sf) + } + } + } +} diff --git a/pkg/middlewares/connectionheader/connectionheader_test.go b/pkg/middlewares/connectionheader/connectionheader_test.go new file mode 100644 index 000000000..7ee047bd0 --- /dev/null +++ b/pkg/middlewares/connectionheader/connectionheader_test.go @@ -0,0 +1,71 @@ +package connectionheader + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRemover(t *testing.T) { + testCases := []struct { + desc string + reqHeaders map[string]string + expected http.Header + }{ + { + desc: "simple remove", + reqHeaders: map[string]string{ + "Foo": "bar", + connectionHeader: "foo", + }, + expected: http.Header{}, + }, + { + desc: "remove and Upgrade", + reqHeaders: map[string]string{ + upgradeHeader: "test", + "Foo": "bar", + connectionHeader: "Upgrade,foo", + }, + expected: http.Header{ + upgradeHeader: []string{"test"}, + connectionHeader: []string{"Upgrade"}, + }, + }, + { + desc: "no remove", + reqHeaders: map[string]string{ + "Foo": "bar", + connectionHeader: "fii", + }, + expected: http.Header{ + "Foo": []string{"bar"}, + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}) + + h := Remover(next) + + req := httptest.NewRequest(http.MethodGet, "https://localhost", nil) + + for k, v := range test.reqHeaders { + req.Header.Set(k, v) + } + + rw := httptest.NewRecorder() + + h.ServeHTTP(rw, req) + + assert.Equal(t, test.expected, req.Header) + }) + } +} diff --git a/pkg/middlewares/headers/headers.go b/pkg/middlewares/headers/headers.go index 39594494a..ece7fece2 100644 --- a/pkg/middlewares/headers/headers.go +++ b/pkg/middlewares/headers/headers.go @@ -10,6 +10,7 @@ import ( "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/middlewares" + "github.com/traefik/traefik/v2/pkg/middlewares/connectionheader" "github.com/traefik/traefik/v2/pkg/tracing" ) @@ -58,11 +59,12 @@ func New(ctx context.Context, next http.Handler, cfg dynamic.Headers, name strin if hasCustomHeaders || hasCorsHeaders { logger.Debugf("Setting up customHeaders/Cors from %v", cfg) - var err error - handler, err = NewHeader(nextHandler, cfg) + h, err := NewHeader(nextHandler, cfg) if err != nil { return nil, err } + + handler = connectionheader.Remover(h) } return &headers{