traefik/pkg/server/middleware/tcp/middlewares.go
2024-03-12 12:08:03 +01:00

122 lines
3.5 KiB
Go

package tcpmiddleware
import (
"context"
"fmt"
"slices"
"strings"
"github.com/rs/zerolog/log"
"github.com/traefik/traefik/v3/pkg/config/runtime"
"github.com/traefik/traefik/v3/pkg/middlewares/tcp/inflightconn"
"github.com/traefik/traefik/v3/pkg/middlewares/tcp/ipallowlist"
"github.com/traefik/traefik/v3/pkg/middlewares/tcp/ipwhitelist"
"github.com/traefik/traefik/v3/pkg/server/provider"
"github.com/traefik/traefik/v3/pkg/tcp"
)
// middlewareStackType is the unexported type of the context keys defined by
// this package. Using a dedicated type keeps these keys from colliding with
// context keys set by other packages.
type middlewareStackType int
const (
// middlewareStackKey is the context key under which checkRecursion stores the
// names of the middlewares currently being built (a []string).
middlewareStackKey middlewareStackType = iota
)
// Builder builds TCP middleware chains from a set of named middleware
// configurations.
type Builder struct {
// configs holds the runtime info of every known TCP middleware; BuildChain
// looks entries up by the name returned from provider.GetQualifiedName.
configs map[string]*runtime.TCPMiddlewareInfo
}
// NewBuilder returns a Builder that resolves middlewares from configs.
// The map is stored as-is (it is not copied).
func NewBuilder(configs map[string]*runtime.TCPMiddlewareInfo) *Builder {
	builder := new(Builder)
	builder.configs = configs
	return builder
}
// BuildChain assembles a tcp.Chain holding one constructor per entry of
// middlewares, appended in the order given.
//
// Each name is qualified with provider.GetQualifiedName up front. The rest of
// the work (existence check, recursion check, construction of the handler)
// happens inside the appended constructor, i.e. only when that constructor is
// invoked with the next handler.
//
// A name with no usable configuration yields a "does not exist" error. Any
// later failure is recorded on the middleware's runtime info through AddError
// before being returned.
func (b *Builder) BuildChain(ctx context.Context, middlewares []string) *tcp.Chain {
	result := tcp.NewChain()

	for _, rawName := range middlewares {
		qualified := provider.GetQualifiedName(ctx, rawName)

		result = result.Append(func(next tcp.Handler) (tcp.Handler, error) {
			buildCtx := provider.AddInContext(ctx, qualified)

			info, known := b.configs[qualified]
			if !known || info.TCPMiddleware == nil {
				return nil, fmt.Errorf("middleware %q does not exist", qualified)
			}

			// fail records err on the middleware's runtime info, then hands
			// it back to the caller with a nil handler.
			fail := func(err error) (tcp.Handler, error) {
				info.AddError(err, true)
				return nil, err
			}

			buildCtx, err := checkRecursion(buildCtx, qualified)
			if err != nil {
				return fail(err)
			}

			construct, err := b.buildConstructor(buildCtx, qualified)
			if err != nil {
				return fail(err)
			}

			handler, err := construct(next)
			if err != nil {
				return fail(err)
			}

			return handler, nil
		})
	}

	return &result
}
// checkRecursion records middlewareName on the stack of middlewares being
// built, which is carried in ctx under middlewareStackKey.
//
// If middlewareName is already on that stack, the original ctx is returned
// together with an error describing the cycle. Otherwise a derived context
// carrying the extended stack is returned.
func checkRecursion(ctx context.Context, middlewareName string) (context.Context, error) {
	// A missing (or wrongly typed) value simply means an empty stack.
	currentStack, _ := ctx.Value(middlewareStackKey).([]string)

	// The stack slice is shared with every context derived from the same
	// parent. Clipping its capacity forces append to allocate a fresh backing
	// array, so extending the stack here can never overwrite an element that
	// a sibling context already appended.
	extended := append(slices.Clip(currentStack), middlewareName)

	if slices.Contains(currentStack, middlewareName) {
		return ctx, fmt.Errorf("could not instantiate middleware %s: recursion detected in %s", middlewareName, strings.Join(extended, "->"))
	}

	return context.WithValue(ctx, middlewareStackKey, extended), nil
}
// buildConstructor returns the tcp.Constructor matching the configuration
// registered under middlewareName.
//
// An error is returned when no usable configuration is registered for the
// name, or when none of InFlightConn, IPWhiteList or IPAllowList is set on it.
// The three fields are inspected in that order and a later match replaces an
// earlier one, so IPAllowList wins when several are set. A configured
// IPWhiteList always logs a deprecation warning, even if it ends up replaced.
func (b *Builder) buildConstructor(ctx context.Context, middlewareName string) (tcp.Constructor, error) {
	info := b.configs[middlewareName]
	if info == nil || info.TCPMiddleware == nil {
		return nil, fmt.Errorf("invalid middleware %q configuration", middlewareName)
	}

	var built tcp.Constructor

	// InFlightConn
	if info.InFlightConn != nil {
		built = func(next tcp.Handler) (tcp.Handler, error) {
			return inflightconn.New(ctx, next, *info.InFlightConn, middlewareName)
		}
	}

	// IPWhiteList (deprecated in favour of IPAllowList)
	if info.IPWhiteList != nil {
		log.Warn().Msg("IPWhiteList is deprecated, please use IPAllowList instead.")

		built = func(next tcp.Handler) (tcp.Handler, error) {
			return ipwhitelist.New(ctx, next, *info.IPWhiteList, middlewareName)
		}
	}

	// IPAllowList
	if info.IPAllowList != nil {
		built = func(next tcp.Handler) (tcp.Handler, error) {
			return ipallowlist.New(ctx, next, *info.IPAllowList, middlewareName)
		}
	}

	if built != nil {
		return built, nil
	}

	return nil, fmt.Errorf("invalid middleware %q configuration: invalid middleware type or middleware does not exist", middlewareName)
}