diff --git a/pkg/provider/acme/provider.go b/pkg/provider/acme/provider.go index 5a4278320..e0fddd06f 100644 --- a/pkg/provider/acme/provider.go +++ b/pkg/provider/acme/provider.go @@ -179,12 +179,12 @@ func (p *Provider) Provide(configurationChan chan<- dynamic.Message, pool *safe. p.renewCertificates(ctx) ticker := time.NewTicker(24 * time.Hour) - pool.Go(func(stop chan bool) { + pool.GoCtx(func(ctxPool context.Context) { for { select { case <-ticker.C: p.renewCertificates(ctx) - case <-stop: + case <-ctxPool.Done(): ticker.Stop() return } @@ -341,7 +341,7 @@ func (p *Provider) resolveDomains(ctx context.Context, domains []string, tlsStor } func (p *Provider) watchNewDomains(ctx context.Context) { - p.pool.Go(func(stop chan bool) { + p.pool.GoCtx(func(ctxPool context.Context) { for { select { case config := <-p.configFromListenerChan: @@ -415,7 +415,7 @@ func (p *Provider) watchNewDomains(ctx context.Context) { p.resolveDomains(ctxRouter, domains, tlsStore) } } - case <-stop: + case <-ctxPool.Done(): return } } @@ -556,7 +556,7 @@ func deleteUnnecessaryDomains(ctx context.Context, domains []types.Domain) []typ func (p *Provider) watchCertificate(ctx context.Context) { p.certsChan = make(chan *CertAndStore) - p.pool.Go(func(stop chan bool) { + p.pool.GoCtx(func(ctxPool context.Context) { for { select { case cert := <-p.certsChan: @@ -576,7 +576,7 @@ func (p *Provider) watchCertificate(ctx context.Context) { if err != nil { log.FromContext(ctx).Error(err) } - case <-stop: + case <-ctxPool.Done(): return } } diff --git a/pkg/provider/file/file.go b/pkg/provider/file/file.go index 862b9e1a2..3902ad12c 100644 --- a/pkg/provider/file/file.go +++ b/pkg/provider/file/file.go @@ -103,11 +103,11 @@ func (p *Provider) addWatcher(pool *safe.Pool, directory string, configurationCh } // Process events - pool.Go(func(stop chan bool) { + pool.GoCtx(func(ctx context.Context) { defer watcher.Close() for { select { - case <-stop: + case <-ctx.Done(): return case evt := <-watcher.Events: if p.Directory == "" { diff --git a/pkg/provider/kubernetes/crd/kubernetes.go b/pkg/provider/kubernetes/crd/kubernetes.go index 6b9c6f332..c4ff55019 100644 --- a/pkg/provider/kubernetes/crd/kubernetes.go +++ b/pkg/provider/kubernetes/crd/kubernetes.go @@ -98,11 +98,9 @@ func (p *Provider) Provide(configurationChan chan<- dynamic.Message, pool *safe. return err } - pool.Go(func(stop chan bool) { + pool.GoCtx(func(ctxPool context.Context) { operation := func() error { - stopWatch := make(chan struct{}, 1) - defer close(stopWatch) - eventsChan, err := k8sClient.WatchAll(p.Namespaces, stopWatch) + eventsChan, err := k8sClient.WatchAll(p.Namespaces, ctxPool.Done()) if err != nil { logger.Errorf("Error watching kubernetes events: %v", err) @@ -110,20 +108,20 @@ func (p *Provider) Provide(configurationChan chan<- dynamic.Message, pool *safe. select { case <-timer.C: return err - case <-stop: + case <-ctxPool.Done(): return nil } } throttleDuration := time.Duration(p.ThrottleDuration) - throttledChan := throttleEvents(ctxLog, throttleDuration, stop, eventsChan) + throttledChan := throttleEvents(ctxLog, throttleDuration, pool, eventsChan) if throttledChan != nil { eventsChan = throttledChan } for { select { - case <-stop: + case <-ctxPool.Done(): return nil case event := <-eventsChan: // Note that event is the *first* event that came in during this throttling interval -- if we're hitting our throttle, we may have dropped events. @@ -156,7 +154,7 @@ func (p *Provider) Provide(configurationChan chan<- dynamic.Message, pool *safe. notify := func(err error, time time.Duration) { logger.Errorf("Provider connection error: %v; retrying in %s", err, time) } - err := backoff.RetryNotify(safe.OperationWithRecover(operation), job.NewBackOff(backoff.NewExponentialBackOff()), notify) + err := backoff.RetryNotify(safe.OperationWithRecover(operation), backoff.WithContext(job.NewBackOff(backoff.NewExponentialBackOff()), ctxPool), notify) if err != nil { logger.Errorf("Cannot connect to Provider: %v", err) } @@ -625,7 +623,7 @@ func getCABlocks(secret *corev1.Secret, namespace, secretName string) (string, e return cert, nil } -func throttleEvents(ctx context.Context, throttleDuration time.Duration, stop chan bool, eventsChan <-chan interface{}) chan interface{} { +func throttleEvents(ctx context.Context, throttleDuration time.Duration, pool *safe.Pool, eventsChan <-chan interface{}) chan interface{} { if throttleDuration == 0 { return nil } @@ -635,10 +633,10 @@ func throttleEvents(ctx context.Context, throttleDuration time.Duration, stop ch // Run a goroutine that reads events from eventChan and does a non-blocking write to pendingEvent. // This guarantees that writing to eventChan will never block, // and that pendingEvent will have something in it if there's been an event since we read from that channel. - go func() { + pool.GoCtx(func(ctxPool context.Context) { for { select { - case <-stop: + case <-ctxPool.Done(): return case nextEvent := <-eventsChan: select { @@ -650,7 +648,7 @@ func throttleEvents(ctx context.Context, throttleDuration time.Duration, stop ch } } } - }() + }) return eventsChanBuffered } diff --git a/pkg/provider/kubernetes/ingress/kubernetes.go b/pkg/provider/kubernetes/ingress/kubernetes.go index 03902a62f..1f1d5184f 100644 --- a/pkg/provider/kubernetes/ingress/kubernetes.go +++ b/pkg/provider/kubernetes/ingress/kubernetes.go @@ -104,32 +104,29 @@ func (p *Provider) Provide(configurationChan chan<- dynamic.Message, pool *safe. return err } - pool.Go(func(stop chan bool) { + pool.GoCtx(func(ctxPool context.Context) { operation := func() error { - stopWatch := make(chan struct{}, 1) - defer close(stopWatch) - - eventsChan, err := k8sClient.WatchAll(p.Namespaces, stopWatch) + eventsChan, err := k8sClient.WatchAll(p.Namespaces, ctxPool.Done()) if err != nil { logger.Errorf("Error watching kubernetes events: %v", err) timer := time.NewTimer(1 * time.Second) select { case <-timer.C: return err - case <-stop: + case <-ctxPool.Done(): return nil } } throttleDuration := time.Duration(p.ThrottleDuration) - throttledChan := throttleEvents(ctxLog, throttleDuration, stop, eventsChan) + throttledChan := throttleEvents(ctxLog, throttleDuration, pool, eventsChan) if throttledChan != nil { eventsChan = throttledChan } for { select { - case <-stop: + case <-ctxPool.Done(): return nil case event := <-eventsChan: // Note that event is the *first* event that came in during this @@ -164,7 +161,8 @@ func (p *Provider) Provide(configurationChan chan<- dynamic.Message, pool *safe. notify := func(err error, time time.Duration) { logger.Errorf("Provider connection error: %s; retrying in %s", err, time) } - err := backoff.RetryNotify(safe.OperationWithRecover(operation), job.NewBackOff(backoff.NewExponentialBackOff()), notify) + + err := backoff.RetryNotify(safe.OperationWithRecover(operation), backoff.WithContext(job.NewBackOff(backoff.NewExponentialBackOff()), ctxPool), notify) if err != nil { logger.Errorf("Cannot connect to Provider: %s", err) } @@ -517,7 +515,7 @@ func (p *Provider) updateIngressStatus(i *v1beta1.Ingress, k8sClient Client) err return k8sClient.UpdateIngressStatus(i.Namespace, i.Name, service.Status.LoadBalancer.Ingress[0].IP, service.Status.LoadBalancer.Ingress[0].Hostname) } -func throttleEvents(ctx context.Context, throttleDuration time.Duration, stop chan bool, eventsChan <-chan interface{}) chan interface{} { +func throttleEvents(ctx context.Context, throttleDuration time.Duration, pool *safe.Pool, eventsChan <-chan interface{}) chan interface{} { if throttleDuration == 0 { return nil } @@ -528,10 +526,10 @@ func throttleEvents(ctx context.Context, throttleDuration time.Duration, stop ch // non-blocking write to pendingEvent. This guarantees that writing to // eventChan will never block, and that pendingEvent will have // something in it if there's been an event since we read from that channel. - go func() { + pool.GoCtx(func(ctxPool context.Context) { for { select { - case <-stop: + case <-ctxPool.Done(): return case nextEvent := <-eventsChan: select { @@ -545,7 +543,7 @@ func throttleEvents(ctx context.Context, throttleDuration time.Duration, stop ch } } } - }() + }) return eventsChanBuffered } diff --git a/pkg/provider/marathon/marathon.go b/pkg/provider/marathon/marathon.go index 78213f8e7..e1aabf0eb 100644 --- a/pkg/provider/marathon/marathon.go +++ b/pkg/provider/marathon/marathon.go @@ -159,11 +159,11 @@ func (p *Provider) Provide(configurationChan chan<- dynamic.Message, pool *safe. logger.Errorf("Failed to register for events, %s", err) return err } - pool.Go(func(stop chan bool) { + pool.GoCtx(func(ctxPool context.Context) { defer close(update) for { select { - case <-stop: + case <-ctxPool.Done(): return case event := <-update: logger.Debugf("Received provider event %s", event) diff --git a/pkg/safe/routine.go b/pkg/safe/routine.go index c1f81e5c3..a8de1e872 100644 --- a/pkg/safe/routine.go +++ b/pkg/safe/routine.go @@ -10,88 +10,37 @@ import ( "github.com/containous/traefik/v2/pkg/log" ) -type routine struct { - goroutine func(chan bool) - stop chan bool -} - type routineCtx func(ctx context.Context) // Pool is a pool of go routines type Pool struct { - routines []routine - waitGroup sync.WaitGroup - lock sync.Mutex - baseCtx context.Context - baseCancel context.CancelFunc - ctx context.Context - cancel context.CancelFunc + waitGroup sync.WaitGroup + ctx context.Context + cancel context.CancelFunc } // NewPool creates a Pool func NewPool(parentCtx context.Context) *Pool { - baseCtx, baseCancel := context.WithCancel(parentCtx) - ctx, cancel := context.WithCancel(baseCtx) + ctx, cancel := context.WithCancel(parentCtx) return &Pool{ - baseCtx: baseCtx, - baseCancel: baseCancel, - ctx: ctx, - cancel: cancel, + ctx: ctx, + cancel: cancel, } } -// Ctx returns main context -func (p *Pool) Ctx() context.Context { - return p.baseCtx -} - // GoCtx starts a recoverable goroutine with a context func (p *Pool) GoCtx(goroutine routineCtx) { - p.lock.Lock() p.waitGroup.Add(1) Go(func() { defer p.waitGroup.Done() goroutine(p.ctx) }) - p.lock.Unlock() -} - -// Go starts a recoverable goroutine, and can be stopped with stop chan -func (p *Pool) Go(goroutine func(stop chan bool)) { - p.lock.Lock() - newRoutine := routine{ - goroutine: goroutine, - stop: make(chan bool, 1), - } - p.routines = append(p.routines, newRoutine) - p.waitGroup.Add(1) - Go(func() { - defer p.waitGroup.Done() - goroutine(newRoutine.stop) - }) - p.lock.Unlock() } // Stop stops all started routines, waiting for their termination func (p *Pool) Stop() { - p.lock.Lock() - defer p.lock.Unlock() p.cancel() - for _, routine := range p.routines { - routine.stop <- true - } p.waitGroup.Wait() - for _, routine := range p.routines { - close(routine.stop) - } -} - -// Cleanup releases resources used by the pool, and should be called when the pool will no longer be used -func (p *Pool) Cleanup() { - p.Stop() - p.lock.Lock() - defer p.lock.Unlock() - p.baseCancel() } // Go starts a recoverable goroutine diff --git a/pkg/safe/routine_test.go b/pkg/safe/routine_test.go index caeef93ca..85d87bf10 100644 --- a/pkg/safe/routine_test.go +++ b/pkg/safe/routine_test.go @@ -18,12 +18,13 @@ func TestNewPoolContext(t *testing.T) { ctx := context.WithValue(context.Background(), testKey, "test") p := NewPool(ctx) - retCtx := p.Ctx() - - retCtxVal, ok := retCtx.Value(testKey).(string) - if !ok || retCtxVal != "test" { - t.Errorf("Pool.Ctx() did not return a derived context, got %#v, expected context with test value", retCtx) - } + p.GoCtx(func(ctx context.Context) { + retCtxVal, ok := ctx.Value(testKey).(string) + if !ok || retCtxVal != "test" { + t.Errorf("Pool.Ctx() did not return a derived context, got %#v, expected context with test value", ctx) + } + }) + p.Stop() } type fakeRoutine struct { @@ -46,14 +47,6 @@ func (tr *fakeRoutine) routineCtx(ctx context.Context) { <-ctx.Done() } -func (tr *fakeRoutine) routine(stop chan bool) { - tr.Lock() - tr.started = true - tr.Unlock() - tr.startSig <- true - <-stop -} - func TestPoolWithCtx(t *testing.T) { testRoutine := newFakeRoutine() @@ -79,12 +72,12 @@ func TestPoolWithCtx(t *testing.T) { defer timer.Stop() test.fn(p) - defer p.Cleanup() + defer p.Stop() testDone := make(chan bool, 1) go func() { <-testRoutine.startSig - p.Cleanup() + p.Stop() testDone <- true }() @@ -100,89 +93,30 @@ func TestPoolWithCtx(t *testing.T) { } } -func TestPoolWithStopChan(t *testing.T) { - testRoutine := newFakeRoutine() - +func TestPoolCleanupWithGoPanicking(t *testing.T) { p := NewPool(context.Background()) timer := time.NewTimer(500 * time.Millisecond) defer timer.Stop() - p.Go(testRoutine.routine) - if len(p.routines) != 1 { - t.Fatalf("After Pool.Go(func), Pool did have %d goroutines, expected 1", len(p.routines)) - } + p.GoCtx(func(ctx context.Context) { + panic("BOOM") + }) testDone := make(chan bool, 1) go func() { - <-testRoutine.startSig - p.Cleanup() + p.Stop() testDone <- true }() select { case <-timer.C: - testRoutine.Lock() - defer testRoutine.Unlock() - t.Fatalf("Pool test did not complete in time, goroutine started equals '%t'", testRoutine.started) + t.Fatalf("Pool.Cleanup() did not complete in time with a panicking goroutine") case <-testDone: return } } -func TestPoolCleanupWithGoPanicking(t *testing.T) { - testRoutine := func(stop chan bool) { - panic("BOOM") - } - - testCtxRoutine := func(ctx context.Context) { - panic("BOOM") - } - - testCases := []struct { - desc string - fn func(*Pool) - }{ - { - desc: "Go()", - fn: func(p *Pool) { - p.Go(testRoutine) - }, - }, - { - desc: "GoCtx()", - fn: func(p *Pool) { - p.GoCtx(testCtxRoutine) - }, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.desc, func(t *testing.T) { - p := NewPool(context.Background()) - - timer := time.NewTimer(500 * time.Millisecond) - defer timer.Stop() - - test.fn(p) - - testDone := make(chan bool, 1) - go func() { - p.Cleanup() - testDone <- true - }() - - select { - case <-timer.C: - t.Fatalf("Pool.Cleanup() did not complete in time with a panicking goroutine") - case <-testDone: - return - } - }) - } -} - func TestGoroutineRecover(t *testing.T) { // if recover fails the test will panic Go(func() { diff --git a/pkg/server/configurationwatcher.go b/pkg/server/configurationwatcher.go index 06b07aedd..714069c5d 100644 --- a/pkg/server/configurationwatcher.go +++ b/pkg/server/configurationwatcher.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "reflect" "time" @@ -49,8 +50,8 @@ func NewConfigurationWatcher(routinesPool *safe.Pool, pvd provider.Provider, pro // Start the configuration watcher. func (c *ConfigurationWatcher) Start() { - c.routinesPool.Go(c.listenProviders) - c.routinesPool.Go(c.listenConfigurations) + c.routinesPool.GoCtx(c.listenProviders) + c.routinesPool.GoCtx(c.listenConfigurations) c.startProvider() } @@ -90,10 +91,10 @@ func (c *ConfigurationWatcher) startProvider() { // listenProviders receives configuration changes from the providers. // The configuration message then gets passed along a series of check // to finally end up in a throttler that sends it to listenConfigurations (through c. configurationValidatedChan). -func (c *ConfigurationWatcher) listenProviders(stop chan bool) { +func (c *ConfigurationWatcher) listenProviders(ctx context.Context) { for { select { - case <-stop: + case <-ctx.Done(): return case configMsg, ok := <-c.configurationChan: if !ok { @@ -111,10 +112,10 @@ func (c *ConfigurationWatcher) listenProviders(stop chan bool) { } } -func (c *ConfigurationWatcher) listenConfigurations(stop chan bool) { +func (c *ConfigurationWatcher) listenConfigurations(ctx context.Context) { for { select { - case <-stop: + case <-ctx.Done(): return case configMsg, ok := <-c.configurationValidatedChan: if !ok || configMsg.Configuration == nil { @@ -178,8 +179,8 @@ func (c *ConfigurationWatcher) preLoadConfiguration(configMsg dynamic.Message) { if !ok { providerConfigUpdateCh = make(chan dynamic.Message) c.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh - c.routinesPool.Go(func(stop chan bool) { - c.throttleProviderConfigReload(c.providersThrottleDuration, c.configurationValidatedChan, providerConfigUpdateCh, stop) + c.routinesPool.GoCtx(func(ctxPool context.Context) { + c.throttleProviderConfigReload(ctxPool, c.providersThrottleDuration, c.configurationValidatedChan, providerConfigUpdateCh) }) } @@ -190,14 +191,14 @@ func (c *ConfigurationWatcher) preLoadConfiguration(configMsg dynamic.Message) { // It will immediately publish a new configuration and then only publish the next configuration after the throttle duration. // Note that in the case it receives N new configs in the timeframe of the throttle duration after publishing, // it will publish the last of the newly received configurations. -func (c *ConfigurationWatcher) throttleProviderConfigReload(throttle time.Duration, publish chan<- dynamic.Message, in <-chan dynamic.Message, stop chan bool) { +func (c *ConfigurationWatcher) throttleProviderConfigReload(ctx context.Context, throttle time.Duration, publish chan<- dynamic.Message, in <-chan dynamic.Message) { ring := channels.NewRingChannel(1) defer ring.Close() - c.routinesPool.Go(func(stop chan bool) { + c.routinesPool.GoCtx(func(ctxPool context.Context) { for { select { - case <-stop: + case <-ctxPool.Done(): return case nextConfig := <-ring.Out(): if config, ok := nextConfig.(dynamic.Message); ok { @@ -210,7 +211,7 @@ func (c *ConfigurationWatcher) throttleProviderConfigReload(throttle time.Durati for { select { - case <-stop: + case <-ctx.Done(): return case nextConfig := <-in: ring.In() <- nextConfig diff --git a/pkg/server/server.go b/pkg/server/server.go index 7e395a759..99950a62d 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -58,7 +58,7 @@ func (s *Server) Start(ctx context.Context) { s.tcpEntryPoints.Start() s.watcher.Start() - s.routinesPool.Go(s.listenSignals) + s.routinesPool.GoCtx(s.listenSignals) } // Wait blocks until the server shutdown. @@ -90,7 +90,7 @@ func (s *Server) Close() { stopMetricsClients() - s.routinesPool.Cleanup() + s.routinesPool.Stop() signal.Stop(s.signals) close(s.signals) diff --git a/pkg/server/server_signals.go b/pkg/server/server_signals.go index a1411dc65..e02702652 100644 --- a/pkg/server/server_signals.go +++ b/pkg/server/server_signals.go @@ -3,6 +3,7 @@ package server import ( + "context" "os/signal" "syscall" @@ -13,10 +14,10 @@ func (s *Server) configureSignals() { signal.Notify(s.signals, syscall.SIGUSR1) } -func (s *Server) listenSignals(stop chan bool) { +func (s *Server) listenSignals(ctx context.Context) { for { select { - case <-stop: + case <-ctx.Done(): return case sig := <-s.signals: if sig == syscall.SIGUSR1 { diff --git a/pkg/server/server_signals_windows.go b/pkg/server/server_signals_windows.go index 05cf4eace..91c14979d 100644 --- a/pkg/server/server_signals_windows.go +++ b/pkg/server/server_signals_windows.go @@ -2,6 +2,8 @@ package server +import "context" + func (s *Server) configureSignals() {} -func (s *Server) listenSignals(stop chan bool) {} +func (s *Server) listenSignals(ctx context.Context) {}