From e11ff98608b40671dff7a54b3b2811524ce0008e Mon Sep 17 00:00:00 2001 From: Julien Salleyron Date: Tue, 6 Feb 2024 17:34:07 +0100 Subject: [PATCH] Fix NTLM and Kerberos --- pkg/server/server_entrypoint_tcp.go | 11 ++++ pkg/server/service/roundtripper.go | 67 +++++++++++++++++++- pkg/server/service/roundtripper_test.go | 78 ++++++++++++++++++++++++ pkg/server/service/smart_roundtripper.go | 8 ++- 4 files changed, 161 insertions(+), 3 deletions(-) diff --git a/pkg/server/server_entrypoint_tcp.go b/pkg/server/server_entrypoint_tcp.go index 26b417293..16c3a0caa 100644 --- a/pkg/server/server_entrypoint_tcp.go +++ b/pkg/server/server_entrypoint_tcp.go @@ -27,6 +27,7 @@ import ( "github.com/traefik/traefik/v2/pkg/safe" "github.com/traefik/traefik/v2/pkg/server/router" tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp" + "github.com/traefik/traefik/v2/pkg/server/service" "github.com/traefik/traefik/v2/pkg/tcp" "github.com/traefik/traefik/v2/pkg/types" "golang.org/x/net/http2" @@ -613,6 +614,16 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati } } + prevConnContext := serverHTTP.ConnContext + serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + // This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM. + ctx = service.AddTransportOnContext(ctx) + if prevConnContext != nil { + return prevConnContext(ctx, c) + } + return ctx + } + // ConfigureServer configures HTTP/2 with the MaxConcurrentStreams option for the given server. // Also keeping behavior the same as // https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/http/server.go;l=3262 diff --git a/pkg/server/service/roundtripper.go b/pkg/server/service/roundtripper.go index 68b3b0bbf..31493bbb9 100644 --- a/pkg/server/service/roundtripper.go +++ b/pkg/server/service/roundtripper.go @@ -1,6 +1,7 @@ package service import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -8,6 +9,7 @@ import ( "net" "net/http" "reflect" + "strings" "sync" "time" @@ -149,10 +151,71 @@ func createRoundTripper(cfg *dynamic.ServersTransport) (http.RoundTripper, error // Return directly HTTP/1.1 transport when HTTP/2 is disabled if cfg.DisableHTTP2 { - return transport, nil + return &KerberosRoundTripper{ + OriginalRoundTripper: transport, + new: func() http.RoundTripper { + return transport.Clone() + }, + }, nil } - return newSmartRoundTripper(transport, cfg.ForwardingTimeouts) + rt, err := newSmartRoundTripper(transport, cfg.ForwardingTimeouts) + if err != nil { + return nil, err + } + return &KerberosRoundTripper{ + OriginalRoundTripper: rt, + new: func() http.RoundTripper { + return rt.Clone() + }, + }, nil +} + +type KerberosRoundTripper struct { + new func() http.RoundTripper + OriginalRoundTripper http.RoundTripper +} + +type stickyRoundTripper struct { + RoundTripper http.RoundTripper +} + +type transportKeyType string + +var transportKey transportKeyType = "transport" + +func AddTransportOnContext(ctx context.Context) context.Context { + return context.WithValue(ctx, transportKey, &stickyRoundTripper{}) +} + +func (k *KerberosRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { + value, ok := request.Context().Value(transportKey).(*stickyRoundTripper) + if !ok { + return k.OriginalRoundTripper.RoundTrip(request) + } + + if value.RoundTripper != nil { + return value.RoundTripper.RoundTrip(request) + } + + resp, err := k.OriginalRoundTripper.RoundTrip(request) + + // If we found that we are authenticating with Kerberos (Negotiate) or NTLM. + // We put a dedicated roundTripper in the ConnContext. + // This will stick the next calls to the same connection with the backend. + if err == nil && containsNTLMorNegotiate(resp.Header.Values("WWW-Authenticate")) { + value.RoundTripper = k.new() + } + return resp, err +} + +func containsNTLMorNegotiate(h []string) bool { + for _, s := range h { + if strings.HasPrefix(s, "NTLM") || strings.HasPrefix(s, "Negotiate") { + return true + } + } + return false } func createRootCACertPool(rootCAs []traefiktls.FileOrContent) *x509.CertPool { diff --git a/pkg/server/service/roundtripper_test.go b/pkg/server/service/roundtripper_test.go index a057166ab..460bc9a60 100644 --- a/pkg/server/service/roundtripper_test.go +++ b/pkg/server/service/roundtripper_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "crypto/tls" "crypto/x509" "net" @@ -293,3 +294,80 @@ func TestDisableHTTP2(t *testing.T) { }) } } + +type roundTripperFn func(req *http.Request) (*http.Response, error) + +func (r roundTripperFn) RoundTrip(request *http.Request) (*http.Response, error) { + return r(request) +} + +func TestKerberosRoundTripper(t *testing.T) { + testCases := []struct { + desc string + + originalRoundTripperHeaders map[string][]string + + expectedStatusCode []int + expectedDedicatedCount int + expectedOriginalCount int + }{ + { + desc: "without special header", + expectedStatusCode: []int{http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized}, + expectedOriginalCount: 3, + }, + { + desc: "with Negotiate (Kerberos)", + originalRoundTripperHeaders: map[string][]string{"Www-Authenticate": {"Negotiate"}}, + expectedStatusCode: []int{http.StatusUnauthorized, http.StatusOK, http.StatusOK}, + expectedOriginalCount: 1, + expectedDedicatedCount: 2, + }, + { + desc: "with NTLM", + originalRoundTripperHeaders: map[string][]string{"Www-Authenticate": {"NTLM"}}, + expectedStatusCode: []int{http.StatusUnauthorized, http.StatusOK, http.StatusOK}, + expectedOriginalCount: 1, + expectedDedicatedCount: 2, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + origCount := 0 + dedicatedCount := 0 + rt := KerberosRoundTripper{ + new: func() http.RoundTripper { + return roundTripperFn(func(req *http.Request) (*http.Response, error) { + dedicatedCount++ + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }) + }, + OriginalRoundTripper: roundTripperFn(func(req *http.Request) (*http.Response, error) { + origCount++ + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: test.originalRoundTripperHeaders, + }, nil + }), + } + + ctx := AddTransportOnContext(context.Background()) + for _, expected := range test.expectedStatusCode { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1", http.NoBody) + require.NoError(t, err) + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, expected, resp.StatusCode) + } + + require.Equal(t, test.expectedOriginalCount, origCount) + require.Equal(t, test.expectedDedicatedCount, dedicatedCount) + }) + } +} diff --git a/pkg/server/service/smart_roundtripper.go b/pkg/server/service/smart_roundtripper.go index 5643e60ed..0d6c280bc 100644 --- a/pkg/server/service/smart_roundtripper.go +++ b/pkg/server/service/smart_roundtripper.go @@ -11,7 +11,7 @@ import ( "golang.org/x/net/http2" ) -func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (http.RoundTripper, error) { +func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (*smartRoundTripper, error) { transportHTTP1 := transport.Clone() transportHTTP2, err := http2.ConfigureTransports(transport) @@ -53,6 +53,12 @@ type smartRoundTripper struct { http *http.Transport } +func (m *smartRoundTripper) Clone() http.RoundTripper { + h := m.http.Clone() + h2 := m.http2.Clone() + return &smartRoundTripper{http: h, http2: h2} +} + func (m *smartRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // If we have a connection upgrade, we don't use HTTP/2 if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {