traefik/pkg/server/service/loadbalancer/mirror/mirror.go
2019-08-29 01:28:05 -07:00

112 lines
2.6 KiB
Go

package mirror
import (
	"bufio"
	"context"
	"errors"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/containous/traefik/v2/pkg/safe"
)
// Mirroring is an http.Handler that can mirror requests.
//
// Every request is served by handler; afterwards a copy of it may be replayed
// to each registered mirror handler, whose output is discarded.
type Mirroring struct {
	// handler serves the real request; its response is what the client gets.
	handler http.Handler
	// mirrorHandlers are the mirror targets registered through AddMirror.
	mirrorHandlers []*mirrorHandler
	// rw is the ResponseWriter handed to mirror handlers (New sets it to a
	// writer that throws everything away).
	rw http.ResponseWriter
	// routinePool is used (via GoCtx) to run the mirroring work, presumably
	// off the client's request path.
	routinePool *safe.Pool

	// lock guards total, the number of requests that reached the mirroring
	// stage.
	lock  sync.RWMutex
	total uint64
}
// New returns a new instance of *Mirroring.
//
// handler serves the real traffic and pool is used to launch the mirroring
// work. Mirror targets start out empty; register them with AddMirror.
// Responses produced by mirror targets are sent to a discarding writer.
func New(handler http.Handler, pool *safe.Pool) *Mirroring {
	m := new(Mirroring)
	m.handler = handler
	m.routinePool = pool
	m.rw = blackholeResponseWriter{}
	return m
}
// inc bumps the request counter while holding m.lock and returns the value it
// was bumped to.
func (m *Mirroring) inc() uint64 {
	m.lock.Lock()
	m.total++
	current := m.total
	m.lock.Unlock()
	return current
}
// mirrorHandler is one mirror target plus the bookkeeping used to decide
// which requests are replayed to it.
type mirrorHandler struct {
	// Handler receives the mirrored requests.
	http.Handler

	// percent is the share of requests to mirror; AddMirror only accepts
	// values from 0 to 100.
	percent int

	// lock guards count, the number of requests mirrored to this handler so far.
	lock  sync.RWMutex
	count uint64
}
// ServeHTTP serves req with the main handler first. If req's context is not
// done by then, it submits a task to the routine pool. That task bumps the
// shared request counter and replays req to every mirror handler whose
// mirrored share so far (count/total) is still below its percent. Mirror
// responses are written to m.rw and never reach the client.
//
// NOTE(review): mirrored requests are shallow copies of req, so they share
// req.Body, which the main handler has probably already consumed. They also
// use req after the main handler returned; net/http forbids reading the body
// once the server's ServeHTTP call has completed. Confirm mirrors do not
// depend on the body.
func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	m.handler.ServeHTTP(rw, req)

	select {
	case <-req.Context().Done():
		// Canceled while the main handler was running: skip mirroring.
		return
	default:
	}

	m.routinePool.GoCtx(func(_ context.Context) {
		total := m.inc()

		for _, mh := range m.mirrorHandlers {
			mh.lock.Lock()
			due := mh.count*100 < total*uint64(mh.percent)
			if due {
				mh.count++
			}
			mh.lock.Unlock()

			if !due {
				continue
			}

			// Once the main request completes, its context is canceled, which
			// would abort an in-flight mirrored call. Detach the cancellation
			// so that each mirrored call can run to completion on its own.
			mirrored := req.WithContext(contextStopPropagation{req.Context()})
			mh.ServeHTTP(m.rw, mirrored)
		}
	})
}
// AddMirror adds handler as a mirror target.
//
// percent is the share of requests, from 0 to 100 inclusive, that should be
// replayed to handler. Any other value is rejected with an error and nothing
// is registered.
//
// AddMirror takes no lock; presumably it is only called during setup, before
// ServeHTTP starts being used.
func (m *Mirroring) AddMirror(handler http.Handler, percent int) error {
	inRange := percent >= 0 && percent <= 100
	if !inRange {
		return errors.New("percent must be between 0 and 100")
	}

	target := &mirrorHandler{Handler: handler, percent: percent}
	m.mirrorHandlers = append(m.mirrorHandlers, target)
	return nil
}
// blackholeResponseWriter is an http.ResponseWriter that discards everything
// written to it. Mirroring hands it to mirror handlers so that their
// responses are dropped.
type blackholeResponseWriter struct{}

// Compile-time checks that blackholeResponseWriter implements the optional
// interfaces handlers commonly probe for.
var (
	_ http.ResponseWriter = blackholeResponseWriter{}
	_ http.Flusher        = blackholeResponseWriter{}
	_ http.Hijacker       = blackholeResponseWriter{}
)

// Flush is a no-op: nothing is ever buffered.
func (b blackholeResponseWriter) Flush() {}

// Hijack always fails, because there is no underlying connection to hand over.
func (b blackholeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return nil, nil, errors.New("you can't hijack connection on blackholeResponseWriter")
}

// Header returns a fresh, empty header map on every call, so headers set on it
// are simply dropped. Keeping no state matters: a single value of this type
// (Mirroring.rw) is shared by every mirrored request, potentially from several
// goroutines, so a shared map would be a data race.
func (b blackholeResponseWriter) Header() http.Header {
	return http.Header{}
}

// Write discards bytes and reports them as fully written.
func (b blackholeResponseWriter) Write(bytes []byte) (int, error) {
	return len(bytes), nil
}

// WriteHeader discards the status code.
func (b blackholeResponseWriter) WriteHeader(statusCode int) {}
// contextStopPropagation wraps a context so that its values stay reachable
// while its cancellation and deadline are not propagated. Mirrored requests
// use it so they run to completion on their own, even after the original
// request's context has been canceled.
//
// It mirrors context.WithoutCancel (Go 1.21+):
//   - Deadline reports no deadline.
//   - Done returns nil, meaning the context is never canceled.
//   - Err always returns nil.
//   - Value delegates to the wrapped context.
type contextStopPropagation struct {
	context.Context
}

// Deadline reports that there is no deadline. Forwarding the wrapped context's
// deadline would let dialers and transports that honor it cut the mirrored
// call short, which is what this wrapper exists to prevent.
func (c contextStopPropagation) Deadline() (time.Time, bool) {
	return time.Time{}, false
}

// Done returns nil. Per the context.Context contract, a nil Done channel means
// the context can never be canceled. It also avoids allocating a new,
// never-closed channel on every call.
func (c contextStopPropagation) Done() <-chan struct{} {
	return nil
}

// Err always returns nil. The context contract requires Err to be nil while
// Done is not closed. Delegating to the wrapped context would break that: it
// would report context.Canceled on a context whose Done never fires.
func (c contextStopPropagation) Err() error {
	return nil
}