From 39a3cefc216e6c30b625de17119f73b39fd984c6 Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Mon, 9 Dec 2019 12:20:06 +0100 Subject: [PATCH] fix: PassClientTLSCert middleware separators and formatting --- docs/content/middlewares/passtlsclientcert.md | 12 +- .../passtlsclientcert/pass_tls_client_cert.go | 251 +++++++------ .../pass_tls_client_cert_test.go | 333 +++++++++--------- 3 files changed, 303 insertions(+), 293 deletions(-) diff --git a/docs/content/middlewares/passtlsclientcert.md b/docs/content/middlewares/passtlsclientcert.md index 9da4dbed0..137336ba1 100644 --- a/docs/content/middlewares/passtlsclientcert.md +++ b/docs/content/middlewares/passtlsclientcert.md @@ -380,7 +380,7 @@ In the example, it is the part between `-----BEGIN CERTIFICATE-----` and `-----E !!! info "Extracted data" The delimiters and `\n` will be removed. - If there are more than one certificate, they are separated by a "`;`". + If there are more than one certificate, they are separated by a "`,`". !!! warning "`X-Forwarded-Tls-Client-Cert` value could exceed the web server header size limit" @@ -395,12 +395,12 @@ The value of the header will be an escaped concatenation of all the selected cer The following example shows an unescaped result that uses all the available fields: ```text -Subject="DC=org,DC=cheese,C=FR,C=US,ST=Cheese org state,ST=Cheese com state,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=*.cheese.com",Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2",NB=1544094616,NA=1607166616,SAN=*.cheese.org,*.cheese.net,*.cheese.com,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2 +Subject="DC=org,DC=cheese,C=FR,C=US,ST=Cheese org state,ST=Cheese com state,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=*.cheese.com";Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2";NB="1544094616";NA="1607166616";SAN="*.cheese.org,*.cheese.net,*.cheese.com,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2" ``` !!! info "Multiple certificates" - If there are more than one certificate, they are separated by a `;`. + If there are more than one certificate, they are separated by a `,`. #### `info.notAfter` @@ -416,7 +416,7 @@ The data are taken from the following certificate part: The escape `notAfter` info part will be like: ```text -NA=1607166616 +NA="1607166616" ``` #### `info.notBefore` @@ -433,7 +433,7 @@ Validity The escape `notBefore` info part will be like: ```text -NB=1544094616 +NB="1544094616" ``` #### `info.sans` @@ -450,7 +450,7 @@ The data are taken from the following certificate part: The escape SANs info part will be like: ```text -SAN=*.cheese.org,*.cheese.net,*.cheese.com,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2 +SAN="*.cheese.org,*.cheese.net,*.cheese.com,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2" ``` !!! info "multiple values" diff --git a/pkg/middlewares/passtlsclientcert/pass_tls_client_cert.go b/pkg/middlewares/passtlsclientcert/pass_tls_client_cert.go index a07bb9a73..2b5e29dd2 100644 --- a/pkg/middlewares/passtlsclientcert/pass_tls_client_cert.go +++ b/pkg/middlewares/passtlsclientcert/pass_tls_client_cert.go @@ -18,10 +18,17 @@ import ( "github.com/opentracing/opentracing-go/ext" ) +const typeName = "PassClientTLSCert" + const ( xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert" xForwardedTLSClientCertInfo = "X-Forwarded-Tls-Client-Cert-Info" - typeName = "PassClientTLSCert" +) + +const ( + certSeparator = "," + fieldSeparator = ";" + subFieldSeparator = "," ) var attributeTypeNames = map[string]string{ @@ -55,6 +62,29 @@ func newDistinguishedNameOptions(info *dynamic.TLSCLientCertificateDNInfo) *Dist } } +// tlsClientCertificateInfo is a struct for specifying the configuration for the passTLSClientCert middleware. +type tlsClientCertificateInfo struct { + notAfter bool + notBefore bool + sans bool + subject *DistinguishedNameOptions + issuer *DistinguishedNameOptions +} + +func newTLSClientCertificateInfo(info *dynamic.TLSClientCertificateInfo) *tlsClientCertificateInfo { + if info == nil { + return nil + } + + return &tlsClientCertificateInfo{ + issuer: newDistinguishedNameOptions(info.Issuer), + notAfter: info.NotAfter, + notBefore: info.NotBefore, + subject: newDistinguishedNameOptions(info.Subject), + sans: info.Sans, + } +} + // passTLSClientCert is a middleware that helps setup a few tls info features. type passTLSClientCert struct { next http.Handler @@ -71,45 +101,84 @@ func New(ctx context.Context, next http.Handler, config dynamic.PassTLSClientCer next: next, name: name, pem: config.PEM, - info: newTLSClientInfo(config.Info), + info: newTLSClientCertificateInfo(config.Info), }, nil } -// tlsClientCertificateInfo is a struct for specifying the configuration for the passTLSClientCert middleware. -type tlsClientCertificateInfo struct { - notAfter bool - notBefore bool - sans bool - subject *DistinguishedNameOptions - issuer *DistinguishedNameOptions -} - -func newTLSClientInfo(info *dynamic.TLSClientCertificateInfo) *tlsClientCertificateInfo { - if info == nil { - return nil - } - - return &tlsClientCertificateInfo{ - issuer: newDistinguishedNameOptions(info.Issuer), - notAfter: info.NotAfter, - notBefore: info.NotBefore, - subject: newDistinguishedNameOptions(info.Subject), - sans: info.Sans, - } -} - func (p *passTLSClientCert) GetTracingInformation() (string, ext.SpanKindEnum) { return p.name, tracing.SpanKindNoneEnum } func (p *passTLSClientCert) ServeHTTP(rw http.ResponseWriter, req *http.Request) { ctx := middlewares.GetLoggerCtx(req.Context(), p.name, typeName) + logger := log.FromContext(ctx) + + if p.pem { + if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 { + req.Header.Set(xForwardedTLSClientCert, getCertificates(ctx, req.TLS.PeerCertificates)) + } else { + logger.Warn("Tried to extract a certificate on a request without mutual TLS") + } + } + + if p.info != nil { + if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 { + headerContent := p.getCertInfo(ctx, req.TLS.PeerCertificates) + req.Header.Set(xForwardedTLSClientCertInfo, url.QueryEscape(headerContent)) + } else { + logger.Warn("Tried to extract a certificate on a request without mutual TLS") + } + } - p.modifyRequestHeaders(ctx, req) p.next.ServeHTTP(rw, req) } -func getDNInfo(ctx context.Context, prefix string, options *DistinguishedNameOptions, cs *pkix.Name) string { +// getCertInfo Build a string with the wanted client certificates information +// - the `,` is used to separate certificates +// - the `;` is used to separate root fields +// - the value of root fields is always wrapped by double quote +// - if a field is empty, the field is ignored +func (p *passTLSClientCert) getCertInfo(ctx context.Context, certs []*x509.Certificate) string { + var headerValues []string + + for _, peerCert := range certs { + var values []string + + if p.info != nil { + subject := getDNInfo(ctx, p.info.subject, &peerCert.Subject) + if subject != "" { + values = append(values, fmt.Sprintf(`Subject="%s"`, strings.TrimSuffix(subject, subFieldSeparator))) + } + + issuer := getDNInfo(ctx, p.info.issuer, &peerCert.Issuer) + if issuer != "" { + values = append(values, fmt.Sprintf(`Issuer="%s"`, strings.TrimSuffix(issuer, subFieldSeparator))) + } + + if p.info.notBefore { + values = append(values, fmt.Sprintf(`NB="%d"`, uint64(peerCert.NotBefore.Unix()))) + } + + if p.info.notAfter { + values = append(values, fmt.Sprintf(`NA="%d"`, uint64(peerCert.NotAfter.Unix()))) + } + + if p.info.sans { + sans := getSANs(peerCert) + if len(sans) > 0 { + values = append(values, fmt.Sprintf(`SAN="%s"`, strings.Join(sans, subFieldSeparator))) + } + } + } + + value := strings.Join(values, fieldSeparator) + headerValues = append(headerValues, value) + } + + return strings.Join(headerValues, certSeparator) +} + +func getDNInfo(ctx context.Context, options *DistinguishedNameOptions, cs *pkix.Name) string { if options == nil { return "" } @@ -120,7 +189,7 @@ func getDNInfo(ctx context.Context, prefix string, options *DistinguishedNameOpt for _, name := range cs.Names { // Domain Component - RFC 2247 if options.DomainComponent && attributeTypeNames[name.Type.String()] == "DC" { - content.WriteString(fmt.Sprintf("DC=%s,", name.Value)) + content.WriteString(fmt.Sprintf("DC=%s%s", name.Value, subFieldSeparator)) } } @@ -148,11 +217,7 @@ func getDNInfo(ctx context.Context, prefix string, options *DistinguishedNameOpt writePart(ctx, content, cs.CommonName, "CN") } - if content.Len() > 0 { - return prefix + `="` + strings.TrimSuffix(content.String(), ",") + `"` - } - - return "" + return content.String() } func writeParts(ctx context.Context, content io.StringWriter, entries []string, prefix string) { @@ -163,135 +228,63 @@ func writeParts(ctx context.Context, content io.StringWriter, entries []string, func writePart(ctx context.Context, content io.StringWriter, entry string, prefix string) { if len(entry) > 0 { - _, err := content.WriteString(fmt.Sprintf("%s=%s,", prefix, entry)) + _, err := content.WriteString(fmt.Sprintf("%s=%s%s", prefix, entry, subFieldSeparator)) if err != nil { log.FromContext(ctx).Error(err) } } } -// getXForwardedTLSClientCertInfo Build a string with the wanted client certificates information -// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s; -func (p *passTLSClientCert) getXForwardedTLSClientCertInfo(ctx context.Context, certs []*x509.Certificate) string { - var headerValues []string - - for _, peerCert := range certs { - var values []string - var sans string - var nb string - var na string - - if p.info != nil { - subject := getDNInfo(ctx, "Subject", p.info.subject, &peerCert.Subject) - if len(subject) > 0 { - values = append(values, subject) - } - - issuer := getDNInfo(ctx, "Issuer", p.info.issuer, &peerCert.Issuer) - if len(issuer) > 0 { - values = append(values, issuer) - } - } - - ci := p.info - if ci != nil { - if ci.notBefore { - nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix())) - values = append(values, nb) - } - if ci.notAfter { - na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix())) - values = append(values, na) - } - - if ci.sans { - sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ",")) - values = append(values, sans) - } - } - - value := strings.Join(values, ",") - headerValues = append(headerValues, value) - } - - return strings.Join(headerValues, ";") -} - -// modifyRequestHeaders set the wanted headers with the certificates information. -func (p *passTLSClientCert) modifyRequestHeaders(ctx context.Context, r *http.Request) { - logger := log.FromContext(ctx) - - if p.pem { - if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { - r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(ctx, r.TLS.PeerCertificates)) - } else { - logger.Warn("Tried to extract a certificate on a request without mutual TLS") - } - } - - if p.info != nil { - if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { - headerContent := p.getXForwardedTLSClientCertInfo(ctx, r.TLS.PeerCertificates) - r.Header.Set(xForwardedTLSClientCertInfo, url.QueryEscape(headerContent)) - } else { - logger.Warn("Tried to extract a certificate on a request without mutual TLS") - } - } -} - // sanitize As we pass the raw certificates, remove the useless data and make it http request compliant. func sanitize(cert []byte) string { - s := string(cert) - r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "", + cleaned := strings.NewReplacer( + "-----BEGIN CERTIFICATE-----", "", "-----END CERTIFICATE-----", "", - "\n", "") - cleaned := r.Replace(s) + "\n", "", + ).Replace(string(cert)) return url.QueryEscape(cleaned) } -// extractCertificate extract the certificate from the request. -func extractCertificate(ctx context.Context, cert *x509.Certificate) string { - b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw} - certPEM := pem.EncodeToMemory(&b) - if certPEM == nil { - log.FromContext(ctx).Error("Cannot extract the certificate content") - return "" - } - return sanitize(certPEM) -} - -// getXForwardedTLSClientCert Build a string with the client certificates. -func getXForwardedTLSClientCert(ctx context.Context, certs []*x509.Certificate) string { +// getCertificates Build a string with the client certificates. +func getCertificates(ctx context.Context, certs []*x509.Certificate) string { var headerValues []string for _, peerCert := range certs { headerValues = append(headerValues, extractCertificate(ctx, peerCert)) } - return strings.Join(headerValues, ",") + return strings.Join(headerValues, certSeparator) +} + +// extractCertificate extract the certificate from the request. +func extractCertificate(ctx context.Context, cert *x509.Certificate) string { + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}) + if certPEM == nil { + log.FromContext(ctx).Error("Cannot extract the certificate content") + return "" + } + + return sanitize(certPEM) } // getSANs get the Subject Alternate Name values. func getSANs(cert *x509.Certificate) []string { - var sans []string if cert == nil { - return sans + return nil } + var sans []string sans = append(sans, cert.DNSNames...) sans = append(sans, cert.EmailAddresses...) - var ips []string for _, ip := range cert.IPAddresses { - ips = append(ips, ip.String()) + sans = append(sans, ip.String()) } - sans = append(sans, ips...) - var uris []string for _, uri := range cert.URIs { - uris = append(uris, uri.String()) + sans = append(sans, uri.String()) } - return append(sans, uris...) + return sans } diff --git a/pkg/middlewares/passtlsclientcert/pass_tls_client_cert_test.go b/pkg/middlewares/passtlsclientcert/pass_tls_client_cert_test.go index 8bf5b2e50..49497080e 100644 --- a/pkg/middlewares/passtlsclientcert/pass_tls_client_cert_test.go +++ b/pkg/middlewares/passtlsclientcert/pass_tls_client_cert_test.go @@ -15,6 +15,7 @@ import ( "github.com/containous/traefik/v2/pkg/config/dynamic" "github.com/containous/traefik/v2/pkg/testhelpers" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -113,6 +114,7 @@ Cg+XKmHzexmTnKaKac2w9ZECpRsQ9IBdQq9OghIwPtOnERTOUJEEgNcqA+9xELjb pQ== -----END CERTIFICATE----- ` + minimalCheeseCrt = `-----BEGIN CERTIFICATE----- MIIEQDCCAygCFFRY0OBk/L5Se0IZRj3CMljawL2UMA0GCSqGSIb3DQEBCwUAMIIB hDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEP @@ -262,47 +264,6 @@ jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM= -----END CERTIFICATE-----` ) -func getCleanCertContents(certContents []string) string { - var re = regexp.MustCompile("-----BEGIN CERTIFICATE-----(?s)(.*)") - - var cleanedCertContent []string - for _, certContent := range certContents { - cert := re.FindString(certContent) - cleanedCertContent = append(cleanedCertContent, sanitize([]byte(cert))) - } - - return strings.Join(cleanedCertContent, ",") -} - -func getCertificate(certContent string) *x509.Certificate { - roots := x509.NewCertPool() - ok := roots.AppendCertsFromPEM([]byte(signingCA)) - if !ok { - panic("failed to parse root certificate") - } - - block, _ := pem.Decode([]byte(certContent)) - if block == nil { - panic("failed to parse certificate PEM") - } - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - panic("failed to parse certificate: " + err.Error()) - } - - return cert -} - -func buildTLSWith(certContents []string) *tls.ConnectionState { - var peerCertificates []*x509.Certificate - - for _, certContent := range certContents { - peerCertificates = append(peerCertificates, getCertificate(certContent)) - } - - return &tls.ConnectionState{PeerCertificates: peerCertificates} -} - var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("bar")) if err != nil { @@ -310,59 +271,7 @@ var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { } }) -func getExpectedSanitized(s string) string { - return url.QueryEscape(strings.Replace(s, "\n", "", -1)) -} - -func TestSanitize(t *testing.T) { - testCases := []struct { - desc string - toSanitize []byte - expected string - }{ - { - desc: "Empty", - }, - { - desc: "With a minimal cert", - toSanitize: []byte(minimalCheeseCrt), - expected: getExpectedSanitized(`MIIEQDCCAygCFFRY0OBk/L5Se0IZRj3CMljawL2UMA0GCSqGSIb3DQEBCwUAMIIB -hDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEP -MA0GA1UECgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2lt -cGxlIFNpZ25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2Vj -dGlvbiAyMRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2lt -cGxlIFNpZ25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYD -VQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBT -dGF0ZTEYMBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJz -aW1wbGVAc2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmlu -Zy5jb20wHhcNMTgxMjA2MTExMDM2WhcNMjEwOTI1MTExMDM2WjAzMQswCQYDVQQG -EwJGUjETMBEGA1UECAwKU29tZS1TdGF0ZTEPMA0GA1UECgwGQ2hlZXNlMIIBIjAN -BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAskX/bUtwFo1gF2BTPNaNcTUMaRFu -FMZozK8IgLjccZ4kZ0R9oFO6Yp8Zl/IvPaf7tE26PI7XP7eHriUdhnQzX7iioDd0 -RZa68waIhAGc+xPzRFrP3b3yj3S2a9Rve3c0K+SCV+EtKAwsxMqQDhoo9PcBfo5B -RHfht07uD5MncUcGirwN+/pxHV5xzAGPcc7On0/5L7bq/G+63nhu78zw9XyuLaHC -PM5VbOUvpyIESJHbMMzTdFGL8ob9VKO+Kr1kVGdEA9i8FLGl3xz/GBKuW/JD0xyW -DrU29mri5vYWHmkuv7ZWHGXnpXjTtPHwveE9/0/ArnmpMyR9JtqFr1oEvQIDAQAB -MA0GCSqGSIb3DQEBCwUAA4IBAQBHta+NWXI08UHeOkGzOTGRiWXsOH2dqdX6gTe9 -xF1AIjyoQ0gvpoGVvlnChSzmlUj+vnx/nOYGIt1poE3hZA3ZHZD/awsvGyp3GwWD -IfXrEViSCIyF+8tNNKYyUcEO3xdAsAUGgfUwwF/mZ6MBV5+A/ZEEILlTq8zFt9dV -vdKzIt7fZYxYBBHFSarl1x8pDgWXlf3hAufevGJXip9xGYmznF0T5cq1RbWJ4be3 -/9K7yuWhuBYC3sbTbCneHBa91M82za+PIISc1ygCYtWSBoZKSAqLk0rkZpHaekDP -WqeUSNGYV//RunTeuRDAf5OxehERb1srzBXhRZ3cZdzXbgR/`), - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - require.Equal(t, test.expected, sanitize(test.toSanitize), "The sanitized certificates should be equal") - }) - } -} - -func TestTLSClientHeadersWithPEM(t *testing.T) { +func TestPassTLSClientCert_PEM(t *testing.T) { testCases := []struct { desc string certContents []string // set the request TLS attribute if defined @@ -417,70 +326,36 @@ func TestTLSClientHeadersWithPEM(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK") - require.Equal(t, "bar", res.Body.String(), "Should be the expected body") + assert.Equal(t, http.StatusOK, res.Code, "Http Status should be OK") + assert.Equal(t, "bar", res.Body.String(), "Should be the expected body") if test.expectedHeader != "" { - require.Equal(t, getCleanCertContents(test.certContents), req.Header.Get(xForwardedTLSClientCert), "The request header should contain the cleaned certificate") + expected := getCleanCertContents(test.certContents) + assert.Equal(t, expected, req.Header.Get(xForwardedTLSClientCert), "The request header should contain the cleaned certificate") } else { - require.Empty(t, req.Header.Get(xForwardedTLSClientCert)) + assert.Empty(t, req.Header.Get(xForwardedTLSClientCert)) } - require.Empty(t, res.Header().Get(xForwardedTLSClientCert), "The response header should be always empty") + + assert.Empty(t, res.Header().Get(xForwardedTLSClientCert), "The response header should be always empty") }) } } -func TestGetSans(t *testing.T) { - urlFoo, err := url.Parse("my.foo.com") - require.NoError(t, err) - urlBar, err := url.Parse("my.bar.com") - require.NoError(t, err) +func TestPassTLSClientCert_certInfo(t *testing.T) { + minimalCheeseCertAllInfo := strings.Join([]string{ + `Subject="C=FR,ST=Some-State,O=Cheese"`, + `Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2"`, + `NB="1544094636"`, + `NA="1632568236"`, + }, fieldSeparator) - testCases := []struct { - desc string - cert *x509.Certificate // set the request TLS attribute if defined - expected []string - }{ - { - desc: "With nil", - }, - { - desc: "Certificate without Sans", - cert: &x509.Certificate{}, - }, - { - desc: "Certificate with all Sans", - cert: &x509.Certificate{ - DNSNames: []string{"foo", "bar"}, - EmailAddresses: []string{"test@test.com", "test2@test.com"}, - IPAddresses: []net.IP{net.IPv4(10, 0, 0, 1), net.IPv4(10, 0, 0, 2)}, - URIs: []*url.URL{urlFoo, urlBar}, - }, - expected: []string{"foo", "bar", "test@test.com", "test2@test.com", "10.0.0.1", "10.0.0.2", urlFoo.String(), urlBar.String()}, - }, - } - - for _, test := range testCases { - sans := getSANs(test.cert) - - test := test - t.Run(test.desc, func(t *testing.T) { - t.Parallel() - - if len(test.expected) > 0 { - for i, expected := range test.expected { - require.Equal(t, expected, sans[i]) - } - } else { - require.Empty(t, sans) - } - }) - } -} - -func TestTLSClientHeadersWithCertInfo(t *testing.T) { - minimalCheeseCertAllInfo := `Subject="C=FR,ST=Some-State,O=Cheese",Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2",NB=1544094636,NA=1632568236,SAN=` - completeCertAllInfo := `Subject="DC=org,DC=cheese,C=FR,C=US,ST=Cheese org state,ST=Cheese com state,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=*.cheese.com",Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2",NB=1544094616,NA=1607166616,SAN=*.cheese.org,*.cheese.net,*.cheese.com,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2` + completeCertAllInfo := strings.Join([]string{ + `Subject="DC=org,DC=cheese,C=FR,C=US,ST=Cheese org state,ST=Cheese com state,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=*.cheese.com"`, + `Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2"`, + `NB="1544094616"`, + `NA="1607166616"`, + `SAN="*.cheese.org,*.cheese.net,*.cheese.com,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2"`, + }, fieldSeparator) testCases := []struct { desc string @@ -547,7 +422,7 @@ func TestTLSClientHeadersWithCertInfo(t *testing.T) { }, }, }, - expectedHeader: url.QueryEscape(minimalCheeseCertAllInfo), + expectedHeader: minimalCheeseCertAllInfo, }, { desc: "TLS with simple certificate, with some info", @@ -564,7 +439,7 @@ func TestTLSClientHeadersWithCertInfo(t *testing.T) { }, }, }, - expectedHeader: url.QueryEscape(`Subject="O=Cheese",Issuer="C=FR,C=US",NA=1632568236,SAN=`), + expectedHeader: `Subject="O=Cheese";Issuer="C=FR,C=US";NA="1632568236"`, }, { desc: "TLS with complete certificate, with all info", @@ -594,7 +469,7 @@ func TestTLSClientHeadersWithCertInfo(t *testing.T) { }, }, }, - expectedHeader: url.QueryEscape(completeCertAllInfo), + expectedHeader: completeCertAllInfo, }, { desc: "TLS with 2 certificates, with all info", @@ -624,7 +499,7 @@ func TestTLSClientHeadersWithCertInfo(t *testing.T) { }, }, }, - expectedHeader: url.QueryEscape(strings.Join([]string{minimalCheeseCertAllInfo, completeCertAllInfo}, ";")), + expectedHeader: strings.Join([]string{minimalCheeseCertAllInfo, completeCertAllInfo}, certSeparator), }, } @@ -645,15 +520,157 @@ func TestTLSClientHeadersWithCertInfo(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK") - require.Equal(t, "bar", res.Body.String(), "Should be the expected body") + assert.Equal(t, http.StatusOK, res.Code, "Http Status should be OK") + assert.Equal(t, "bar", res.Body.String(), "Should be the expected body") if test.expectedHeader != "" { - require.Equal(t, test.expectedHeader, req.Header.Get(xForwardedTLSClientCertInfo), "The request header should contain the cleaned certificate") + unescape, err := url.QueryUnescape(req.Header.Get(xForwardedTLSClientCertInfo)) + require.NoError(t, err) + assert.Equal(t, test.expectedHeader, unescape, "The request header should contain the cleaned certificate") } else { - require.Empty(t, req.Header.Get(xForwardedTLSClientCertInfo)) + assert.Empty(t, req.Header.Get(xForwardedTLSClientCertInfo)) } - require.Empty(t, res.Header().Get(xForwardedTLSClientCertInfo), "The response header should be always empty") + + assert.Empty(t, res.Header().Get(xForwardedTLSClientCertInfo), "The response header should be always empty") }) } } + +func Test_sanitize(t *testing.T) { + testCases := []struct { + desc string + toSanitize []byte + expected string + }{ + { + desc: "Empty", + }, + { + desc: "With a minimal cert", + toSanitize: []byte(minimalCheeseCrt), + expected: `MIIEQDCCAygCFFRY0OBk/L5Se0IZRj3CMljawL2UMA0GCSqGSIb3DQEBCwUAMIIB +hDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEP +MA0GA1UECgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2lt +cGxlIFNpZ25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2Vj +dGlvbiAyMRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2lt +cGxlIFNpZ25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYD +VQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBT +dGF0ZTEYMBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJz +aW1wbGVAc2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmlu +Zy5jb20wHhcNMTgxMjA2MTExMDM2WhcNMjEwOTI1MTExMDM2WjAzMQswCQYDVQQG +EwJGUjETMBEGA1UECAwKU29tZS1TdGF0ZTEPMA0GA1UECgwGQ2hlZXNlMIIBIjAN +BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAskX/bUtwFo1gF2BTPNaNcTUMaRFu +FMZozK8IgLjccZ4kZ0R9oFO6Yp8Zl/IvPaf7tE26PI7XP7eHriUdhnQzX7iioDd0 +RZa68waIhAGc+xPzRFrP3b3yj3S2a9Rve3c0K+SCV+EtKAwsxMqQDhoo9PcBfo5B +RHfht07uD5MncUcGirwN+/pxHV5xzAGPcc7On0/5L7bq/G+63nhu78zw9XyuLaHC +PM5VbOUvpyIESJHbMMzTdFGL8ob9VKO+Kr1kVGdEA9i8FLGl3xz/GBKuW/JD0xyW +DrU29mri5vYWHmkuv7ZWHGXnpXjTtPHwveE9/0/ArnmpMyR9JtqFr1oEvQIDAQAB +MA0GCSqGSIb3DQEBCwUAA4IBAQBHta+NWXI08UHeOkGzOTGRiWXsOH2dqdX6gTe9 +xF1AIjyoQ0gvpoGVvlnChSzmlUj+vnx/nOYGIt1poE3hZA3ZHZD/awsvGyp3GwWD +IfXrEViSCIyF+8tNNKYyUcEO3xdAsAUGgfUwwF/mZ6MBV5+A/ZEEILlTq8zFt9dV +vdKzIt7fZYxYBBHFSarl1x8pDgWXlf3hAufevGJXip9xGYmznF0T5cq1RbWJ4be3 +/9K7yuWhuBYC3sbTbCneHBa91M82za+PIISc1ygCYtWSBoZKSAqLk0rkZpHaekDP +WqeUSNGYV//RunTeuRDAf5OxehERb1srzBXhRZ3cZdzXbgR/`, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + content := sanitize(test.toSanitize) + + expected := url.QueryEscape(strings.Replace(test.expected, "\n", "", -1)) + assert.Equal(t, expected, content, "The sanitized certificates should be equal") + }) + } +} + +func Test_getSANs(t *testing.T) { + urlFoo := testhelpers.MustParseURL("my.foo.com") + urlBar := testhelpers.MustParseURL("my.bar.com") + + testCases := []struct { + desc string + cert *x509.Certificate // set the request TLS attribute if defined + expected []string + }{ + { + desc: "With nil", + }, + { + desc: "Certificate without Sans", + cert: &x509.Certificate{}, + }, + { + desc: "Certificate with all Sans", + cert: &x509.Certificate{ + DNSNames: []string{"foo", "bar"}, + EmailAddresses: []string{"test@test.com", "test2@test.com"}, + IPAddresses: []net.IP{net.IPv4(10, 0, 0, 1), net.IPv4(10, 0, 0, 2)}, + URIs: []*url.URL{urlFoo, urlBar}, + }, + expected: []string{"foo", "bar", "test@test.com", "test2@test.com", "10.0.0.1", "10.0.0.2", urlFoo.String(), urlBar.String()}, + }, + } + + for _, test := range testCases { + sans := getSANs(test.cert) + + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + if len(test.expected) > 0 { + for i, expected := range test.expected { + assert.Equal(t, expected, sans[i]) + } + } else { + assert.Empty(t, sans) + } + }) + } +} + +func getCleanCertContents(certContents []string) string { + exp := regexp.MustCompile("-----BEGIN CERTIFICATE-----(?s)(.*)") + + var cleanedCertContent []string + for _, certContent := range certContents { + cert := sanitize([]byte(exp.FindString(certContent))) + cleanedCertContent = append(cleanedCertContent, cert) + } + + return strings.Join(cleanedCertContent, certSeparator) +} + +func buildTLSWith(certContents []string) *tls.ConnectionState { + var peerCertificates []*x509.Certificate + + for _, certContent := range certContents { + peerCertificates = append(peerCertificates, getCertificate(certContent)) + } + + return &tls.ConnectionState{PeerCertificates: peerCertificates} +} + +func getCertificate(certContent string) *x509.Certificate { + roots := x509.NewCertPool() + ok := roots.AppendCertsFromPEM([]byte(signingCA)) + if !ok { + panic("failed to parse root certificate") + } + + block, _ := pem.Decode([]byte(certContent)) + if block == nil { + panic("failed to parse certificate PEM") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + panic("failed to parse certificate: " + err.Error()) + } + + return cert +}