traefik/pkg/provider/ecs/ecs.go
2024-03-12 12:08:03 +01:00

570 lines
18 KiB
Go

package ecs
import (
"context"
"fmt"
"strings"
"text/template"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/cenkalti/backoff/v4"
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/job"
"github.com/traefik/traefik/v3/pkg/logs"
"github.com/traefik/traefik/v3/pkg/provider"
"github.com/traefik/traefik/v3/pkg/safe"
)
// Provider holds the configuration of the ECS provider.
// SetDefaults fills in the default values and Init parses DefaultRule into
// defaultRuleTpl.
type Provider struct {
Constraints string `description:"Constraints is an expression that Traefik matches against the container's labels to determine whether to create any route for that container." json:"constraints,omitempty" toml:"constraints,omitempty" yaml:"constraints,omitempty" export:"true"`
ExposedByDefault bool `description:"Expose services by default." json:"exposedByDefault,omitempty" toml:"exposedByDefault,omitempty" yaml:"exposedByDefault,omitempty" export:"true"`
// RefreshSeconds is the period, in seconds, of the ticker that Provide uses
// to reload the configuration.
RefreshSeconds int `description:"Polling interval (in seconds)." json:"refreshSeconds,omitempty" toml:"refreshSeconds,omitempty" yaml:"refreshSeconds,omitempty" export:"true"`
// DefaultRule is the rule template parsed by Init.
DefaultRule string `description:"Default rule." json:"defaultRule,omitempty" toml:"defaultRule,omitempty" yaml:"defaultRule,omitempty"`
// Provider lookup parameters.
// Clusters lists the clusters to inspect; listInstances ignores it when
// AutoDiscoverClusters is set.
Clusters []string `description:"ECS Cluster names." json:"clusters,omitempty" toml:"clusters,omitempty" yaml:"clusters,omitempty" export:"true"`
// AutoDiscoverClusters makes listInstances enumerate the clusters through
// the ECS ListClusters API instead of using Clusters.
AutoDiscoverClusters bool `description:"Auto discover cluster." json:"autoDiscoverClusters,omitempty" toml:"autoDiscoverClusters,omitempty" yaml:"autoDiscoverClusters,omitempty" export:"true"`
// HealthyTasksOnly makes listInstances skip every task whose health status
// is not ecs.HealthStatusHealthy.
HealthyTasksOnly bool `description:"Determines whether to discover only healthy tasks." json:"healthyTasksOnly,omitempty" toml:"healthyTasksOnly,omitempty" yaml:"healthyTasksOnly,omitempty" export:"true"`
// ECSAnywhere makes listInstances additionally resolve container instances
// through lookupMiInstances.
ECSAnywhere bool `description:"Enable ECS Anywhere support." json:"ecsAnywhere,omitempty" toml:"ecsAnywhere,omitempty" yaml:"ecsAnywhere,omitempty" export:"true"`
// Region is the AWS region; when empty, createClient tries to read it from
// the EC2 instance metadata endpoint.
Region string `description:"AWS region to use for requests." json:"region,omitempty" toml:"region,omitempty" yaml:"region,omitempty" export:"true"`
// AccessKeyID and SecretAccessKey form the static credentials that
// createClient puts first in its credential chain.
AccessKeyID string `description:"AWS credentials access key ID to use for making requests." json:"accessKeyID,omitempty" toml:"accessKeyID,omitempty" yaml:"accessKeyID,omitempty" loggable:"false"`
SecretAccessKey string `description:"AWS credentials access key to use for making requests." json:"secretAccessKey,omitempty" toml:"secretAccessKey,omitempty" yaml:"secretAccessKey,omitempty" loggable:"false"`
// defaultRuleTpl is the parsed form of DefaultRule, set by Init.
defaultRuleTpl *template.Template
}
// ecsInstance describes one container of a running ECS task, as assembled by
// listInstances.
type ecsInstance struct {
// Name is the task group (first ":" replaced by "-") followed by "-" and
// the container name.
Name string
// ID is derived from the trailing 12 characters of the task ARN.
ID string
// containerDefinition is the task-definition entry whose name matches the
// container.
containerDefinition *ecs.ContainerDefinition
// machine holds the network and state details of the host running the container.
machine *machine
// Labels are the Docker labels of the container definition.
Labels map[string]string
// ExtraConf is the result of Provider.getConfiguration for this instance.
ExtraConf configuration
}
// portMapping associates a container port with the host port it is published on.
type portMapping struct {
containerPort int64
hostPort int64
// protocol is the upper-case protocol name ("TCP" or "UDP") when it is
// populated by listInstances.
protocol string
}
// machine carries the host-side details of a container, as filled in by
// listInstances.
type machine struct {
// state is either the task's last status or the EC2 instance state name,
// depending on how the container is attached to the network.
state string
// privateIP is the private address of the container's network interface
// or of its container instance.
privateIP string
ports []portMapping
// healthStatus is the health status reported for the task, when populated.
healthStatus string
}
// awsClient bundles the AWS service clients used by the provider; createClient
// builds all three from the same session and configuration.
type awsClient struct {
ecs *ecs.ECS
ec2 *ec2.EC2
ssm *ssm.SSM
}
// DefaultTemplateRule is the default template for the default rule.
// SetDefaults assigns it to Provider.DefaultRule.
const DefaultTemplateRule = "Host(`{{ normalize .Name }}`)"
var (
// Compile-time check that *Provider satisfies provider.Provider.
_ provider.Provider = (*Provider)(nil)
// existingTaskDefCache stores the task definitions fetched by
// lookupTaskDefinitions, keyed by task ARN. It is created with a 30-minute
// default expiration and a 5-minute cleanup interval.
existingTaskDefCache = cache.New(30*time.Minute, 5*time.Minute)
)
// SetDefaults resets the tunable fields of the provider to their default
// values. Fields that are not listed here are left untouched.
func (p *Provider) SetDefaults() {
	// Routing defaults.
	p.DefaultRule = DefaultTemplateRule
	p.ExposedByDefault = true

	// Discovery defaults: poll the "default" cluster every 15 seconds and
	// keep every running task, healthy or not.
	p.RefreshSeconds = 15
	p.Clusters = []string{"default"}
	p.AutoDiscoverClusters = false
	p.HealthyTasksOnly = false
}
// Init prepares the provider for use by parsing DefaultRule into a template
// and storing it in defaultRuleTpl. It returns an error when the rule cannot
// be parsed.
func (p *Provider) Init() error {
	tpl, err := provider.MakeDefaultRuleTemplate(p.DefaultRule, nil)
	if err != nil {
		return fmt.Errorf("error while parsing default rule: %w", err)
	}

	p.defaultRuleTpl = tpl

	return nil
}
// createClient builds the ECS, EC2 and SSM clients from a single AWS session.
//
// When no region is configured and the EC2 metadata endpoint is reachable, the
// region is read from the instance identity document and stored in p.Region.
// Credentials are resolved through a chain: the static keys configured on the
// provider, then the environment, then the shared credentials file, then the
// SDK's remote credential provider.
func (p *Provider) createClient(logger zerolog.Logger) (*awsClient, error) {
	sess, err := session.NewSessionWithOptions(session.Options{
		SharedConfigState: session.SharedConfigEnable,
	})
	if err != nil {
		return nil, err
	}

	if p.Region == "" {
		if meta := ec2metadata.New(sess); meta.Available() {
			logger.Info().Msg("No region provided, querying instance metadata endpoint...")

			doc, err := meta.GetInstanceIdentityDocument()
			if err != nil {
				return nil, err
			}

			p.Region = doc.Region
		}
	}

	staticCreds := &credentials.StaticProvider{
		Value: credentials.Value{
			AccessKeyID:     p.AccessKeyID,
			SecretAccessKey: p.SecretAccessKey,
		},
	}
	chain := []credentials.Provider{
		staticCreds,
		&credentials.EnvProvider{},
		&credentials.SharedCredentialsProvider{},
		defaults.RemoteCredProvider(*(defaults.Config()), defaults.Handlers()),
	}

	cfg := &aws.Config{Credentials: credentials.NewChainCredentials(chain)}

	// Set the region if it is defined by the user or resolved from the EC2 metadata.
	if p.Region != "" {
		cfg.Region = &p.Region
	}

	cfg.WithLogger(logs.NewAWSWrapper(logger))

	return &awsClient{
		ecs: ecs.New(sess, cfg),
		ec2: ec2.New(sess, cfg),
		ssm: ssm.New(sess, cfg),
	}, nil
}
// Provide starts a goroutine in pool that pushes the ECS configuration to
// configurationChan: once right after the AWS client is created, then on every
// tick of a RefreshSeconds ticker until the routine context is done. Any
// failure is retried with an exponential backoff bound to that context.
// Provide itself always returns nil.
func (p *Provider) Provide(configurationChan chan<- dynamic.Message, pool *safe.Pool) error {
	pool.GoCtx(func(routineCtx context.Context) {
		logger := log.Ctx(routineCtx).With().Str(logs.ProviderName, "ecs").Logger()
		ctxLog := logger.WithContext(routineCtx)

		// poll runs one full client lifetime: create the client, load once,
		// then reload on every tick. Returning an error triggers a retry.
		poll := func() error {
			client, err := p.createClient(logger)
			if err != nil {
				return fmt.Errorf("unable to create AWS client: %w", err)
			}

			if err := p.loadConfiguration(ctxLog, client, configurationChan); err != nil {
				return fmt.Errorf("failed to get ECS configuration: %w", err)
			}

			interval := time.Duration(p.RefreshSeconds) * time.Second
			ticker := time.NewTicker(interval)
			defer ticker.Stop()

			for {
				select {
				case <-routineCtx.Done():
					return nil
				case <-ticker.C:
					if err := p.loadConfiguration(ctxLog, client, configurationChan); err != nil {
						return fmt.Errorf("failed to refresh ECS configuration: %w", err)
					}
				}
			}
		}

		onRetry := func(err error, wait time.Duration) {
			logger.Error().Err(err).Msgf("Provider error, retrying in %s", wait)
		}

		retryPolicy := backoff.WithContext(job.NewBackOff(backoff.NewExponentialBackOff()), routineCtx)
		if err := backoff.RetryNotify(safe.OperationWithRecover(poll), retryPolicy, onRetry); err != nil {
			logger.Error().Err(err).Msg("Cannot retrieve data")
		}
	})

	return nil
}
// loadConfiguration lists the current ECS instances, builds the dynamic
// configuration from them and sends it on configurationChan under the "ecs"
// provider name. It returns the listing error, if any, without sending.
func (p *Provider) loadConfiguration(ctx context.Context, client *awsClient, configurationChan chan<- dynamic.Message) error {
	instances, err := p.listInstances(ctx, client)
	if err != nil {
		return err
	}

	conf := p.buildConfiguration(ctx, instances)
	msg := dynamic.Message{
		ProviderName:  "ecs",
		Configuration: conf,
	}
	configurationChan <- msg

	return nil
}
// listInstances finds all running ECS tasks in the configured clusters (or in
// every cluster returned by ListClusters when AutoDiscoverClusters is set) and
// returns one ecsInstance per task container.
//
// For each cluster it collects the task descriptions, the EC2 instances (and,
// with ECSAnywhere, the SSM managed instances) hosting them, and the task
// definitions (for the Docker labels). Containers are skipped, with a log
// entry, when their definition, network interface, host instance or extra
// configuration cannot be resolved. A page of tasks that cannot be described
// is logged and skipped as well.
//
// An error is returned when listing clusters or tasks fails, or when one of
// the instance or task-definition lookups fails.
func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsInstance, error) {
	logger := log.Ctx(ctx)

	clusters := p.Clusters
	if p.AutoDiscoverClusters {
		clusters = nil

		input := &ecs.ListClustersInput{}
		for {
			// Use the context-aware call so that discovery is cancelled
			// together with every other request made here.
			result, err := client.ecs.ListClustersWithContext(ctx, input)
			if err != nil {
				return nil, err
			}
			if result == nil {
				break
			}

			for _, arn := range result.ClusterArns {
				if arn != nil {
					clusters = append(clusters, *arn)
				}
			}

			if result.NextToken == nil {
				break
			}
			input.NextToken = result.NextToken
		}
	}

	// protocolName normalizes the protocol reported by AWS: "udp" becomes
	// "UDP", anything else (including an unset value) is treated as "TCP".
	protocolName := func(protocol *string) string {
		if aws.StringValue(protocol) == "udp" {
			return "UDP"
		}
		return "TCP"
	}

	var instances []ecsInstance

	logger.Debug().Msgf("ECS Clusters: %s", clusters)

	for _, c := range clusters {
		// Private copy of the name, so the pointer handed to the SDK never
		// aliases the loop variable.
		clusterName := aws.String(c)

		input := &ecs.ListTasksInput{
			Cluster:       clusterName,
			DesiredStatus: aws.String(ecs.DesiredStatusRunning),
		}

		tasks := make(map[string]*ecs.Task)

		err := client.ecs.ListTasksPagesWithContext(ctx, input, func(page *ecs.ListTasksOutput, lastPage bool) bool {
			if len(page.TaskArns) == 0 {
				return !lastPage
			}

			resp, err := client.ecs.DescribeTasksWithContext(ctx, &ecs.DescribeTasksInput{
				Tasks:   page.TaskArns,
				Cluster: clusterName,
			})
			if err != nil {
				// Best effort: keep going with the other pages, but record
				// the cause and the readable ARNs.
				logger.Error().Err(err).Msgf("Unable to describe tasks for %v", aws.StringValueSlice(page.TaskArns))
				return !lastPage
			}

			for _, t := range resp.Tasks {
				if p.HealthyTasksOnly && aws.StringValue(t.HealthStatus) != ecs.HealthStatusHealthy {
					logger.Debug().Msgf("Skipping unhealthy task %s", aws.StringValue(t.TaskArn))
					continue
				}

				tasks[aws.StringValue(t.TaskArn)] = t
			}

			return !lastPage
		})
		if err != nil {
			return nil, fmt.Errorf("listing tasks: %w", err)
		}

		// Skip to the next cluster if there are no tasks found on
		// this cluster.
		if len(tasks) == 0 {
			continue
		}

		ec2Instances, err := p.lookupEc2Instances(ctx, client, clusterName, tasks)
		if err != nil {
			return nil, err
		}

		// Reading from a nil map is fine; it is only filled for ECS Anywhere.
		var miInstances map[string]*ssm.InstanceInformation
		if p.ECSAnywhere {
			// Try looking up for instances on ECS Anywhere
			miInstances, err = p.lookupMiInstances(ctx, client, clusterName, tasks)
			if err != nil {
				return nil, err
			}
		}

		taskDefinitions, err := p.lookupTaskDefinitions(ctx, client, tasks)
		if err != nil {
			return nil, err
		}

		for key, task := range tasks {
			taskDef := taskDefinitions[key]
			if taskDef == nil {
				logger.Error().Msgf("Unable to find task definition for task %s", key)
				continue
			}

			containerInstance := ec2Instances[aws.StringValue(task.ContainerInstanceArn)]
			miContainerInstance := miInstances[aws.StringValue(task.ContainerInstanceArn)]

			// The instance ID is the 12-character tail of the task ARN;
			// shorter keys are used as they are.
			id := key
			if len(id) > 12 {
				id = id[len(id)-12:]
			}

			for _, container := range task.Containers {
				containerName := aws.StringValue(container.Name)

				var containerDefinition *ecs.ContainerDefinition
				for _, def := range taskDef.ContainerDefinitions {
					if containerName == aws.StringValue(def.Name) {
						containerDefinition = def
						break
					}
				}

				if containerDefinition == nil {
					logger.Debug().Msgf("Unable to find container definition for %s", containerName)
					continue
				}

				var mach *machine
				if len(task.Attachments) != 0 {
					// The task has its own attachments: address the container
					// through its network interface and take the ports from
					// the container definition.
					if len(container.NetworkInterfaces) == 0 {
						logger.Error().Msgf("Skip container %s: no network interfaces", containerName)
						continue
					}

					var ports []portMapping
					for _, mapping := range containerDefinition.PortMappings {
						if mapping == nil {
							continue
						}
						ports = append(ports, portMapping{
							hostPort:      aws.Int64Value(mapping.HostPort),
							containerPort: aws.Int64Value(mapping.ContainerPort),
							protocol:      protocolName(mapping.Protocol),
						})
					}

					mach = &machine{
						privateIP:    aws.StringValue(container.NetworkInterfaces[0].PrivateIpv4Address),
						ports:        ports,
						state:        aws.StringValue(task.LastStatus),
						healthStatus: aws.StringValue(task.HealthStatus),
					}
				} else {
					// No attachment: the container is reached through its
					// host instance and the runtime network bindings.
					if containerInstance == nil && miContainerInstance == nil {
						logger.Error().Msgf("Unable to find container instance information for %s", containerName)
						continue
					}

					var ports []portMapping
					for _, mapping := range container.NetworkBindings {
						if mapping == nil {
							continue
						}
						ports = append(ports, portMapping{
							hostPort:      aws.Int64Value(mapping.HostPort),
							containerPort: aws.Int64Value(mapping.ContainerPort),
							// Same protocol handling as the branch above.
							protocol: protocolName(mapping.Protocol),
						})
					}

					var privateIPAddress, stateName string
					if containerInstance != nil {
						privateIPAddress = aws.StringValue(containerInstance.PrivateIpAddress)
						// State is a pointer in the EC2 API; guard against a
						// response that omits it.
						if containerInstance.State != nil {
							stateName = aws.StringValue(containerInstance.State.Name)
						}
					} else {
						privateIPAddress = aws.StringValue(miContainerInstance.IPAddress)
						stateName = aws.StringValue(task.LastStatus)
					}

					mach = &machine{
						privateIP: privateIPAddress,
						ports:     ports,
						state:     stateName,
						// Same health reporting as the branch above.
						healthStatus: aws.StringValue(task.HealthStatus),
					}
				}

				instance := ecsInstance{
					Name:                fmt.Sprintf("%s-%s", strings.Replace(aws.StringValue(task.Group), ":", "-", 1), containerName),
					ID:                  id,
					containerDefinition: containerDefinition,
					machine:             mach,
					Labels:              aws.StringValueMap(containerDefinition.DockerLabels),
				}

				extraConf, err := p.getConfiguration(instance)
				if err != nil {
					logger.Error().Err(err).Msgf("Skip container %s", getServiceName(instance))
					continue
				}
				instance.ExtraConf = extraConf

				instances = append(instances, instance)
			}
		}
	}

	return instances, nil
}
// lookupMiInstances resolves the SSM managed instances ("mi-" prefixed IDs,
// as used by ECS Anywhere) hosting the given tasks in clusterName.
// The returned map is keyed by container instance ARN. Tasks without a
// container instance ARN are ignored. Both the ECS and the SSM requests are
// issued in batches produced by chunkIDs.
func (p *Provider) lookupMiInstances(ctx context.Context, client *awsClient, clusterName *string, ecsDatas map[string]*ecs.Task) (map[string]*ssm.InstanceInformation, error) {
	var containerArns []*string
	for _, t := range ecsDatas {
		if arn := t.ContainerInstanceArn; arn != nil {
			containerArns = append(containerArns, arn)
		}
	}

	// arnByInstanceID maps an instance ID back to its container instance ARN.
	arnByInstanceID := make(map[string]string)

	var managedIDs []*string
	for _, batch := range p.chunkIDs(containerArns) {
		out, err := client.ecs.DescribeContainerInstancesWithContext(ctx, &ecs.DescribeContainerInstancesInput{
			Cluster:            clusterName,
			ContainerInstances: batch,
		})
		if err != nil {
			return nil, fmt.Errorf("describing container instances: %w", err)
		}

		for _, ci := range out.ContainerInstances {
			id := aws.StringValue(ci.Ec2InstanceId)
			arnByInstanceID[id] = aws.StringValue(ci.ContainerInstanceArn)

			// Keep managed instances only. Plain EC2 instance IDs are left
			// out here; they are resolved by lookupEc2Instances.
			if strings.HasPrefix(id, "mi-") {
				managedIDs = append(managedIDs, ci.Ec2InstanceId)
			}
		}
	}

	found := make(map[string]*ssm.InstanceInformation)
	for _, batch := range p.chunkIDs(managedIDs) {
		query := &ssm.DescribeInstanceInformationInput{
			Filters: []*ssm.InstanceInformationStringFilter{{
				Key:    aws.String("InstanceIds"),
				Values: batch,
			}},
		}

		collect := func(page *ssm.DescribeInstanceInformationOutput, lastPage bool) bool {
			for _, info := range page.InstanceInformationList {
				if info.InstanceId == nil {
					continue
				}
				found[arnByInstanceID[aws.StringValue(info.InstanceId)]] = info
			}
			return !lastPage
		}

		if err := client.ssm.DescribeInstanceInformationPagesWithContext(ctx, query, collect); err != nil {
			return nil, fmt.Errorf("describing instances: %w", err)
		}
	}

	return found, nil
}
// lookupEc2Instances resolves the EC2 instances hosting the given tasks in
// clusterName. The returned map is keyed by container instance ARN. Tasks
// without a container instance ARN are ignored. Both the ECS and the EC2
// requests are issued in batches produced by chunkIDs.
func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, clusterName *string, ecsDatas map[string]*ecs.Task) (map[string]*ec2.Instance, error) {
	var containerArns []*string
	for _, t := range ecsDatas {
		if arn := t.ContainerInstanceArn; arn != nil {
			containerArns = append(containerArns, arn)
		}
	}

	// arnByInstanceID maps an instance ID back to its container instance ARN.
	arnByInstanceID := make(map[string]string)

	var ec2IDs []*string
	for _, batch := range p.chunkIDs(containerArns) {
		out, err := client.ecs.DescribeContainerInstancesWithContext(ctx, &ecs.DescribeContainerInstancesInput{
			Cluster:            clusterName,
			ContainerInstances: batch,
		})
		if err != nil {
			return nil, fmt.Errorf("describing container instances: %w", err)
		}

		for _, ci := range out.ContainerInstances {
			id := aws.StringValue(ci.Ec2InstanceId)
			arnByInstanceID[id] = aws.StringValue(ci.ContainerInstanceArn)

			// Disallow Instance IDs of the form mi-*
			// This prevents considering external instances in ECS Anywhere setups
			// and getting InvalidInstanceID.Malformed error when calling the describe-instances endpoint.
			if strings.HasPrefix(id, "mi-") {
				continue
			}

			ec2IDs = append(ec2IDs, ci.Ec2InstanceId)
		}
	}

	found := make(map[string]*ec2.Instance)
	for _, batch := range p.chunkIDs(ec2IDs) {
		collect := func(page *ec2.DescribeInstancesOutput, lastPage bool) bool {
			for _, reservation := range page.Reservations {
				for _, inst := range reservation.Instances {
					if inst.InstanceId == nil {
						continue
					}
					found[arnByInstanceID[aws.StringValue(inst.InstanceId)]] = inst
				}
			}
			return !lastPage
		}

		err := client.ec2.DescribeInstancesPagesWithContext(ctx, &ec2.DescribeInstancesInput{InstanceIds: batch}, collect)
		if err != nil {
			return nil, fmt.Errorf("describing instances: %w", err)
		}
	}

	return found, nil
}
// lookupTaskDefinitions returns the task definition of every task in
// taskDefArns, keyed like the input map. Definitions already present in
// existingTaskDefCache under the same key are reused; the others are fetched
// with DescribeTaskDefinition and stored in the cache with its default
// expiration. The first failed request aborts the lookup.
func (p *Provider) lookupTaskDefinitions(ctx context.Context, client *awsClient, taskDefArns map[string]*ecs.Task) (map[string]*ecs.TaskDefinition, error) {
	logger := log.Ctx(ctx)

	definitions := make(map[string]*ecs.TaskDefinition)
	for taskArn, task := range taskDefArns {
		if cached, hit := existingTaskDefCache.Get(taskArn); hit {
			definitions[taskArn] = cached.(*ecs.TaskDefinition)
			logger.Debug().Msgf("Found cached task definition for %s. Skipping the call", taskArn)
			continue
		}

		out, err := client.ecs.DescribeTaskDefinitionWithContext(ctx, &ecs.DescribeTaskDefinitionInput{
			TaskDefinition: task.TaskDefinitionArn,
		})
		if err != nil {
			return nil, fmt.Errorf("describing task definition: %w", err)
		}

		definitions[taskArn] = out.TaskDefinition
		existingTaskDefCache.Set(taskArn, out.TaskDefinition, cache.DefaultExpiration)
	}

	return definitions, nil
}
// chunkIDs splits ids into consecutive batches of at most 100 elements, since
// ECS expects no more than 100 parameters to be passed to an API call.
// The batches are sub-slices of ids (no copy is made); a nil result is
// returned for an empty input.
func (p *Provider) chunkIDs(ids []*string) [][]*string {
	const maxBatch = 100

	var batches [][]*string
	for rest := ids; len(rest) > 0; {
		size := len(rest)
		if size > maxBatch {
			size = maxBatch
		}

		batches = append(batches, rest[:size])
		rest = rest[size:]
	}

	return batches
}