diff --git a/pkg/cli/commands.go b/pkg/cli/commands.go index eaf8f73a7..6e24e2a4b 100644 --- a/pkg/cli/commands.go +++ b/pkg/cli/commands.go @@ -3,7 +3,6 @@ package cli import ( - "errors" "fmt" "os" "path/filepath" @@ -17,7 +16,9 @@ type Command struct { Resources []ResourceLoader Run func([]string) error Hidden bool - subCommands []*Command + // AllowArg if not set, disallows any argument that is not a known command or a sub-command. + AllowArg bool + subCommands []*Command } // AddCommand Adds a sub command. @@ -40,13 +41,17 @@ func Execute(cmd *Command) error { } func execute(cmd *Command, args []string, root bool) error { + // Calls command without args. if len(args) == 1 { - if err := run(cmd, args); err != nil { + if err := run(cmd, args[1:]); err != nil { return fmt.Errorf("command %s error: %v", args[0], err) } return nil } + // Special case: if the command is the top level one, + // and the first arg (`args[1]`) is not the command name or a known sub-command, + // then we run the top level command itself. if root && cmd.Name != args[1] && !contains(cmd.subCommands, args[1]) { if err := run(cmd, args[1:]); err != nil { return fmt.Errorf("command %s error: %v", filepath.Base(args[0]), err) @@ -54,6 +59,7 @@ func execute(cmd *Command, args []string, root bool) error { return nil } + // Calls command by its name. if len(args) >= 2 && cmd.Name == args[1] { if err := run(cmd, args[2:]); err != nil { return fmt.Errorf("command %s error: %v", cmd.Name, err) @@ -61,6 +67,7 @@ func execute(cmd *Command, args []string, root bool) error { return nil } + // No sub-command, calls the current command. if len(cmd.subCommands) == 0 { if err := run(cmd, args[1:]); err != nil { return fmt.Errorf("command %s error: %v", cmd.Name, err) @@ -68,6 +75,7 @@ func execute(cmd *Command, args []string, root bool) error { return nil } + // Trying to find the sub-command. for _, subCmd := range cmd.subCommands { if len(args) >= 2 && subCmd.Name == args[1] { return execute(subCmd, args[1:], false) @@ -84,7 +92,12 @@ func run(cmd *Command, args []string) error { if cmd.Run == nil { _ = PrintHelp(os.Stdout, cmd) - return errors.New("command not found") + return fmt.Errorf("command %s is not runnable", cmd.Name) + } + + if len(args) > 0 && !isFlag(args[0]) && !cmd.AllowArg { + _ = PrintHelp(os.Stdout, cmd) + return fmt.Errorf("command not found: %v", args) } if cmd.Configuration == nil { @@ -113,3 +126,7 @@ func contains(cmds []*Command, name string) bool { return false } + +func isFlag(arg string) bool { + return len(arg) > 0 && arg[1] == '-' +} diff --git a/pkg/cli/commands_test.go b/pkg/cli/commands_test.go index b635af3f2..77a6b8329 100644 --- a/pkg/cli/commands_test.go +++ b/pkg/cli/commands_test.go @@ -86,6 +86,23 @@ func Test_execute(t *testing.T) { }, expected: expected{result: "root"}, }, + { + desc: "root command, with argument, command not found", + args: []string{"", "echo"}, + command: func() *Command { + return &Command{ + Name: "root", + Description: "This is a test", + Configuration: nil, + Run: func(_ []string) error { + called = "root" + return nil + }, + } + + }, + expected: expected{error: true}, + }, { desc: "one sub command", args: []string{"", "sub1"}, @@ -114,6 +131,34 @@ func Test_execute(t *testing.T) { }, expected: expected{result: "sub1"}, }, + { + desc: "one sub command, with argument, command not found", + args: []string{"", "sub1", "echo"}, + command: func() *Command { + rootCmd := &Command{ + Name: "test", + Description: "This is a test", + Configuration: nil, + Run: func(_ []string) error { + called += "root" + return nil + }, + } + + _ = rootCmd.AddCommand(&Command{ + Name: "sub1", + Description: "sub1", + Configuration: nil, + Run: func(_ []string) error { + called += "sub1" + return nil + }, + }) + + return rootCmd + }, + expected: expected{error: true}, + }, { desc: "two sub commands", args: []string{"", "sub2"}, @@ -376,6 +421,7 @@ func Test_execute(t *testing.T) { Name: "sub1", Description: "sub1", Configuration: nil, + AllowArg: true, Run: func(args []string) error { called += "sub1-" + strings.Join(args, "-") return nil @@ -394,6 +440,7 @@ func Test_execute(t *testing.T) { Name: "root", Description: "This is a test", Configuration: nil, + AllowArg: true, Run: func(args []string) error { called += "root-" + strings.Join(args, "-") return nil