From 862957c30c237864c0ea8a8d3e1f99b266654a8f Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Thu, 8 Mar 2018 10:08:03 +0100 Subject: [PATCH 1/2] Safe access to ECS API pointer values. --- provider/ecs/ecs.go | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/provider/ecs/ecs.go b/provider/ecs/ecs.go index 5f5e47174..4f8c958cb 100644 --- a/provider/ecs/ecs.go +++ b/provider/ecs/ecs.go @@ -308,12 +308,12 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI byTaskDefinition := make(map[string]int) for _, task := range tasks { - if _, found := byContainerInstance[*task.ContainerInstanceArn]; !found { - byContainerInstance[*task.ContainerInstanceArn] = len(containerInstanceArns) + if _, found := byContainerInstance[aws.StringValue(task.ContainerInstanceArn)]; !found { + byContainerInstance[aws.StringValue(task.ContainerInstanceArn)] = len(containerInstanceArns) containerInstanceArns = append(containerInstanceArns, task.ContainerInstanceArn) } - if _, found := byTaskDefinition[*task.TaskDefinitionArn]; !found { - byTaskDefinition[*task.TaskDefinitionArn] = len(taskDefinitionArns) + if _, found := byTaskDefinition[aws.StringValue(task.TaskDefinitionArn)]; !found { + byTaskDefinition[aws.StringValue(task.TaskDefinitionArn)] = len(taskDefinitionArns) taskDefinitionArns = append(taskDefinitionArns, task.TaskDefinitionArn) } } @@ -327,11 +327,10 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI if err != nil { return nil, err } - for _, task := range tasks { - machineIdx := byContainerInstance[*task.ContainerInstanceArn] - taskDefIdx := byTaskDefinition[*task.TaskDefinitionArn] + machineIdx := byContainerInstance[aws.StringValue(task.ContainerInstanceArn)] + taskDefIdx := byTaskDefinition[aws.StringValue(task.TaskDefinitionArn)] for _, container := range task.Containers { @@ -345,8 +344,8 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI } instances = append(instances, ecsInstance{ - fmt.Sprintf("%s-%s", strings.Replace(*task.Group, ":", "-", 1), *container.Name), - (*task.TaskArn)[len(*task.TaskArn)-12:], + fmt.Sprintf("%s-%s", strings.Replace(aws.StringValue(task.Group), ":", "-", 1), *container.Name), + (aws.StringValue(task.TaskArn))[len(aws.StringValue(task.TaskArn))-12:], task, taskDefinition, container, @@ -381,7 +380,7 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl containerResp := req.Data.(*ecs.DescribeContainerInstancesOutput) for i, container := range containerResp.ContainerInstances { - order[*container.Ec2InstanceId] = order[*container.ContainerInstanceArn] + order[aws.StringValue(container.Ec2InstanceId)] = order[aws.StringValue(container.ContainerInstanceArn)] instanceIds[i] = container.Ec2InstanceId } } @@ -399,7 +398,7 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl for _, r := range instancesResp.Reservations { for _, i := range r.Instances { if i.InstanceId != nil { - instances[order[*i.InstanceId]] = i + instances[order[aws.StringValue(i.InstanceId)]] = i } } } @@ -426,7 +425,7 @@ func (p *Provider) lookupTaskDefinitions(ctx context.Context, client *awsClient, func (p *Provider) label(i ecsInstance, k string) string { if v, found := i.containerDefinition.DockerLabels[k]; found { - return *v + return aws.StringValue(v) } return "" } @@ -565,14 +564,14 @@ func (p *Provider) getProtocol(i ecsInstance) string { } func (p *Provider) getHost(i ecsInstance) string { - return *i.machine.PrivateIpAddress + return aws.StringValue(i.machine.PrivateIpAddress) } func (p *Provider) getPort(i ecsInstance) string { if port := p.label(i, types.LabelPort); port != "" { return port } - return strconv.FormatInt(*i.container.NetworkBindings[0].HostPort, 10) + return strconv.FormatInt(aws.Int64Value(i.container.NetworkBindings[0].HostPort), 10) } func (p *Provider) getWeight(i ecsInstance) string { From 59f7b2ea98517c1b5168e996b3894add3065b36b Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Thu, 8 Mar 2018 15:08:03 +0100 Subject: [PATCH 2/2] Propagate insecure in white list. --- whitelist/ip.go | 4 ++-- whitelist/ip_test.go | 57 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/whitelist/ip.go b/whitelist/ip.go index 322404fab..200e7ae7c 100644 --- a/whitelist/ip.go +++ b/whitelist/ip.go @@ -17,10 +17,10 @@ type IP struct { // NewIP builds a new IP given a list of CIDR-Strings to whitelist func NewIP(whitelistStrings []string, insecure bool) (*IP, error) { if len(whitelistStrings) == 0 && !insecure { - return nil, errors.New("no whiteListsNet provided") + return nil, errors.New("no white list provided") } - ip := IP{} + ip := IP{insecure: insecure} if !insecure { for _, whitelistString := range whitelistStrings { diff --git a/whitelist/ip_test.go b/whitelist/ip_test.go index abd65f297..a80fe98d0 100644 --- a/whitelist/ip_test.go +++ b/whitelist/ip_test.go @@ -19,12 +19,12 @@ func TestNew(t *testing.T) { desc: "nil whitelist", whitelistStrings: nil, expectedWhitelists: nil, - errMessage: "no whiteListsNet provided", + errMessage: "no white list provided", }, { desc: "empty whitelist", whitelistStrings: []string{}, expectedWhitelists: nil, - errMessage: "no whiteListsNet provided", + errMessage: "no white list provided", }, { desc: "whitelist containing empty string", whitelistStrings: []string{ @@ -90,7 +90,7 @@ func TestNew(t *testing.T) { } } -func TestIsAllowed(t *testing.T) { +func TestContainsIsAllowed(t *testing.T) { cases := []struct { desc string whitelistStrings []string @@ -275,6 +275,7 @@ func TestIsAllowed(t *testing.T) { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() + whiteLister, err := NewIP(test.whitelistStrings, false) require.NoError(t, err) @@ -297,7 +298,55 @@ func TestIsAllowed(t *testing.T) { } } -func TestBrokenIPs(t *testing.T) { +func TestContainsInsecure(t *testing.T) { + mustNewIP := func(whitelistStrings []string, insecure bool) *IP { + ip, err := NewIP(whitelistStrings, insecure) + if err != nil { + t.Fatal(err) + } + return ip + } + + testCases := []struct { + desc string + whiteLister *IP + ip string + expected bool + }{ + { + desc: "valid ip and insecure", + whiteLister: mustNewIP([]string{"1.2.3.4/24"}, true), + ip: "1.2.3.1", + expected: true, + }, + { + desc: "invalid ip and insecure", + whiteLister: mustNewIP([]string{"1.2.3.4/24"}, true), + ip: "10.2.3.1", + expected: true, + }, + { + desc: "invalid ip and secure", + whiteLister: mustNewIP([]string{"1.2.3.4/24"}, false), + ip: "10.2.3.1", + expected: false, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + ok, _, err := test.whiteLister.Contains(test.ip) + require.NoError(t, err) + + assert.Equal(t, test.expected, ok) + }) + } +} + +func TestContainsBrokenIPs(t *testing.T) { brokenIPs := []string{ "foo", "10.0.0.350",