Refactor getDefaultCertificate to use CertificateData

Signed-off-by: baalajimaestro <me@baalajimaestro.me>
This commit is contained in:
baalajimaestro 2022-10-23 12:55:43 +05:30
parent 2685d94660
commit b77524eadc
Signed by: baalajimaestro
GPG key ID: F93C394FE9BBAFD5
2 changed files with 12 additions and 10 deletions

View file

@ -23,7 +23,7 @@ type CertificateStore struct {
// NewCertificateStore create a store for dynamic certificates. // NewCertificateStore create a store for dynamic certificates.
func NewCertificateStore() *CertificateStore { func NewCertificateStore() *CertificateStore {
s := &safe.Safe{} s := &safe.Safe{}
s.Set(make(map[string]*tls.Certificate)) s.Set(make(map[string]*CertificateData))
return &CertificateStore{ return &CertificateStore{
DynamicCerts: s, DynamicCerts: s,
@ -118,7 +118,7 @@ func (c *CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo)
} }
// GetCertificate returns the first certificate matching all the given domains. // GetCertificate returns the first certificate matching all the given domains.
func (c *CertificateStore) GetCertificate(domains []string) *tls.Certificate { func (c *CertificateStore) GetCertificate(domains []string) *CertificateData {
if c == nil { if c == nil {
return nil return nil
} }
@ -127,11 +127,11 @@ func (c *CertificateStore) GetCertificate(domains []string) *tls.Certificate {
domainsKey := strings.Join(domains, ",") domainsKey := strings.Join(domains, ",")
if cert, ok := c.CertCache.Get(domainsKey); ok { if cert, ok := c.CertCache.Get(domainsKey); ok {
return cert.(*tls.Certificate) return cert.(*CertificateData)
} }
if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil { if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil {
for certDomains, cert := range c.DynamicCerts.Get().(map[string]*tls.Certificate) { for certDomains, cert := range c.DynamicCerts.Get().(map[string]*CertificateData) {
if domainsKey == certDomains { if domainsKey == certDomains {
c.CertCache.SetDefault(domainsKey, cert) c.CertCache.SetDefault(domainsKey, cert)
return cert return cert

View file

@ -130,7 +130,7 @@ func (m *Manager) UpdateConfigs(ctx context.Context, stores map[string]Store, co
log.FromContext(ctxStore).Errorf("Error while creating certificate store: %v", err) log.FromContext(ctxStore).Errorf("Error while creating certificate store: %v", err)
} }
st.DefaultCertificate = certificate st.DefaultCertificate = certificate.Certificate
} }
} }
@ -273,17 +273,19 @@ func (m *Manager) GetStore(storeName string) *CertificateStore {
return m.getStore(storeName) return m.getStore(storeName)
} }
func getDefaultCertificate(ctx context.Context, tlsStore Store, st *CertificateStore) (*tls.Certificate, error) { func getDefaultCertificate(ctx context.Context, tlsStore Store, st *CertificateStore) (*CertificateData, error) {
if tlsStore.DefaultCertificate != nil { if tlsStore.DefaultCertificate != nil {
cert, err := buildDefaultCertificate(tlsStore.DefaultCertificate) cert, err := buildDefaultCertificate(tlsStore.DefaultCertificate)
certificate := CertificateData{Certificate: cert}
if err != nil { if err != nil {
return nil, err return nil, err
} }
return cert, nil return &certificate, nil
} }
defaultCert, err := generate.DefaultCertificate() defaultCert, err := generate.DefaultCertificate()
defaultCertificate := CertificateData{Certificate: defaultCert}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -291,19 +293,19 @@ func getDefaultCertificate(ctx context.Context, tlsStore Store, st *CertificateS
if tlsStore.DefaultGeneratedCert != nil && tlsStore.DefaultGeneratedCert.Domain != nil && tlsStore.DefaultGeneratedCert.Resolver != "" { if tlsStore.DefaultGeneratedCert != nil && tlsStore.DefaultGeneratedCert.Domain != nil && tlsStore.DefaultGeneratedCert.Resolver != "" {
domains, err := sanitizeDomains(*tlsStore.DefaultGeneratedCert.Domain) domains, err := sanitizeDomains(*tlsStore.DefaultGeneratedCert.Domain)
if err != nil { if err != nil {
return defaultCert, fmt.Errorf("falling back to the internal generated certificate because invalid domains: %w", err) return &defaultCertificate, fmt.Errorf("falling back to the internal generated certificate because invalid domains: %w", err)
} }
defaultACMECert := st.GetCertificate(domains) defaultACMECert := st.GetCertificate(domains)
if defaultACMECert == nil { if defaultACMECert == nil {
return defaultCert, fmt.Errorf("unable to find certificate for domains %q: falling back to the internal generated certificate", strings.Join(domains, ",")) return &defaultCertificate, fmt.Errorf("unable to find certificate for domains %q: falling back to the internal generated certificate", strings.Join(domains, ","))
} }
return defaultACMECert, nil return defaultACMECert, nil
} }
log.FromContext(ctx).Debug("No default certificate, fallback to the internal generated certificate") log.FromContext(ctx).Debug("No default certificate, fallback to the internal generated certificate")
return defaultCert, nil return &defaultCertificate, nil
} }
// creates a TLS config that allows terminating HTTPS for multiple domains using SNI. // creates a TLS config that allows terminating HTTPS for multiple domains using SNI.