diff --git a/integration/helloworld/helloworld.pb.go b/integration/helloworld/helloworld.pb.go index 7b3335444..96320f02a 100644 --- a/integration/helloworld/helloworld.pb.go +++ b/integration/helloworld/helloworld.pb.go @@ -16,9 +16,10 @@ It has these top-level messages: package helloworld import ( - proto "github.com/golang/protobuf/proto" fmt "fmt" math "math" + + proto "github.com/golang/protobuf/proto" ) import ( diff --git a/pkg/middlewares/accesslog/logger.go b/pkg/middlewares/accesslog/logger.go index c876efcb3..13af24c9b 100644 --- a/pkg/middlewares/accesslog/logger.go +++ b/pkg/middlewares/accesslog/logger.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/http" + "net/textproto" "net/url" "os" "path/filepath" @@ -100,6 +101,17 @@ func NewHandler(config *types.AccessLog) (*Handler, error) { Level: logrus.InfoLevel, } + // Transform headers names in config to a canonical form, to be used as is without further transformations. + if config.Fields != nil && config.Fields.Headers != nil && len(config.Fields.Headers.Names) > 0 { + fields := map[string]string{} + + for h, v := range config.Fields.Headers.Names { + fields[textproto.CanonicalMIMEHeaderKey(h)] = v + } + + config.Fields.Headers.Names = fields + } + logHandler := &Handler{ config: config, logger: logger, diff --git a/pkg/middlewares/accesslog/logger_test.go b/pkg/middlewares/accesslog/logger_test.go index 6d1567db0..eeeceba7b 100644 --- a/pkg/middlewares/accesslog/logger_test.go +++ b/pkg/middlewares/accesslog/logger_test.go @@ -41,11 +41,7 @@ var ( ) func TestLogRotation(t *testing.T) { - tempDir, err := ioutil.TempDir("", "traefik_") - if err != nil { - t.Fatalf("Error setting up temporary directory: %s", err) - } - defer os.RemoveAll(tempDir) + tempDir := createTempDir(t, "traefik_") fileName := filepath.Join(tempDir, "traefik.log") rotatedFileName := fileName + ".rotated" @@ -119,9 +115,106 @@ func lineCount(t *testing.T, fileName string) int { return count } +func TestLoggerHeaderFields(t *testing.T) { + tmpDir := createTempDir(t, CommonFormat) + + expectedValue := "expectedValue" + + testCases := []struct { + desc string + accessLogFields types.AccessLogFields + header string + expected string + }{ + { + desc: "with default mode", + header: "User-Agent", + expected: types.AccessLogDrop, + accessLogFields: types.AccessLogFields{ + DefaultMode: types.AccessLogDrop, + Headers: &types.FieldHeaders{ + DefaultMode: types.AccessLogDrop, + Names: map[string]string{}, + }, + }, + }, + { + desc: "with exact header name", + header: "User-Agent", + expected: types.AccessLogKeep, + accessLogFields: types.AccessLogFields{ + DefaultMode: types.AccessLogDrop, + Headers: &types.FieldHeaders{ + DefaultMode: types.AccessLogDrop, + Names: map[string]string{ + "User-Agent": types.AccessLogKeep, + }, + }, + }, + }, + { + desc: "with case insensitive match on header name", + header: "User-Agent", + expected: types.AccessLogKeep, + accessLogFields: types.AccessLogFields{ + DefaultMode: types.AccessLogDrop, + Headers: &types.FieldHeaders{ + DefaultMode: types.AccessLogDrop, + Names: map[string]string{ + "user-agent": types.AccessLogKeep, + }, + }, + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + logFile, err := ioutil.TempFile(tmpDir, "*.log") + require.NoError(t, err) + + config := &types.AccessLog{ + FilePath: logFile.Name(), + Format: CommonFormat, + Fields: &test.accessLogFields, + } + + logger, err := NewHandler(config) + require.NoError(t, err) + defer logger.Close() + + if config.FilePath != "" { + _, err = os.Stat(config.FilePath) + require.NoError(t, err, fmt.Sprintf("logger should create %s", config.FilePath)) + } + + req := &http.Request{ + Header: map[string][]string{}, + URL: &url.URL{ + Path: testPath, + }, + } + req.Header.Set(test.header, expectedValue) + + logger.ServeHTTP(httptest.NewRecorder(), req, http.HandlerFunc(func(writer http.ResponseWriter, r *http.Request) { + writer.WriteHeader(http.StatusOK) + })) + + logData, err := ioutil.ReadFile(logFile.Name()) + require.NoError(t, err) + + if test.expected == types.AccessLogDrop { + assert.NotContains(t, string(logData), expectedValue) + } else { + assert.Contains(t, string(logData), expectedValue) + } + }) + } +} + func TestLoggerCLF(t *testing.T) { tmpDir := createTempDir(t, CommonFormat) - defer os.RemoveAll(tmpDir) logFilePath := filepath.Join(tmpDir, logFileNameSuffix) config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat} @@ -136,7 +229,6 @@ func TestLoggerCLF(t *testing.T) { func TestAsyncLoggerCLF(t *testing.T) { tmpDir := createTempDir(t, CommonFormat) - defer os.RemoveAll(tmpDir) logFilePath := filepath.Join(tmpDir, logFileNameSuffix) config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat, BufferingSize: 1024} @@ -358,7 +450,6 @@ func TestLoggerJSON(t *testing.T) { t.Parallel() tmpDir := createTempDir(t, JSONFormat) - defer os.RemoveAll(tmpDir) logFilePath := filepath.Join(tmpDir, logFileNameSuffix) @@ -642,6 +733,8 @@ func createTempDir(t *testing.T, prefix string) string { tmpDir, err := ioutil.TempDir("", prefix) require.NoError(t, err, "failed to create temp dir") + t.Cleanup(func() { _ = os.RemoveAll(tmpDir) }) + return tmpDir }