diff --git a/pkg/cli/prompt_test.go b/pkg/cli/prompt_test.go index c9684e3..e72ac5e 100644 --- a/pkg/cli/prompt_test.go +++ b/pkg/cli/prompt_test.go @@ -205,6 +205,14 @@ func TestChoose_Good_Filter(t *testing.T) { assert.Equal(t, "apricot", val) } +func TestChoose_Bad_FilteredDefaultDoesNotFallBackToFirstVisible(t *testing.T) { + SetStdin(strings.NewReader("ap\n\n2\n")) + defer SetStdin(nil) + + val := Choose("Pick", []string{"apple", "banana", "apricot"}, WithDefaultIndex[string](1), Filter[string]()) + assert.Equal(t, "apricot", val) +} + func TestChooseMulti_Good_Filter(t *testing.T) { SetStdin(strings.NewReader("ap\n1 2\n")) defer SetStdin(nil) @@ -213,6 +221,14 @@ func TestChooseMulti_Good_Filter(t *testing.T) { assert.Equal(t, []string{"apple", "apricot"}, vals) } +func TestChooseMulti_Bad_FilteredDefaultDoesNotFallBackToFirstVisible(t *testing.T) { + SetStdin(strings.NewReader("ap\n\n2\n")) + defer SetStdin(nil) + + vals := ChooseMulti("Pick", []string{"apple", "banana", "apricot"}, WithDefaultIndex[string](1), Filter[string]()) + assert.Equal(t, []string{"apricot"}, vals) +} + func TestChooseMulti_Good_Commas(t *testing.T) { SetStdin(strings.NewReader("1,3\n")) defer SetStdin(nil) diff --git a/pkg/cli/utils.go b/pkg/cli/utils.go index 1fca45e..a5cc3fb 100644 --- a/pkg/cli/utils.go +++ b/pkg/cli/utils.go @@ -363,16 +363,20 @@ func Choose[T any](prompt string, items []T, opts ...ChooseOption[T]) T { response = strings.TrimSpace(response) if err != nil && response == "" { - if cfg.defaultN >= 0 { - return items[defaultVisibleIndex(visible, cfg.defaultN)] + if idx, ok := defaultVisibleIndex(visible, cfg.defaultN); ok { + return items[idx] } var zero T return zero } if response == "" { + if idx, ok := defaultVisibleIndex(visible, cfg.defaultN); ok { + return items[idx] + } if cfg.defaultN >= 0 { - return items[defaultVisibleIndex(visible, cfg.defaultN)] + fmt.Printf("Default selection is not available in the current list\n") + continue } fmt.Printf("Please enter a number between 1 and %d\n", len(visible)) continue @@ -454,8 +458,12 @@ func ChooseMulti[T any](prompt string, items []T, opts ...ChooseOption[T]) []T { // Empty response returns no selections if response == "" { + if idx, ok := defaultVisibleIndex(visible, cfg.defaultN); ok { + return []T{items[idx]} + } if cfg.defaultN >= 0 { - return []T{items[defaultVisibleIndex(visible, cfg.defaultN)]} + fmt.Printf("Default selection is not available in the current list\n") + continue } return nil } @@ -499,18 +507,16 @@ func renderChoices[T any](prompt string, items []T, visible []int, displayFn fun } } -func defaultVisibleIndex(visible []int, defaultN int) int { - if defaultN >= 0 { - for _, idx := range visible { - if idx == defaultN { - return idx - } +func defaultVisibleIndex(visible []int, defaultN int) (int, bool) { + if defaultN < 0 { + return 0, false + } + for _, idx := range visible { + if idx == defaultN { + return idx, true } } - if len(visible) > 0 { - return visible[0] - } - return 0 + return 0, false } func filterVisible[T any](items []T, visible []int, query string, displayFn func(T) string) []int {