diff --git a/server/server.go b/server/server.go index 5630ffc2a..ad7b64876 100644 --- a/server/server.go +++ b/server/server.go @@ -254,19 +254,8 @@ func (server *Server) defaultConfigurationValues(configuration *types.Configurat if configuration == nil || configuration.Frontends == nil { return } - for _, frontend := range configuration.Frontends { - // default endpoints if not defined in frontends - if len(frontend.EntryPoints) == 0 { - frontend.EntryPoints = server.globalConfiguration.DefaultEntryPoints - } - } - for backendName, backend := range configuration.Backends { - _, err := types.NewLoadBalancerMethod(backend.LoadBalancer) - if err != nil { - log.Debugf("Load balancer method '%+v' for backend %s: %v. Using default wrr.", backend.LoadBalancer, backendName, err) - backend.LoadBalancer = &types.LoadBalancer{Method: "wrr"} - } - } + server.configureFrontends(configuration.Frontends) + server.configureBackends(configuration.Backends) } func (server *Server) listenConfigurations(stop chan bool) { @@ -890,3 +879,29 @@ func sortedFrontendNamesForConfig(configuration *types.Configuration) []string { sort.Strings(keys) return keys } + +func (server *Server) configureFrontends(frontends map[string]*types.Frontend) { + for _, frontend := range frontends { + // default endpoints if not defined in frontends + if len(frontend.EntryPoints) == 0 { + frontend.EntryPoints = server.globalConfiguration.DefaultEntryPoints + } + } +} + +func (*Server) configureBackends(backends map[string]*types.Backend) { + for backendName, backend := range backends { + _, err := types.NewLoadBalancerMethod(backend.LoadBalancer) + if err != nil { + log.Debugf("Validation of load balancer method for backend %s failed: %s. Using default method wrr.", backendName, err) + var sticky bool + if backend.LoadBalancer != nil { + sticky = backend.LoadBalancer.Sticky + } + backend.LoadBalancer = &types.LoadBalancer{ + Method: "wrr", + Sticky: sticky, + } + } + } +} diff --git a/server/server_test.go b/server/server_test.go index 96c4daae6..cb4c2d6d9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -13,6 +13,7 @@ import ( "github.com/containous/traefik/healthcheck" "github.com/containous/traefik/testhelpers" "github.com/containous/traefik/types" + "github.com/davecgh/go-spew/spew" "github.com/vulcand/oxy/roundrobin" ) @@ -277,3 +278,81 @@ func TestServerLoadConfigEmptyBasicAuth(t *testing.T) { t.Fatalf("got error: %s", err) } } + +func TestConfigureBackends(t *testing.T) { + validMethod := "Drr" + defaultMethod := "wrr" + + tests := []struct { + desc string + lb *types.LoadBalancer + wantMethod string + wantSticky bool + }{ + { + desc: "valid load balancer method with sticky enabled", + lb: &types.LoadBalancer{ + Method: validMethod, + Sticky: true, + }, + wantMethod: validMethod, + wantSticky: true, + }, + { + desc: "valid load balancer method with sticky disabled", + lb: &types.LoadBalancer{ + Method: validMethod, + Sticky: false, + }, + wantMethod: validMethod, + wantSticky: false, + }, + { + desc: "invalid load balancer method with sticky enabled", + lb: &types.LoadBalancer{ + Method: "Invalid", + Sticky: true, + }, + wantMethod: defaultMethod, + wantSticky: true, + }, + { + desc: "invalid load balancer method with sticky disabled", + lb: &types.LoadBalancer{ + Method: "Invalid", + Sticky: false, + }, + wantMethod: defaultMethod, + wantSticky: false, + }, + { + desc: "missing load balancer", + lb: nil, + wantMethod: defaultMethod, + wantSticky: false, + }, + } + + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + backend := &types.Backend{ + LoadBalancer: test.lb, + } + + srv := Server{} + srv.configureBackends(map[string]*types.Backend{ + "backend": backend, + }) + + wantLB := types.LoadBalancer{ + Method: test.wantMethod, + Sticky: test.wantSticky, + } + if !reflect.DeepEqual(*backend.LoadBalancer, wantLB) { + t.Errorf("got backend load-balancer\n%v\nwant\n%v\n", spew.Sdump(backend.LoadBalancer), spew.Sdump(wantLB)) + } + }) + } +} diff --git a/types/types.go b/types/types.go index 05e704adc..3831847e0 100644 --- a/types/types.go +++ b/types/types.go @@ -81,19 +81,18 @@ var loadBalancerMethodNames = []string{ // NewLoadBalancerMethod create a new LoadBalancerMethod from a given LoadBalancer. func NewLoadBalancerMethod(loadBalancer *LoadBalancer) (LoadBalancerMethod, error) { + var method string if loadBalancer != nil { + method = loadBalancer.Method for i, name := range loadBalancerMethodNames { - if strings.EqualFold(name, loadBalancer.Method) { + if strings.EqualFold(name, method) { return LoadBalancerMethod(i), nil } } } - return Wrr, ErrInvalidLoadBalancerMethod + return Wrr, fmt.Errorf("invalid load-balancing method '%s'", method) } -// ErrInvalidLoadBalancerMethod is thrown when the specified load balancing method is invalid. -var ErrInvalidLoadBalancerMethod = errors.New("Invalid method, using default") - // Configuration of a provider. type Configuration struct { Backends map[string]*Backend `json:"backends,omitempty"`