cli/pkg/mcp/tools_ml.go

package mcp

import (
	"context"
	"fmt"
	"strings"

	"forge.lthn.ai/core/cli/pkg/log"
	"forge.lthn.ai/core/cli/pkg/ml"
	"github.com/modelcontextprotocol/go-sdk/mcp"
)

// MLSubsystem exposes ML inference and scoring tools via MCP.
type MLSubsystem struct {
	service *ml.Service
	logger  *log.Logger
}

// NewMLSubsystem creates an MCP subsystem for ML tools.
func NewMLSubsystem(svc *ml.Service) *MLSubsystem {
	return &MLSubsystem{
		service: svc,
		logger:  log.Default(),
	}
}

// Name returns the subsystem identifier used for registration.
func (m *MLSubsystem) Name() string { return "ml" }

// RegisterTools adds ML tools to the MCP server.
func (m *MLSubsystem) RegisterTools(server *mcp.Server) {
	mcp.AddTool(server, &mcp.Tool{
		Name:        "ml_generate",
		Description: "Generate text via a configured ML inference backend.",
	}, m.mlGenerate)
	mcp.AddTool(server, &mcp.Tool{
		Name:        "ml_score",
		Description: "Score a prompt/response pair using heuristic and LLM judge suites.",
	}, m.mlScore)
	mcp.AddTool(server, &mcp.Tool{
		Name:        "ml_probe",
		Description: "Run capability probes against an inference backend.",
	}, m.mlProbe)
	mcp.AddTool(server, &mcp.Tool{
		Name:        "ml_status",
		Description: "Show training and generation progress from InfluxDB.",
	}, m.mlStatus)
	mcp.AddTool(server, &mcp.Tool{
		Name:        "ml_backends",
		Description: "List available inference backends and their status.",
	}, m.mlBackends)
}
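
// Example wiring (illustrative sketch only; "svc" and "server" stand in for an
// *ml.Service and *mcp.Server constructed elsewhere):
//
//	sub := NewMLSubsystem(svc)
//	sub.RegisterTools(server)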

// --- Input/Output types ---

// MLGenerateInput contains parameters for text generation.
type MLGenerateInput struct {
	Prompt      string  `json:"prompt"`                // The prompt to generate from
	Backend     string  `json:"backend,omitempty"`     // Backend name (default: service default)
	Model       string  `json:"model,omitempty"`       // Model override
	Temperature float64 `json:"temperature,omitempty"` // Sampling temperature
	MaxTokens   int     `json:"max_tokens,omitempty"`  // Maximum tokens to generate
}
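
// An illustrative ml_generate call payload (the backend name and values are
// examples only, not defaults):
//
//	{"prompt": "Summarise this log line.", "backend": "llamacpp", "temperature": 0.2, "max_tokens": 256}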

// MLGenerateOutput contains the generation result.
type MLGenerateOutput struct {
	Response string `json:"response"`
	Backend  string `json:"backend"`
	Model    string `json:"model,omitempty"`
}

// MLScoreInput contains parameters for scoring a response.
type MLScoreInput struct {
	Prompt   string `json:"prompt"`           // The original prompt
	Response string `json:"response"`         // The model response to score
	Suites   string `json:"suites,omitempty"` // Comma-separated suites (default: heuristic)
}
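
// An illustrative ml_score call payload (values are examples only):
//
//	{"prompt": "What is 2+2?", "response": "4", "suites": "heuristic,semantic"}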

// MLScoreOutput contains the scoring result.
type MLScoreOutput struct {
	Heuristic *ml.HeuristicScores `json:"heuristic,omitempty"`
	Semantic  *ml.SemanticScores  `json:"semantic,omitempty"`
	Content   *ml.ContentScores   `json:"content,omitempty"`
}

// MLProbeInput contains parameters for running probes.
type MLProbeInput struct {
	Backend    string `json:"backend,omitempty"`    // Backend name
	Categories string `json:"categories,omitempty"` // Comma-separated categories to run
}

// MLProbeOutput contains probe results.
type MLProbeOutput struct {
	Total   int                 `json:"total"`
	Results []MLProbeResultItem `json:"results"`
}

// MLProbeResultItem is a single probe result.
type MLProbeResultItem struct {
	ID       string `json:"id"`
	Category string `json:"category"`
	Response string `json:"response"`
}

// MLStatusInput contains parameters for the status query.
type MLStatusInput struct {
	InfluxURL string `json:"influx_url,omitempty"` // InfluxDB URL override
	InfluxDB  string `json:"influx_db,omitempty"`  // InfluxDB database override
}

// MLStatusOutput contains pipeline status.
type MLStatusOutput struct {
	Status string `json:"status"`
}

// MLBackendsInput is empty; ml_backends takes no parameters and lists all backends.
type MLBackendsInput struct{}

// MLBackendsOutput lists available backends.
type MLBackendsOutput struct {
	Backends []MLBackendInfo `json:"backends"`
	Default  string          `json:"default"`
}

// MLBackendInfo describes a single backend.
type MLBackendInfo struct {
	Name      string `json:"name"`
	Available bool   `json:"available"`
}
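
// An illustrative ml_backends result (the backend name is hypothetical):
//
//	{"backends": [{"name": "llamacpp", "available": true}], "default": "llamacpp"}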

// --- Tool handlers ---

func (m *MLSubsystem) mlGenerate(ctx context.Context, req *mcp.CallToolRequest, input MLGenerateInput) (*mcp.CallToolResult, MLGenerateOutput, error) {
	m.logger.Info("MCP tool execution", "tool", "ml_generate", "backend", input.Backend, "user", log.Username())
	if input.Prompt == "" {
		return nil, MLGenerateOutput{}, fmt.Errorf("prompt cannot be empty")
	}
	opts := ml.GenOpts{
		Temperature: input.Temperature,
		MaxTokens:   input.MaxTokens,
		Model:       input.Model,
	}
	response, err := m.service.Generate(ctx, input.Backend, input.Prompt, opts)
	if err != nil {
		return nil, MLGenerateOutput{}, fmt.Errorf("generate: %w", err)
	}
	return nil, MLGenerateOutput{
		Response: response,
		Backend:  input.Backend,
		Model:    input.Model,
	}, nil
}

func (m *MLSubsystem) mlScore(ctx context.Context, req *mcp.CallToolRequest, input MLScoreInput) (*mcp.CallToolResult, MLScoreOutput, error) {
	m.logger.Info("MCP tool execution", "tool", "ml_score", "suites", input.Suites, "user", log.Username())
	if input.Prompt == "" || input.Response == "" {
		return nil, MLScoreOutput{}, fmt.Errorf("prompt and response cannot be empty")
	}
	suites := input.Suites
	if suites == "" {
		suites = "heuristic"
	}
	output := MLScoreOutput{}
	// Unrecognised suite names are silently skipped.
	for _, suite := range strings.Split(suites, ",") {
		suite = strings.TrimSpace(suite)
		switch suite {
		case "heuristic":
			output.Heuristic = ml.ScoreHeuristic(input.Response)
		case "semantic":
			judge := m.service.Judge()
			if judge == nil {
				return nil, MLScoreOutput{}, fmt.Errorf("semantic scoring requires a judge backend")
			}
			s, err := judge.ScoreSemantic(ctx, input.Prompt, input.Response)
			if err != nil {
				return nil, MLScoreOutput{}, fmt.Errorf("semantic score: %w", err)
			}
			output.Semantic = s
		case "content":
			return nil, MLScoreOutput{}, fmt.Errorf("content scoring requires a ContentProbe — use ml_probe instead")
		}
	}
	return nil, output, nil
}

func (m *MLSubsystem) mlProbe(ctx context.Context, req *mcp.CallToolRequest, input MLProbeInput) (*mcp.CallToolResult, MLProbeOutput, error) {
	m.logger.Info("MCP tool execution", "tool", "ml_probe", "backend", input.Backend, "user", log.Username())
	// Filter probes by category if specified.
	probes := ml.CapabilityProbes
	if input.Categories != "" {
		cats := make(map[string]bool)
		for _, c := range strings.Split(input.Categories, ",") {
			cats[strings.TrimSpace(c)] = true
		}
		var filtered []ml.Probe
		for _, p := range probes {
			if cats[p.Category] {
				filtered = append(filtered, p)
			}
		}
		probes = filtered
	}
	var results []MLProbeResultItem
	for _, probe := range probes {
		resp, err := m.service.Generate(ctx, input.Backend, probe.Prompt, ml.GenOpts{Temperature: 0.7, MaxTokens: 2048})
		if err != nil {
			// A failed probe is recorded in its result rather than aborting the run.
			resp = fmt.Sprintf("error: %v", err)
		}
		results = append(results, MLProbeResultItem{
			ID:       probe.ID,
			Category: probe.Category,
			Response: resp,
		})
	}
	return nil, MLProbeOutput{
		Total:   len(results),
		Results: results,
	}, nil
}

func (m *MLSubsystem) mlStatus(ctx context.Context, req *mcp.CallToolRequest, input MLStatusInput) (*mcp.CallToolResult, MLStatusOutput, error) {
	m.logger.Info("MCP tool execution", "tool", "ml_status", "user", log.Username())
	url := input.InfluxURL
	db := input.InfluxDB
	if url == "" {
		url = "http://localhost:8086"
	}
	if db == "" {
		db = "lem"
	}
	influx := ml.NewInfluxClient(url, db)
	var buf strings.Builder
	if err := ml.PrintStatus(influx, &buf); err != nil {
		return nil, MLStatusOutput{}, fmt.Errorf("status: %w", err)
	}
	return nil, MLStatusOutput{Status: buf.String()}, nil
}

func (m *MLSubsystem) mlBackends(ctx context.Context, req *mcp.CallToolRequest, input MLBackendsInput) (*mcp.CallToolResult, MLBackendsOutput, error) {
	m.logger.Info("MCP tool execution", "tool", "ml_backends", "user", log.Username())
	names := m.service.Backends()
	backends := make([]MLBackendInfo, len(names))
	defaultName := ""
	for i, name := range names {
		b := m.service.Backend(name)
		backends[i] = MLBackendInfo{
			Name:      name,
			Available: b != nil && b.Available(),
		}
	}
	if db := m.service.DefaultBackend(); db != nil {
		defaultName = db.Name()
	}
	return nil, MLBackendsOutput{
		Backends: backends,
		Default:  defaultName,
	}, nil
}