feat(ml): add CoreGo service wrapper and CLI commands (Tasks 6-7)

Service registration with DI lifecycle, typed options, and backend
management, plus nine CLI subcommands under `core ml` for scoring,
probing, export, expansion, status, GGUF/PEFT conversion, the scoring
agent, and the distributed worker.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Claude 2026-02-15 23:52:46 +00:00
parent fcd1758b7d
commit 3dbb5988a8
11 changed files with 839 additions and 0 deletions
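
For orientation, a minimal sketch of wiring the new service into a Core application, following the NewService doc comment in pkg/ml/service.go. The framework.New/WithName calls and the Options fields are from this commit; the option values are illustrative only.

package main

import (
	"forge.lthn.ai/core/cli/pkg/framework"
	"forge.lthn.ai/core/cli/pkg/ml"
)

func main() {
	// Register the ML service under the name "ml" with typed Options
	// (see pkg/ml/service.go below).
	core, err := framework.New(
		framework.WithName("ml", ml.NewService(ml.Options{
			OllamaURL:  "http://localhost:11434",
			JudgeURL:   "http://localhost:11434",
			JudgeModel: "gemma3:27b",
		})),
	)
	if err != nil {
		panic(err)
	}
	_ = core // lifecycle hooks (OnStartup/OnShutdown) are driven by the Core runtime
}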


@@ -0,0 +1,67 @@
package ml
import (
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
var (
agentM3Host string
agentM3User string
agentM3SSHKey string
agentM3AdapterBase string
agentBaseModel string
agentPollInterval int
agentWorkDir string
agentFilter string
agentForce bool
agentOneShot bool
agentDryRun bool
)
var agentCmd = &cli.Command{
Use: "agent",
Short: "Run the scoring agent daemon",
Long: "Polls M3 for unscored LoRA checkpoints, converts, probes, and pushes results to InfluxDB.",
RunE: runAgent,
}
func init() {
agentCmd.Flags().StringVar(&agentM3Host, "m3-host", ml.EnvOr("M3_HOST", "10.69.69.108"), "M3 host address")
agentCmd.Flags().StringVar(&agentM3User, "m3-user", ml.EnvOr("M3_USER", "claude"), "M3 SSH user")
agentCmd.Flags().StringVar(&agentM3SSHKey, "m3-ssh-key", ml.EnvOr("M3_SSH_KEY", ml.ExpandHome("~/.ssh/id_ed25519")), "SSH key for M3")
agentCmd.Flags().StringVar(&agentM3AdapterBase, "m3-adapter-base", ml.EnvOr("M3_ADAPTER_BASE", "/Volumes/Data/lem"), "Adapter base dir on M3")
agentCmd.Flags().StringVar(&agentBaseModel, "base-model", ml.EnvOr("BASE_MODEL", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"), "HuggingFace base model ID")
agentCmd.Flags().IntVar(&agentPollInterval, "poll", ml.IntEnvOr("POLL_INTERVAL", 300), "Poll interval in seconds")
agentCmd.Flags().StringVar(&agentWorkDir, "work-dir", ml.EnvOr("WORK_DIR", "/tmp/scoring-agent"), "Working directory for adapters")
agentCmd.Flags().StringVar(&agentFilter, "filter", "", "Filter adapter dirs by prefix")
agentCmd.Flags().BoolVar(&agentForce, "force", false, "Re-score already-scored checkpoints")
agentCmd.Flags().BoolVar(&agentOneShot, "one-shot", false, "Process one checkpoint and exit")
agentCmd.Flags().BoolVar(&agentDryRun, "dry-run", false, "Discover and plan but don't execute")
}
func runAgent(cmd *cli.Command, args []string) error {
cfg := &ml.AgentConfig{
M3Host: agentM3Host,
M3User: agentM3User,
M3SSHKey: agentM3SSHKey,
M3AdapterBase: agentM3AdapterBase,
InfluxURL: influxURL,
InfluxDB: influxDB,
DBPath: dbPath,
APIURL: apiURL,
JudgeURL: judgeURL,
JudgeModel: judgeModel,
Model: modelName,
BaseModel: agentBaseModel,
PollInterval: agentPollInterval,
WorkDir: agentWorkDir,
Filter: agentFilter,
Force: agentForce,
OneShot: agentOneShot,
DryRun: agentDryRun,
}
ml.RunAgentLoop(cfg)
return nil
}
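
For a safe first run, the agent can be asked to discover and plan a single checkpoint without executing; the host value below is the flag default shown above, the rest is illustrative:

  core ml agent --m3-host 10.69.69.108 --work-dir /tmp/scoring-agent --one-shot --dry-run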


@@ -0,0 +1,40 @@
package ml
import (
"fmt"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
var (
convertInput string
convertConfig string
convertOutputDir string
convertBaseModel string
)
var convertCmd = &cli.Command{
Use: "convert",
Short: "Convert MLX LoRA adapter to PEFT format",
Long: "Converts an MLX safetensors LoRA adapter to HuggingFace PEFT format for Ollama.",
RunE: runConvert,
}
func init() {
convertCmd.Flags().StringVar(&convertInput, "input", "", "Input safetensors file (required)")
convertCmd.Flags().StringVar(&convertConfig, "config", "", "Adapter config JSON (required)")
convertCmd.Flags().StringVar(&convertOutputDir, "output-dir", "", "Output directory (required)")
convertCmd.Flags().StringVar(&convertBaseModel, "base-model", "", "Base model name for adapter_config.json")
convertCmd.MarkFlagRequired("input")
convertCmd.MarkFlagRequired("config")
convertCmd.MarkFlagRequired("output-dir")
}
func runConvert(cmd *cli.Command, args []string) error {
if err := ml.ConvertMLXtoPEFT(convertInput, convertConfig, convertOutputDir, convertBaseModel); err != nil {
return fmt.Errorf("convert to PEFT: %w", err)
}
fmt.Printf("PEFT adapter written to %s\n", convertOutputDir)
return nil
}
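
An example PEFT conversion; file and directory names are placeholders:

  core ml convert --input adapters.safetensors --config adapter_config.json --output-dir ./peft-adapter --base-model deepseek-ai/DeepSeek-R1-Distill-Qwen-7B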


@@ -0,0 +1,81 @@
package ml
import (
"context"
"fmt"
"os"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
var (
expandWorker string
expandOutput string
expandLimit int
expandDryRun bool
)
var expandCmd = &cli.Command{
Use: "expand",
Short: "Generate expansion responses from pending prompts",
Long: "Reads pending expansion prompts from DuckDB and generates responses via an OpenAI-compatible API.",
RunE: runExpand,
}
func init() {
expandCmd.Flags().StringVar(&expandWorker, "worker", "", "Worker hostname (defaults to os.Hostname())")
expandCmd.Flags().StringVar(&expandOutput, "output", ".", "Output directory for JSONL files")
expandCmd.Flags().IntVar(&expandLimit, "limit", 0, "Max prompts to process (0 = all)")
expandCmd.Flags().BoolVar(&expandDryRun, "dry-run", false, "Print plan and exit without generating")
}
func runExpand(cmd *cli.Command, args []string) error {
if modelName == "" {
return fmt.Errorf("--model is required")
}
path := dbPath
if path == "" {
path = os.Getenv("LEM_DB")
}
if path == "" {
return fmt.Errorf("--db or LEM_DB env is required")
}
if expandWorker == "" {
h, _ := os.Hostname()
expandWorker = h
}
db, err := ml.OpenDBReadWrite(path)
if err != nil {
return fmt.Errorf("open db: %w", err)
}
defer db.Close()
rows, err := db.QueryExpansionPrompts("pending", expandLimit)
if err != nil {
return fmt.Errorf("query expansion_prompts: %w", err)
}
fmt.Printf("Loaded %d pending prompts from %s\n", len(rows), path)
var prompts []ml.Response
for _, r := range rows {
prompt := r.Prompt
if prompt == "" && r.PromptEn != "" {
prompt = r.PromptEn
}
prompts = append(prompts, ml.Response{
ID: r.SeedID,
Domain: r.Domain,
Prompt: prompt,
})
}
ctx := context.Background()
backend := ml.NewHTTPBackend(apiURL, modelName)
influx := ml.NewInfluxClient(influxURL, influxDB)
return ml.ExpandPrompts(ctx, backend, influx, prompts, modelName, expandWorker, expandOutput, expandDryRun, expandLimit)
}
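
An example expansion run that takes the database path from LEM_DB and caps the batch; the model name and path are placeholders:

  LEM_DB=/data/lem.duckdb core ml expand --model my-model --limit 25 --dry-run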


@@ -0,0 +1,109 @@
package ml
import (
"fmt"
"os"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
var (
exportOutputDir string
exportMinChars int
exportTrainPct int
exportValidPct int
exportTestPct int
exportSeed int64
exportParquet bool
)
var exportCmd = &cli.Command{
Use: "export",
Short: "Export golden set to training JSONL and Parquet",
Long: "Reads golden set from DuckDB, filters, splits, and exports to JSONL and optionally Parquet.",
RunE: runExport,
}
func init() {
exportCmd.Flags().StringVar(&exportOutputDir, "output-dir", "", "Output directory for training files (required)")
exportCmd.Flags().IntVar(&exportMinChars, "min-chars", 50, "Minimum response length in characters")
exportCmd.Flags().IntVar(&exportTrainPct, "train", 80, "Training split percentage")
exportCmd.Flags().IntVar(&exportValidPct, "valid", 10, "Validation split percentage")
exportCmd.Flags().IntVar(&exportTestPct, "test", 10, "Test split percentage")
exportCmd.Flags().Int64Var(&exportSeed, "seed", 42, "Random seed for shuffle")
exportCmd.Flags().BoolVar(&exportParquet, "parquet", false, "Also export Parquet files")
exportCmd.MarkFlagRequired("output-dir")
}
func runExport(cmd *cli.Command, args []string) error {
if err := ml.ValidatePercentages(exportTrainPct, exportValidPct, exportTestPct); err != nil {
return err
}
path := dbPath
if path == "" {
path = os.Getenv("LEM_DB")
}
if path == "" {
return fmt.Errorf("--db or LEM_DB env is required")
}
db, err := ml.OpenDB(path)
if err != nil {
return fmt.Errorf("open db: %w", err)
}
defer db.Close()
rows, err := db.QueryGoldenSet(exportMinChars)
if err != nil {
return fmt.Errorf("query golden set: %w", err)
}
fmt.Printf("Loaded %d golden set rows (min %d chars)\n", len(rows), exportMinChars)
// Convert to Response format.
var responses []ml.Response
for _, r := range rows {
responses = append(responses, ml.Response{
ID: r.SeedID,
Domain: r.Domain,
Prompt: r.Prompt,
Response: r.Response,
})
}
filtered := ml.FilterResponses(responses)
fmt.Printf("After filtering: %d responses\n", len(filtered))
train, valid, test := ml.SplitData(filtered, exportTrainPct, exportValidPct, exportTestPct, exportSeed)
fmt.Printf("Split: train=%d, valid=%d, test=%d\n", len(train), len(valid), len(test))
if err := os.MkdirAll(exportOutputDir, 0755); err != nil {
return fmt.Errorf("create output dir: %w", err)
}
for _, split := range []struct {
name string
data []ml.Response
}{
{"train", train},
{"valid", valid},
{"test", test},
} {
path := fmt.Sprintf("%s/%s.jsonl", exportOutputDir, split.name)
if err := ml.WriteTrainingJSONL(path, split.data); err != nil {
return fmt.Errorf("write %s: %w", split.name, err)
}
fmt.Printf(" %s.jsonl: %d examples\n", split.name, len(split.data))
}
if exportParquet {
n, err := ml.ExportParquet(exportOutputDir, "")
if err != nil {
return fmt.Errorf("export parquet: %w", err)
}
fmt.Printf(" Parquet: %d total rows\n", n)
}
return nil
}
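
An example export using the default 80/10/10 split, with Parquet output enabled; paths are placeholders:

  core ml export --db /data/lem.duckdb --output-dir ./training --min-chars 50 --parquet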


@@ -0,0 +1,40 @@
package ml
import (
"fmt"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
var (
ggufInput string
ggufConfig string
ggufOutput string
ggufArch string
)
var ggufCmd = &cli.Command{
Use: "gguf",
Short: "Convert MLX LoRA adapter to GGUF format",
Long: "Converts an MLX safetensors LoRA adapter to GGUF v3 format for use with llama.cpp.",
RunE: runGGUF,
}
func init() {
ggufCmd.Flags().StringVar(&ggufInput, "input", "", "Input safetensors file (required)")
ggufCmd.Flags().StringVar(&ggufConfig, "config", "", "Adapter config JSON (required)")
ggufCmd.Flags().StringVar(&ggufOutput, "output", "", "Output GGUF file (required)")
ggufCmd.Flags().StringVar(&ggufArch, "arch", "gemma3", "GGUF architecture name")
ggufCmd.MarkFlagRequired("input")
ggufCmd.MarkFlagRequired("config")
ggufCmd.MarkFlagRequired("output")
}
func runGGUF(cmd *cli.Command, args []string) error {
if err := ml.ConvertMLXtoGGUFLoRA(ggufInput, ggufConfig, ggufOutput, ggufArch); err != nil {
return fmt.Errorf("convert to GGUF: %w", err)
}
fmt.Printf("GGUF LoRA adapter written to %s\n", ggufOutput)
return nil
}
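
An example GGUF conversion targeting the default gemma3 architecture; file names are placeholders:

  core ml gguf --input adapters.safetensors --config adapter_config.json --output adapter.gguf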

internal/cmd/ml/cmd_ml.go (new file, 63 lines)

@@ -0,0 +1,63 @@
// Package ml provides ML inference, scoring, and training pipeline commands.
//
// Commands:
// - core ml score: Score responses with heuristic and LLM judges
// - core ml probe: Run capability and content probes against a model
// - core ml export: Export golden set to training JSONL/Parquet
// - core ml expand: Generate expansion responses
// - core ml status: Show training and generation progress
// - core ml gguf: Convert MLX LoRA adapter to GGUF format
// - core ml convert: Convert MLX LoRA adapter to PEFT format
// - core ml agent: Run the scoring agent daemon
// - core ml worker: Run a distributed worker node
package ml
import (
"forge.lthn.ai/core/cli/pkg/cli"
)
func init() {
cli.RegisterCommands(AddMLCommands)
}
var mlCmd = &cli.Command{
Use: "ml",
Short: "ML inference, scoring, and training pipeline",
Long: "Commands for ML model scoring, probe evaluation, data export, and format conversion.",
}
// AddMLCommands registers the 'ml' command and all subcommands.
func AddMLCommands(root *cli.Command) {
initFlags()
mlCmd.AddCommand(scoreCmd)
mlCmd.AddCommand(probeCmd)
mlCmd.AddCommand(exportCmd)
mlCmd.AddCommand(expandCmd)
mlCmd.AddCommand(statusCmd)
mlCmd.AddCommand(ggufCmd)
mlCmd.AddCommand(convertCmd)
mlCmd.AddCommand(agentCmd)
mlCmd.AddCommand(workerCmd)
root.AddCommand(mlCmd)
}
// Shared persistent flags.
var (
apiURL string
judgeURL string
judgeModel string
influxURL string
influxDB string
dbPath string
modelName string
)
func initFlags() {
mlCmd.PersistentFlags().StringVar(&apiURL, "api-url", "http://10.69.69.108:8090", "OpenAI-compatible API URL")
mlCmd.PersistentFlags().StringVar(&judgeURL, "judge-url", "http://10.69.69.108:11434", "Judge model API URL (Ollama)")
mlCmd.PersistentFlags().StringVar(&judgeModel, "judge-model", "gemma3:27b", "Judge model name")
mlCmd.PersistentFlags().StringVar(&influxURL, "influx", "", "InfluxDB URL (default http://10.69.69.165:8181)")
mlCmd.PersistentFlags().StringVar(&influxDB, "influx-db", "", "InfluxDB database (default training)")
mlCmd.PersistentFlags().StringVar(&dbPath, "db", "", "DuckDB database path (or set LEM_DB env)")
mlCmd.PersistentFlags().StringVar(&modelName, "model", "", "Model name for API")
}
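
Because init() above calls cli.RegisterCommands(AddMLCommands), a host binary only needs to import this package for its side effects. A sketch, with the import path inferred from the module and directory layout:

// In the host CLI's main package:
import _ "forge.lthn.ai/core/cli/internal/cmd/ml" // registers `core ml ...`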


@@ -0,0 +1,66 @@
package ml
import (
"context"
"encoding/json"
"fmt"
"os"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
var (
probeOutput string
)
var probeCmd = &cli.Command{
Use: "probe",
Short: "Run capability and content probes against a model",
Long: "Runs 23 capability probes and 6 content probes against an OpenAI-compatible API.",
RunE: runProbe,
}
func init() {
probeCmd.Flags().StringVar(&probeOutput, "output", "", "Output JSON file for probe results")
}
func runProbe(cmd *cli.Command, args []string) error {
if apiURL == "" {
return fmt.Errorf("--api-url is required")
}
model := modelName
if model == "" {
model = "default"
}
ctx := context.Background()
backend := ml.NewHTTPBackend(apiURL, model)
fmt.Printf("Running %d capability probes against %s...\n", len(ml.CapabilityProbes), apiURL)
results := ml.RunCapabilityProbes(ctx, backend)
fmt.Printf("\nResults: %.1f%% (%d/%d)\n", results.Accuracy, results.Correct, results.Total)
for cat, data := range results.ByCategory {
catAcc := 0.0
if data.Total > 0 {
catAcc = float64(data.Correct) / float64(data.Total) * 100
}
fmt.Printf(" %-20s %d/%d (%.0f%%)\n", cat, data.Correct, data.Total, catAcc)
}
if probeOutput != "" {
data, err := json.MarshalIndent(results, "", " ")
if err != nil {
return fmt.Errorf("marshal results: %w", err)
}
if err := os.WriteFile(probeOutput, data, 0644); err != nil {
return fmt.Errorf("write output: %w", err)
}
fmt.Printf("\nResults written to %s\n", probeOutput)
}
return nil
}
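
An example probe run against the default API endpoint, saving results to JSON; the model name is a placeholder:

  core ml probe --api-url http://10.69.69.108:8090 --model my-model --output probe-results.json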


@@ -0,0 +1,77 @@
package ml
import (
"context"
"fmt"
"time"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
var (
scoreInput string
scoreSuites string
scoreOutput string
scoreConcur int
)
var scoreCmd = &cli.Command{
Use: "score",
Short: "Score responses with heuristic and LLM judges",
Long: "Reads a JSONL file of prompt/response pairs and scores them across configured suites.",
RunE: runScore,
}
func init() {
scoreCmd.Flags().StringVar(&scoreInput, "input", "", "Input JSONL file with prompt/response pairs (required)")
scoreCmd.Flags().StringVar(&scoreSuites, "suites", "all", "Comma-separated scoring suites (heuristic,semantic,content,exact,truthfulqa,donotanswer,toxigen)")
scoreCmd.Flags().StringVar(&scoreOutput, "output", "", "Output JSON file for scores")
scoreCmd.Flags().IntVar(&scoreConcur, "concurrency", 4, "Number of concurrent scoring workers")
scoreCmd.MarkFlagRequired("input")
}
func runScore(cmd *cli.Command, args []string) error {
responses, err := ml.ReadResponses(scoreInput)
if err != nil {
return fmt.Errorf("read input: %w", err)
}
var judge *ml.Judge
if judgeURL != "" {
backend := ml.NewHTTPBackend(judgeURL, judgeModel)
judge = ml.NewJudge(backend)
}
engine := ml.NewEngine(judge, scoreConcur, scoreSuites)
ctx := context.Background()
perPrompt := engine.ScoreAll(ctx, responses)
averages := ml.ComputeAverages(perPrompt)
if scoreOutput != "" {
output := &ml.ScorerOutput{
Metadata: ml.Metadata{
JudgeModel: judgeModel,
JudgeURL: judgeURL,
ScoredAt: time.Now(),
Suites: ml.SplitComma(scoreSuites),
},
ModelAverages: averages,
PerPrompt: perPrompt,
}
if err := ml.WriteScores(scoreOutput, output); err != nil {
return fmt.Errorf("write output: %w", err)
}
fmt.Printf("Scores written to %s\n", scoreOutput)
} else {
for model, avgs := range averages {
fmt.Printf("%s:\n", model)
for field, val := range avgs {
fmt.Printf(" %-25s %.3f\n", field, val)
}
}
}
return nil
}
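
An example scoring run with an Ollama judge and a subset of suites; paths are placeholders:

  core ml score --input responses.jsonl --judge-url http://10.69.69.108:11434 --judge-model gemma3:27b --suites heuristic,semantic --output scores.json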


@@ -0,0 +1,54 @@
package ml
import (
"fmt"
"os"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
var statusCmd = &cli.Command{
Use: "status",
Short: "Show training and generation progress",
Long: "Queries InfluxDB for training status, loss, and generation progress. Optionally shows DuckDB table counts.",
RunE: runStatus,
}
func runStatus(cmd *cli.Command, args []string) error {
influx := ml.NewInfluxClient(influxURL, influxDB)
if err := ml.PrintStatus(influx, os.Stdout); err != nil {
return fmt.Errorf("status: %w", err)
}
path := dbPath
if path == "" {
path = os.Getenv("LEM_DB")
}
if path != "" {
db, err := ml.OpenDB(path)
if err != nil {
return fmt.Errorf("open db: %w", err)
}
defer db.Close()
counts, err := db.TableCounts()
if err != nil {
return fmt.Errorf("table counts: %w", err)
}
fmt.Println()
fmt.Println("DuckDB:")
order := []string{"golden_set", "expansion_prompts", "seeds", "training_examples",
"prompts", "gemini_responses", "benchmark_questions", "benchmark_results", "validations"}
for _, table := range order {
if count, ok := counts[table]; ok {
fmt.Fprintf(os.Stdout, " %-22s %6d rows\n", table, count)
}
}
}
return nil
}
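
An example status check that also prints DuckDB table counts; the InfluxDB URL and database match the documented defaults, the path is a placeholder:

  core ml status --db /data/lem.duckdb --influx http://10.69.69.165:8181 --influx-db training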


@@ -0,0 +1,80 @@
package ml
import (
"time"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
var (
workerAPIBase string
workerID string
workerName string
workerAPIKey string
workerGPU string
workerVRAM int
workerLangs string
workerModels string
workerInferURL string
workerTaskType string
workerBatchSize int
workerPoll time.Duration
workerOneShot bool
workerDryRun bool
)
var workerCmd = &cli.Command{
Use: "worker",
Short: "Run a distributed worker node",
Long: "Polls the LEM API for tasks, runs local inference, and submits results.",
RunE: runWorker,
}
func init() {
workerCmd.Flags().StringVar(&workerAPIBase, "api", ml.EnvOr("LEM_API", "https://infer.lthn.ai"), "LEM API base URL")
workerCmd.Flags().StringVar(&workerID, "id", ml.EnvOr("LEM_WORKER_ID", ml.MachineID()), "Worker ID")
workerCmd.Flags().StringVar(&workerName, "name", ml.EnvOr("LEM_WORKER_NAME", ml.Hostname()), "Worker display name")
workerCmd.Flags().StringVar(&workerAPIKey, "key", ml.EnvOr("LEM_API_KEY", ""), "API key")
workerCmd.Flags().StringVar(&workerGPU, "gpu", ml.EnvOr("LEM_GPU", ""), "GPU type")
workerCmd.Flags().IntVar(&workerVRAM, "vram", ml.IntEnvOr("LEM_VRAM_GB", 0), "GPU VRAM in GB")
workerCmd.Flags().StringVar(&workerLangs, "languages", ml.EnvOr("LEM_LANGUAGES", ""), "Comma-separated language codes")
workerCmd.Flags().StringVar(&workerModels, "models", ml.EnvOr("LEM_MODELS", ""), "Comma-separated model names")
workerCmd.Flags().StringVar(&workerInferURL, "infer", ml.EnvOr("LEM_INFER_URL", "http://localhost:8090"), "Local inference endpoint")
workerCmd.Flags().StringVar(&workerTaskType, "type", "", "Filter by task type")
workerCmd.Flags().IntVar(&workerBatchSize, "batch", 5, "Tasks per poll")
workerCmd.Flags().DurationVar(&workerPoll, "poll", 30*time.Second, "Poll interval")
workerCmd.Flags().BoolVar(&workerOneShot, "one-shot", false, "Process one batch and exit")
workerCmd.Flags().BoolVar(&workerDryRun, "dry-run", false, "Fetch tasks but don't run inference")
}
func runWorker(cmd *cli.Command, args []string) error {
if workerAPIKey == "" {
workerAPIKey = ml.ReadKeyFile()
}
cfg := &ml.WorkerConfig{
APIBase: workerAPIBase,
WorkerID: workerID,
Name: workerName,
APIKey: workerAPIKey,
GPUType: workerGPU,
VRAMGb: workerVRAM,
InferURL: workerInferURL,
TaskType: workerTaskType,
BatchSize: workerBatchSize,
PollInterval: workerPoll,
OneShot: workerOneShot,
DryRun: workerDryRun,
}
if workerLangs != "" {
cfg.Languages = ml.SplitComma(workerLangs)
}
if workerModels != "" {
cfg.Models = ml.SplitComma(workerModels)
}
ml.RunWorkerLoop(cfg)
return nil
}
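
An example worker that fetches a single batch without running inference; the API key comes from --key, LEM_API_KEY, or the ReadKeyFile fallback, and the URLs shown are the flag defaults above:

  core ml worker --api https://infer.lthn.ai --infer http://localhost:8090 --batch 5 --one-shot --dry-run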

pkg/ml/service.go (new file, 162 lines)

@@ -0,0 +1,162 @@
package ml
import (
"context"
"fmt"
"sync"
"forge.lthn.ai/core/cli/pkg/framework"
)
// Service manages ML inference backends and scoring with Core lifecycle.
type Service struct {
*framework.ServiceRuntime[Options]
backends map[string]Backend
mu sync.RWMutex
engine *Engine
judge *Judge
}
// Options configures the ML service.
type Options struct {
// DefaultBackend is the name of the default inference backend.
DefaultBackend string
// LlamaPath is the path to the llama-server binary.
LlamaPath string
// ModelDir is the directory containing model files.
ModelDir string
// OllamaURL is the Ollama API base URL.
OllamaURL string
// JudgeURL is the judge model API URL.
JudgeURL string
// JudgeModel is the judge model name.
JudgeModel string
// InfluxURL is the InfluxDB URL for metrics.
InfluxURL string
// InfluxDB is the InfluxDB database name.
InfluxDB string
// Concurrency is the number of concurrent scoring workers.
Concurrency int
// Suites is a comma-separated list of scoring suites to enable.
Suites string
}
// NewService creates an ML service factory for Core registration.
//
// core, _ := framework.New(
// framework.WithName("ml", ml.NewService(ml.Options{})),
// )
func NewService(opts Options) func(*framework.Core) (any, error) {
return func(c *framework.Core) (any, error) {
if opts.Concurrency == 0 {
opts.Concurrency = 4
}
if opts.Suites == "" {
opts.Suites = "all"
}
svc := &Service{
ServiceRuntime: framework.NewServiceRuntime(c, opts),
backends: make(map[string]Backend),
}
return svc, nil
}
}
// OnStartup initializes backends and scoring engine.
func (s *Service) OnStartup(ctx context.Context) error {
opts := s.Opts()
// Register Ollama backend if URL provided.
if opts.OllamaURL != "" {
s.RegisterBackend("ollama", NewHTTPBackend(opts.OllamaURL, opts.JudgeModel))
}
// Set up judge if judge URL is provided.
if opts.JudgeURL != "" {
judgeBackend := NewHTTPBackend(opts.JudgeURL, opts.JudgeModel)
s.judge = NewJudge(judgeBackend)
s.engine = NewEngine(s.judge, opts.Concurrency, opts.Suites)
}
return nil
}
// OnShutdown cleans up resources.
func (s *Service) OnShutdown(ctx context.Context) error {
return nil
}
// RegisterBackend adds or replaces a named inference backend.
func (s *Service) RegisterBackend(name string, backend Backend) {
s.mu.Lock()
defer s.mu.Unlock()
s.backends[name] = backend
}
// Backend returns a named backend, or nil if not found.
func (s *Service) Backend(name string) Backend {
s.mu.RLock()
defer s.mu.RUnlock()
return s.backends[name]
}
// DefaultBackend returns the configured default backend.
func (s *Service) DefaultBackend() Backend {
name := s.Opts().DefaultBackend
if name == "" {
name = "ollama"
}
return s.Backend(name)
}
// Backends returns the names of all registered backends.
func (s *Service) Backends() []string {
s.mu.RLock()
defer s.mu.RUnlock()
names := make([]string, 0, len(s.backends))
for name := range s.backends {
names = append(names, name)
}
return names
}
// Judge returns the configured judge, or nil if not set up.
func (s *Service) Judge() *Judge {
return s.judge
}
// Engine returns the scoring engine, or nil if not set up.
func (s *Service) Engine() *Engine {
return s.engine
}
// Generate generates text using the named backend (or default).
func (s *Service) Generate(ctx context.Context, backendName, prompt string, opts GenOpts) (string, error) {
b := s.Backend(backendName)
if b == nil {
b = s.DefaultBackend()
}
if b == nil {
return "", fmt.Errorf("no backend available (requested: %q)", backendName)
}
return b.Generate(ctx, prompt, opts)
}
// ScoreResponses scores a batch of responses using the configured engine.
func (s *Service) ScoreResponses(ctx context.Context, responses []Response) (map[string][]PromptScore, error) {
if s.engine == nil {
return nil, fmt.Errorf("scoring engine not configured (set JudgeURL and JudgeModel)")
}
return s.engine.ScoreAll(ctx, responses), nil
}
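
A hedged sketch of calling the service from application code once it has been resolved from the Core container. Generate, ScoreResponses, Response, and GenOpts are from this file; the zero-value GenOpts and how the *ml.Service instance is obtained are assumptions.

package app

import (
	"context"
	"fmt"

	"forge.lthn.ai/core/cli/pkg/ml"
)

// ScoreOne generates a reply on the default backend and scores it.
// svc is assumed to have been resolved from the Core container after startup.
func ScoreOne(ctx context.Context, svc *ml.Service, prompt string) error {
	// An empty backend name falls back to DefaultBackend ("ollama" unless overridden).
	reply, err := svc.Generate(ctx, "", prompt, ml.GenOpts{})
	if err != nil {
		return fmt.Errorf("generate: %w", err)
	}
	// ScoreResponses requires JudgeURL/JudgeModel to have been set in Options.
	scores, err := svc.ScoreResponses(ctx, []ml.Response{
		{ID: "example-1", Domain: "demo", Prompt: prompt, Response: reply},
	})
	if err != nil {
		return fmt.Errorf("score: %w", err)
	}
	fmt.Printf("scored %d model(s)\n", len(scores))
	return nil
}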