diff --git a/pkg/server/service/loadbalancer/wrr/wrr.go b/pkg/server/service/loadbalancer/wrr/wrr.go index a805c66a9..0dca68ee5 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr.go +++ b/pkg/server/service/loadbalancer/wrr/wrr.go @@ -1,6 +1,7 @@ package wrr import ( + "container/heap" "fmt" "net/http" "sync" @@ -11,8 +12,9 @@ import ( type namedHandler struct { http.Handler - name string - weight int + name string + weight float64 + deadline float64 } type stickyCookie struct { @@ -23,10 +25,7 @@ type stickyCookie struct { // New creates a new load balancer. func New(sticky *dynamic.Sticky) *Balancer { - balancer := &Balancer{ - mutex: &sync.Mutex{}, - index: -1, - } + balancer := &Balancer{} if sticky != nil && sticky.Cookie != nil { balancer.stickyCookie = &stickyCookie{ name: sticky.Cookie.Name, @@ -37,43 +36,48 @@ func New(sticky *dynamic.Sticky) *Balancer { return balancer } -// Balancer is a WeightedRoundRobin load balancer. +// Len implements heap.Interface/sort.Interface. +func (b *Balancer) Len() int { return len(b.handlers) } + +// Less implements heap.Interface/sort.Interface. +func (b *Balancer) Less(i, j int) bool { + return b.handlers[i].deadline < b.handlers[j].deadline +} + +// Swap implements heap.Interface/sort.Interface. +func (b *Balancer) Swap(i, j int) { + b.handlers[i], b.handlers[j] = b.handlers[j], b.handlers[i] +} + +// Push implements heap.Interface for pushing an item into the heap. +func (b *Balancer) Push(x interface{}) { + h, ok := x.(*namedHandler) + if !ok { + return + } + + b.handlers = append(b.handlers, h) +} + +// Pop implements heap.Interface for poping an item from the heap. +// It panics if b.Len() < 1. +func (b *Balancer) Pop() interface{} { + h := b.handlers[len(b.handlers)-1] + b.handlers = b.handlers[0 : len(b.handlers)-1] + return h +} + +// Balancer is a WeightedRoundRobin load balancer based on Earliest Deadline First (EDF). +// (https://en.wikipedia.org/wiki/Earliest_deadline_first_scheduling) +// Each pick from the schedule has the earliest deadline entry selected. +// Entries have deadlines set at currentDeadline + 1 / weight, +// providing weighted round robin behavior with floating point weights and an O(log n) pick time. type Balancer struct { - handlers []*namedHandler - mutex *sync.Mutex - // Current index (starts from -1) - index int - currentWeight int - stickyCookie *stickyCookie -} + stickyCookie *stickyCookie -func (b *Balancer) maxWeight() int { - max := -1 - for _, s := range b.handlers { - if s.weight > max { - max = s.weight - } - } - return max -} - -func (b *Balancer) weightGcd() int { - divisor := -1 - for _, s := range b.handlers { - if divisor == -1 { - divisor = s.weight - } else { - divisor = gcd(divisor, s.weight) - } - } - return divisor -} - -func gcd(a, b int) int { - for b != 0 { - a, b = b, a%b - } - return a + mutex sync.RWMutex + handlers []*namedHandler + curDeadline float64 } func (b *Balancer) nextServer() (*namedHandler, error) { @@ -84,32 +88,17 @@ func (b *Balancer) nextServer() (*namedHandler, error) { return nil, fmt.Errorf("no servers in the pool") } - // The algo below may look messy, but is actually very simple - // it calculates the GCD and subtracts it on every iteration, what interleaves servers - // and allows us not to build an iterator every time we readjust weights + // Pick handler with closest deadline. + handler := heap.Pop(b).(*namedHandler) - // GCD across all enabled servers - gcd := b.weightGcd() - // Maximum weight across all enabled servers - max := b.maxWeight() + // curDeadline should be handler's deadline so that new added entry would have a fair competition environment with the old ones. + b.curDeadline = handler.deadline + handler.deadline += 1 / handler.weight - for { - b.index = (b.index + 1) % len(b.handlers) - if b.index == 0 { - b.currentWeight -= gcd - if b.currentWeight <= 0 { - b.currentWeight = max - if b.currentWeight == 0 { - return nil, fmt.Errorf("all servers have 0 weight") - } - } - } - srv := b.handlers[b.index] - if srv.weight >= b.currentWeight { - log.WithoutContext().Debugf("Service Select: %s", srv.name) - return srv, nil - } - } + heap.Push(b, handler) + + log.WithoutContext().Debugf("Service selected by WRR: %s", handler.name) + return handler, nil } func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -146,10 +135,22 @@ func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) { // AddService adds a handler. // It is not thread safe with ServeHTTP. +// A handler with a non-positive weight is ignored. func (b *Balancer) AddService(name string, handler http.Handler, weight *int) { w := 1 if weight != nil { w = *weight } - b.handlers = append(b.handlers, &namedHandler{Handler: handler, name: name, weight: w}) + if w <= 0 { // non-positive weight is meaningless + return + } + + h := &namedHandler{Handler: handler, name: name, weight: float64(w)} + + // use RWLock to protect b.curDeadline + b.mutex.RLock() + h.deadline = b.curDeadline + 1/h.weight + b.mutex.RUnlock() + + heap.Push(b, h) } diff --git a/pkg/server/service/loadbalancer/wrr/wrr_test.go b/pkg/server/service/loadbalancer/wrr/wrr_test.go index b6e2761fa..72771d8bf 100644 --- a/pkg/server/service/loadbalancer/wrr/wrr_test.go +++ b/pkg/server/service/loadbalancer/wrr/wrr_test.go @@ -13,11 +13,13 @@ func Int(v int) *int { return &v } type responseRecorder struct { *httptest.ResponseRecorder - save map[string]int + save map[string]int + sequence []string } func (r *responseRecorder) WriteHeader(statusCode int) { r.save[r.Header().Get("server")]++ + r.sequence = append(r.sequence, r.Header().Get("server")) r.ResponseRecorder.WriteHeader(statusCode) } @@ -112,3 +114,29 @@ func TestSticky(t *testing.T) { assert.Equal(t, 0, recorder.save["first"]) assert.Equal(t, 3, recorder.save["second"]) } + +// TestBalancerBias makes sure that the WRR algorithm spreads elements evenly right from the start, +// and that it does not "over-favor" the high-weighted ones with a biased start-up regime. +func TestBalancerBias(t *testing.T) { + balancer := New(nil) + + balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("server", "A") + rw.WriteHeader(http.StatusOK) + }), Int(11)) + + balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("server", "B") + rw.WriteHeader(http.StatusOK) + }), Int(3)) + + recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}} + + for i := 0; i < 14; i++ { + balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + } + + wantSequence := []string{"A", "A", "A", "B", "A", "A", "A", "A", "B", "A", "A", "A", "B", "A"} + + assert.Equal(t, wantSequence, recorder.sequence) +}