package ipwhitelist import ( "context" "errors" "fmt" "net/http" "github.com/rs/zerolog/log" "github.com/traefik/traefik/v3/pkg/config/dynamic" "github.com/traefik/traefik/v3/pkg/ip" "github.com/traefik/traefik/v3/pkg/middlewares" "github.com/traefik/traefik/v3/pkg/tracing" "go.opentelemetry.io/otel/trace" ) const ( typeName = "IPWhiteLister" ) // ipWhiteLister is a middleware that provides Checks of the Requesting IP against a set of Whitelists. type ipWhiteLister struct { next http.Handler whiteLister *ip.Checker strategy ip.Strategy name string } // New builds a new IPWhiteLister given a list of CIDR-Strings to whitelist. func New(ctx context.Context, next http.Handler, config dynamic.IPWhiteList, name string) (http.Handler, error) { logger := middlewares.GetLogger(ctx, name, typeName) logger.Debug().Msg("Creating middleware") if len(config.SourceRange) == 0 { return nil, errors.New("sourceRange is empty, IPWhiteLister not created") } checker, err := ip.NewChecker(config.SourceRange) if err != nil { return nil, fmt.Errorf("cannot parse CIDR whitelist %s: %w", config.SourceRange, err) } strategy, err := config.IPStrategy.Get() if err != nil { return nil, err } logger.Debug().Msgf("Setting up IPWhiteLister with sourceRange: %s", config.SourceRange) return &ipWhiteLister{ strategy: strategy, whiteLister: checker, next: next, name: name, }, nil } func (wl *ipWhiteLister) GetTracingInformation() (string, string, trace.SpanKind) { return wl.name, typeName, trace.SpanKindInternal } func (wl *ipWhiteLister) ServeHTTP(rw http.ResponseWriter, req *http.Request) { logger := middlewares.GetLogger(req.Context(), wl.name, typeName) ctx := logger.WithContext(req.Context()) clientIP := wl.strategy.GetIP(req) err := wl.whiteLister.IsAuthorized(clientIP) if err != nil { msg := fmt.Sprintf("Rejecting IP %s: %v", clientIP, err) logger.Debug().Msg(msg) tracing.SetStatusErrorf(req.Context(), msg) reject(ctx, rw) return } logger.Debug().Msgf("Accepting IP %s", clientIP) wl.next.ServeHTTP(rw, req) } func reject(ctx context.Context, rw http.ResponseWriter) { statusCode := http.StatusForbidden rw.WriteHeader(statusCode) _, err := rw.Write([]byte(http.StatusText(statusCode))) if err != nil { log.Ctx(ctx).Error().Err(err).Send() } }