feat(ml): add CoreGo service wrapper and CLI commands (Tasks 6-7)
Service registration with DI lifecycle, typed options, and backend management. Nine CLI subcommands under `core ml` for scoring, probing, export, expansion, status, GGUF/PEFT conversion, agent, and worker.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
commit 3dbb5988a8 (parent fcd1758b7d)
11 changed files with 839 additions and 0 deletions
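For reviewers, a minimal sketch of how the new service is meant to be wired into a Core application, based on the NewService doc comment in pkg/ml/service.go below; the option values are placeholders and error handling is trimmed:

package main

import (
	"forge.lthn.ai/core/cli/pkg/framework"
	"forge.lthn.ai/core/cli/pkg/ml"
)

func main() {
	// Register the ML service under the name "ml". Zero-value options fall
	// back to concurrency 4 and the "all" scoring suite inside NewService.
	core, err := framework.New(
		framework.WithName("ml", ml.NewService(ml.Options{
			OllamaURL:  "http://localhost:11434", // placeholder Ollama endpoint
			JudgeURL:   "http://localhost:11434", // placeholder judge endpoint
			JudgeModel: "gemma3:27b",
		})),
	)
	if err != nil {
		panic(err)
	}
	_ = core // OnStartup/OnShutdown are driven by the Core lifecycle, not called here
}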
internal/cmd/ml/cmd_agent.go (new file, 67 lines)
@@ -0,0 +1,67 @@
package ml

import (
	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/cli/pkg/ml"
)

var (
	agentM3Host        string
	agentM3User        string
	agentM3SSHKey      string
	agentM3AdapterBase string
	agentBaseModel     string
	agentPollInterval  int
	agentWorkDir       string
	agentFilter        string
	agentForce         bool
	agentOneShot       bool
	agentDryRun        bool
)

var agentCmd = &cli.Command{
	Use:   "agent",
	Short: "Run the scoring agent daemon",
	Long:  "Polls M3 for unscored LoRA checkpoints, converts, probes, and pushes results to InfluxDB.",
	RunE:  runAgent,
}

func init() {
	agentCmd.Flags().StringVar(&agentM3Host, "m3-host", ml.EnvOr("M3_HOST", "10.69.69.108"), "M3 host address")
	agentCmd.Flags().StringVar(&agentM3User, "m3-user", ml.EnvOr("M3_USER", "claude"), "M3 SSH user")
	agentCmd.Flags().StringVar(&agentM3SSHKey, "m3-ssh-key", ml.EnvOr("M3_SSH_KEY", ml.ExpandHome("~/.ssh/id_ed25519")), "SSH key for M3")
	agentCmd.Flags().StringVar(&agentM3AdapterBase, "m3-adapter-base", ml.EnvOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3")
	agentCmd.Flags().StringVar(&agentBaseModel, "base-model", ml.EnvOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID")
	agentCmd.Flags().IntVar(&agentPollInterval, "poll", ml.IntEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds")
	agentCmd.Flags().StringVar(&agentWorkDir, "work-dir", ml.EnvOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters")
	agentCmd.Flags().StringVar(&agentFilter, "filter", "", "Filter adapter dirs by prefix")
	agentCmd.Flags().BoolVar(&agentForce, "force", false, "Re-score already-scored checkpoints")
	agentCmd.Flags().BoolVar(&agentOneShot, "one-shot", false, "Process one checkpoint and exit")
	agentCmd.Flags().BoolVar(&agentDryRun, "dry-run", false, "Discover and plan but don't execute")
}

func runAgent(cmd *cli.Command, args []string) error {
	cfg := &ml.AgentConfig{
		M3Host:        agentM3Host,
		M3User:        agentM3User,
		M3SSHKey:      agentM3SSHKey,
		M3AdapterBase: agentM3AdapterBase,
		InfluxURL:     influxURL,
		InfluxDB:      influxDB,
		DBPath:        dbPath,
		APIURL:        apiURL,
		JudgeURL:      judgeURL,
		JudgeModel:    judgeModel,
		Model:         modelName,
		BaseModel:     agentBaseModel,
		PollInterval:  agentPollInterval,
		WorkDir:       agentWorkDir,
		Filter:        agentFilter,
		Force:         agentForce,
		OneShot:       agentOneShot,
		DryRun:        agentDryRun,
	}

	ml.RunAgentLoop(cfg)
	return nil
}
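The flag defaults above go through ml.EnvOr, ml.IntEnvOr, and ml.ExpandHome, which live in pkg/ml and are not part of this diff. Below is a sketch of the assumed semantics (the environment variable wins when set, otherwise the literal fallback applies), so --m3-host, for example, defaults to $M3_HOST before falling back to 10.69.69.108:

package example

import (
	"os"
	"strconv"
)

// envOr mirrors the assumed behaviour of ml.EnvOr: return the environment
// variable when it is non-empty, otherwise the fallback.
func envOr(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// intEnvOr mirrors the assumed behaviour of ml.IntEnvOr for integer values:
// a set, parseable variable wins; anything else yields the fallback.
func intEnvOr(key string, fallback int) int {
	if v := os.Getenv(key); v != "" {
		if n, err := strconv.Atoi(v); err == nil {
			return n
		}
	}
	return fallback
}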
internal/cmd/ml/cmd_convert.go (new file, 40 lines)
@@ -0,0 +1,40 @@
package ml

import (
	"fmt"

	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/cli/pkg/ml"
)

var (
	convertInput     string
	convertConfig    string
	convertOutputDir string
	convertBaseModel string
)

var convertCmd = &cli.Command{
	Use:   "convert",
	Short: "Convert MLX LoRA adapter to PEFT format",
	Long:  "Converts an MLX safetensors LoRA adapter to HuggingFace PEFT format for Ollama.",
	RunE:  runConvert,
}

func init() {
	convertCmd.Flags().StringVar(&convertInput, "input", "", "Input safetensors file (required)")
	convertCmd.Flags().StringVar(&convertConfig, "config", "", "Adapter config JSON (required)")
	convertCmd.Flags().StringVar(&convertOutputDir, "output-dir", "", "Output directory (required)")
	convertCmd.Flags().StringVar(&convertBaseModel, "base-model", "", "Base model name for adapter_config.json")
	convertCmd.MarkFlagRequired("input")
	convertCmd.MarkFlagRequired("config")
	convertCmd.MarkFlagRequired("output-dir")
}

func runConvert(cmd *cli.Command, args []string) error {
	if err := ml.ConvertMLXtoPEFT(convertInput, convertConfig, convertOutputDir, convertBaseModel); err != nil {
		return fmt.Errorf("convert to PEFT: %w", err)
	}
	fmt.Printf("PEFT adapter written to %s\n", convertOutputDir)
	return nil
}
internal/cmd/ml/cmd_expand.go (new file, 81 lines)
@@ -0,0 +1,81 @@
package ml

import (
	"context"
	"fmt"
	"os"

	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/cli/pkg/ml"
)

var (
	expandWorker string
	expandOutput string
	expandLimit  int
	expandDryRun bool
)

var expandCmd = &cli.Command{
	Use:   "expand",
	Short: "Generate expansion responses from pending prompts",
	Long:  "Reads pending expansion prompts from DuckDB and generates responses via an OpenAI-compatible API.",
	RunE:  runExpand,
}

func init() {
	expandCmd.Flags().StringVar(&expandWorker, "worker", "", "Worker hostname (defaults to os.Hostname())")
	expandCmd.Flags().StringVar(&expandOutput, "output", ".", "Output directory for JSONL files")
	expandCmd.Flags().IntVar(&expandLimit, "limit", 0, "Max prompts to process (0 = all)")
	expandCmd.Flags().BoolVar(&expandDryRun, "dry-run", false, "Print plan and exit without generating")
}

func runExpand(cmd *cli.Command, args []string) error {
	if modelName == "" {
		return fmt.Errorf("--model is required")
	}

	path := dbPath
	if path == "" {
		path = os.Getenv("LEM_DB")
	}
	if path == "" {
		return fmt.Errorf("--db or LEM_DB env is required")
	}

	if expandWorker == "" {
		h, _ := os.Hostname()
		expandWorker = h
	}

	db, err := ml.OpenDBReadWrite(path)
	if err != nil {
		return fmt.Errorf("open db: %w", err)
	}
	defer db.Close()

	rows, err := db.QueryExpansionPrompts("pending", expandLimit)
	if err != nil {
		return fmt.Errorf("query expansion_prompts: %w", err)
	}
	fmt.Printf("Loaded %d pending prompts from %s\n", len(rows), path)

	var prompts []ml.Response
	for _, r := range rows {
		prompt := r.Prompt
		if prompt == "" && r.PromptEn != "" {
			prompt = r.PromptEn
		}
		prompts = append(prompts, ml.Response{
			ID:     r.SeedID,
			Domain: r.Domain,
			Prompt: prompt,
		})
	}

	ctx := context.Background()
	backend := ml.NewHTTPBackend(apiURL, modelName)
	influx := ml.NewInfluxClient(influxURL, influxDB)

	return ml.ExpandPrompts(ctx, backend, influx, prompts, modelName, expandWorker, expandOutput, expandDryRun, expandLimit)
}
internal/cmd/ml/cmd_export.go (new file, 109 lines)
@@ -0,0 +1,109 @@
package ml

import (
	"fmt"
	"os"

	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/cli/pkg/ml"
)

var (
	exportOutputDir string
	exportMinChars  int
	exportTrainPct  int
	exportValidPct  int
	exportTestPct   int
	exportSeed      int64
	exportParquet   bool
)

var exportCmd = &cli.Command{
	Use:   "export",
	Short: "Export golden set to training JSONL and Parquet",
	Long:  "Reads golden set from DuckDB, filters, splits, and exports to JSONL and optionally Parquet.",
	RunE:  runExport,
}

func init() {
	exportCmd.Flags().StringVar(&exportOutputDir, "output-dir", "", "Output directory for training files (required)")
	exportCmd.Flags().IntVar(&exportMinChars, "min-chars", 50, "Minimum response length in characters")
	exportCmd.Flags().IntVar(&exportTrainPct, "train", 80, "Training split percentage")
	exportCmd.Flags().IntVar(&exportValidPct, "valid", 10, "Validation split percentage")
	exportCmd.Flags().IntVar(&exportTestPct, "test", 10, "Test split percentage")
	exportCmd.Flags().Int64Var(&exportSeed, "seed", 42, "Random seed for shuffle")
	exportCmd.Flags().BoolVar(&exportParquet, "parquet", false, "Also export Parquet files")
	exportCmd.MarkFlagRequired("output-dir")
}

func runExport(cmd *cli.Command, args []string) error {
	if err := ml.ValidatePercentages(exportTrainPct, exportValidPct, exportTestPct); err != nil {
		return err
	}

	path := dbPath
	if path == "" {
		path = os.Getenv("LEM_DB")
	}
	if path == "" {
		return fmt.Errorf("--db or LEM_DB env is required")
	}

	db, err := ml.OpenDB(path)
	if err != nil {
		return fmt.Errorf("open db: %w", err)
	}
	defer db.Close()

	rows, err := db.QueryGoldenSet(exportMinChars)
	if err != nil {
		return fmt.Errorf("query golden set: %w", err)
	}
	fmt.Printf("Loaded %d golden set rows (min %d chars)\n", len(rows), exportMinChars)

	// Convert to Response format.
	var responses []ml.Response
	for _, r := range rows {
		responses = append(responses, ml.Response{
			ID:       r.SeedID,
			Domain:   r.Domain,
			Prompt:   r.Prompt,
			Response: r.Response,
		})
	}

	filtered := ml.FilterResponses(responses)
	fmt.Printf("After filtering: %d responses\n", len(filtered))

	train, valid, test := ml.SplitData(filtered, exportTrainPct, exportValidPct, exportTestPct, exportSeed)
	fmt.Printf("Split: train=%d, valid=%d, test=%d\n", len(train), len(valid), len(test))

	if err := os.MkdirAll(exportOutputDir, 0755); err != nil {
		return fmt.Errorf("create output dir: %w", err)
	}

	for _, split := range []struct {
		name string
		data []ml.Response
	}{
		{"train", train},
		{"valid", valid},
		{"test", test},
	} {
		path := fmt.Sprintf("%s/%s.jsonl", exportOutputDir, split.name)
		if err := ml.WriteTrainingJSONL(path, split.data); err != nil {
			return fmt.Errorf("write %s: %w", split.name, err)
		}
		fmt.Printf(" %s.jsonl: %d examples\n", split.name, len(split.data))
	}

	if exportParquet {
		n, err := ml.ExportParquet(exportOutputDir, "")
		if err != nil {
			return fmt.Errorf("export parquet: %w", err)
		}
		fmt.Printf(" Parquet: %d total rows\n", n)
	}

	return nil
}
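The export flow hinges on ml.ValidatePercentages and ml.SplitData, both defined in pkg/ml rather than in this diff. A sketch of the assumed contract for the split flags, namely non-negative percentages that sum to exactly 100 (so the default 80/10/10 passes):

package example

import "fmt"

// validatePercentages illustrates the assumed check behind ml.ValidatePercentages:
// the train/valid/test split must form a full partition of the data.
func validatePercentages(train, valid, test int) error {
	if train < 0 || valid < 0 || test < 0 {
		return fmt.Errorf("split percentages must be non-negative")
	}
	if sum := train + valid + test; sum != 100 {
		return fmt.Errorf("split percentages must sum to 100, got %d", sum)
	}
	return nil
}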
internal/cmd/ml/cmd_gguf.go (new file, 40 lines)
@@ -0,0 +1,40 @@
package ml

import (
	"fmt"

	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/cli/pkg/ml"
)

var (
	ggufInput  string
	ggufConfig string
	ggufOutput string
	ggufArch   string
)

var ggufCmd = &cli.Command{
	Use:   "gguf",
	Short: "Convert MLX LoRA adapter to GGUF format",
	Long:  "Converts an MLX safetensors LoRA adapter to GGUF v3 format for use with llama.cpp.",
	RunE:  runGGUF,
}

func init() {
	ggufCmd.Flags().StringVar(&ggufInput, "input", "", "Input safetensors file (required)")
	ggufCmd.Flags().StringVar(&ggufConfig, "config", "", "Adapter config JSON (required)")
	ggufCmd.Flags().StringVar(&ggufOutput, "output", "", "Output GGUF file (required)")
	ggufCmd.Flags().StringVar(&ggufArch, "arch", "gemma3", "GGUF architecture name")
	ggufCmd.MarkFlagRequired("input")
	ggufCmd.MarkFlagRequired("config")
	ggufCmd.MarkFlagRequired("output")
}

func runGGUF(cmd *cli.Command, args []string) error {
	if err := ml.ConvertMLXtoGGUFLoRA(ggufInput, ggufConfig, ggufOutput, ggufArch); err != nil {
		return fmt.Errorf("convert to GGUF: %w", err)
	}
	fmt.Printf("GGUF LoRA adapter written to %s\n", ggufOutput)
	return nil
}
internal/cmd/ml/cmd_ml.go (new file, 63 lines)
@@ -0,0 +1,63 @@
// Package ml provides ML inference, scoring, and training pipeline commands.
//
// Commands:
// - core ml score: Score responses with heuristic and LLM judges
// - core ml probe: Run capability and content probes against a model
// - core ml export: Export golden set to training JSONL/Parquet
// - core ml expand: Generate expansion responses
// - core ml status: Show training and generation progress
// - core ml gguf: Convert MLX LoRA adapter to GGUF format
// - core ml convert: Convert MLX LoRA adapter to PEFT format
// - core ml agent: Run the scoring agent daemon
// - core ml worker: Run a distributed worker node
package ml

import (
	"forge.lthn.ai/core/cli/pkg/cli"
)

func init() {
	cli.RegisterCommands(AddMLCommands)
}

var mlCmd = &cli.Command{
	Use:   "ml",
	Short: "ML inference, scoring, and training pipeline",
	Long:  "Commands for ML model scoring, probe evaluation, data export, and format conversion.",
}

// AddMLCommands registers the 'ml' command and all subcommands.
func AddMLCommands(root *cli.Command) {
	initFlags()
	mlCmd.AddCommand(scoreCmd)
	mlCmd.AddCommand(probeCmd)
	mlCmd.AddCommand(exportCmd)
	mlCmd.AddCommand(expandCmd)
	mlCmd.AddCommand(statusCmd)
	mlCmd.AddCommand(ggufCmd)
	mlCmd.AddCommand(convertCmd)
	mlCmd.AddCommand(agentCmd)
	mlCmd.AddCommand(workerCmd)
	root.AddCommand(mlCmd)
}

// Shared persistent flags.
var (
	apiURL     string
	judgeURL   string
	judgeModel string
	influxURL  string
	influxDB   string
	dbPath     string
	modelName  string
)

func initFlags() {
	mlCmd.PersistentFlags().StringVar(&apiURL, "api-url", "http://10.69.69.108:8090", "OpenAI-compatible API URL")
	mlCmd.PersistentFlags().StringVar(&judgeURL, "judge-url", "http://10.69.69.108:11434", "Judge model API URL (Ollama)")
	mlCmd.PersistentFlags().StringVar(&judgeModel, "judge-model", "gemma3:27b", "Judge model name")
	mlCmd.PersistentFlags().StringVar(&influxURL, "influx", "", "InfluxDB URL (default http://10.69.69.165:8181)")
	mlCmd.PersistentFlags().StringVar(&influxDB, "influx-db", "", "InfluxDB database (default training)")
	mlCmd.PersistentFlags().StringVar(&dbPath, "db", "", "DuckDB database path (or set LEM_DB env)")
	mlCmd.PersistentFlags().StringVar(&modelName, "model", "", "Model name for API")
}
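Because the package registers itself through cli.RegisterCommands in init(), a host binary only has to construct the root command and let the registered hooks attach their subtrees. A sketch under that assumption (the real "core" root command and its wiring are not part of this diff), calling AddMLCommands directly to make the attachment explicit:

package main

import (
	"forge.lthn.ai/core/cli/internal/cmd/ml"
	"forge.lthn.ai/core/cli/pkg/cli"
)

func main() {
	// Hypothetical root command; the actual "core" root lives elsewhere in the repo.
	root := &cli.Command{Use: "core"}

	// Normally cli.RegisterCommands(AddMLCommands) arranges this during init;
	// calling it directly shows the "ml" command tree being attached to the root.
	ml.AddMLCommands(root)
	_ = root
}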
internal/cmd/ml/cmd_probe.go (new file, 66 lines)
@@ -0,0 +1,66 @@
package ml

import (
	"context"
	"encoding/json"
	"fmt"
	"os"

	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/cli/pkg/ml"
)

var (
	probeOutput string
)

var probeCmd = &cli.Command{
	Use:   "probe",
	Short: "Run capability and content probes against a model",
	Long:  "Runs 23 capability probes and 6 content probes against an OpenAI-compatible API.",
	RunE:  runProbe,
}

func init() {
	probeCmd.Flags().StringVar(&probeOutput, "output", "", "Output JSON file for probe results")
}

func runProbe(cmd *cli.Command, args []string) error {
	if apiURL == "" {
		return fmt.Errorf("--api-url is required")
	}

	model := modelName
	if model == "" {
		model = "default"
	}

	ctx := context.Background()
	backend := ml.NewHTTPBackend(apiURL, model)

	fmt.Printf("Running %d capability probes against %s...\n", len(ml.CapabilityProbes), apiURL)
	results := ml.RunCapabilityProbes(ctx, backend)

	fmt.Printf("\nResults: %.1f%% (%d/%d)\n", results.Accuracy, results.Correct, results.Total)

	for cat, data := range results.ByCategory {
		catAcc := 0.0
		if data.Total > 0 {
			catAcc = float64(data.Correct) / float64(data.Total) * 100
		}
		fmt.Printf(" %-20s %d/%d (%.0f%%)\n", cat, data.Correct, data.Total, catAcc)
	}

	if probeOutput != "" {
		data, err := json.MarshalIndent(results, "", " ")
		if err != nil {
			return fmt.Errorf("marshal results: %w", err)
		}
		if err := os.WriteFile(probeOutput, data, 0644); err != nil {
			return fmt.Errorf("write output: %w", err)
		}
		fmt.Printf("\nResults written to %s\n", probeOutput)
	}

	return nil
}
internal/cmd/ml/cmd_score.go (new file, 77 lines)
@@ -0,0 +1,77 @@
package ml

import (
	"context"
	"fmt"
	"time"

	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/cli/pkg/ml"
)

var (
	scoreInput  string
	scoreSuites string
	scoreOutput string
	scoreConcur int
)

var scoreCmd = &cli.Command{
	Use:   "score",
	Short: "Score responses with heuristic and LLM judges",
	Long:  "Reads a JSONL file of prompt/response pairs and scores them across configured suites.",
	RunE:  runScore,
}

func init() {
	scoreCmd.Flags().StringVar(&scoreInput, "input", "", "Input JSONL file with prompt/response pairs (required)")
	scoreCmd.Flags().StringVar(&scoreSuites, "suites", "all", "Comma-separated scoring suites (heuristic,semantic,content,exact,truthfulqa,donotanswer,toxigen)")
	scoreCmd.Flags().StringVar(&scoreOutput, "output", "", "Output JSON file for scores")
	scoreCmd.Flags().IntVar(&scoreConcur, "concurrency", 4, "Number of concurrent scoring workers")
	scoreCmd.MarkFlagRequired("input")
}

func runScore(cmd *cli.Command, args []string) error {
	responses, err := ml.ReadResponses(scoreInput)
	if err != nil {
		return fmt.Errorf("read input: %w", err)
	}

	var judge *ml.Judge
	if judgeURL != "" {
		backend := ml.NewHTTPBackend(judgeURL, judgeModel)
		judge = ml.NewJudge(backend)
	}

	engine := ml.NewEngine(judge, scoreConcur, scoreSuites)

	ctx := context.Background()
	perPrompt := engine.ScoreAll(ctx, responses)
	averages := ml.ComputeAverages(perPrompt)

	if scoreOutput != "" {
		output := &ml.ScorerOutput{
			Metadata: ml.Metadata{
				JudgeModel: judgeModel,
				JudgeURL:   judgeURL,
				ScoredAt:   time.Now(),
				Suites:     ml.SplitComma(scoreSuites),
			},
			ModelAverages: averages,
			PerPrompt:     perPrompt,
		}
		if err := ml.WriteScores(scoreOutput, output); err != nil {
			return fmt.Errorf("write output: %w", err)
		}
		fmt.Printf("Scores written to %s\n", scoreOutput)
	} else {
		for model, avgs := range averages {
			fmt.Printf("%s:\n", model)
			for field, val := range avgs {
				fmt.Printf(" %-25s %.3f\n", field, val)
			}
		}
	}

	return nil
}
internal/cmd/ml/cmd_status.go (new file, 54 lines)
@@ -0,0 +1,54 @@
package ml

import (
	"fmt"
	"os"

	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/cli/pkg/ml"
)

var statusCmd = &cli.Command{
	Use:   "status",
	Short: "Show training and generation progress",
	Long:  "Queries InfluxDB for training status, loss, and generation progress. Optionally shows DuckDB table counts.",
	RunE:  runStatus,
}

func runStatus(cmd *cli.Command, args []string) error {
	influx := ml.NewInfluxClient(influxURL, influxDB)

	if err := ml.PrintStatus(influx, os.Stdout); err != nil {
		return fmt.Errorf("status: %w", err)
	}

	path := dbPath
	if path == "" {
		path = os.Getenv("LEM_DB")
	}

	if path != "" {
		db, err := ml.OpenDB(path)
		if err != nil {
			return fmt.Errorf("open db: %w", err)
		}
		defer db.Close()

		counts, err := db.TableCounts()
		if err != nil {
			return fmt.Errorf("table counts: %w", err)
		}

		fmt.Println()
		fmt.Println("DuckDB:")
		order := []string{"golden_set", "expansion_prompts", "seeds", "training_examples",
			"prompts", "gemini_responses", "benchmark_questions", "benchmark_results", "validations"}
		for _, table := range order {
			if count, ok := counts[table]; ok {
				fmt.Fprintf(os.Stdout, " %-22s %6d rows\n", table, count)
			}
		}
	}

	return nil
}
internal/cmd/ml/cmd_worker.go (new file, 80 lines)
@@ -0,0 +1,80 @@
package ml

import (
	"time"

	"forge.lthn.ai/core/cli/pkg/cli"
	"forge.lthn.ai/core/cli/pkg/ml"
)

var (
	workerAPIBase   string
	workerID        string
	workerName      string
	workerAPIKey    string
	workerGPU       string
	workerVRAM      int
	workerLangs     string
	workerModels    string
	workerInferURL  string
	workerTaskType  string
	workerBatchSize int
	workerPoll      time.Duration
	workerOneShot   bool
	workerDryRun    bool
)

var workerCmd = &cli.Command{
	Use:   "worker",
	Short: "Run a distributed worker node",
	Long:  "Polls the LEM API for tasks, runs local inference, and submits results.",
	RunE:  runWorker,
}

func init() {
	workerCmd.Flags().StringVar(&workerAPIBase, "api", ml.EnvOr("LEM_API", "https://infer.lthn.ai"), "LEM API base URL")
	workerCmd.Flags().StringVar(&workerID, "id", ml.EnvOr("LEM_WORKER_ID", ml.MachineID()), "Worker ID")
	workerCmd.Flags().StringVar(&workerName, "name", ml.EnvOr("LEM_WORKER_NAME", ml.Hostname()), "Worker display name")
	workerCmd.Flags().StringVar(&workerAPIKey, "key", ml.EnvOr("LEM_API_KEY", ""), "API key")
	workerCmd.Flags().StringVar(&workerGPU, "gpu", ml.EnvOr("LEM_GPU", ""), "GPU type")
	workerCmd.Flags().IntVar(&workerVRAM, "vram", ml.IntEnvOr("LEM_VRAM_GB", 0), "GPU VRAM in GB")
	workerCmd.Flags().StringVar(&workerLangs, "languages", ml.EnvOr("LEM_LANGUAGES", ""), "Comma-separated language codes")
	workerCmd.Flags().StringVar(&workerModels, "models", ml.EnvOr("LEM_MODELS", ""), "Comma-separated model names")
	workerCmd.Flags().StringVar(&workerInferURL, "infer", ml.EnvOr("LEM_INFER_URL", "http://localhost:8090"), "Local inference endpoint")
	workerCmd.Flags().StringVar(&workerTaskType, "type", "", "Filter by task type")
	workerCmd.Flags().IntVar(&workerBatchSize, "batch", 5, "Tasks per poll")
	workerCmd.Flags().DurationVar(&workerPoll, "poll", 30*time.Second, "Poll interval")
	workerCmd.Flags().BoolVar(&workerOneShot, "one-shot", false, "Process one batch and exit")
	workerCmd.Flags().BoolVar(&workerDryRun, "dry-run", false, "Fetch tasks but don't run inference")
}

func runWorker(cmd *cli.Command, args []string) error {
	if workerAPIKey == "" {
		workerAPIKey = ml.ReadKeyFile()
	}

	cfg := &ml.WorkerConfig{
		APIBase:      workerAPIBase,
		WorkerID:     workerID,
		Name:         workerName,
		APIKey:       workerAPIKey,
		GPUType:      workerGPU,
		VRAMGb:       workerVRAM,
		InferURL:     workerInferURL,
		TaskType:     workerTaskType,
		BatchSize:    workerBatchSize,
		PollInterval: workerPoll,
		OneShot:      workerOneShot,
		DryRun:       workerDryRun,
	}

	if workerLangs != "" {
		cfg.Languages = ml.SplitComma(workerLangs)
	}
	if workerModels != "" {
		cfg.Models = ml.SplitComma(workerModels)
	}

	ml.RunWorkerLoop(cfg)
	return nil
}
pkg/ml/service.go (new file, 162 lines)
@@ -0,0 +1,162 @@
package ml

import (
	"context"
	"fmt"
	"sync"

	"forge.lthn.ai/core/cli/pkg/framework"
)

// Service manages ML inference backends and scoring with Core lifecycle.
type Service struct {
	*framework.ServiceRuntime[Options]

	backends map[string]Backend
	mu       sync.RWMutex
	engine   *Engine
	judge    *Judge
}

// Options configures the ML service.
type Options struct {
	// DefaultBackend is the name of the default inference backend.
	DefaultBackend string

	// LlamaPath is the path to the llama-server binary.
	LlamaPath string

	// ModelDir is the directory containing model files.
	ModelDir string

	// OllamaURL is the Ollama API base URL.
	OllamaURL string

	// JudgeURL is the judge model API URL.
	JudgeURL string

	// JudgeModel is the judge model name.
	JudgeModel string

	// InfluxURL is the InfluxDB URL for metrics.
	InfluxURL string

	// InfluxDB is the InfluxDB database name.
	InfluxDB string

	// Concurrency is the number of concurrent scoring workers.
	Concurrency int

	// Suites is a comma-separated list of scoring suites to enable.
	Suites string
}

// NewService creates an ML service factory for Core registration.
//
// core, _ := framework.New(
//     framework.WithName("ml", ml.NewService(ml.Options{})),
// )
func NewService(opts Options) func(*framework.Core) (any, error) {
	return func(c *framework.Core) (any, error) {
		if opts.Concurrency == 0 {
			opts.Concurrency = 4
		}
		if opts.Suites == "" {
			opts.Suites = "all"
		}

		svc := &Service{
			ServiceRuntime: framework.NewServiceRuntime(c, opts),
			backends:       make(map[string]Backend),
		}
		return svc, nil
	}
}

// OnStartup initializes backends and scoring engine.
func (s *Service) OnStartup(ctx context.Context) error {
	opts := s.Opts()

	// Register Ollama backend if URL provided.
	if opts.OllamaURL != "" {
		s.RegisterBackend("ollama", NewHTTPBackend(opts.OllamaURL, opts.JudgeModel))
	}

	// Set up judge if judge URL is provided.
	if opts.JudgeURL != "" {
		judgeBackend := NewHTTPBackend(opts.JudgeURL, opts.JudgeModel)
		s.judge = NewJudge(judgeBackend)
		s.engine = NewEngine(s.judge, opts.Concurrency, opts.Suites)
	}

	return nil
}

// OnShutdown cleans up resources.
func (s *Service) OnShutdown(ctx context.Context) error {
	return nil
}

// RegisterBackend adds or replaces a named inference backend.
func (s *Service) RegisterBackend(name string, backend Backend) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.backends[name] = backend
}

// Backend returns a named backend, or nil if not found.
func (s *Service) Backend(name string) Backend {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.backends[name]
}

// DefaultBackend returns the configured default backend.
func (s *Service) DefaultBackend() Backend {
	name := s.Opts().DefaultBackend
	if name == "" {
		name = "ollama"
	}
	return s.Backend(name)
}

// Backends returns the names of all registered backends.
func (s *Service) Backends() []string {
	s.mu.RLock()
	defer s.mu.RUnlock()
	names := make([]string, 0, len(s.backends))
	for name := range s.backends {
		names = append(names, name)
	}
	return names
}

// Judge returns the configured judge, or nil if not set up.
func (s *Service) Judge() *Judge {
	return s.judge
}

// Engine returns the scoring engine, or nil if not set up.
func (s *Service) Engine() *Engine {
	return s.engine
}

// Generate generates text using the named backend (or default).
func (s *Service) Generate(ctx context.Context, backendName, prompt string, opts GenOpts) (string, error) {
	b := s.Backend(backendName)
	if b == nil {
		b = s.DefaultBackend()
	}
	if b == nil {
		return "", fmt.Errorf("no backend available (requested: %q)", backendName)
	}
	return b.Generate(ctx, prompt, opts)
}

// ScoreResponses scores a batch of responses using the configured engine.
func (s *Service) ScoreResponses(ctx context.Context, responses []Response) (map[string][]PromptScore, error) {
	if s.engine == nil {
		return nil, fmt.Errorf("scoring engine not configured (set JudgeURL and JudgeModel)")
	}
	return s.engine.ScoreAll(ctx, responses), nil
}
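A short note on Service.Generate's lookup order, which the sketch below exercises: the named backend is tried first, then DefaultBackend ("ollama" unless Options.DefaultBackend is set), and an error is returned only when neither exists. Here svc is assumed to be the *ml.Service instance that Core built from ml.NewService; how the service is retrieved from the container is outside this diff:

package example

import (
	"context"

	"forge.lthn.ai/core/cli/pkg/ml"
)

// generateWithFallback asks for a backend that was never registered, so
// Service.Generate falls back to DefaultBackend before giving up.
func generateWithFallback(ctx context.Context, svc *ml.Service) (string, error) {
	return svc.Generate(ctx, "does-not-exist", "Say hello in one sentence.", ml.GenOpts{})
}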