Merge branch 'v1.5' into master

This commit is contained in:
Fernandez Ludovic 2018-03-09 12:02:29 +01:00
commit 0a41cd43a5
4 changed files with 70 additions and 21 deletions

View file

@ -372,7 +372,7 @@ func getBoolValue(i ecsInstance, labelName string, defaultValue bool) bool {
rawValue, ok := i.containerDefinition.DockerLabels[labelName]
if ok {
if rawValue != nil {
v, err := strconv.ParseBool(*rawValue)
v, err := strconv.ParseBool(aws.StringValue(rawValue))
if err == nil {
return v
}

View file

@ -280,12 +280,12 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
byTaskDefinition := make(map[string]int)
for _, task := range tasks {
if _, found := byContainerInstance[*task.ContainerInstanceArn]; !found {
byContainerInstance[*task.ContainerInstanceArn] = len(containerInstanceArns)
if _, found := byContainerInstance[aws.StringValue(task.ContainerInstanceArn)]; !found {
byContainerInstance[aws.StringValue(task.ContainerInstanceArn)] = len(containerInstanceArns)
containerInstanceArns = append(containerInstanceArns, task.ContainerInstanceArn)
}
if _, found := byTaskDefinition[*task.TaskDefinitionArn]; !found {
byTaskDefinition[*task.TaskDefinitionArn] = len(taskDefinitionArns)
if _, found := byTaskDefinition[aws.StringValue(task.TaskDefinitionArn)]; !found {
byTaskDefinition[aws.StringValue(task.TaskDefinitionArn)] = len(taskDefinitionArns)
taskDefinitionArns = append(taskDefinitionArns, task.TaskDefinitionArn)
}
}
@ -302,23 +302,23 @@ func (p *Provider) listInstances(ctx context.Context, client *awsClient) ([]ecsI
for _, task := range tasks {
machineIdx := byContainerInstance[*task.ContainerInstanceArn]
taskDefIdx := byTaskDefinition[*task.TaskDefinitionArn]
machineIdx := byContainerInstance[aws.StringValue(task.ContainerInstanceArn)]
taskDefIdx := byTaskDefinition[aws.StringValue(task.TaskDefinitionArn)]
for _, container := range task.Containers {
taskDefinition := taskDefinitions[taskDefIdx]
var containerDefinition *ecs.ContainerDefinition
for _, def := range taskDefinition.ContainerDefinitions {
if *container.Name == *def.Name {
if aws.StringValue(container.Name) == aws.StringValue(def.Name) {
containerDefinition = def
break
}
}
instances = append(instances, ecsInstance{
fmt.Sprintf("%s-%s", strings.Replace(*task.Group, ":", "-", 1), *container.Name),
(*task.TaskArn)[len(*task.TaskArn)-12:],
fmt.Sprintf("%s-%s", strings.Replace(aws.StringValue(task.Group), ":", "-", 1), *container.Name),
(aws.StringValue(task.TaskArn))[len(aws.StringValue(task.TaskArn))-12:],
task,
taskDefinition,
container,
@ -338,7 +338,7 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl
instanceIds := make([]*string, len(containerArns))
instances := make([]*ec2.Instance, len(containerArns))
for i, arn := range containerArns {
order[*arn] = i
order[aws.StringValue(arn)] = i
}
req, _ := client.ecs.DescribeContainerInstancesRequest(&ecs.DescribeContainerInstancesInput{
@ -353,7 +353,7 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl
containerResp := req.Data.(*ecs.DescribeContainerInstancesOutput)
for i, container := range containerResp.ContainerInstances {
order[*container.Ec2InstanceId] = order[*container.ContainerInstanceArn]
order[aws.StringValue(container.Ec2InstanceId)] = order[aws.StringValue(container.ContainerInstanceArn)]
instanceIds[i] = container.Ec2InstanceId
}
}
@ -371,7 +371,7 @@ func (p *Provider) lookupEc2Instances(ctx context.Context, client *awsClient, cl
for _, r := range instancesResp.Reservations {
for _, i := range r.Instances {
if i.InstanceId != nil {
instances[order[*i.InstanceId]] = i
instances[order[aws.StringValue(i.InstanceId)]] = i
}
}
}
@ -408,8 +408,8 @@ func (p *Provider) filterInstance(i ecsInstance) bool {
return false
}
if *i.machine.State.Name != ec2.InstanceStateNameRunning {
log.Debugf("Filtering ecs instance in an incorrect state %s (%s) (state = %s)", i.Name, i.ID, *i.machine.State.Name)
if aws.StringValue(i.machine.State.Name) != ec2.InstanceStateNameRunning {
log.Debugf("Filtering ecs instance in an incorrect state %s (%s) (state = %s)", i.Name, i.ID, aws.StringValue(i.machine.State.Name))
return false
}

View file

@ -17,10 +17,10 @@ type IP struct {
// NewIP builds a new IP given a list of CIDR-Strings to whitelist
func NewIP(whitelistStrings []string, insecure bool) (*IP, error) {
if len(whitelistStrings) == 0 && !insecure {
return nil, errors.New("no whiteListsNet provided")
return nil, errors.New("no white list provided")
}
ip := IP{}
ip := IP{insecure: insecure}
if !insecure {
for _, whitelistString := range whitelistStrings {

View file

@ -19,12 +19,12 @@ func TestNew(t *testing.T) {
desc: "nil whitelist",
whitelistStrings: nil,
expectedWhitelists: nil,
errMessage: "no whiteListsNet provided",
errMessage: "no white list provided",
}, {
desc: "empty whitelist",
whitelistStrings: []string{},
expectedWhitelists: nil,
errMessage: "no whiteListsNet provided",
errMessage: "no white list provided",
}, {
desc: "whitelist containing empty string",
whitelistStrings: []string{
@ -90,7 +90,7 @@ func TestNew(t *testing.T) {
}
}
func TestIsAllowed(t *testing.T) {
func TestContainsIsAllowed(t *testing.T) {
cases := []struct {
desc string
whitelistStrings []string
@ -275,6 +275,7 @@ func TestIsAllowed(t *testing.T) {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
whiteLister, err := NewIP(test.whitelistStrings, false)
require.NoError(t, err)
@ -297,7 +298,55 @@ func TestIsAllowed(t *testing.T) {
}
}
func TestBrokenIPs(t *testing.T) {
func TestContainsInsecure(t *testing.T) {
mustNewIP := func(whitelistStrings []string, insecure bool) *IP {
ip, err := NewIP(whitelistStrings, insecure)
if err != nil {
t.Fatal(err)
}
return ip
}
testCases := []struct {
desc string
whiteLister *IP
ip string
expected bool
}{
{
desc: "valid ip and insecure",
whiteLister: mustNewIP([]string{"1.2.3.4/24"}, true),
ip: "1.2.3.1",
expected: true,
},
{
desc: "invalid ip and insecure",
whiteLister: mustNewIP([]string{"1.2.3.4/24"}, true),
ip: "10.2.3.1",
expected: true,
},
{
desc: "invalid ip and secure",
whiteLister: mustNewIP([]string{"1.2.3.4/24"}, false),
ip: "10.2.3.1",
expected: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
ok, _, err := test.whiteLister.Contains(test.ip)
require.NoError(t, err)
assert.Equal(t, test.expected, ok)
})
}
}
func TestContainsBrokenIPs(t *testing.T) {
brokenIPs := []string{
"foo",
"10.0.0.350",