diff --git a/internal/cmd/ml/cmd_agent.go b/internal/cmd/ml/cmd_agent.go
new file mode 100644
index 0000000..841ddc4
--- /dev/null
+++ b/internal/cmd/ml/cmd_agent.go
@@ -0,0 +1,67 @@
+package ml
+
+import (
+	"forge.lthn.ai/core/cli/pkg/cli"
+	"forge.lthn.ai/core/cli/pkg/ml"
+)
+
+var (
+	agentM3Host        string
+	agentM3User        string
+	agentM3SSHKey      string
+	agentM3AdapterBase string
+	agentBaseModel     string
+	agentPollInterval  int
+	agentWorkDir       string
+	agentFilter        string
+	agentForce         bool
+	agentOneShot       bool
+	agentDryRun        bool
+)
+
+var agentCmd = &cli.Command{
+	Use:   "agent",
+	Short: "Run the scoring agent daemon",
+	Long:  "Polls M3 for unscored LoRA checkpoints, converts, probes, and pushes results to InfluxDB.",
+	RunE:  runAgent,
+}
+
+func init() {
+	agentCmd.Flags().StringVar(&agentM3Host, "m3-host", ml.EnvOr("M3_HOST", "10.69.69.108"), "M3 host address")
+	agentCmd.Flags().StringVar(&agentM3User, "m3-user", ml.EnvOr("M3_USER", "claude"), "M3 SSH user")
+	agentCmd.Flags().StringVar(&agentM3SSHKey, "m3-ssh-key", ml.EnvOr("M3_SSH_KEY", ml.ExpandHome("~/.ssh/id_ed25519")), "SSH key for M3")
+	agentCmd.Flags().StringVar(&agentM3AdapterBase, "m3-adapter-base", ml.EnvOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3")
+	agentCmd.Flags().StringVar(&agentBaseModel, "base-model", ml.EnvOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID")
+	agentCmd.Flags().IntVar(&agentPollInterval, "poll", ml.IntEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds")
+	agentCmd.Flags().StringVar(&agentWorkDir, "work-dir", ml.EnvOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters")
+	agentCmd.Flags().StringVar(&agentFilter, "filter", "", "Filter adapter dirs by prefix")
+	agentCmd.Flags().BoolVar(&agentForce, "force", false, "Re-score already-scored checkpoints")
+	agentCmd.Flags().BoolVar(&agentOneShot, "one-shot", false, "Process one checkpoint and exit")
+	agentCmd.Flags().BoolVar(&agentDryRun, "dry-run", false, "Discover and plan but don't execute")
+}
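+
+// Example invocation (illustrative sketch; the host, key path, and options
+// shown are placeholders, not required values):
+//
+//	core ml agent --m3-host 10.0.0.5 --m3-ssh-key ~/.ssh/id_ed25519 --one-shot --dry-run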
+
+func runAgent(cmd *cli.Command, args []string) error {
+	cfg := &ml.AgentConfig{
+		M3Host:        agentM3Host,
+		M3User:        agentM3User,
+		M3SSHKey:      agentM3SSHKey,
+		M3AdapterBase: agentM3AdapterBase,
+		InfluxURL:     influxURL,
+		InfluxDB:      influxDB,
+		DBPath:        dbPath,
+		APIURL:        apiURL,
+		JudgeURL:      judgeURL,
+		JudgeModel:    judgeModel,
+		Model:         modelName,
+		BaseModel:     agentBaseModel,
+		PollInterval:  agentPollInterval,
+		WorkDir:       agentWorkDir,
+		Filter:        agentFilter,
+		Force:         agentForce,
+		OneShot:       agentOneShot,
+		DryRun:        agentDryRun,
+	}
+
+	ml.RunAgentLoop(cfg)
+	return nil
+}
diff --git a/internal/cmd/ml/cmd_convert.go b/internal/cmd/ml/cmd_convert.go
new file mode 100644
index 0000000..11b544e
--- /dev/null
+++ b/internal/cmd/ml/cmd_convert.go
@@ -0,0 +1,40 @@
+package ml
+
+import (
+	"fmt"
+
+	"forge.lthn.ai/core/cli/pkg/cli"
+	"forge.lthn.ai/core/cli/pkg/ml"
+)
+
+var (
+	convertInput     string
+	convertConfig    string
+	convertOutputDir string
+	convertBaseModel string
+)
+
+var convertCmd = &cli.Command{
+	Use:   "convert",
+	Short: "Convert MLX LoRA adapter to PEFT format",
+	Long:  "Converts an MLX safetensors LoRA adapter to HuggingFace PEFT format for Ollama.",
+	RunE:  runConvert,
+}
+
+func init() {
+	convertCmd.Flags().StringVar(&convertInput, "input", "", "Input safetensors file (required)")
+	convertCmd.Flags().StringVar(&convertConfig, "config", "", "Adapter config JSON (required)")
+	convertCmd.Flags().StringVar(&convertOutputDir, "output-dir", "", "Output directory (required)")
+	convertCmd.Flags().StringVar(&convertBaseModel, "base-model", "", "Base model name for adapter_config.json")
+	convertCmd.MarkFlagRequired("input")
+	convertCmd.MarkFlagRequired("config")
+	convertCmd.MarkFlagRequired("output-dir")
+}
+
+func runConvert(cmd *cli.Command, args []string) error {
+	if err := ml.ConvertMLXtoPEFT(convertInput, convertConfig, convertOutputDir, convertBaseModel); err != nil {
+		return fmt.Errorf("convert to PEFT: %w", err)
+	}
+	fmt.Printf("PEFT adapter written to %s\n", convertOutputDir)
+	return nil
+}
diff --git a/internal/cmd/ml/cmd_expand.go b/internal/cmd/ml/cmd_expand.go
new file mode 100644
index 0000000..1dd3b97
--- /dev/null
+++ b/internal/cmd/ml/cmd_expand.go
@@ -0,0 +1,81 @@
+package ml
+
+import (
+	"context"
+	"fmt"
+	"os"
+
+	"forge.lthn.ai/core/cli/pkg/cli"
+	"forge.lthn.ai/core/cli/pkg/ml"
+)
+
+var (
+	expandWorker string
+	expandOutput string
+	expandLimit  int
+	expandDryRun bool
+)
+
+var expandCmd = &cli.Command{
+	Use:   "expand",
+	Short: "Generate expansion responses from pending prompts",
+	Long:  "Reads pending expansion prompts from DuckDB and generates responses via an OpenAI-compatible API.",
+	RunE:  runExpand,
+}
+
+func init() {
+	expandCmd.Flags().StringVar(&expandWorker, "worker", "", "Worker hostname (defaults to os.Hostname())")
+	expandCmd.Flags().StringVar(&expandOutput, "output", ".", "Output directory for JSONL files")
+	expandCmd.Flags().IntVar(&expandLimit, "limit", 0, "Max prompts to process (0 = all)")
+	expandCmd.Flags().BoolVar(&expandDryRun, "dry-run", false, "Print plan and exit without generating")
+}
+
+func runExpand(cmd *cli.Command, args []string) error {
+	if modelName == "" {
+		return fmt.Errorf("--model is required")
+	}
+
+	path := dbPath
+	if path == "" {
+		path = os.Getenv("LEM_DB")
+	}
+	if path == "" {
+		return fmt.Errorf("--db or LEM_DB env is required")
+	}
+
+	if expandWorker == "" {
+		h, _ := os.Hostname()
+		expandWorker = h
+	}
+
+	db, err := ml.OpenDBReadWrite(path)
+	if err != nil {
+		return fmt.Errorf("open db: %w", err)
+	}
+	defer db.Close()
+
+	rows, err := db.QueryExpansionPrompts("pending", expandLimit)
+	if err != nil {
+		return fmt.Errorf("query expansion_prompts: %w", err)
+	}
+	fmt.Printf("Loaded %d pending prompts from %s\n", len(rows), path)
+
+	var prompts []ml.Response
+	for _, r := range rows {
+		prompt := r.Prompt
+		if prompt == "" && r.PromptEn != "" {
+			prompt = r.PromptEn
+		}
+		prompts = append(prompts, ml.Response{
+			ID:     r.SeedID,
+			Domain: r.Domain,
+			Prompt: prompt,
+		})
+	}
+
+	ctx := context.Background()
+	backend := ml.NewHTTPBackend(apiURL, modelName)
+	influx := ml.NewInfluxClient(influxURL, influxDB)
+
+	return ml.ExpandPrompts(ctx, backend, influx, prompts, modelName, expandWorker, expandOutput, expandDryRun, expandLimit)
+}
diff --git a/internal/cmd/ml/cmd_export.go b/internal/cmd/ml/cmd_export.go
new file mode 100644
index 0000000..2e6dba4
--- /dev/null
+++ b/internal/cmd/ml/cmd_export.go
@@ -0,0 +1,109 @@
+package ml
+
+import (
+	"fmt"
+	"os"
+
+	"forge.lthn.ai/core/cli/pkg/cli"
+	"forge.lthn.ai/core/cli/pkg/ml"
+)
+
+var (
+	exportOutputDir string
+	exportMinChars  int
+	exportTrainPct  int
+	exportValidPct  int
+	exportTestPct   int
+	exportSeed      int64
+	exportParquet   bool
+)
+
+var exportCmd = &cli.Command{
+	Use:   "export",
+	Short: "Export golden set to training JSONL and Parquet",
+	Long:  "Reads golden set from DuckDB, filters, splits, and exports to JSONL and optionally Parquet.",
+	RunE:  runExport,
+}
+
+func init() {
+	exportCmd.Flags().StringVar(&exportOutputDir, "output-dir", "", "Output directory for training files (required)")
+	exportCmd.Flags().IntVar(&exportMinChars, "min-chars", 50, "Minimum response length in characters")
+	exportCmd.Flags().IntVar(&exportTrainPct, "train", 80, "Training split percentage")
+	exportCmd.Flags().IntVar(&exportValidPct, "valid", 10, "Validation split percentage")
+	exportCmd.Flags().IntVar(&exportTestPct, "test", 10, "Test split percentage")
+	exportCmd.Flags().Int64Var(&exportSeed, "seed", 42, "Random seed for shuffle")
+	exportCmd.Flags().BoolVar(&exportParquet, "parquet", false, "Also export Parquet files")
+	exportCmd.MarkFlagRequired("output-dir")
+}
+
+func runExport(cmd *cli.Command, args []string) error {
+	if err := ml.ValidatePercentages(exportTrainPct, exportValidPct, exportTestPct); err != nil {
+		return err
+	}
+
+	path := dbPath
+	if path == "" {
+		path = os.Getenv("LEM_DB")
+	}
+	if path == "" {
+		return fmt.Errorf("--db or LEM_DB env is required")
+	}
+
+	db, err := ml.OpenDB(path)
+	if err != nil {
+		return fmt.Errorf("open db: %w", err)
+	}
+	defer db.Close()
+
+	rows, err := db.QueryGoldenSet(exportMinChars)
+	if err != nil {
+		return fmt.Errorf("query golden set: %w", err)
+	}
+	fmt.Printf("Loaded %d golden set rows (min %d chars)\n", len(rows), exportMinChars)
+
+	// Convert to Response format.
+	var responses []ml.Response
+	for _, r := range rows {
+		responses = append(responses, ml.Response{
+			ID:       r.SeedID,
+			Domain:   r.Domain,
+			Prompt:   r.Prompt,
+			Response: r.Response,
+		})
+	}
+
+	filtered := ml.FilterResponses(responses)
+	fmt.Printf("After filtering: %d responses\n", len(filtered))
+
+	train, valid, test := ml.SplitData(filtered, exportTrainPct, exportValidPct, exportTestPct, exportSeed)
+	fmt.Printf("Split: train=%d, valid=%d, test=%d\n", len(train), len(valid), len(test))
+
+	if err := os.MkdirAll(exportOutputDir, 0755); err != nil {
+		return fmt.Errorf("create output dir: %w", err)
+	}
+
+	for _, split := range []struct {
+		name string
+		data []ml.Response
+	}{
+		{"train", train},
+		{"valid", valid},
+		{"test", test},
+	} {
+		path := fmt.Sprintf("%s/%s.jsonl", exportOutputDir, split.name)
+		if err := ml.WriteTrainingJSONL(path, split.data); err != nil {
+			return fmt.Errorf("write %s: %w", split.name, err)
+		}
+		fmt.Printf(" %s.jsonl: %d examples\n", split.name, len(split.data))
+	}
+
+	if exportParquet {
+		n, err := ml.ExportParquet(exportOutputDir, "")
+		if err != nil {
+			return fmt.Errorf("export parquet: %w", err)
+		}
+		fmt.Printf(" Parquet: %d total rows\n", n)
+	}
+
+	return nil
+}
diff --git a/internal/cmd/ml/cmd_gguf.go b/internal/cmd/ml/cmd_gguf.go
new file mode 100644
index 0000000..6545554
--- /dev/null
+++ b/internal/cmd/ml/cmd_gguf.go
@@ -0,0 +1,40 @@
+package ml
+
+import (
+	"fmt"
+
+	"forge.lthn.ai/core/cli/pkg/cli"
+	"forge.lthn.ai/core/cli/pkg/ml"
+)
+
+var (
+	ggufInput  string
+	ggufConfig string
+	ggufOutput string
+	ggufArch   string
+)
+
+var ggufCmd = &cli.Command{
+	Use:   "gguf",
+	Short: "Convert MLX LoRA adapter to GGUF format",
+	Long:  "Converts an MLX safetensors LoRA adapter to GGUF v3 format for use with llama.cpp.",
+	RunE:  runGGUF,
+}
+
+func init() {
+	ggufCmd.Flags().StringVar(&ggufInput, "input", "", "Input safetensors file (required)")
+	ggufCmd.Flags().StringVar(&ggufConfig, "config", "", "Adapter config JSON (required)")
+	ggufCmd.Flags().StringVar(&ggufOutput, "output", "", "Output GGUF file (required)")
+	ggufCmd.Flags().StringVar(&ggufArch, "arch", "gemma3", "GGUF architecture name")
+	ggufCmd.MarkFlagRequired("input")
+	ggufCmd.MarkFlagRequired("config")
+	ggufCmd.MarkFlagRequired("output")
+}
+
+func runGGUF(cmd *cli.Command, args []string) error {
+	if err := ml.ConvertMLXtoGGUFLoRA(ggufInput, ggufConfig, ggufOutput, ggufArch); err != nil {
+		return fmt.Errorf("convert to GGUF: %w", err)
+	}
+	fmt.Printf("GGUF LoRA adapter written to %s\n", ggufOutput)
+	return nil
+}
diff --git a/internal/cmd/ml/cmd_ml.go b/internal/cmd/ml/cmd_ml.go
new file mode 100644
index 0000000..07a908c
--- /dev/null
+++ b/internal/cmd/ml/cmd_ml.go
@@ -0,0 +1,63 @@
+// Package ml provides ML inference, scoring, and training pipeline commands.
+//
+// Commands:
+// - core ml score: Score responses with heuristic and LLM judges
+// - core ml probe: Run capability and content probes against a model
+// - core ml export: Export golden set to training JSONL/Parquet
+// - core ml expand: Generate expansion responses
+// - core ml status: Show training and generation progress
+// - core ml gguf: Convert MLX LoRA adapter to GGUF format
+// - core ml convert: Convert MLX LoRA adapter to PEFT format
+// - core ml agent: Run the scoring agent daemon
+// - core ml worker: Run a distributed worker node
+package ml
+
+import (
+	"forge.lthn.ai/core/cli/pkg/cli"
+)
+
+func init() {
+	cli.RegisterCommands(AddMLCommands)
+}
+
+var mlCmd = &cli.Command{
+	Use:   "ml",
+	Short: "ML inference, scoring, and training pipeline",
+	Long:  "Commands for ML model scoring, probe evaluation, data export, and format conversion.",
+}
+
+// AddMLCommands registers the 'ml' command and all subcommands.
+func AddMLCommands(root *cli.Command) {
+	initFlags()
+	mlCmd.AddCommand(scoreCmd)
+	mlCmd.AddCommand(probeCmd)
+	mlCmd.AddCommand(exportCmd)
+	mlCmd.AddCommand(expandCmd)
+	mlCmd.AddCommand(statusCmd)
+	mlCmd.AddCommand(ggufCmd)
+	mlCmd.AddCommand(convertCmd)
+	mlCmd.AddCommand(agentCmd)
+	mlCmd.AddCommand(workerCmd)
+	root.AddCommand(mlCmd)
+}
+
+// Shared persistent flags.
+var (
+	apiURL     string
+	judgeURL   string
+	judgeModel string
+	influxURL  string
+	influxDB   string
+	dbPath     string
+	modelName  string
+)
+
+func initFlags() {
+	mlCmd.PersistentFlags().StringVar(&apiURL, "api-url", "http://10.69.69.108:8090", "OpenAI-compatible API URL")
+	mlCmd.PersistentFlags().StringVar(&judgeURL, "judge-url", "http://10.69.69.108:11434", "Judge model API URL (Ollama)")
+	mlCmd.PersistentFlags().StringVar(&judgeModel, "judge-model", "gemma3:27b", "Judge model name")
+	mlCmd.PersistentFlags().StringVar(&influxURL, "influx", "", "InfluxDB URL (default http://10.69.69.165:8181)")
+	mlCmd.PersistentFlags().StringVar(&influxDB, "influx-db", "", "InfluxDB database (default training)")
+	mlCmd.PersistentFlags().StringVar(&dbPath, "db", "", "DuckDB database path (or set LEM_DB env)")
+	mlCmd.PersistentFlags().StringVar(&modelName, "model", "", "Model name for API")
+}
diff --git a/internal/cmd/ml/cmd_probe.go b/internal/cmd/ml/cmd_probe.go
new file mode 100644
index 0000000..72594f8
--- /dev/null
+++ b/internal/cmd/ml/cmd_probe.go
@@ -0,0 +1,66 @@
+package ml
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+
+	"forge.lthn.ai/core/cli/pkg/cli"
+	"forge.lthn.ai/core/cli/pkg/ml"
+)
+
+var (
+	probeOutput string
+)
+
+var probeCmd = &cli.Command{
+	Use:   "probe",
+	Short: "Run capability and content probes against a model",
+	Long:  "Runs 23 capability probes and 6 content probes against an OpenAI-compatible API.",
+	RunE:  runProbe,
+}
+
+func init() {
+	probeCmd.Flags().StringVar(&probeOutput, "output", "", "Output JSON file for probe results")
+}
+
+func runProbe(cmd *cli.Command, args []string) error {
+	if apiURL == "" {
+		return fmt.Errorf("--api-url is required")
+	}
+
+	model := modelName
+	if model == "" {
+		model = "default"
+	}
+
+	ctx := context.Background()
+	backend := ml.NewHTTPBackend(apiURL, model)
+
+	fmt.Printf("Running %d capability probes against %s...\n", len(ml.CapabilityProbes), apiURL)
+	results := ml.RunCapabilityProbes(ctx, backend)
+
+	fmt.Printf("\nResults: %.1f%% (%d/%d)\n", results.Accuracy, results.Correct, results.Total)
+
+	for cat, data := range results.ByCategory {
+		catAcc := 0.0
+		if data.Total > 0 {
+			catAcc = float64(data.Correct) / float64(data.Total) * 100
+		}
+		fmt.Printf(" %-20s %d/%d (%.0f%%)\n", cat, data.Correct, data.Total, catAcc)
+	}
+
+	if probeOutput != "" {
+		data, err := json.MarshalIndent(results, "", " ")
+		if err != nil {
+			return fmt.Errorf("marshal results: %w", err)
+		}
+		if err := os.WriteFile(probeOutput, data, 0644); err != nil {
+			return fmt.Errorf("write output: %w", err)
+		}
+		fmt.Printf("\nResults written to %s\n", probeOutput)
+	}
+
+	return nil
+}
diff --git a/internal/cmd/ml/cmd_score.go b/internal/cmd/ml/cmd_score.go
new file mode 100644
index 0000000..cb28a18
--- /dev/null
+++ b/internal/cmd/ml/cmd_score.go
@@ -0,0 +1,77 @@
+package ml
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"forge.lthn.ai/core/cli/pkg/cli"
+	"forge.lthn.ai/core/cli/pkg/ml"
+)
+
+var (
+	scoreInput  string
+	scoreSuites string
+	scoreOutput string
+	scoreConcur int
+)
+
+var scoreCmd = &cli.Command{
+	Use:   "score",
+	Short: "Score responses with heuristic and LLM judges",
+	Long:  "Reads a JSONL file of prompt/response pairs and scores them across configured suites.",
+	RunE:  runScore,
+}
+
+func init() {
+	scoreCmd.Flags().StringVar(&scoreInput, "input", "", "Input JSONL file with prompt/response pairs (required)")
+	scoreCmd.Flags().StringVar(&scoreSuites, "suites", "all", "Comma-separated scoring suites (heuristic,semantic,content,exact,truthfulqa,donotanswer,toxigen)")
+	scoreCmd.Flags().StringVar(&scoreOutput, "output", "", "Output JSON file for scores")
+	scoreCmd.Flags().IntVar(&scoreConcur, "concurrency", 4, "Number of concurrent scoring workers")
+	scoreCmd.MarkFlagRequired("input")
+}
+
+func runScore(cmd *cli.Command, args []string) error {
+	responses, err := ml.ReadResponses(scoreInput)
+	if err != nil {
+		return fmt.Errorf("read input: %w", err)
+	}
+
+	var judge *ml.Judge
+	if judgeURL != "" {
+		backend := ml.NewHTTPBackend(judgeURL, judgeModel)
+		judge = ml.NewJudge(backend)
+	}
+
+	engine := ml.NewEngine(judge, scoreConcur, scoreSuites)
+
+	ctx := context.Background()
+	perPrompt := engine.ScoreAll(ctx, responses)
+	averages := ml.ComputeAverages(perPrompt)
+
+	if scoreOutput != "" {
+		output := &ml.ScorerOutput{
+			Metadata: ml.Metadata{
+				JudgeModel: judgeModel,
+				JudgeURL:   judgeURL,
+				ScoredAt:   time.Now(),
+				Suites:     ml.SplitComma(scoreSuites),
+			},
+			ModelAverages: averages,
+			PerPrompt:     perPrompt,
+		}
+		if err := ml.WriteScores(scoreOutput, output); err != nil {
+			return fmt.Errorf("write output: %w", err)
+		}
+		fmt.Printf("Scores written to %s\n", scoreOutput)
+	} else {
+		for model, avgs := range averages {
+			fmt.Printf("%s:\n", model)
+			for field, val := range avgs {
+				fmt.Printf(" %-25s %.3f\n", field, val)
+			}
+		}
+	}
+
+	return nil
+}
diff --git a/internal/cmd/ml/cmd_status.go b/internal/cmd/ml/cmd_status.go
new file mode 100644
index 0000000..35a9020
--- /dev/null
+++ b/internal/cmd/ml/cmd_status.go
@@ -0,0 +1,54 @@
+package ml
+
+import (
+	"fmt"
+	"os"
+
+	"forge.lthn.ai/core/cli/pkg/cli"
+	"forge.lthn.ai/core/cli/pkg/ml"
+)
+
+var statusCmd = &cli.Command{
+	Use:   "status",
+	Short: "Show training and generation progress",
+	Long:  "Queries InfluxDB for training status, loss, and generation progress. Optionally shows DuckDB table counts.",
+	RunE:  runStatus,
+}
+
+func runStatus(cmd *cli.Command, args []string) error {
+	influx := ml.NewInfluxClient(influxURL, influxDB)
+
+	if err := ml.PrintStatus(influx, os.Stdout); err != nil {
+		return fmt.Errorf("status: %w", err)
+	}
+
+	path := dbPath
+	if path == "" {
+		path = os.Getenv("LEM_DB")
+	}
+
+	if path != "" {
+		db, err := ml.OpenDB(path)
+		if err != nil {
+			return fmt.Errorf("open db: %w", err)
+		}
+		defer db.Close()
+
+		counts, err := db.TableCounts()
+		if err != nil {
+			return fmt.Errorf("table counts: %w", err)
+		}
+
+		fmt.Println()
+		fmt.Println("DuckDB:")
+		order := []string{"golden_set", "expansion_prompts", "seeds", "training_examples",
+			"prompts", "gemini_responses", "benchmark_questions", "benchmark_results", "validations"}
+		for _, table := range order {
+			if count, ok := counts[table]; ok {
+				fmt.Fprintf(os.Stdout, " %-22s %6d rows\n", table, count)
+			}
+		}
+	}
+
+	return nil
+}
diff --git a/internal/cmd/ml/cmd_worker.go b/internal/cmd/ml/cmd_worker.go
new file mode 100644
index 0000000..41ddbfa
--- /dev/null
+++ b/internal/cmd/ml/cmd_worker.go
@@ -0,0 +1,80 @@
+package ml
+
+import (
+	"time"
+
+	"forge.lthn.ai/core/cli/pkg/cli"
+	"forge.lthn.ai/core/cli/pkg/ml"
+)
+
+var (
+	workerAPIBase   string
+	workerID        string
+	workerName      string
+	workerAPIKey    string
+	workerGPU       string
+	workerVRAM      int
+	workerLangs     string
+	workerModels    string
+	workerInferURL  string
+	workerTaskType  string
+	workerBatchSize int
+	workerPoll      time.Duration
+	workerOneShot   bool
+	workerDryRun    bool
+)
+
+var workerCmd = &cli.Command{
+	Use:   "worker",
+	Short: "Run a distributed worker node",
+	Long:  "Polls the LEM API for tasks, runs local inference, and submits results.",
+	RunE:  runWorker,
+}
+
+func init() {
+	workerCmd.Flags().StringVar(&workerAPIBase, "api", ml.EnvOr("LEM_API", "https://infer.lthn.ai"), "LEM API base URL")
+	workerCmd.Flags().StringVar(&workerID, "id", ml.EnvOr("LEM_WORKER_ID", ml.MachineID()), "Worker ID")
+	workerCmd.Flags().StringVar(&workerName, "name", ml.EnvOr("LEM_WORKER_NAME", ml.Hostname()), "Worker display name")
+	workerCmd.Flags().StringVar(&workerAPIKey, "key", ml.EnvOr("LEM_API_KEY", ""), "API key")
+	workerCmd.Flags().StringVar(&workerGPU, "gpu", ml.EnvOr("LEM_GPU", ""), "GPU type")
+	workerCmd.Flags().IntVar(&workerVRAM, "vram", ml.IntEnvOr("LEM_VRAM_GB", 0), "GPU VRAM in GB")
+	workerCmd.Flags().StringVar(&workerLangs, "languages", ml.EnvOr("LEM_LANGUAGES", ""), "Comma-separated language codes")
+	workerCmd.Flags().StringVar(&workerModels, "models", ml.EnvOr("LEM_MODELS", ""), "Comma-separated model names")
+	workerCmd.Flags().StringVar(&workerInferURL, "infer", ml.EnvOr("LEM_INFER_URL", "http://localhost:8090"), "Local inference endpoint")
+	workerCmd.Flags().StringVar(&workerTaskType, "type", "", "Filter by task type")
+	workerCmd.Flags().IntVar(&workerBatchSize, "batch", 5, "Tasks per poll")
+	workerCmd.Flags().DurationVar(&workerPoll, "poll", 30*time.Second, "Poll interval")
+	workerCmd.Flags().BoolVar(&workerOneShot, "one-shot", false, "Process one batch and exit")
+	workerCmd.Flags().BoolVar(&workerDryRun, "dry-run", false, "Fetch tasks but don't run inference")
+}
+
+func runWorker(cmd *cli.Command, args []string) error {
+	if workerAPIKey == "" {
+		workerAPIKey = ml.ReadKeyFile()
+	}
+
+	cfg := &ml.WorkerConfig{
+		APIBase:      workerAPIBase,
+		WorkerID:     workerID,
+		Name:         workerName,
+		APIKey:       workerAPIKey,
+		GPUType:      workerGPU,
+		VRAMGb:       workerVRAM,
+		InferURL:     workerInferURL,
+		TaskType:     workerTaskType,
+		BatchSize:    workerBatchSize,
+		PollInterval: workerPoll,
+		OneShot:      workerOneShot,
+		DryRun:       workerDryRun,
+	}
+
+	if workerLangs != "" {
+		cfg.Languages = ml.SplitComma(workerLangs)
+	}
+	if workerModels != "" {
+		cfg.Models = ml.SplitComma(workerModels)
+	}
+
+	ml.RunWorkerLoop(cfg)
+	return nil
+}
diff --git a/pkg/ml/service.go b/pkg/ml/service.go
new file mode 100644
index 0000000..0cfff4b
--- /dev/null
+++ b/pkg/ml/service.go
@@ -0,0 +1,162 @@
+package ml
+
+import (
+	"context"
+	"fmt"
+	"sync"
+
+	"forge.lthn.ai/core/cli/pkg/framework"
+)
+
+// Service manages ML inference backends and scoring with Core lifecycle.
+type Service struct {
+	*framework.ServiceRuntime[Options]
+
+	backends map[string]Backend
+	mu       sync.RWMutex
+	engine   *Engine
+	judge    *Judge
+}
+
+// Options configures the ML service.
+type Options struct {
+	// DefaultBackend is the name of the default inference backend.
+	DefaultBackend string
+
+	// LlamaPath is the path to the llama-server binary.
+	LlamaPath string
+
+	// ModelDir is the directory containing model files.
+	ModelDir string
+
+	// OllamaURL is the Ollama API base URL.
+	OllamaURL string
+
+	// JudgeURL is the judge model API URL.
+	JudgeURL string
+
+	// JudgeModel is the judge model name.
+	JudgeModel string
+
+	// InfluxURL is the InfluxDB URL for metrics.
+	InfluxURL string
+
+	// InfluxDB is the InfluxDB database name.
+	InfluxDB string
+
+	// Concurrency is the number of concurrent scoring workers.
+	Concurrency int
+
+	// Suites is a comma-separated list of scoring suites to enable.
+	Suites string
+}
+
+// NewService creates an ML service factory for Core registration.
+//
+//	core, _ := framework.New(
+//		framework.WithName("ml", ml.NewService(ml.Options{})),
+//	)
+func NewService(opts Options) func(*framework.Core) (any, error) {
+	return func(c *framework.Core) (any, error) {
+		if opts.Concurrency == 0 {
+			opts.Concurrency = 4
+		}
+		if opts.Suites == "" {
+			opts.Suites = "all"
+		}
+
+		svc := &Service{
+			ServiceRuntime: framework.NewServiceRuntime(c, opts),
+			backends:       make(map[string]Backend),
+		}
+		return svc, nil
+	}
+}
+
+// OnStartup initializes backends and scoring engine.
+func (s *Service) OnStartup(ctx context.Context) error {
+	opts := s.Opts()
+
+	// Register Ollama backend if URL provided.
+	if opts.OllamaURL != "" {
+		s.RegisterBackend("ollama", NewHTTPBackend(opts.OllamaURL, opts.JudgeModel))
+	}
+
+	// Set up judge if judge URL is provided.
+	if opts.JudgeURL != "" {
+		judgeBackend := NewHTTPBackend(opts.JudgeURL, opts.JudgeModel)
+		s.judge = NewJudge(judgeBackend)
+		s.engine = NewEngine(s.judge, opts.Concurrency, opts.Suites)
+	}
+
+	return nil
+}
+
+// OnShutdown cleans up resources.
+func (s *Service) OnShutdown(ctx context.Context) error {
+	return nil
+}
+
+// RegisterBackend adds or replaces a named inference backend.
+func (s *Service) RegisterBackend(name string, backend Backend) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.backends[name] = backend
+}
+
+// Backend returns a named backend, or nil if not found.
+func (s *Service) Backend(name string) Backend {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	return s.backends[name]
+}
+
+// DefaultBackend returns the configured default backend.
+func (s *Service) DefaultBackend() Backend {
+	name := s.Opts().DefaultBackend
+	if name == "" {
+		name = "ollama"
+	}
+	return s.Backend(name)
+}
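+
+// Example (illustrative): registering an additional backend after startup;
+// the URL and model name are placeholders.
+//
+//	svc.RegisterBackend("llama", NewHTTPBackend("http://localhost:8080", "my-model"))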
+
+// Backends returns the names of all registered backends.
+func (s *Service) Backends() []string {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+	names := make([]string, 0, len(s.backends))
+	for name := range s.backends {
+		names = append(names, name)
+	}
+	return names
+}
+
+// Judge returns the configured judge, or nil if not set up.
+func (s *Service) Judge() *Judge {
+	return s.judge
+}
+
+// Engine returns the scoring engine, or nil if not set up.
+func (s *Service) Engine() *Engine {
+	return s.engine
+}
+
+// Generate generates text using the named backend (or default).
+func (s *Service) Generate(ctx context.Context, backendName, prompt string, opts GenOpts) (string, error) {
+	b := s.Backend(backendName)
+	if b == nil {
+		b = s.DefaultBackend()
+	}
+	if b == nil {
+		return "", fmt.Errorf("no backend available (requested: %q)", backendName)
+	}
+	return b.Generate(ctx, prompt, opts)
+}
+
+// ScoreResponses scores a batch of responses using the configured engine.
+func (s *Service) ScoreResponses(ctx context.Context, responses []Response) (map[string][]PromptScore, error) {
+	if s.engine == nil {
+		return nil, fmt.Errorf("scoring engine not configured (set JudgeURL and JudgeModel)")
+	}
+	return s.engine.ScoreAll(ctx, responses), nil
+}
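+
+// Example wiring (illustrative sketch; option values are placeholders and the
+// way the registered *Service is retrieved depends on the framework package):
+//
+//	core, _ := framework.New(
+//		framework.WithName("ml", NewService(Options{
+//			OllamaURL:   "http://localhost:11434",
+//			JudgeURL:    "http://localhost:11434",
+//			JudgeModel:  "gemma3:27b",
+//			Concurrency: 4,
+//		})),
+//	)
+//
+//	// With the resolved *Service in hand, generation and scoring go through
+//	// the registered backends and the configured engine:
+//	text, err := svc.Generate(ctx, "ollama", "Explain LoRA adapters.", GenOpts{})
+//	scores, err := svc.ScoreResponses(ctx, responses)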