diff --git a/pkg/middlewares/auth/forward.go b/pkg/middlewares/auth/forward.go index f1b4421bf..4bd5bbaf6 100644 --- a/pkg/middlewares/auth/forward.go +++ b/pkg/middlewares/auth/forward.go @@ -88,9 +88,11 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - writeHeader(req, forwardReq, fa.trustForwardHeader) + // Ensure tracing headers are in the request before we copy the headers to the + // forwardReq. + tracing.InjectRequestHeaders(req) - tracing.InjectRequestHeaders(forwardReq) + writeHeader(req, forwardReq, fa.trustForwardHeader) forwardResponse, forwardErr := httpClient.Do(forwardReq) if forwardErr != nil { diff --git a/pkg/middlewares/auth/forward_test.go b/pkg/middlewares/auth/forward_test.go index 20dfa6608..707e4f4ed 100644 --- a/pkg/middlewares/auth/forward_test.go +++ b/pkg/middlewares/auth/forward_test.go @@ -3,13 +3,18 @@ package auth import ( "context" "fmt" + "io" "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/containous/traefik/v2/pkg/config/dynamic" + tracingMiddleware "github.com/containous/traefik/v2/pkg/middlewares/tracing" "github.com/containous/traefik/v2/pkg/testhelpers" + "github.com/containous/traefik/v2/pkg/tracing" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/mocktracer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vulcand/oxy/forward" @@ -394,3 +399,44 @@ func Test_writeHeader(t *testing.T) { }) } } + +func TestForwardAuthUsesTracing(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Mockpfx-Ids-Traceid") == "" { + t.Errorf("expected Mockpfx-Ids-Traceid header to be present in request") + } + })) + defer server.Close() + + next := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + + auth := dynamic.ForwardAuth{ + Address: server.URL, + } + + tracer := mocktracer.New() + opentracing.SetGlobalTracer(tracer) + + tr, _ := tracing.NewTracing("testApp", 100, &mockBackend{tracer}) + + next, err := NewForward(context.Background(), next, auth, "authTest") + require.NoError(t, err) + + next = tracingMiddleware.NewEntryPoint(context.Background(), tr, "tracingTest", next) + + ts := httptest.NewServer(next) + defer ts.Close() + + req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) +} + +type mockBackend struct { + opentracing.Tracer +} + +func (b *mockBackend) Setup(componentName string) (opentracing.Tracer, io.Closer, error) { + return b.Tracer, ioutil.NopCloser(nil), nil +}