diff --git a/pkg/cli/command.go b/pkg/cli/command.go index 58ec8673..be888dd8 100644 --- a/pkg/cli/command.go +++ b/pkg/cli/command.go @@ -1,9 +1,18 @@ package cli import ( + "time" + "github.com/spf13/cobra" ) +// ───────────────────────────────────────────────────────────────────────────── +// Cobra Re-exports +// ───────────────────────────────────────────────────────────────────────────── + +// PositionalArgs is the cobra positional args type. +type PositionalArgs = cobra.PositionalArgs + // ───────────────────────────────────────────────────────────────────────────── // Command Type Re-export // ───────────────────────────────────────────────────────────────────────────── @@ -69,23 +78,6 @@ func NewRun(use, short, long string, run func(cmd *Command, args []string)) *Com return cmd } -// NewPassthrough creates a command that passes all arguments (including flags) -// to the given function. Used for commands that do their own flag parsing -// (e.g. incremental migration from flag.FlagSet to cobra). -// -// cmd := cli.NewPassthrough("train", "Train a model", func(args []string) { -// // args includes all flags: ["--model", "gemma-3-1b", "--epochs", "10"] -// fs := flag.NewFlagSet("train", flag.ExitOnError) -// // ... -// }) -func NewPassthrough(use, short string, fn func(args []string)) *Command { - cmd := NewRun(use, short, "", func(_ *Command, args []string) { - fn(args) - }) - cmd.DisableFlagParsing = true - return cmd -} - // ───────────────────────────────────────────────────────────────────────────── // Flag Helpers // ───────────────────────────────────────────────────────────────────────────── @@ -129,6 +121,45 @@ func IntFlag(cmd *Command, ptr *int, name, short string, def int, usage string) } } +// Float64Flag adds a float64 flag to a command. +// The value will be stored in the provided pointer. +// +// var threshold float64 +// cli.Float64Flag(cmd, &threshold, "threshold", "t", 0.0, "Score threshold") +func Float64Flag(cmd *Command, ptr *float64, name, short string, def float64, usage string) { + if short != "" { + cmd.Flags().Float64VarP(ptr, name, short, def, usage) + } else { + cmd.Flags().Float64Var(ptr, name, def, usage) + } +} + +// Int64Flag adds an int64 flag to a command. +// The value will be stored in the provided pointer. +// +// var seed int64 +// cli.Int64Flag(cmd, &seed, "seed", "s", 0, "Random seed") +func Int64Flag(cmd *Command, ptr *int64, name, short string, def int64, usage string) { + if short != "" { + cmd.Flags().Int64VarP(ptr, name, short, def, usage) + } else { + cmd.Flags().Int64Var(ptr, name, def, usage) + } +} + +// DurationFlag adds a time.Duration flag to a command. +// The value will be stored in the provided pointer. +// +// var timeout time.Duration +// cli.DurationFlag(cmd, &timeout, "timeout", "t", 30*time.Second, "Request timeout") +func DurationFlag(cmd *Command, ptr *time.Duration, name, short string, def time.Duration, usage string) { + if short != "" { + cmd.Flags().DurationVarP(ptr, name, short, def, usage) + } else { + cmd.Flags().DurationVar(ptr, name, def, usage) + } +} + // StringSliceFlag adds a string slice flag to a command. // The value will be stored in the provided pointer. // diff --git a/pkg/cli/commands_test.go b/pkg/cli/commands_test.go index 08654e4b..f5229564 100644 --- a/pkg/cli/commands_test.go +++ b/pkg/cli/commands_test.go @@ -164,22 +164,3 @@ func TestWithAppName_Good(t *testing.T) { }) } -// TestNewPassthrough_Good tests the passthrough command builder. -func TestNewPassthrough_Good(t *testing.T) { - t.Run("passes all args including flags", func(t *testing.T) { - var received []string - cmd := NewPassthrough("train", "Train", func(args []string) { - received = args - }) - - cmd.SetArgs([]string{"--model", "gemma", "--epochs", "10"}) - err := cmd.Execute() - require.NoError(t, err) - assert.Equal(t, []string{"--model", "gemma", "--epochs", "10"}, received) - }) - - t.Run("flag parsing is disabled", func(t *testing.T) { - cmd := NewPassthrough("run", "Run", func(_ []string) {}) - assert.True(t, cmd.DisableFlagParsing) - }) -}