feat(cli): add Int64Flag, DurationFlag helpers; remove NewPassthrough

Add Int64Flag and DurationFlag to the flag helper set for commands
needing int64 seeds and time.Duration intervals. Remove NewPassthrough
which enabled the anti-pattern of bypassing cobra flag parsing with
stdlib flag.FlagSet.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-23 03:32:39 +00:00
parent 0006650a10
commit 38765962f8
2 changed files with 48 additions and 36 deletions

View file

@ -1,9 +1,18 @@
package cli
import (
"time"
"github.com/spf13/cobra"
)
// ─────────────────────────────────────────────────────────────────────────────
// Cobra Re-exports
// ─────────────────────────────────────────────────────────────────────────────
// PositionalArgs is the cobra positional args type.
type PositionalArgs = cobra.PositionalArgs
// ─────────────────────────────────────────────────────────────────────────────
// Command Type Re-export
// ─────────────────────────────────────────────────────────────────────────────
@ -69,23 +78,6 @@ func NewRun(use, short, long string, run func(cmd *Command, args []string)) *Com
return cmd
}
// NewPassthrough creates a command that passes all arguments (including flags)
// to the given function. Used for commands that do their own flag parsing
// (e.g. incremental migration from flag.FlagSet to cobra).
//
// cmd := cli.NewPassthrough("train", "Train a model", func(args []string) {
// // args includes all flags: ["--model", "gemma-3-1b", "--epochs", "10"]
// fs := flag.NewFlagSet("train", flag.ExitOnError)
// // ...
// })
func NewPassthrough(use, short string, fn func(args []string)) *Command {
cmd := NewRun(use, short, "", func(_ *Command, args []string) {
fn(args)
})
cmd.DisableFlagParsing = true
return cmd
}
// ─────────────────────────────────────────────────────────────────────────────
// Flag Helpers
// ─────────────────────────────────────────────────────────────────────────────
@ -129,6 +121,45 @@ func IntFlag(cmd *Command, ptr *int, name, short string, def int, usage string)
}
}
// Float64Flag adds a float64 flag to a command.
// The value will be stored in the provided pointer.
//
// var threshold float64
// cli.Float64Flag(cmd, &threshold, "threshold", "t", 0.0, "Score threshold")
func Float64Flag(cmd *Command, ptr *float64, name, short string, def float64, usage string) {
if short != "" {
cmd.Flags().Float64VarP(ptr, name, short, def, usage)
} else {
cmd.Flags().Float64Var(ptr, name, def, usage)
}
}
// Int64Flag adds an int64 flag to a command.
// The value will be stored in the provided pointer.
//
// var seed int64
// cli.Int64Flag(cmd, &seed, "seed", "s", 0, "Random seed")
func Int64Flag(cmd *Command, ptr *int64, name, short string, def int64, usage string) {
if short != "" {
cmd.Flags().Int64VarP(ptr, name, short, def, usage)
} else {
cmd.Flags().Int64Var(ptr, name, def, usage)
}
}
// DurationFlag adds a time.Duration flag to a command.
// The value will be stored in the provided pointer.
//
// var timeout time.Duration
// cli.DurationFlag(cmd, &timeout, "timeout", "t", 30*time.Second, "Request timeout")
func DurationFlag(cmd *Command, ptr *time.Duration, name, short string, def time.Duration, usage string) {
if short != "" {
cmd.Flags().DurationVarP(ptr, name, short, def, usage)
} else {
cmd.Flags().DurationVar(ptr, name, def, usage)
}
}
// StringSliceFlag adds a string slice flag to a command.
// The value will be stored in the provided pointer.
//

View file

@ -164,22 +164,3 @@ func TestWithAppName_Good(t *testing.T) {
})
}
// TestNewPassthrough_Good tests the passthrough command builder.
func TestNewPassthrough_Good(t *testing.T) {
t.Run("passes all args including flags", func(t *testing.T) {
var received []string
cmd := NewPassthrough("train", "Train", func(args []string) {
received = args
})
cmd.SetArgs([]string{"--model", "gemma", "--epochs", "10"})
err := cmd.Execute()
require.NoError(t, err)
assert.Equal(t, []string{"--model", "gemma", "--epochs", "10"}, received)
})
t.Run("flag parsing is disabled", func(t *testing.T) {
cmd := NewPassthrough("run", "Run", func(_ []string) {})
assert.True(t, cmd.DisableFlagParsing)
})
}