traefik/pkg/muxer/tcp/mux.go
2024-02-19 15:44:03 +01:00

486 lines
12 KiB
Go

package tcp
import (
"bytes"
"errors"
"fmt"
"net"
"regexp"
"slices"
"sort"
"strconv"
"strings"
"github.com/go-acme/lego/v4/challenge/tlsalpn01"
"github.com/traefik/traefik/v2/pkg/ip"
"github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/rules"
"github.com/traefik/traefik/v2/pkg/tcp"
"github.com/traefik/traefik/v2/pkg/types"
"github.com/vulcand/predicate"
)
var tcpFuncs = map[string]func(*matchersTree, ...string) error{
"HostSNI": hostSNI,
"HostSNIRegexp": hostSNIRegexp,
"ClientIP": clientIP,
"ALPN": alpn,
}
// ParseHostSNI extracts the HostSNIs declared in a rule.
// This is a first naive implementation used in TCP routing.
func ParseHostSNI(rule string) ([]string, error) {
var matchers []string
for matcher := range tcpFuncs {
matchers = append(matchers, matcher)
}
parser, err := rules.NewParser(matchers)
if err != nil {
return nil, err
}
parse, err := parser.Parse(rule)
if err != nil {
return nil, err
}
buildTree, ok := parse.(rules.TreeBuilder)
if !ok {
return nil, fmt.Errorf("error while parsing rule %s", rule)
}
return buildTree().ParseMatchers([]string{"HostSNI"}), nil
}
// ConnData contains TCP connection metadata.
type ConnData struct {
serverName string
remoteIP string
alpnProtos []string
}
// NewConnData builds a connData struct from the given parameters.
func NewConnData(serverName string, conn tcp.WriteCloser, alpnProtos []string) (ConnData, error) {
remoteIP, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
return ConnData{}, fmt.Errorf("error while parsing remote address %q: %w", conn.RemoteAddr().String(), err)
}
// as per https://datatracker.ietf.org/doc/html/rfc6066:
// > The hostname is represented as a byte string using ASCII encoding without a trailing dot.
// so there is no need to trim a potential trailing dot
serverName = types.CanonicalDomain(serverName)
return ConnData{
serverName: types.CanonicalDomain(serverName),
remoteIP: remoteIP,
alpnProtos: alpnProtos,
}, nil
}
// Muxer defines a muxer that handles TCP routing with rules.
type Muxer struct {
routes []*route
parser predicate.Parser
}
// NewMuxer returns a TCP muxer.
func NewMuxer() (*Muxer, error) {
var matcherNames []string
for matcherName := range tcpFuncs {
matcherNames = append(matcherNames, matcherName)
}
parser, err := rules.NewParser(matcherNames)
if err != nil {
return nil, fmt.Errorf("error while creating rules parser: %w", err)
}
return &Muxer{parser: parser}, nil
}
// Match returns the handler of the first route matching the connection metadata,
// and whether the match is exactly from the rule HostSNI(*).
func (m Muxer) Match(meta ConnData) (tcp.Handler, bool) {
for _, route := range m.routes {
if route.matchers.match(meta) {
return route.handler, route.catchAll
}
}
return nil, false
}
// AddRoute adds a new route, associated to the given handler, at the given
// priority, to the muxer.
func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error {
parse, err := m.parser.Parse(rule)
if err != nil {
return fmt.Errorf("error while parsing rule %s: %w", rule, err)
}
buildTree, ok := parse.(rules.TreeBuilder)
if !ok {
return fmt.Errorf("error while parsing rule %s", rule)
}
ruleTree := buildTree()
var matchers matchersTree
err = addRule(&matchers, ruleTree)
if err != nil {
return err
}
var catchAll bool
if ruleTree.RuleLeft == nil && ruleTree.RuleRight == nil && len(ruleTree.Value) == 1 {
catchAll = ruleTree.Value[0] == "*" && strings.EqualFold(ruleTree.Matcher, "HostSNI")
}
// Special case for when the catchAll fallback is present.
// When no user-defined priority is found, the lowest computable priority minus one is used,
// in order to make the fallback the last to be evaluated.
if priority == 0 && catchAll {
priority = -1
}
// Default value, which means the user has not set it, so we'll compute it.
if priority == 0 {
priority = len(rule)
}
newRoute := &route{
handler: handler,
matchers: matchers,
catchAll: catchAll,
priority: priority,
}
m.routes = append(m.routes, newRoute)
sort.Sort(routes(m.routes))
return nil
}
func addRule(tree *matchersTree, rule *rules.Tree) error {
switch rule.Matcher {
case "and", "or":
tree.operator = rule.Matcher
tree.left = &matchersTree{}
err := addRule(tree.left, rule.RuleLeft)
if err != nil {
return err
}
tree.right = &matchersTree{}
return addRule(tree.right, rule.RuleRight)
default:
err := rules.CheckRule(rule)
if err != nil {
return err
}
err = tcpFuncs[rule.Matcher](tree, rule.Value...)
if err != nil {
return err
}
if rule.Not {
matcherFunc := tree.matcher
tree.matcher = func(meta ConnData) bool {
return !matcherFunc(meta)
}
}
}
return nil
}
// HasRoutes returns whether the muxer has routes.
func (m *Muxer) HasRoutes() bool {
return len(m.routes) > 0
}
// routes implements sort.Interface.
type routes []*route
// Len implements sort.Interface.
func (r routes) Len() int { return len(r) }
// Swap implements sort.Interface.
func (r routes) Swap(i, j int) { r[i], r[j] = r[j], r[i] }
// Less implements sort.Interface.
func (r routes) Less(i, j int) bool { return r[i].priority > r[j].priority }
// route holds the matchers to match TCP route,
// and the handler that will serve the connection.
type route struct {
// matchers tree structure reflecting the rule.
matchers matchersTree
// handler responsible for handling the route.
handler tcp.Handler
// catchAll indicates whether the route rule has exactly the catchAll value (HostSNI(`*`)).
catchAll bool
// priority is used to disambiguate between two (or more) rules that would
// all match for a given request.
// Computed from the matching rule length, if not user-set.
priority int
}
// matcher is a matcher func used to match connection properties.
type matcher func(meta ConnData) bool
// matchersTree represents the matchers tree structure.
type matchersTree struct {
// If matcher is not nil, it means that this matcherTree is a leaf of the tree.
// It is therefore mutually exclusive with left and right.
matcher matcher
// operator to combine the evaluation of left and right leaves.
operator string
// Mutually exclusive with matcher.
left *matchersTree
right *matchersTree
}
func (m *matchersTree) match(meta ConnData) bool {
if m == nil {
// This should never happen as it should have been detected during parsing.
log.WithoutContext().Warnf("Rule matcher is nil")
return false
}
if m.matcher != nil {
return m.matcher(meta)
}
switch m.operator {
case "or":
return m.left.match(meta) || m.right.match(meta)
case "and":
return m.left.match(meta) && m.right.match(meta)
default:
// This should never happen as it should have been detected during parsing.
log.WithoutContext().Warnf("Invalid rule operator %s", m.operator)
return false
}
}
func clientIP(tree *matchersTree, clientIPs ...string) error {
checker, err := ip.NewChecker(clientIPs)
if err != nil {
return fmt.Errorf("could not initialize IP Checker for \"ClientIP\" matcher: %w", err)
}
tree.matcher = func(meta ConnData) bool {
if meta.remoteIP == "" {
return false
}
ok, err := checker.Contains(meta.remoteIP)
if err != nil {
log.WithoutContext().Warnf("\"ClientIP\" matcher: could not match remote address: %v", err)
return false
}
return ok
}
return nil
}
// alpn checks if any of the connection ALPN protocols matches one of the matcher protocols.
func alpn(tree *matchersTree, protos ...string) error {
if len(protos) == 0 {
return errors.New("empty value for \"ALPN\" matcher is not allowed")
}
for _, proto := range protos {
if proto == tlsalpn01.ACMETLS1Protocol {
return fmt.Errorf("invalid protocol value for \"ALPN\" matcher, %q is not allowed", proto)
}
}
tree.matcher = func(meta ConnData) bool {
return slices.ContainsFunc(meta.alpnProtos, func(proto string) bool {
return slices.Contains(protos, proto)
})
}
return nil
}
var hostOrIP = regexp.MustCompile(`^[[:alnum:]\.\-\:]+$`)
// hostSNI checks if the SNI Host of the connection match the matcher host.
func hostSNI(tree *matchersTree, hosts ...string) error {
if len(hosts) == 0 {
return errors.New("empty value for \"HostSNI\" matcher is not allowed")
}
for i, host := range hosts {
// Special case to allow global wildcard
if host == "*" {
continue
}
if !hostOrIP.MatchString(host) {
return fmt.Errorf("invalid value for \"HostSNI\" matcher, %q is not a valid hostname or IP", host)
}
hosts[i] = strings.ToLower(host)
}
tree.matcher = func(meta ConnData) bool {
// Since a HostSNI(`*`) rule has been provided as catchAll for non-TLS TCP,
// it allows matching with an empty serverName.
// Which is why we make sure to take that case into account before
// checking meta.serverName.
if hosts[0] == "*" {
return true
}
if meta.serverName == "" {
return false
}
for _, host := range hosts {
if host == "*" {
return true
}
if host == meta.serverName {
return true
}
// trim trailing period in case of FQDN
host = strings.TrimSuffix(host, ".")
if host == meta.serverName {
return true
}
}
return false
}
return nil
}
// hostSNIRegexp checks if the SNI Host of the connection matches the matcher host regexp.
func hostSNIRegexp(tree *matchersTree, templates ...string) error {
if len(templates) == 0 {
return errors.New("empty value for \"HostSNIRegexp\" matcher is not allowed")
}
var regexps []*regexp.Regexp
for _, template := range templates {
preparedPattern, err := preparePattern(template)
if err != nil {
return fmt.Errorf("invalid pattern value for \"HostSNIRegexp\" matcher, %q is not a valid pattern: %w", template, err)
}
regexp, err := regexp.Compile(preparedPattern)
if err != nil {
return err
}
regexps = append(regexps, regexp)
}
tree.matcher = func(meta ConnData) bool {
for _, regexp := range regexps {
if regexp.MatchString(meta.serverName) {
return true
}
}
return false
}
return nil
}
// TODO: expose more of containous/mux fork to get rid of the following copied code (https://github.com/containous/mux/blob/8ffa4f6d063c/regexp.go).
// preparePattern builds a regexp pattern from the initial user defined expression.
// This function reuses the code dedicated to host matching of the newRouteRegexp func from the gorilla/mux library.
// https://github.com/containous/mux/tree/8ffa4f6d063c1e2b834a73be6a1515cca3992618.
func preparePattern(template string) (string, error) {
// Check if it is well-formed.
idxs, errBraces := braceIndices(template)
if errBraces != nil {
return "", errBraces
}
defaultPattern := "[^.]+"
pattern := bytes.NewBufferString("")
// Host SNI matching is case-insensitive
_, _ = fmt.Fprint(pattern, "(?i)")
pattern.WriteByte('^')
var end int
for i := 0; i < len(idxs); i += 2 {
// Set all values we are interested in.
raw := template[end:idxs[i]]
end = idxs[i+1]
parts := strings.SplitN(template[idxs[i]+1:end-1], ":", 2)
name := parts[0]
patt := defaultPattern
if len(parts) == 2 {
patt = parts[1]
}
// Name or pattern can't be empty.
if name == "" || patt == "" {
return "", fmt.Errorf("mux: missing name or pattern in %q",
template[idxs[i]:end])
}
// Build the regexp pattern.
_, _ = fmt.Fprintf(pattern, "%s(?P<%s>%s)", regexp.QuoteMeta(raw), varGroupName(i/2), patt)
}
// Add the remaining.
raw := template[end:]
pattern.WriteString(regexp.QuoteMeta(raw))
pattern.WriteByte('$')
return pattern.String(), nil
}
// varGroupName builds a capturing group name for the indexed variable.
// This function is a copy of varGroupName func from the gorilla/mux library.
// https://github.com/containous/mux/tree/8ffa4f6d063c1e2b834a73be6a1515cca3992618.
func varGroupName(idx int) string {
return "v" + strconv.Itoa(idx)
}
// braceIndices returns the first level curly brace indices from a string.
// This function is a copy of braceIndices func from the gorilla/mux library.
// https://github.com/containous/mux/tree/8ffa4f6d063c1e2b834a73be6a1515cca3992618.
func braceIndices(s string) ([]int, error) {
var level, idx int
var idxs []int
for i := range len(s) {
switch s[i] {
case '{':
if level++; level == 1 {
idx = i
}
case '}':
if level--; level == 0 {
idxs = append(idxs, idx, i+1)
} else if level < 0 {
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
}
}
}
if level != 0 {
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
}
return idxs, nil
}