diff --git a/acme/acme.go b/acme/acme.go index 07e881270..7f75b8ad8 100644 --- a/acme/acme.go +++ b/acme/acme.go @@ -166,9 +166,12 @@ type ACME struct { Domains []Domain `description:"SANs (alternative domains) to each main domain using format: --acme.domains='main.com,san1.com,san2.com' --acme.domains='main.net,san1.net,san2.net'"` StorageFile string `description:"File used for certificates storage."` OnDemand bool `description:"Enable on demand certificate. This will request a certificate from Let's Encrypt during the first TLS handshake for a hostname that does not yet have a certificate."` + OnHostRule bool `description:"Enable certificate generation on frontends Host rules."` CAServer string `description:"CA server to use."` EntryPoint string `description:"Entrypoint to proxy acme challenge to."` storageLock sync.RWMutex + client *acme.Client + account *Account } //Domains parse []Domain @@ -229,14 +232,14 @@ func (a *ACME) CreateConfig(tlsConfig *tls.Config, CheckOnDemandDomain func(doma } tlsConfig.Certificates = append(tlsConfig.Certificates, *cert) } - var account *Account var needRegister bool + var err error // if certificates in storage, load them if fileInfo, err := os.Stat(a.StorageFile); err == nil && fileInfo.Size() != 0 { log.Infof("Loading ACME certificates...") // load account - account, err = a.loadAccount(a) + a.account, err = a.loadAccount(a) if err != nil { return err } @@ -247,42 +250,42 @@ func (a *ACME) CreateConfig(tlsConfig *tls.Config, CheckOnDemandDomain func(doma if err != nil { return err } - account = &Account{ + a.account = &Account{ Email: a.Email, PrivateKey: x509.MarshalPKCS1PrivateKey(privateKey), } - account.DomainsCertificate = DomainsCertificates{Certs: []*DomainsCertificate{}, lock: &sync.RWMutex{}} + a.account.DomainsCertificate = DomainsCertificates{Certs: []*DomainsCertificate{}, lock: &sync.RWMutex{}} needRegister = true } - client, err := a.buildACMEClient(account) + a.client, err = a.buildACMEClient() if err != nil { return err } - client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.DNS01}) + a.client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.DNS01}) wrapperChallengeProvider := newWrapperChallengeProvider() - client.SetChallengeProvider(acme.TLSSNI01, wrapperChallengeProvider) + a.client.SetChallengeProvider(acme.TLSSNI01, wrapperChallengeProvider) if needRegister { // New users will need to register; be sure to save it - reg, err := client.Register() + reg, err := a.client.Register() if err != nil { return err } - account.Registration = reg + a.account.Registration = reg } // The client has a URL to the current Let's Encrypt Subscriber // Agreement. The user will need to agree to it. - err = client.AgreeToTOS() + err = a.client.AgreeToTOS() if err != nil { return err } safe.Go(func() { - a.retrieveCertificates(client, account) - if err := a.renewCertificates(client, account); err != nil { - log.Errorf("Error renewing ACME certificate %+v: %s", account, err.Error()) + a.retrieveCertificates(a.client) + if err := a.renewCertificates(a.client); err != nil { + log.Errorf("Error renewing ACME certificate %+v: %s", a.account, err.Error()) } }) @@ -290,14 +293,14 @@ func (a *ACME) CreateConfig(tlsConfig *tls.Config, CheckOnDemandDomain func(doma if challengeCert, ok := wrapperChallengeProvider.getCertificate(clientHello.ServerName); ok { return challengeCert, nil } - if domainCert, ok := account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok { + if domainCert, ok := a.account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok { return domainCert.tlsCert, nil } if a.OnDemand { if CheckOnDemandDomain != nil && !CheckOnDemandDomain(clientHello.ServerName) { return nil, nil } - return a.loadCertificateOnDemand(client, account, clientHello) + return a.loadCertificateOnDemand(clientHello) } return nil, nil } @@ -307,8 +310,8 @@ func (a *ACME) CreateConfig(tlsConfig *tls.Config, CheckOnDemandDomain func(doma for { select { case <-ticker.C: - if err := a.renewCertificates(client, account); err != nil { - log.Errorf("Error renewing ACME certificate %+v: %s", account, err.Error()) + if err := a.renewCertificates(a.client); err != nil { + log.Errorf("Error renewing ACME certificate %+v: %s", a.account, err.Error()) } } } @@ -317,11 +320,11 @@ func (a *ACME) CreateConfig(tlsConfig *tls.Config, CheckOnDemandDomain func(doma return nil } -func (a *ACME) retrieveCertificates(client *acme.Client, account *Account) { +func (a *ACME) retrieveCertificates(client *acme.Client) { log.Infof("Retrieving ACME certificates...") for _, domain := range a.Domains { // check if cert isn't already loaded - if _, exists := account.DomainsCertificate.exists(domain); !exists { + if _, exists := a.account.DomainsCertificate.exists(domain); !exists { domains := []string{} domains = append(domains, domain.Main) domains = append(domains, domain.SANs...) @@ -330,13 +333,13 @@ func (a *ACME) retrieveCertificates(client *acme.Client, account *Account) { log.Errorf("Error getting ACME certificate for domain %s: %s", domains, err.Error()) continue } - _, err = account.DomainsCertificate.addCertificateForDomains(certificateResource, domain) + _, err = a.account.DomainsCertificate.addCertificateForDomains(certificateResource, domain) if err != nil { log.Errorf("Error adding ACME certificate for domain %s: %s", domains, err.Error()) continue } - if err = a.saveAccount(account); err != nil { - log.Errorf("Error Saving ACME account %+v: %s", account, err.Error()) + if err = a.saveAccount(); err != nil { + log.Errorf("Error Saving ACME account %+v: %s", a.account, err.Error()) continue } } @@ -344,9 +347,9 @@ func (a *ACME) retrieveCertificates(client *acme.Client, account *Account) { log.Infof("Retrieved ACME certificates") } -func (a *ACME) renewCertificates(client *acme.Client, account *Account) error { +func (a *ACME) renewCertificates(client *acme.Client) error { log.Debugf("Testing certificate renew...") - for _, certificateResource := range account.DomainsCertificate.Certs { + for _, certificateResource := range a.account.DomainsCertificate.Certs { if certificateResource.needRenew() { log.Debugf("Renewing certificate %+v", certificateResource.Domains) renewedCert, err := client.RenewCertificate(acme.CertificateResource{ @@ -368,12 +371,12 @@ func (a *ACME) renewCertificates(client *acme.Client, account *Account) error { PrivateKey: renewedCert.PrivateKey, Certificate: renewedCert.Certificate, } - err = account.DomainsCertificate.renewCertificates(renewedACMECert, certificateResource.Domains) + err = a.account.DomainsCertificate.renewCertificates(renewedACMECert, certificateResource.Domains) if err != nil { log.Errorf("Error renewing certificate: %v", err) continue } - if err = a.saveAccount(account); err != nil { + if err = a.saveAccount(); err != nil { log.Errorf("Error saving ACME account: %v", err) continue } @@ -382,12 +385,12 @@ func (a *ACME) renewCertificates(client *acme.Client, account *Account) error { return nil } -func (a *ACME) buildACMEClient(Account *Account) (*acme.Client, error) { +func (a *ACME) buildACMEClient() (*acme.Client, error) { caServer := "https://acme-v01.api.letsencrypt.org/directory" if len(a.CAServer) > 0 { caServer = a.CAServer } - client, err := acme.NewClient(caServer, Account, acme.RSA4096) + client, err := acme.NewClient(caServer, a.account, acme.RSA4096) if err != nil { return nil, err } @@ -395,25 +398,60 @@ func (a *ACME) buildACMEClient(Account *Account) (*acme.Client, error) { return client, nil } -func (a *ACME) loadCertificateOnDemand(client *acme.Client, Account *Account, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - if certificateResource, ok := Account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok { +func (a *ACME) loadCertificateOnDemand(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if certificateResource, ok := a.account.DomainsCertificate.getCertificateForDomain(clientHello.ServerName); ok { return certificateResource.tlsCert, nil } - Certificate, err := a.getDomainsCertificates(client, []string{clientHello.ServerName}) + certificate, err := a.getDomainsCertificates(a.client, []string{clientHello.ServerName}) if err != nil { return nil, err } log.Debugf("Got certificate on demand for domain %s", clientHello.ServerName) - cert, err := Account.DomainsCertificate.addCertificateForDomains(Certificate, Domain{Main: clientHello.ServerName}) + cert, err := a.account.DomainsCertificate.addCertificateForDomains(certificate, Domain{Main: clientHello.ServerName}) if err != nil { return nil, err } - if err = a.saveAccount(Account); err != nil { + if err = a.saveAccount(); err != nil { return nil, err } return cert.tlsCert, nil } +// LoadCertificateForDomains loads certificates from ACME for given domains +func (a *ACME) LoadCertificateForDomains(domains []string) { + safe.Go(func() { + var domain Domain + if len(domains) == 0 { + // no domain + return + + } else if len(domains) > 1 { + domain = Domain{Main: domains[0], SANs: domains[1:]} + } else { + domain = Domain{Main: domains[0]} + } + if _, exists := a.account.DomainsCertificate.exists(domain); exists { + // domain already exists + return + } + certificate, err := a.getDomainsCertificates(a.client, domains) + if err != nil { + log.Errorf("Error getting ACME certificates %+v : %v", domains, err) + return + } + log.Debugf("Got certificate for domains %+v", domains) + _, err = a.account.DomainsCertificate.addCertificateForDomains(certificate, domain) + if err != nil { + log.Errorf("Error adding ACME certificates %+v : %v", domains, err) + return + } + if err = a.saveAccount(); err != nil { + log.Errorf("Error Saving ACME account %+v: %v", a.account, err) + return + } + }) +} + func (a *ACME) loadAccount(acmeConfig *ACME) (*Account, error) { a.storageLock.RLock() defer a.storageLock.RUnlock() @@ -435,11 +473,11 @@ func (a *ACME) loadAccount(acmeConfig *ACME) (*Account, error) { return &Account, nil } -func (a *ACME) saveAccount(Account *Account) error { +func (a *ACME) saveAccount() error { a.storageLock.Lock() defer a.storageLock.Unlock() // write account to file - data, err := json.MarshalIndent(Account, "", " ") + data, err := json.MarshalIndent(a.account, "", " ") if err != nil { return err } diff --git a/docs/toml.md b/docs/toml.md index b184d5b46..b417a9e42 100644 --- a/docs/toml.md +++ b/docs/toml.md @@ -206,6 +206,13 @@ entryPoint = "https" # # onDemand = true +# Enable certificate generation on frontends Host rules. This will request a certificate from Let's Encrypt for each frontend with a Host rule. +# For example, a rule Host:test1.traefik.io,test2.traefik.io will request a certificate with main domain test1.traefik.io and SAN test2.traefik.io. +# +# Optional +# +# OnHostRule = true + # CA server to use # Uncomment the line to run on the staging let's encrypt server # Leave comment to go to prod diff --git a/rules.go b/rules.go index f0c03409d..6527c7c0e 100644 --- a/rules.go +++ b/rules.go @@ -2,6 +2,7 @@ package main import ( "errors" + "fmt" "github.com/containous/mux" "net" "net/http" @@ -93,8 +94,7 @@ func (r *Rules) headersRegexp(headers ...string) *mux.Route { return r.route.route.HeadersRegexp(headers...) } -// Parse parses rules expressions -func (r *Rules) Parse(expression string) (*mux.Route, error) { +func (r *Rules) parseRules(expression string, onRule func(functionName string, function interface{}, arguments []string) error) error { functions := map[string]interface{}{ "Host": r.host, "HostRegexp": r.hostRegexp, @@ -108,7 +108,7 @@ func (r *Rules) Parse(expression string) (*mux.Route, error) { } if len(expression) == 0 { - return nil, errors.New("Empty rule") + return errors.New("Empty rule") } f := func(c rune) bool { @@ -122,17 +122,16 @@ func (r *Rules) Parse(expression string) (*mux.Route, error) { parsedRules := strings.FieldsFunc(expression, splitRule) - var resultRoute *mux.Route - for _, rule := range parsedRules { // get function parsedFunctions := strings.FieldsFunc(rule, f) if len(parsedFunctions) == 0 { - return nil, errors.New("Error parsing rule: '" + rule + "'") + return errors.New("Error parsing rule: '" + rule + "'") } - parsedFunction, ok := functions[strings.TrimSpace(parsedFunctions[0])] + functionName := strings.TrimSpace(parsedFunctions[0]) + parsedFunction, ok := functions[functionName] if !ok { - return nil, errors.New("Error parsing rule: '" + rule + "'. Unknown function: '" + parsedFunctions[0] + "'") + return errors.New("Error parsing rule: '" + rule + "'. Unknown function: '" + parsedFunctions[0] + "'") } parsedFunctions = append(parsedFunctions[:0], parsedFunctions[1:]...) fargs := func(c rune) bool { @@ -141,26 +140,62 @@ func (r *Rules) Parse(expression string) (*mux.Route, error) { // get function parsedArgs := strings.FieldsFunc(strings.Join(parsedFunctions, ":"), fargs) if len(parsedArgs) == 0 { - return nil, errors.New("Error parsing args from rule: '" + rule + "'") + return errors.New("Error parsing args from rule: '" + rule + "'") } - inputs := make([]reflect.Value, len(parsedArgs)) for i := range parsedArgs { - inputs[i] = reflect.ValueOf(strings.TrimSpace(parsedArgs[i])) + parsedArgs[i] = strings.TrimSpace(parsedArgs[i]) } - method := reflect.ValueOf(parsedFunction) + + err := onRule(functionName, parsedFunction, parsedArgs) + if err != nil { + return fmt.Errorf("Parsing error on rule:", err) + } + } + return nil + +} + +// Parse parses rules expressions +func (r *Rules) Parse(expression string) (*mux.Route, error) { + var resultRoute *mux.Route + err := r.parseRules(expression, func(functionName string, function interface{}, arguments []string) error { + inputs := make([]reflect.Value, len(arguments)) + for i := range arguments { + inputs[i] = reflect.ValueOf(arguments[i]) + } + method := reflect.ValueOf(function) if method.IsValid() { resultRoute = method.Call(inputs)[0].Interface().(*mux.Route) if r.err != nil { - return nil, r.err + return r.err } if resultRoute.GetError() != nil { - return nil, resultRoute.GetError() + return resultRoute.GetError() } } else { - return nil, errors.New("Method not found: '" + parsedFunctions[0] + "'") + return errors.New("Method not found: '" + functionName + "'") } + return nil + }) + if err != nil { + return nil, fmt.Errorf("Error parsing rule:", err) } return resultRoute, nil } + +// ParseDomains parses rules expressions and returns domains +func (r *Rules) ParseDomains(expression string) ([]string, error) { + domains := []string{} + err := r.parseRules(expression, func(functionName string, function interface{}, arguments []string) error { + if functionName == "Host" { + domains = append(domains, arguments...) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("Error parsing domains:", err) + } + return domains, nil +} diff --git a/rules_test.go b/rules_test.go index 96c41b475..2bb89a347 100644 --- a/rules_test.go +++ b/rules_test.go @@ -4,11 +4,11 @@ import ( "github.com/containous/mux" "net/http" "net/url" + "reflect" "testing" ) func TestParseOneRule(t *testing.T) { - router := mux.NewRouter() route := router.NewRoute() serverRoute := &serverRoute{route: route} @@ -31,7 +31,6 @@ func TestParseOneRule(t *testing.T) { } func TestParseTwoRules(t *testing.T) { - router := mux.NewRouter() route := router.NewRoute() serverRoute := &serverRoute{route: route} @@ -53,6 +52,29 @@ func TestParseTwoRules(t *testing.T) { } } +func TestParseDomains(t *testing.T) { + rules := &Rules{} + expressionsSlice := []string{ + "Host:foo.bar,test.bar", + "Path:/test", + "Host:foo.bar;Path:/test", + } + domainsSlice := [][]string{ + {"foo.bar", "test.bar"}, + {}, + {"foo.bar"}, + } + for i, expression := range expressionsSlice { + domains, err := rules.ParseDomains(expression) + if err != nil { + t.Fatalf("Error while parsing domains: %v", err) + } + if !reflect.DeepEqual(domains, domainsSlice[i]) { + t.Fatalf("Error parsing domains: expected %+v, got %+v", domainsSlice[i], domains) + } + } +} + func TestPriorites(t *testing.T) { router := mux.NewRouter() router.StrictSlash(true) diff --git a/server.go b/server.go index 6e93fa2cf..3b2ad3e8b 100644 --- a/server.go +++ b/server.go @@ -244,6 +244,7 @@ func (server *Server) listenConfigurations(stop chan bool) { log.Infof("Server configuration reloaded on %s", server.serverEntryPoints[newServerEntryPointName].httpServer.Addr) } server.currentConfigurations.Set(newConfigurations) + server.postLoadConfig() } else { log.Error("Error loading new configuration, aborted ", err) } @@ -251,6 +252,26 @@ func (server *Server) listenConfigurations(stop chan bool) { } } +func (server *Server) postLoadConfig() { + if server.globalConfiguration.ACME != nil && server.globalConfiguration.ACME.OnHostRule { + currentConfigurations := server.currentConfigurations.Get().(configs) + for _, configuration := range currentConfigurations { + for _, frontend := range configuration.Frontends { + for _, route := range frontend.Routes { + rules := Rules{} + domains, err := rules.ParseDomains(route.Rule) + if err != nil { + log.Errorf("Error parsing domains: %v", err) + } else { + server.globalConfiguration.ACME.LoadCertificateForDomains(domains) + } + } + + } + } + } +} + func (server *Server) configureProviders() { // configure providers if server.globalConfiguration.Docker != nil { diff --git a/traefik.sample.toml b/traefik.sample.toml index 05a396874..54aa2721c 100644 --- a/traefik.sample.toml +++ b/traefik.sample.toml @@ -96,6 +96,13 @@ # # onDemand = true +# Enable certificate generation on frontends Host rules. This will request a certificate from Let's Encrypt for each frontend with a Host rule. +# For example, a rule Host:test1.traefik.io,test2.traefik.io will request a certificate with main domain test1.traefik.io and SAN test2.traefik.io. +# +# Optional +# +# OnHostRule = true + # CA server to use # Uncomment the line to run on the staging let's encrypt server # Leave comment to go to prod