From 0202bec84a336465fdfac9894771eb19404c6980 Mon Sep 17 00:00:00 2001 From: Snider Date: Mon, 9 Mar 2026 18:40:50 +0000 Subject: [PATCH] refactor: extract MCP server to core/mcp Move mcp/, cmd/mcpcmd/, cmd/brain-seed/ to the new core/mcp repo. Update daemon import to use forge.lthn.ai/core/mcp/pkg/mcp. Co-Authored-By: Claude Opus 4.6 --- cmd/brain-seed/main.go | 502 ---------------------- cmd/daemon/cmd.go | 2 +- cmd/mcpcmd/cmd_mcp.go | 92 ----- mcp/brain/brain.go | 42 -- mcp/brain/brain_test.go | 229 ---------- mcp/brain/tools.go | 220 ---------- mcp/bridge.go | 64 --- mcp/bridge_test.go | 250 ----------- mcp/ide/bridge.go | 191 --------- mcp/ide/bridge_test.go | 442 -------------------- mcp/ide/config.go | 57 --- mcp/ide/ide.go | 62 --- mcp/ide/tools_build.go | 114 ----- mcp/ide/tools_chat.go | 201 --------- mcp/ide/tools_dashboard.go | 132 ------ mcp/ide/tools_test.go | 781 ----------------------------------- mcp/integration_test.go | 121 ------ mcp/iter_test.go | 40 -- mcp/mcp.go | 580 -------------------------- mcp/mcp_test.go | 180 -------- mcp/registry.go | 149 ------- mcp/registry_test.go | 150 ------- mcp/subsystem.go | 32 -- mcp/subsystem_test.go | 114 ----- mcp/tools_metrics.go | 213 ---------- mcp/tools_metrics_test.go | 207 ---------- mcp/tools_ml.go | 290 ------------- mcp/tools_ml_test.go | 479 --------------------- mcp/tools_process.go | 305 -------------- mcp/tools_process_ci_test.go | 515 ----------------------- mcp/tools_process_test.go | 290 ------------- mcp/tools_rag.go | 233 ----------- mcp/tools_rag_ci_test.go | 181 -------- mcp/tools_rag_test.go | 173 -------- mcp/tools_webview.go | 497 ---------------------- mcp/tools_webview_test.go | 452 -------------------- mcp/tools_ws.go | 142 ------- mcp/tools_ws_test.go | 174 -------- mcp/transport_e2e_test.go | 742 --------------------------------- mcp/transport_stdio.go | 15 - mcp/transport_tcp.go | 177 -------- mcp/transport_tcp_test.go | 184 --------- mcp/transport_unix.go | 52 --- 43 files changed, 1 insertion(+), 10067 deletions(-) delete mode 100644 cmd/brain-seed/main.go delete mode 100644 cmd/mcpcmd/cmd_mcp.go delete mode 100644 mcp/brain/brain.go delete mode 100644 mcp/brain/brain_test.go delete mode 100644 mcp/brain/tools.go delete mode 100644 mcp/bridge.go delete mode 100644 mcp/bridge_test.go delete mode 100644 mcp/ide/bridge.go delete mode 100644 mcp/ide/bridge_test.go delete mode 100644 mcp/ide/config.go delete mode 100644 mcp/ide/ide.go delete mode 100644 mcp/ide/tools_build.go delete mode 100644 mcp/ide/tools_chat.go delete mode 100644 mcp/ide/tools_dashboard.go delete mode 100644 mcp/ide/tools_test.go delete mode 100644 mcp/integration_test.go delete mode 100644 mcp/iter_test.go delete mode 100644 mcp/mcp.go delete mode 100644 mcp/mcp_test.go delete mode 100644 mcp/registry.go delete mode 100644 mcp/registry_test.go delete mode 100644 mcp/subsystem.go delete mode 100644 mcp/subsystem_test.go delete mode 100644 mcp/tools_metrics.go delete mode 100644 mcp/tools_metrics_test.go delete mode 100644 mcp/tools_ml.go delete mode 100644 mcp/tools_ml_test.go delete mode 100644 mcp/tools_process.go delete mode 100644 mcp/tools_process_ci_test.go delete mode 100644 mcp/tools_process_test.go delete mode 100644 mcp/tools_rag.go delete mode 100644 mcp/tools_rag_ci_test.go delete mode 100644 mcp/tools_rag_test.go delete mode 100644 mcp/tools_webview.go delete mode 100644 mcp/tools_webview_test.go delete mode 100644 mcp/tools_ws.go delete mode 100644 mcp/tools_ws_test.go delete mode 100644 mcp/transport_e2e_test.go delete mode 100644 mcp/transport_stdio.go delete mode 100644 mcp/transport_tcp.go delete mode 100644 mcp/transport_tcp_test.go delete mode 100644 mcp/transport_unix.go diff --git a/cmd/brain-seed/main.go b/cmd/brain-seed/main.go deleted file mode 100644 index 692e1ae..0000000 --- a/cmd/brain-seed/main.go +++ /dev/null @@ -1,502 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -// brain-seed imports Claude Code MEMORY.md files into the OpenBrain knowledge -// store via the MCP HTTP API (brain_remember tool). The Laravel app handles -// embedding, Qdrant storage, and MariaDB dual-write internally. -// -// Usage: -// -// go run ./cmd/brain-seed -api-key YOUR_KEY -// go run ./cmd/brain-seed -api-key YOUR_KEY -api https://lthn.sh/api/v1/mcp -// go run ./cmd/brain-seed -api-key YOUR_KEY -dry-run -// go run ./cmd/brain-seed -api-key YOUR_KEY -plans -// go run ./cmd/brain-seed -api-key YOUR_KEY -claude-md # Also import CLAUDE.md files -package main - -import ( - "bytes" - "crypto/tls" - "encoding/json" - "flag" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "regexp" - "strings" - "time" -) - -var ( - apiURL = flag.String("api", "https://lthn.sh/api/v1/mcp", "MCP API base URL") - apiKey = flag.String("api-key", "", "MCP API key (Bearer token)") - server = flag.String("server", "hosthub-agent", "MCP server ID") - agent = flag.String("agent", "charon", "Agent ID for attribution") - dryRun = flag.Bool("dry-run", false, "Preview without storing") - plans = flag.Bool("plans", false, "Also import plan documents") - claudeMd = flag.Bool("claude-md", false, "Also import CLAUDE.md files") - memoryPath = flag.String("memory-path", "", "Override memory scan path (default: ~/.claude/projects/*/memory/)") - planPath = flag.String("plan-path", "", "Override plan scan path (default: ~/Code/*/docs/plans/)") - codePath = flag.String("code-path", "", "Override code root for CLAUDE.md scan (default: ~/Code)") - maxChars = flag.Int("max-chars", 3800, "Max chars per section (embeddinggemma limit ~4000)") -) - -// httpClient with TLS skip for non-public TLDs (.lthn.sh has real certs, but -// allow .lan/.local if someone has legacy config). -var httpClient = &http.Client{ - Timeout: 30 * time.Second, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, - }, -} - -func main() { - flag.Parse() - - fmt.Println("OpenBrain Seed — MCP API Client") - fmt.Println(strings.Repeat("=", 55)) - - if *apiKey == "" && !*dryRun { - fmt.Println("ERROR: -api-key is required (or use -dry-run)") - fmt.Println(" Generate one at: https://lthn.sh/admin/mcp/api-keys") - os.Exit(1) - } - - if *dryRun { - fmt.Println("[DRY RUN] — no data will be stored") - } - - fmt.Printf("API: %s\n", *apiURL) - fmt.Printf("Server: %s | Agent: %s\n", *server, *agent) - - // Discover memory files - memPath := *memoryPath - if memPath == "" { - home, _ := os.UserHomeDir() - memPath = filepath.Join(home, ".claude", "projects", "*", "memory") - } - memFiles, _ := filepath.Glob(filepath.Join(memPath, "*.md")) - fmt.Printf("\nFound %d memory files\n", len(memFiles)) - - // Discover plan files - var planFiles []string - if *plans { - pPath := *planPath - if pPath == "" { - home, _ := os.UserHomeDir() - pPath = filepath.Join(home, "Code", "*", "docs", "plans") - } - planFiles, _ = filepath.Glob(filepath.Join(pPath, "*.md")) - // Also check nested dirs (completed/, etc.) - nested, _ := filepath.Glob(filepath.Join(pPath, "*", "*.md")) - planFiles = append(planFiles, nested...) - - // Also check host-uk nested repos - home, _ := os.UserHomeDir() - hostUkPath := filepath.Join(home, "Code", "host-uk", "*", "docs", "plans") - hostUkFiles, _ := filepath.Glob(filepath.Join(hostUkPath, "*.md")) - planFiles = append(planFiles, hostUkFiles...) - hostUkNested, _ := filepath.Glob(filepath.Join(hostUkPath, "*", "*.md")) - planFiles = append(planFiles, hostUkNested...) - - fmt.Printf("Found %d plan files\n", len(planFiles)) - } - - // Discover CLAUDE.md files - var claudeFiles []string - if *claudeMd { - cPath := *codePath - if cPath == "" { - home, _ := os.UserHomeDir() - cPath = filepath.Join(home, "Code") - } - claudeFiles = discoverClaudeMdFiles(cPath) - fmt.Printf("Found %d CLAUDE.md files\n", len(claudeFiles)) - } - - imported := 0 - skipped := 0 - errors := 0 - - // Process memory files - fmt.Println("\n--- Memory Files ---") - for _, f := range memFiles { - project := extractProject(f) - sections := parseMarkdownSections(f) - filename := strings.TrimSuffix(filepath.Base(f), ".md") - - if len(sections) == 0 { - fmt.Printf(" skip %s/%s (no sections)\n", project, filename) - skipped++ - continue - } - - for _, sec := range sections { - content := sec.heading + "\n\n" + sec.content - if strings.TrimSpace(sec.content) == "" { - skipped++ - continue - } - - memType := inferType(sec.heading, sec.content, "memory") - tags := buildTags(filename, "memory", project) - confidence := confidenceForSource("memory") - - // Truncate to embedding model limit - content = truncate(content, *maxChars) - - if *dryRun { - fmt.Printf(" [DRY] %s/%s :: %s (%s) — %d chars\n", - project, filename, sec.heading, memType, len(content)) - imported++ - continue - } - - if err := callBrainRemember(content, memType, tags, project, confidence); err != nil { - fmt.Printf(" FAIL %s/%s :: %s — %v\n", project, filename, sec.heading, err) - errors++ - continue - } - fmt.Printf(" ok %s/%s :: %s (%s)\n", project, filename, sec.heading, memType) - imported++ - } - } - - // Process plan files - if *plans && len(planFiles) > 0 { - fmt.Println("\n--- Plan Documents ---") - for _, f := range planFiles { - project := extractProjectFromPlan(f) - sections := parseMarkdownSections(f) - filename := strings.TrimSuffix(filepath.Base(f), ".md") - - if len(sections) == 0 { - skipped++ - continue - } - - for _, sec := range sections { - content := sec.heading + "\n\n" + sec.content - if strings.TrimSpace(sec.content) == "" { - skipped++ - continue - } - - tags := buildTags(filename, "plans", project) - content = truncate(content, *maxChars) - - if *dryRun { - fmt.Printf(" [DRY] %s :: %s / %s (plan) — %d chars\n", - project, filename, sec.heading, len(content)) - imported++ - continue - } - - if err := callBrainRemember(content, "plan", tags, project, 0.6); err != nil { - fmt.Printf(" FAIL %s :: %s / %s — %v\n", project, filename, sec.heading, err) - errors++ - continue - } - fmt.Printf(" ok %s :: %s / %s (plan)\n", project, filename, sec.heading) - imported++ - } - } - } - - // Process CLAUDE.md files - if *claudeMd && len(claudeFiles) > 0 { - fmt.Println("\n--- CLAUDE.md Files ---") - for _, f := range claudeFiles { - project := extractProjectFromClaudeMd(f) - sections := parseMarkdownSections(f) - - if len(sections) == 0 { - skipped++ - continue - } - - for _, sec := range sections { - content := sec.heading + "\n\n" + sec.content - if strings.TrimSpace(sec.content) == "" { - skipped++ - continue - } - - tags := buildTags("CLAUDE", "claude-md", project) - content = truncate(content, *maxChars) - - if *dryRun { - fmt.Printf(" [DRY] %s :: CLAUDE.md / %s (convention) — %d chars\n", - project, sec.heading, len(content)) - imported++ - continue - } - - if err := callBrainRemember(content, "convention", tags, project, 0.9); err != nil { - fmt.Printf(" FAIL %s :: CLAUDE.md / %s — %v\n", project, sec.heading, err) - errors++ - continue - } - fmt.Printf(" ok %s :: CLAUDE.md / %s (convention)\n", project, sec.heading) - imported++ - } - } - } - - fmt.Printf("\n%s\n", strings.Repeat("=", 55)) - prefix := "" - if *dryRun { - prefix = "[DRY RUN] " - } - fmt.Printf("%sImported: %d | Skipped: %d | Errors: %d\n", prefix, imported, skipped, errors) -} - -// callBrainRemember sends a memory to the MCP API via brain_remember tool. -func callBrainRemember(content, memType string, tags []string, project string, confidence float64) error { - args := map[string]any{ - "content": content, - "type": memType, - "tags": tags, - "confidence": confidence, - } - if project != "" && project != "unknown" { - args["project"] = project - } - - payload := map[string]any{ - "server": *server, - "tool": "brain_remember", - "arguments": args, - } - - body, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("marshal: %w", err) - } - - req, err := http.NewRequest("POST", *apiURL+"/tools/call", bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+*apiKey) - - resp, err := httpClient.Do(req) - if err != nil { - return fmt.Errorf("http: %w", err) - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != 200 { - return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody)) - } - - var result struct { - Success bool `json:"success"` - Error string `json:"error"` - } - if err := json.Unmarshal(respBody, &result); err != nil { - return fmt.Errorf("decode: %w", err) - } - if !result.Success { - return fmt.Errorf("API: %s", result.Error) - } - - return nil -} - -// truncate caps content to maxLen chars, appending an ellipsis if truncated. -func truncate(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - // Find last space before limit to avoid splitting mid-word - cut := maxLen - if idx := strings.LastIndex(s[:maxLen], " "); idx > maxLen-200 { - cut = idx - } - return s[:cut] + "…" -} - -// discoverClaudeMdFiles finds CLAUDE.md files across a code directory. -func discoverClaudeMdFiles(codePath string) []string { - var files []string - - // Walk up to 4 levels deep, skip node_modules/vendor/.claude - _ = filepath.WalkDir(codePath, func(path string, d os.DirEntry, err error) error { - if err != nil { - return nil - } - if d.IsDir() { - name := d.Name() - if name == "node_modules" || name == "vendor" || name == ".claude" { - return filepath.SkipDir - } - // Limit depth - rel, _ := filepath.Rel(codePath, path) - if strings.Count(rel, string(os.PathSeparator)) > 3 { - return filepath.SkipDir - } - return nil - } - if d.Name() == "CLAUDE.md" { - files = append(files, path) - } - return nil - }) - - return files -} - -// section is a parsed markdown section. -type section struct { - heading string - content string -} - -var headingRe = regexp.MustCompile(`^#{1,3}\s+(.+)$`) - -// parseMarkdownSections splits a markdown file by headings. -func parseMarkdownSections(path string) []section { - data, err := os.ReadFile(path) - if err != nil || len(data) == 0 { - return nil - } - - var sections []section - lines := strings.Split(string(data), "\n") - var curHeading string - var curContent []string - - for _, line := range lines { - if m := headingRe.FindStringSubmatch(line); m != nil { - if curHeading != "" && len(curContent) > 0 { - text := strings.TrimSpace(strings.Join(curContent, "\n")) - if text != "" { - sections = append(sections, section{curHeading, text}) - } - } - curHeading = strings.TrimSpace(m[1]) - curContent = nil - } else { - curContent = append(curContent, line) - } - } - - // Flush last section - if curHeading != "" && len(curContent) > 0 { - text := strings.TrimSpace(strings.Join(curContent, "\n")) - if text != "" { - sections = append(sections, section{curHeading, text}) - } - } - - // If no headings found, treat entire file as one section - if len(sections) == 0 && strings.TrimSpace(string(data)) != "" { - sections = append(sections, section{ - heading: strings.TrimSuffix(filepath.Base(path), ".md"), - content: strings.TrimSpace(string(data)), - }) - } - - return sections -} - -// extractProject derives a project name from a Claude memory path. -// ~/.claude/projects/-Users-snider-Code-eaas/memory/MEMORY.md → "eaas" -func extractProject(path string) string { - re := regexp.MustCompile(`projects/[^/]*-([^-/]+)/memory/`) - if m := re.FindStringSubmatch(path); m != nil { - return m[1] - } - return "unknown" -} - -// extractProjectFromPlan derives a project name from a plan path. -// ~/Code/eaas/docs/plans/foo.md → "eaas" -// ~/Code/host-uk/core/docs/plans/foo.md → "core" -func extractProjectFromPlan(path string) string { - // Check host-uk nested repos first - re := regexp.MustCompile(`Code/host-uk/([^/]+)/docs/plans/`) - if m := re.FindStringSubmatch(path); m != nil { - return m[1] - } - re = regexp.MustCompile(`Code/([^/]+)/docs/plans/`) - if m := re.FindStringSubmatch(path); m != nil { - return m[1] - } - return "unknown" -} - -// extractProjectFromClaudeMd derives a project name from a CLAUDE.md path. -// ~/Code/host-uk/core/CLAUDE.md → "core" -// ~/Code/eaas/CLAUDE.md → "eaas" -func extractProjectFromClaudeMd(path string) string { - re := regexp.MustCompile(`Code/host-uk/([^/]+)/`) - if m := re.FindStringSubmatch(path); m != nil { - return m[1] - } - re = regexp.MustCompile(`Code/([^/]+)/`) - if m := re.FindStringSubmatch(path); m != nil { - return m[1] - } - return "unknown" -} - -// inferType guesses the memory type from heading + content keywords. -func inferType(heading, content, source string) string { - // Source-specific defaults (match PHP BrainIngestCommand behaviour) - if source == "plans" { - return "plan" - } - if source == "claude-md" { - return "convention" - } - - lower := strings.ToLower(heading + " " + content) - patterns := map[string][]string{ - "architecture": {"architecture", "stack", "infrastructure", "layer", "service mesh"}, - "convention": {"convention", "standard", "naming", "pattern", "rule", "coding"}, - "decision": {"decision", "chose", "strategy", "approach", "domain"}, - "bug": {"bug", "fix", "broken", "error", "issue", "lesson"}, - "plan": {"plan", "todo", "roadmap", "milestone", "phase", "task"}, - "research": {"research", "finding", "discovery", "analysis", "rfc"}, - } - for t, keywords := range patterns { - for _, kw := range keywords { - if strings.Contains(lower, kw) { - return t - } - } - } - return "observation" -} - -// buildTags creates the tag list for a memory. -func buildTags(filename, source, project string) []string { - tags := []string{"source:" + source} - if project != "" && project != "unknown" { - tags = append(tags, "project:"+project) - } - if filename != "MEMORY" && filename != "CLAUDE" { - tags = append(tags, strings.ReplaceAll(strings.ReplaceAll(filename, "-", " "), "_", " ")) - } - return tags -} - -// confidenceForSource returns a default confidence level matching the PHP ingest command. -func confidenceForSource(source string) float64 { - switch source { - case "claude-md": - return 0.9 - case "memory": - return 0.8 - case "plans": - return 0.6 - default: - return 0.5 - } -} diff --git a/cmd/daemon/cmd.go b/cmd/daemon/cmd.go index bf00d12..31db594 100644 --- a/cmd/daemon/cmd.go +++ b/cmd/daemon/cmd.go @@ -8,7 +8,7 @@ import ( "path/filepath" "forge.lthn.ai/core/cli/pkg/cli" - "forge.lthn.ai/core/go-ai/mcp" + "forge.lthn.ai/core/mcp/pkg/mcp" "forge.lthn.ai/core/go-log" "forge.lthn.ai/core/go-process" ) diff --git a/cmd/mcpcmd/cmd_mcp.go b/cmd/mcpcmd/cmd_mcp.go deleted file mode 100644 index 13ded33..0000000 --- a/cmd/mcpcmd/cmd_mcp.go +++ /dev/null @@ -1,92 +0,0 @@ -// Package mcpcmd provides the MCP server command. -// -// Commands: -// - mcp serve: Start the MCP server for AI tool integration -package mcpcmd - -import ( - "context" - "os" - "os/signal" - "syscall" - - "forge.lthn.ai/core/cli/pkg/cli" - "forge.lthn.ai/core/go-ai/mcp" -) - -var workspaceFlag string - -var mcpCmd = &cli.Command{ - Use: "mcp", - Short: "MCP server for AI tool integration", - Long: "Model Context Protocol (MCP) server providing file operations, RAG, and metrics tools.", -} - -var serveCmd = &cli.Command{ - Use: "serve", - Short: "Start the MCP server", - Long: `Start the MCP server on stdio (default) or TCP. - -The server provides file operations, RAG tools, and metrics tools for AI assistants. - -Environment variables: - MCP_ADDR TCP address to listen on (e.g., "localhost:9999") - If not set, uses stdio transport. - -Examples: - # Start with stdio transport (for Claude Code integration) - core mcp serve - - # Start with workspace restriction - core mcp serve --workspace /path/to/project - - # Start TCP server - MCP_ADDR=localhost:9999 core mcp serve`, - RunE: func(cmd *cli.Command, args []string) error { - return runServe() - }, -} - -func initFlags() { - cli.StringFlag(serveCmd, &workspaceFlag, "workspace", "w", "", "Restrict file operations to this directory (empty = unrestricted)") -} - -// AddMCPCommands registers the 'mcp' command and all subcommands. -func AddMCPCommands(root *cli.Command) { - initFlags() - mcpCmd.AddCommand(serveCmd) - root.AddCommand(mcpCmd) -} - -func runServe() error { - // Build MCP service options - var opts []mcp.Option - - if workspaceFlag != "" { - opts = append(opts, mcp.WithWorkspaceRoot(workspaceFlag)) - } else { - // Explicitly unrestricted when no workspace specified - opts = append(opts, mcp.WithWorkspaceRoot("")) - } - - // Create the MCP service - svc, err := mcp.New(opts...) - if err != nil { - return cli.Wrap(err, "create MCP service") - } - - // Set up signal handling for clean shutdown - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - - go func() { - <-sigCh - cancel() - }() - - // Run the server (blocks until context cancelled or error) - return svc.Run(ctx) -} diff --git a/mcp/brain/brain.go b/mcp/brain/brain.go deleted file mode 100644 index dfa95e7..0000000 --- a/mcp/brain/brain.go +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -// Package brain provides an MCP subsystem that proxies OpenBrain knowledge -// store operations to the Laravel php-agentic backend via the IDE bridge. -package brain - -import ( - "context" - "errors" - - "forge.lthn.ai/core/go-ai/mcp/ide" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// errBridgeNotAvailable is returned when a tool requires the Laravel bridge -// but it has not been initialised (headless mode). -var errBridgeNotAvailable = errors.New("brain: bridge not available") - -// Subsystem implements mcp.Subsystem for OpenBrain knowledge store operations. -// It proxies brain_* tool calls to the Laravel backend via the shared IDE bridge. -type Subsystem struct { - bridge *ide.Bridge -} - -// New creates a brain subsystem that uses the given IDE bridge for Laravel communication. -// Pass nil if headless (tools will return errBridgeNotAvailable). -func New(bridge *ide.Bridge) *Subsystem { - return &Subsystem{bridge: bridge} -} - -// Name implements mcp.Subsystem. -func (s *Subsystem) Name() string { return "brain" } - -// RegisterTools implements mcp.Subsystem. -func (s *Subsystem) RegisterTools(server *mcp.Server) { - s.registerBrainTools(server) -} - -// Shutdown implements mcp.SubsystemWithShutdown. -func (s *Subsystem) Shutdown(_ context.Context) error { - return nil -} diff --git a/mcp/brain/brain_test.go b/mcp/brain/brain_test.go deleted file mode 100644 index bf71cc5..0000000 --- a/mcp/brain/brain_test.go +++ /dev/null @@ -1,229 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -package brain - -import ( - "context" - "encoding/json" - "testing" - "time" -) - -// --- Nil bridge tests (headless mode) --- - -func TestBrainRemember_Bad_NilBridge(t *testing.T) { - sub := New(nil) - _, _, err := sub.brainRemember(context.Background(), nil, RememberInput{ - Content: "test memory", - Type: "observation", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -func TestBrainRecall_Bad_NilBridge(t *testing.T) { - sub := New(nil) - _, _, err := sub.brainRecall(context.Background(), nil, RecallInput{ - Query: "how does scoring work?", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -func TestBrainForget_Bad_NilBridge(t *testing.T) { - sub := New(nil) - _, _, err := sub.brainForget(context.Background(), nil, ForgetInput{ - ID: "550e8400-e29b-41d4-a716-446655440000", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -func TestBrainList_Bad_NilBridge(t *testing.T) { - sub := New(nil) - _, _, err := sub.brainList(context.Background(), nil, ListInput{ - Project: "eaas", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// --- Subsystem interface tests --- - -func TestSubsystem_Good_Name(t *testing.T) { - sub := New(nil) - if sub.Name() != "brain" { - t.Errorf("expected Name() = 'brain', got %q", sub.Name()) - } -} - -func TestSubsystem_Good_ShutdownNoop(t *testing.T) { - sub := New(nil) - if err := sub.Shutdown(context.Background()); err != nil { - t.Errorf("Shutdown failed: %v", err) - } -} - -// --- Struct round-trip tests --- - -func TestRememberInput_Good_RoundTrip(t *testing.T) { - in := RememberInput{ - Content: "LEM scoring was blind to negative emotions", - Type: "bug", - Tags: []string{"scoring", "lem"}, - Project: "eaas", - Confidence: 0.95, - Supersedes: "550e8400-e29b-41d4-a716-446655440000", - ExpiresIn: 24, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out RememberInput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.Content != in.Content || out.Type != in.Type { - t.Errorf("round-trip mismatch: content or type") - } - if len(out.Tags) != 2 || out.Tags[0] != "scoring" { - t.Errorf("round-trip mismatch: tags") - } - if out.Confidence != 0.95 { - t.Errorf("round-trip mismatch: confidence %f != 0.95", out.Confidence) - } -} - -func TestRememberOutput_Good_RoundTrip(t *testing.T) { - in := RememberOutput{ - Success: true, - MemoryID: "550e8400-e29b-41d4-a716-446655440000", - Timestamp: time.Now().Truncate(time.Second), - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out RememberOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if !out.Success || out.MemoryID != in.MemoryID { - t.Errorf("round-trip mismatch: %+v != %+v", out, in) - } -} - -func TestRecallInput_Good_RoundTrip(t *testing.T) { - in := RecallInput{ - Query: "how does verdict classification work?", - TopK: 5, - Filter: RecallFilter{ - Project: "eaas", - MinConfidence: 0.5, - }, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out RecallInput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.Query != in.Query || out.TopK != 5 { - t.Errorf("round-trip mismatch: query or topK") - } - if out.Filter.Project != "eaas" || out.Filter.MinConfidence != 0.5 { - t.Errorf("round-trip mismatch: filter") - } -} - -func TestMemory_Good_RoundTrip(t *testing.T) { - in := Memory{ - ID: "550e8400-e29b-41d4-a716-446655440000", - AgentID: "virgil", - Type: "decision", - Content: "Use Qdrant for vector search", - Tags: []string{"architecture", "openbrain"}, - Project: "php-agentic", - Confidence: 0.9, - CreatedAt: "2026-03-03T12:00:00+00:00", - UpdatedAt: "2026-03-03T12:00:00+00:00", - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out Memory - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.ID != in.ID || out.AgentID != "virgil" || out.Type != "decision" { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -func TestForgetInput_Good_RoundTrip(t *testing.T) { - in := ForgetInput{ - ID: "550e8400-e29b-41d4-a716-446655440000", - Reason: "Superseded by new approach", - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out ForgetInput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.ID != in.ID || out.Reason != in.Reason { - t.Errorf("round-trip mismatch: %+v != %+v", out, in) - } -} - -func TestListInput_Good_RoundTrip(t *testing.T) { - in := ListInput{ - Project: "eaas", - Type: "decision", - AgentID: "charon", - Limit: 20, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out ListInput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.Project != "eaas" || out.Type != "decision" || out.AgentID != "charon" || out.Limit != 20 { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -func TestListOutput_Good_RoundTrip(t *testing.T) { - in := ListOutput{ - Success: true, - Count: 2, - Memories: []Memory{ - {ID: "id-1", AgentID: "virgil", Type: "decision", Content: "memory 1", Confidence: 0.9, CreatedAt: "2026-03-03T12:00:00+00:00", UpdatedAt: "2026-03-03T12:00:00+00:00"}, - {ID: "id-2", AgentID: "charon", Type: "bug", Content: "memory 2", Confidence: 0.8, CreatedAt: "2026-03-03T13:00:00+00:00", UpdatedAt: "2026-03-03T13:00:00+00:00"}, - }, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out ListOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if !out.Success || out.Count != 2 || len(out.Memories) != 2 { - t.Errorf("round-trip mismatch: %+v", out) - } -} diff --git a/mcp/brain/tools.go b/mcp/brain/tools.go deleted file mode 100644 index e724713..0000000 --- a/mcp/brain/tools.go +++ /dev/null @@ -1,220 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -package brain - -import ( - "context" - "fmt" - "time" - - "forge.lthn.ai/core/go-ai/mcp/ide" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// -- Input/Output types ------------------------------------------------------- - -// RememberInput is the input for brain_remember. -type RememberInput struct { - Content string `json:"content"` - Type string `json:"type"` - Tags []string `json:"tags,omitempty"` - Project string `json:"project,omitempty"` - Confidence float64 `json:"confidence,omitempty"` - Supersedes string `json:"supersedes,omitempty"` - ExpiresIn int `json:"expires_in,omitempty"` -} - -// RememberOutput is the output for brain_remember. -type RememberOutput struct { - Success bool `json:"success"` - MemoryID string `json:"memoryId,omitempty"` - Timestamp time.Time `json:"timestamp"` -} - -// RecallInput is the input for brain_recall. -type RecallInput struct { - Query string `json:"query"` - TopK int `json:"top_k,omitempty"` - Filter RecallFilter `json:"filter,omitempty"` -} - -// RecallFilter holds optional filter criteria for brain_recall. -type RecallFilter struct { - Project string `json:"project,omitempty"` - Type any `json:"type,omitempty"` - AgentID string `json:"agent_id,omitempty"` - MinConfidence float64 `json:"min_confidence,omitempty"` -} - -// RecallOutput is the output for brain_recall. -type RecallOutput struct { - Success bool `json:"success"` - Count int `json:"count"` - Memories []Memory `json:"memories"` -} - -// Memory is a single memory entry returned by recall or list. -type Memory struct { - ID string `json:"id"` - AgentID string `json:"agent_id"` - Type string `json:"type"` - Content string `json:"content"` - Tags []string `json:"tags,omitempty"` - Project string `json:"project,omitempty"` - Confidence float64 `json:"confidence"` - SupersedesID string `json:"supersedes_id,omitempty"` - ExpiresAt string `json:"expires_at,omitempty"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` -} - -// ForgetInput is the input for brain_forget. -type ForgetInput struct { - ID string `json:"id"` - Reason string `json:"reason,omitempty"` -} - -// ForgetOutput is the output for brain_forget. -type ForgetOutput struct { - Success bool `json:"success"` - Forgotten string `json:"forgotten"` - Timestamp time.Time `json:"timestamp"` -} - -// ListInput is the input for brain_list. -type ListInput struct { - Project string `json:"project,omitempty"` - Type string `json:"type,omitempty"` - AgentID string `json:"agent_id,omitempty"` - Limit int `json:"limit,omitempty"` -} - -// ListOutput is the output for brain_list. -type ListOutput struct { - Success bool `json:"success"` - Count int `json:"count"` - Memories []Memory `json:"memories"` -} - -// -- Tool registration -------------------------------------------------------- - -func (s *Subsystem) registerBrainTools(server *mcp.Server) { - mcp.AddTool(server, &mcp.Tool{ - Name: "brain_remember", - Description: "Store a memory in the shared OpenBrain knowledge store. Persists decisions, observations, conventions, research, plans, bugs, or architecture knowledge for other agents.", - }, s.brainRemember) - - mcp.AddTool(server, &mcp.Tool{ - Name: "brain_recall", - Description: "Semantic search across the shared OpenBrain knowledge store. Returns memories ranked by similarity to your query, with optional filtering.", - }, s.brainRecall) - - mcp.AddTool(server, &mcp.Tool{ - Name: "brain_forget", - Description: "Remove a memory from the shared OpenBrain knowledge store. Permanently deletes from both database and vector index.", - }, s.brainForget) - - mcp.AddTool(server, &mcp.Tool{ - Name: "brain_list", - Description: "List memories in the shared OpenBrain knowledge store. Supports filtering by project, type, and agent. No vector search -- use brain_recall for semantic queries.", - }, s.brainList) -} - -// -- Tool handlers ------------------------------------------------------------ - -func (s *Subsystem) brainRemember(_ context.Context, _ *mcp.CallToolRequest, input RememberInput) (*mcp.CallToolResult, RememberOutput, error) { - if s.bridge == nil { - return nil, RememberOutput{}, errBridgeNotAvailable - } - - err := s.bridge.Send(ide.BridgeMessage{ - Type: "brain_remember", - Data: map[string]any{ - "content": input.Content, - "type": input.Type, - "tags": input.Tags, - "project": input.Project, - "confidence": input.Confidence, - "supersedes": input.Supersedes, - "expires_in": input.ExpiresIn, - }, - }) - if err != nil { - return nil, RememberOutput{}, fmt.Errorf("failed to send brain_remember: %w", err) - } - - return nil, RememberOutput{ - Success: true, - Timestamp: time.Now(), - }, nil -} - -func (s *Subsystem) brainRecall(_ context.Context, _ *mcp.CallToolRequest, input RecallInput) (*mcp.CallToolResult, RecallOutput, error) { - if s.bridge == nil { - return nil, RecallOutput{}, errBridgeNotAvailable - } - - err := s.bridge.Send(ide.BridgeMessage{ - Type: "brain_recall", - Data: map[string]any{ - "query": input.Query, - "top_k": input.TopK, - "filter": input.Filter, - }, - }) - if err != nil { - return nil, RecallOutput{}, fmt.Errorf("failed to send brain_recall: %w", err) - } - - return nil, RecallOutput{ - Success: true, - Memories: []Memory{}, - }, nil -} - -func (s *Subsystem) brainForget(_ context.Context, _ *mcp.CallToolRequest, input ForgetInput) (*mcp.CallToolResult, ForgetOutput, error) { - if s.bridge == nil { - return nil, ForgetOutput{}, errBridgeNotAvailable - } - - err := s.bridge.Send(ide.BridgeMessage{ - Type: "brain_forget", - Data: map[string]any{ - "id": input.ID, - "reason": input.Reason, - }, - }) - if err != nil { - return nil, ForgetOutput{}, fmt.Errorf("failed to send brain_forget: %w", err) - } - - return nil, ForgetOutput{ - Success: true, - Forgotten: input.ID, - Timestamp: time.Now(), - }, nil -} - -func (s *Subsystem) brainList(_ context.Context, _ *mcp.CallToolRequest, input ListInput) (*mcp.CallToolResult, ListOutput, error) { - if s.bridge == nil { - return nil, ListOutput{}, errBridgeNotAvailable - } - - err := s.bridge.Send(ide.BridgeMessage{ - Type: "brain_list", - Data: map[string]any{ - "project": input.Project, - "type": input.Type, - "agent_id": input.AgentID, - "limit": input.Limit, - }, - }) - if err != nil { - return nil, ListOutput{}, fmt.Errorf("failed to send brain_list: %w", err) - } - - return nil, ListOutput{ - Success: true, - Memories: []Memory{}, - }, nil -} diff --git a/mcp/bridge.go b/mcp/bridge.go deleted file mode 100644 index eb5689f..0000000 --- a/mcp/bridge.go +++ /dev/null @@ -1,64 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -package mcp - -import ( - "encoding/json" - "errors" - "io" - "net/http" - - "github.com/gin-gonic/gin" - - api "forge.lthn.ai/core/go-api" -) - -// maxBodySize is the maximum request body size accepted by bridged tool endpoints. -const maxBodySize = 10 << 20 // 10 MB - -// BridgeToAPI populates a go-api ToolBridge from recorded MCP tools. -// Each tool becomes a POST endpoint that reads a JSON body, dispatches -// to the tool's RESTHandler (which knows the concrete input type), and -// wraps the result in the standard api.Response envelope. -func BridgeToAPI(svc *Service, bridge *api.ToolBridge) { - for rec := range svc.ToolsSeq() { - desc := api.ToolDescriptor{ - Name: rec.Name, - Description: rec.Description, - Group: rec.Group, - InputSchema: rec.InputSchema, - OutputSchema: rec.OutputSchema, - } - - // Capture the handler for the closure. - handler := rec.RESTHandler - - bridge.Add(desc, func(c *gin.Context) { - var body []byte - if c.Request.Body != nil { - var err error - body, err = io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize)) - if err != nil { - c.JSON(http.StatusBadRequest, api.Fail("invalid_request", "Failed to read request body")) - return - } - } - - result, err := handler(c.Request.Context(), body) - if err != nil { - // Classify JSON parse errors as client errors (400), - // everything else as server errors (500). - var syntaxErr *json.SyntaxError - var typeErr *json.UnmarshalTypeError - if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) { - c.JSON(http.StatusBadRequest, api.Fail("invalid_input", "Malformed JSON in request body")) - return - } - c.JSON(http.StatusInternalServerError, api.Fail("tool_error", "Tool execution failed")) - return - } - - c.JSON(http.StatusOK, api.OK(result)) - }) - } -} diff --git a/mcp/bridge_test.go b/mcp/bridge_test.go deleted file mode 100644 index bc8df91..0000000 --- a/mcp/bridge_test.go +++ /dev/null @@ -1,250 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -package mcp - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/gin-gonic/gin" - - api "forge.lthn.ai/core/go-api" -) - -func init() { - gin.SetMode(gin.TestMode) -} - -func TestBridgeToAPI_Good_AllTools(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - bridge := api.NewToolBridge("/tools") - BridgeToAPI(svc, bridge) - - svcCount := len(svc.Tools()) - bridgeCount := len(bridge.Tools()) - - if svcCount == 0 { - t.Fatal("expected non-zero tool count from service") - } - if bridgeCount != svcCount { - t.Fatalf("bridge tool count %d != service tool count %d", bridgeCount, svcCount) - } - - // Verify names match. - svcNames := make(map[string]bool) - for _, tr := range svc.Tools() { - svcNames[tr.Name] = true - } - for _, td := range bridge.Tools() { - if !svcNames[td.Name] { - t.Errorf("bridge has tool %q not found in service", td.Name) - } - } -} - -func TestBridgeToAPI_Good_DescribableGroup(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - bridge := api.NewToolBridge("/tools") - BridgeToAPI(svc, bridge) - - // ToolBridge implements DescribableGroup. - var dg api.DescribableGroup = bridge - descs := dg.Describe() - - if len(descs) != len(svc.Tools()) { - t.Fatalf("expected %d descriptions, got %d", len(svc.Tools()), len(descs)) - } - - for _, d := range descs { - if d.Method != "POST" { - t.Errorf("expected Method=POST for %s, got %q", d.Path, d.Method) - } - if d.Summary == "" { - t.Errorf("expected non-empty Summary for %s", d.Path) - } - if len(d.Tags) == 0 { - t.Errorf("expected non-empty Tags for %s", d.Path) - } - } -} - -func TestBridgeToAPI_Good_FileRead(t *testing.T) { - tmpDir := t.TempDir() - - // Create a test file in the workspace. - testContent := "hello from bridge test" - if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(testContent), 0644); err != nil { - t.Fatal(err) - } - - svc, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatal(err) - } - - bridge := api.NewToolBridge("/tools") - BridgeToAPI(svc, bridge) - - // Register with a Gin engine and make a request. - engine := gin.New() - rg := engine.Group(bridge.BasePath()) - bridge.RegisterRoutes(rg) - - body := `{"path":"test.txt"}` - w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", strings.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - engine.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) - } - - // Parse the response envelope. - var resp api.Response[ReadFileOutput] - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("unmarshal error: %v", err) - } - if !resp.Success { - t.Fatalf("expected Success=true, got error: %+v", resp.Error) - } - if resp.Data.Content != testContent { - t.Fatalf("expected content %q, got %q", testContent, resp.Data.Content) - } - if resp.Data.Path != "test.txt" { - t.Fatalf("expected path %q, got %q", "test.txt", resp.Data.Path) - } -} - -func TestBridgeToAPI_Bad_InvalidJSON(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - bridge := api.NewToolBridge("/tools") - BridgeToAPI(svc, bridge) - - engine := gin.New() - rg := engine.Group(bridge.BasePath()) - bridge.RegisterRoutes(rg) - - // Send malformed JSON. - w := httptest.NewRecorder() - req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", strings.NewReader("{bad json")) - req.Header.Set("Content-Type", "application/json") - engine.ServeHTTP(w, req) - - if w.Code != http.StatusInternalServerError { - // The handler unmarshals via RESTHandler which returns an error, - // but since it's a JSON parse error it ends up as tool_error. - // Check we get a non-200 with an error envelope. - if w.Code == http.StatusOK { - t.Fatalf("expected non-200 for invalid JSON, got 200") - } - } - - var resp api.Response[any] - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("unmarshal error: %v", err) - } - if resp.Success { - t.Fatal("expected Success=false for invalid JSON") - } - if resp.Error == nil { - t.Fatal("expected error in response") - } -} - -func TestBridgeToAPI_Good_EndToEnd(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - bridge := api.NewToolBridge("/tools") - BridgeToAPI(svc, bridge) - - // Create an api.Engine with the bridge registered and Swagger enabled. - e, err := api.New( - api.WithSwagger("MCP Bridge Test", "Testing MCP-to-REST bridge", "0.1.0"), - ) - if err != nil { - t.Fatal(err) - } - e.Register(bridge) - - // Use a real test server because gin-swagger reads RequestURI - // which is not populated by httptest.NewRecorder. - srv := httptest.NewServer(e.Handler()) - defer srv.Close() - - // Verify the health endpoint still works. - resp, err := http.Get(srv.URL + "/health") - if err != nil { - t.Fatalf("health request failed: %v", err) - } - resp.Body.Close() - if resp.StatusCode != http.StatusOK { - t.Fatalf("expected 200 for /health, got %d", resp.StatusCode) - } - - // Verify a tool endpoint is reachable through the engine. - resp2, err := http.Post(srv.URL+"/tools/lang_list", "application/json", nil) - if err != nil { - t.Fatalf("lang_list request failed: %v", err) - } - defer resp2.Body.Close() - if resp2.StatusCode != http.StatusOK { - t.Fatalf("expected 200 for /tools/lang_list, got %d", resp2.StatusCode) - } - - var langResp api.Response[GetSupportedLanguagesOutput] - if err := json.NewDecoder(resp2.Body).Decode(&langResp); err != nil { - t.Fatalf("unmarshal error: %v", err) - } - if !langResp.Success { - t.Fatalf("expected Success=true, got error: %+v", langResp.Error) - } - if len(langResp.Data.Languages) == 0 { - t.Fatal("expected non-empty languages list") - } - - // Verify Swagger endpoint contains tool paths. - resp3, err := http.Get(srv.URL + "/swagger/doc.json") - if err != nil { - t.Fatalf("swagger request failed: %v", err) - } - defer resp3.Body.Close() - if resp3.StatusCode != http.StatusOK { - t.Fatalf("expected 200 for /swagger/doc.json, got %d", resp3.StatusCode) - } - - var specDoc map[string]any - if err := json.NewDecoder(resp3.Body).Decode(&specDoc); err != nil { - t.Fatalf("swagger unmarshal error: %v", err) - } - paths, ok := specDoc["paths"].(map[string]any) - if !ok { - t.Fatal("expected 'paths' in swagger spec") - } - if _, ok := paths["/tools/file_read"]; !ok { - t.Error("expected /tools/file_read in swagger paths") - } - if _, ok := paths["/tools/lang_list"]; !ok { - t.Error("expected /tools/lang_list in swagger paths") - } -} diff --git a/mcp/ide/bridge.go b/mcp/ide/bridge.go deleted file mode 100644 index 56ce884..0000000 --- a/mcp/ide/bridge.go +++ /dev/null @@ -1,191 +0,0 @@ -package ide - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log" - "net/http" - "sync" - "time" - - "forge.lthn.ai/core/go-ws" - "github.com/gorilla/websocket" -) - -// BridgeMessage is the wire format between the IDE and Laravel. -type BridgeMessage struct { - Type string `json:"type"` - Channel string `json:"channel,omitempty"` - SessionID string `json:"sessionId,omitempty"` - Data any `json:"data,omitempty"` - Timestamp time.Time `json:"timestamp"` -} - -// Bridge maintains a WebSocket connection to the Laravel core-agentic -// backend and forwards responses to a local ws.Hub. -type Bridge struct { - cfg Config - hub *ws.Hub - conn *websocket.Conn - - mu sync.Mutex - connected bool - cancel context.CancelFunc -} - -// NewBridge creates a bridge that will connect to the Laravel backend and -// forward incoming messages to the provided ws.Hub channels. -func NewBridge(hub *ws.Hub, cfg Config) *Bridge { - return &Bridge{cfg: cfg, hub: hub} -} - -// Start begins the connection loop in a background goroutine. -// Call Shutdown to stop it. -func (b *Bridge) Start(ctx context.Context) { - ctx, b.cancel = context.WithCancel(ctx) - go b.connectLoop(ctx) -} - -// Shutdown cleanly closes the bridge. -func (b *Bridge) Shutdown() { - if b.cancel != nil { - b.cancel() - } - b.mu.Lock() - defer b.mu.Unlock() - if b.conn != nil { - b.conn.Close() - b.conn = nil - } - b.connected = false -} - -// Connected reports whether the bridge has an active connection. -func (b *Bridge) Connected() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.connected -} - -// Send sends a message to the Laravel backend. -func (b *Bridge) Send(msg BridgeMessage) error { - b.mu.Lock() - defer b.mu.Unlock() - if b.conn == nil { - return errors.New("bridge: not connected") - } - msg.Timestamp = time.Now() - data, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("bridge: marshal failed: %w", err) - } - return b.conn.WriteMessage(websocket.TextMessage, data) -} - -// connectLoop reconnects to Laravel with exponential backoff. -func (b *Bridge) connectLoop(ctx context.Context) { - delay := b.cfg.ReconnectInterval - for { - select { - case <-ctx.Done(): - return - default: - } - - if err := b.dial(ctx); err != nil { - log.Printf("ide bridge: connect failed: %v", err) - select { - case <-ctx.Done(): - return - case <-time.After(delay): - } - delay = min(delay*2, b.cfg.MaxReconnectInterval) - continue - } - - // Reset backoff on successful connection - delay = b.cfg.ReconnectInterval - b.readLoop(ctx) - } -} - -func (b *Bridge) dial(ctx context.Context) error { - dialer := websocket.Dialer{ - HandshakeTimeout: 10 * time.Second, - } - - var header http.Header - if b.cfg.Token != "" { - header = http.Header{} - header.Set("Authorization", "Bearer "+b.cfg.Token) - } - - conn, _, err := dialer.DialContext(ctx, b.cfg.LaravelWSURL, header) - if err != nil { - return err - } - - b.mu.Lock() - b.conn = conn - b.connected = true - b.mu.Unlock() - - log.Printf("ide bridge: connected to %s", b.cfg.LaravelWSURL) - return nil -} - -func (b *Bridge) readLoop(ctx context.Context) { - defer func() { - b.mu.Lock() - if b.conn != nil { - b.conn.Close() - } - b.connected = false - b.mu.Unlock() - }() - - for { - select { - case <-ctx.Done(): - return - default: - } - - _, data, err := b.conn.ReadMessage() - if err != nil { - log.Printf("ide bridge: read error: %v", err) - return - } - - var msg BridgeMessage - if err := json.Unmarshal(data, &msg); err != nil { - log.Printf("ide bridge: unmarshal error: %v", err) - continue - } - - b.dispatch(msg) - } -} - -// dispatch routes an incoming message to the appropriate ws.Hub channel. -func (b *Bridge) dispatch(msg BridgeMessage) { - if b.hub == nil { - return - } - - wsMsg := ws.Message{ - Type: ws.TypeEvent, - Data: msg.Data, - } - - channel := msg.Channel - if channel == "" { - channel = "ide:" + msg.Type - } - - if err := b.hub.SendToChannel(channel, wsMsg); err != nil { - log.Printf("ide bridge: dispatch to %s failed: %v", channel, err) - } -} diff --git a/mcp/ide/bridge_test.go b/mcp/ide/bridge_test.go deleted file mode 100644 index f1e3881..0000000 --- a/mcp/ide/bridge_test.go +++ /dev/null @@ -1,442 +0,0 @@ -package ide - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "forge.lthn.ai/core/go-ws" - "github.com/gorilla/websocket" -) - -var testUpgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, -} - -// echoServer creates a test WebSocket server that echoes messages back. -func echoServer(t *testing.T) *httptest.Server { - t.Helper() - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := testUpgrader.Upgrade(w, r, nil) - if err != nil { - t.Logf("upgrade error: %v", err) - return - } - defer conn.Close() - for { - mt, data, err := conn.ReadMessage() - if err != nil { - break - } - if err := conn.WriteMessage(mt, data); err != nil { - break - } - } - })) -} - -func wsURL(ts *httptest.Server) string { - return "ws" + strings.TrimPrefix(ts.URL, "http") -} - -// waitConnected polls bridge.Connected() until true or timeout. -func waitConnected(t *testing.T, bridge *Bridge, timeout time.Duration) { - t.Helper() - deadline := time.Now().Add(timeout) - for !bridge.Connected() && time.Now().Before(deadline) { - time.Sleep(50 * time.Millisecond) - } - if !bridge.Connected() { - t.Fatal("bridge did not connect within timeout") - } -} - -func TestBridge_Good_ConnectAndSend(t *testing.T) { - ts := echoServer(t) - defer ts.Close() - - hub := ws.NewHub() - ctx := t.Context() - go hub.Run(ctx) - - cfg := DefaultConfig() - cfg.LaravelWSURL = wsURL(ts) - cfg.ReconnectInterval = 100 * time.Millisecond - - bridge := NewBridge(hub, cfg) - bridge.Start(ctx) - - waitConnected(t, bridge, 2*time.Second) - - err := bridge.Send(BridgeMessage{ - Type: "test", - Data: "hello", - }) - if err != nil { - t.Fatalf("Send() failed: %v", err) - } -} - -func TestBridge_Good_Shutdown(t *testing.T) { - ts := echoServer(t) - defer ts.Close() - - hub := ws.NewHub() - ctx := t.Context() - go hub.Run(ctx) - - cfg := DefaultConfig() - cfg.LaravelWSURL = wsURL(ts) - cfg.ReconnectInterval = 100 * time.Millisecond - - bridge := NewBridge(hub, cfg) - bridge.Start(ctx) - - waitConnected(t, bridge, 2*time.Second) - - bridge.Shutdown() - if bridge.Connected() { - t.Error("bridge should be disconnected after Shutdown") - } -} - -func TestBridge_Bad_SendWithoutConnection(t *testing.T) { - hub := ws.NewHub() - cfg := DefaultConfig() - bridge := NewBridge(hub, cfg) - - err := bridge.Send(BridgeMessage{Type: "test"}) - if err == nil { - t.Error("expected error when sending without connection") - } -} - -func TestBridge_Good_MessageDispatch(t *testing.T) { - // Server that sends a message to the bridge on connect. - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := testUpgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - msg := BridgeMessage{ - Type: "chat_response", - Channel: "chat:session-1", - Data: "hello from laravel", - } - data, _ := json.Marshal(msg) - conn.WriteMessage(websocket.TextMessage, data) - - // Keep connection open - for { - _, _, err := conn.ReadMessage() - if err != nil { - break - } - } - })) - defer ts.Close() - - hub := ws.NewHub() - ctx := t.Context() - go hub.Run(ctx) - - cfg := DefaultConfig() - cfg.LaravelWSURL = wsURL(ts) - cfg.ReconnectInterval = 100 * time.Millisecond - - bridge := NewBridge(hub, cfg) - bridge.Start(ctx) - - waitConnected(t, bridge, 2*time.Second) - - // Give time for the dispatched message to be processed. - time.Sleep(200 * time.Millisecond) - - // Verify hub stats — the message was dispatched (even without subscribers). - // This confirms the dispatch path ran without error. -} - -func TestBridge_Good_Reconnect(t *testing.T) { - // Use atomic counter to avoid data race between HTTP handler goroutine - // and the test goroutine. - var callCount atomic.Int32 - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - n := callCount.Add(1) - conn, err := testUpgrader.Upgrade(w, r, nil) - if err != nil { - return - } - // Close immediately on first connection to force reconnect - if n == 1 { - conn.Close() - return - } - defer conn.Close() - for { - _, _, err := conn.ReadMessage() - if err != nil { - break - } - } - })) - defer ts.Close() - - hub := ws.NewHub() - ctx := t.Context() - go hub.Run(ctx) - - cfg := DefaultConfig() - cfg.LaravelWSURL = wsURL(ts) - cfg.ReconnectInterval = 100 * time.Millisecond - cfg.MaxReconnectInterval = 200 * time.Millisecond - - bridge := NewBridge(hub, cfg) - bridge.Start(ctx) - - waitConnected(t, bridge, 3*time.Second) - - if callCount.Load() < 2 { - t.Errorf("expected at least 2 connection attempts, got %d", callCount.Load()) - } -} - -func TestBridge_Good_ExponentialBackoff(t *testing.T) { - // Track timestamps of dial attempts to verify backoff behaviour. - // The server rejects the WebSocket upgrade with HTTP 403, so dial() - // returns an error and the exponential backoff path fires. - var attempts []time.Time - var mu sync.Mutex - var attemptCount atomic.Int32 - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mu.Lock() - attempts = append(attempts, time.Now()) - mu.Unlock() - attemptCount.Add(1) - - // Reject the upgrade — this makes dial() fail, triggering backoff. - http.Error(w, "forbidden", http.StatusForbidden) - })) - defer ts.Close() - - hub := ws.NewHub() - ctx := t.Context() - go hub.Run(ctx) - - cfg := DefaultConfig() - cfg.LaravelWSURL = wsURL(ts) - cfg.ReconnectInterval = 100 * time.Millisecond - cfg.MaxReconnectInterval = 400 * time.Millisecond - - bridge := NewBridge(hub, cfg) - bridge.Start(ctx) - - // Wait for at least 4 dial attempts. - deadline := time.Now().Add(5 * time.Second) - for attemptCount.Load() < 4 && time.Now().Before(deadline) { - time.Sleep(50 * time.Millisecond) - } - bridge.Shutdown() - - mu.Lock() - defer mu.Unlock() - - if len(attempts) < 4 { - t.Fatalf("expected at least 4 connection attempts, got %d", len(attempts)) - } - - // Verify exponential backoff: gap between attempts should increase. - // Expected delays: ~100ms, ~200ms, ~400ms (capped). - // Allow generous tolerance since timing is non-deterministic. - for i := 1; i < len(attempts) && i <= 3; i++ { - gap := attempts[i].Sub(attempts[i-1]) - // Minimum expected delay doubles each time: 100, 200, 400. - // We check a lower bound (50% of expected) to be resilient. - expectedMin := time.Duration(50*(1<<(i-1))) * time.Millisecond - if gap < expectedMin { - t.Errorf("attempt %d->%d gap %v < expected minimum %v", i-1, i, gap, expectedMin) - } - } - - // Verify the backoff caps at MaxReconnectInterval. - if len(attempts) >= 5 { - gap := attempts[4].Sub(attempts[3]) - // After cap is hit, delay should not exceed MaxReconnectInterval + tolerance. - maxExpected := cfg.MaxReconnectInterval + 200*time.Millisecond - if gap > maxExpected { - t.Errorf("attempt 3->4 gap %v exceeded max backoff %v", gap, maxExpected) - } - } -} - -func TestBridge_Good_ReconnectDetectsServerShutdown(t *testing.T) { - // Start a server that closes the WS connection on demand, then close - // the server entirely so the bridge cannot reconnect. - closeConn := make(chan struct{}, 1) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := testUpgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - // Wait for signal to close - <-closeConn - conn.WriteMessage(websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseGoingAway, "shutdown")) - })) - - hub := ws.NewHub() - ctx := t.Context() - go hub.Run(ctx) - - cfg := DefaultConfig() - cfg.LaravelWSURL = wsURL(ts) - // Use long reconnect so bridge stays disconnected after server dies. - cfg.ReconnectInterval = 5 * time.Second - cfg.MaxReconnectInterval = 5 * time.Second - - bridge := NewBridge(hub, cfg) - bridge.Start(ctx) - - waitConnected(t, bridge, 2*time.Second) - - // Signal server handler to close the WS connection, then shut down - // the server so the reconnect dial() also fails. - closeConn <- struct{}{} - ts.Close() - - // Wait for disconnection. - deadline := time.Now().Add(3 * time.Second) - for bridge.Connected() && time.Now().Before(deadline) { - time.Sleep(50 * time.Millisecond) - } - - if bridge.Connected() { - t.Error("expected bridge to detect server-side connection close") - } -} - -func TestBridge_Good_AuthHeader(t *testing.T) { - // Server that checks for the Authorization header on upgrade. - var receivedAuth atomic.Value - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedAuth.Store(r.Header.Get("Authorization")) - conn, err := testUpgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - for { - _, _, err := conn.ReadMessage() - if err != nil { - break - } - } - })) - defer ts.Close() - - hub := ws.NewHub() - ctx := t.Context() - go hub.Run(ctx) - - cfg := DefaultConfig() - cfg.LaravelWSURL = wsURL(ts) - cfg.ReconnectInterval = 100 * time.Millisecond - cfg.Token = "test-secret-token-42" - - bridge := NewBridge(hub, cfg) - bridge.Start(ctx) - - waitConnected(t, bridge, 2*time.Second) - - auth, ok := receivedAuth.Load().(string) - if !ok || auth == "" { - t.Fatal("server did not receive Authorization header") - } - - expected := "Bearer test-secret-token-42" - if auth != expected { - t.Errorf("expected auth header %q, got %q", expected, auth) - } -} - -func TestBridge_Good_NoAuthHeaderWhenTokenEmpty(t *testing.T) { - // Verify that no Authorization header is sent when Token is empty. - var receivedAuth atomic.Value - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedAuth.Store(r.Header.Get("Authorization")) - conn, err := testUpgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - for { - _, _, err := conn.ReadMessage() - if err != nil { - break - } - } - })) - defer ts.Close() - - hub := ws.NewHub() - ctx := t.Context() - go hub.Run(ctx) - - cfg := DefaultConfig() - cfg.LaravelWSURL = wsURL(ts) - cfg.ReconnectInterval = 100 * time.Millisecond - // Token intentionally left empty - - bridge := NewBridge(hub, cfg) - bridge.Start(ctx) - - waitConnected(t, bridge, 2*time.Second) - - auth, _ := receivedAuth.Load().(string) - if auth != "" { - t.Errorf("expected no Authorization header when token is empty, got %q", auth) - } -} - -func TestBridge_Good_WithTokenOption(t *testing.T) { - // Verify the WithToken option function works. - cfg := DefaultConfig() - opt := WithToken("my-token") - opt(&cfg) - - if cfg.Token != "my-token" { - t.Errorf("expected token 'my-token', got %q", cfg.Token) - } -} - -func TestSubsystem_Good_Name(t *testing.T) { - sub := New(nil) - if sub.Name() != "ide" { - t.Errorf("expected name 'ide', got %q", sub.Name()) - } -} - -func TestSubsystem_Good_NilHub(t *testing.T) { - sub := New(nil) - if sub.Bridge() != nil { - t.Error("expected nil bridge when hub is nil") - } - // Shutdown should not panic - if err := sub.Shutdown(context.Background()); err != nil { - t.Errorf("Shutdown with nil bridge failed: %v", err) - } -} diff --git a/mcp/ide/config.go b/mcp/ide/config.go deleted file mode 100644 index ff64419..0000000 --- a/mcp/ide/config.go +++ /dev/null @@ -1,57 +0,0 @@ -// Package ide provides an MCP subsystem that bridges the desktop IDE to -// a Laravel core-agentic backend over WebSocket. -package ide - -import "time" - -// Config holds connection and workspace settings for the IDE subsystem. -type Config struct { - // LaravelWSURL is the WebSocket endpoint for the Laravel core-agentic backend. - LaravelWSURL string - - // WorkspaceRoot is the local path used as the default workspace context. - WorkspaceRoot string - - // Token is the Bearer token sent in the Authorization header during - // WebSocket upgrade. When empty, no auth header is sent. - Token string - - // ReconnectInterval controls how long to wait between reconnect attempts. - ReconnectInterval time.Duration - - // MaxReconnectInterval caps exponential backoff for reconnection. - MaxReconnectInterval time.Duration -} - -// DefaultConfig returns sensible defaults for local development. -func DefaultConfig() Config { - return Config{ - LaravelWSURL: "ws://localhost:9876/ws", - WorkspaceRoot: ".", - ReconnectInterval: 2 * time.Second, - MaxReconnectInterval: 30 * time.Second, - } -} - -// Option configures the IDE subsystem. -type Option func(*Config) - -// WithLaravelURL sets the Laravel WebSocket endpoint. -func WithLaravelURL(url string) Option { - return func(c *Config) { c.LaravelWSURL = url } -} - -// WithWorkspaceRoot sets the workspace root directory. -func WithWorkspaceRoot(root string) Option { - return func(c *Config) { c.WorkspaceRoot = root } -} - -// WithReconnectInterval sets the base reconnect interval. -func WithReconnectInterval(d time.Duration) Option { - return func(c *Config) { c.ReconnectInterval = d } -} - -// WithToken sets the Bearer token for WebSocket authentication. -func WithToken(token string) Option { - return func(c *Config) { c.Token = token } -} diff --git a/mcp/ide/ide.go b/mcp/ide/ide.go deleted file mode 100644 index ba3a833..0000000 --- a/mcp/ide/ide.go +++ /dev/null @@ -1,62 +0,0 @@ -package ide - -import ( - "context" - "errors" - - "forge.lthn.ai/core/go-ws" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// errBridgeNotAvailable is returned when a tool requires the Laravel bridge -// but it has not been initialised (headless mode). -var errBridgeNotAvailable = errors.New("bridge not available") - -// Subsystem implements mcp.Subsystem and mcp.SubsystemWithShutdown for the IDE. -type Subsystem struct { - cfg Config - bridge *Bridge - hub *ws.Hub -} - -// New creates an IDE subsystem. The ws.Hub is used for real-time forwarding; -// pass nil if headless (tools still work but real-time streaming is disabled). -func New(hub *ws.Hub, opts ...Option) *Subsystem { - cfg := DefaultConfig() - for _, opt := range opts { - opt(&cfg) - } - var bridge *Bridge - if hub != nil { - bridge = NewBridge(hub, cfg) - } - return &Subsystem{cfg: cfg, bridge: bridge, hub: hub} -} - -// Name implements mcp.Subsystem. -func (s *Subsystem) Name() string { return "ide" } - -// RegisterTools implements mcp.Subsystem. -func (s *Subsystem) RegisterTools(server *mcp.Server) { - s.registerChatTools(server) - s.registerBuildTools(server) - s.registerDashboardTools(server) -} - -// Shutdown implements mcp.SubsystemWithShutdown. -func (s *Subsystem) Shutdown(_ context.Context) error { - if s.bridge != nil { - s.bridge.Shutdown() - } - return nil -} - -// Bridge returns the Laravel WebSocket bridge (may be nil in headless mode). -func (s *Subsystem) Bridge() *Bridge { return s.bridge } - -// StartBridge begins the background connection to the Laravel backend. -func (s *Subsystem) StartBridge(ctx context.Context) { - if s.bridge != nil { - s.bridge.Start(ctx) - } -} diff --git a/mcp/ide/tools_build.go b/mcp/ide/tools_build.go deleted file mode 100644 index 57a6a86..0000000 --- a/mcp/ide/tools_build.go +++ /dev/null @@ -1,114 +0,0 @@ -package ide - -import ( - "context" - "time" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// Build tool input/output types. - -// BuildStatusInput is the input for ide_build_status. -type BuildStatusInput struct { - BuildID string `json:"buildId"` -} - -// BuildInfo represents a single build. -type BuildInfo struct { - ID string `json:"id"` - Repo string `json:"repo"` - Branch string `json:"branch"` - Status string `json:"status"` - Duration string `json:"duration,omitempty"` - StartedAt time.Time `json:"startedAt"` -} - -// BuildStatusOutput is the output for ide_build_status. -type BuildStatusOutput struct { - Build BuildInfo `json:"build"` -} - -// BuildListInput is the input for ide_build_list. -type BuildListInput struct { - Repo string `json:"repo,omitempty"` - Limit int `json:"limit,omitempty"` -} - -// BuildListOutput is the output for ide_build_list. -type BuildListOutput struct { - Builds []BuildInfo `json:"builds"` -} - -// BuildLogsInput is the input for ide_build_logs. -type BuildLogsInput struct { - BuildID string `json:"buildId"` - Tail int `json:"tail,omitempty"` -} - -// BuildLogsOutput is the output for ide_build_logs. -type BuildLogsOutput struct { - BuildID string `json:"buildId"` - Lines []string `json:"lines"` -} - -func (s *Subsystem) registerBuildTools(server *mcp.Server) { - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_build_status", - Description: "Get the status of a specific build", - }, s.buildStatus) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_build_list", - Description: "List recent builds, optionally filtered by repository", - }, s.buildList) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_build_logs", - Description: "Retrieve log output for a build", - }, s.buildLogs) -} - -// buildStatus requests build status from the Laravel backend. -// Stub implementation: sends request via bridge, returns "unknown" status. Awaiting Laravel backend. -func (s *Subsystem) buildStatus(_ context.Context, _ *mcp.CallToolRequest, input BuildStatusInput) (*mcp.CallToolResult, BuildStatusOutput, error) { - if s.bridge == nil { - return nil, BuildStatusOutput{}, errBridgeNotAvailable - } - _ = s.bridge.Send(BridgeMessage{ - Type: "build_status", - Data: map[string]any{"buildId": input.BuildID}, - }) - return nil, BuildStatusOutput{ - Build: BuildInfo{ID: input.BuildID, Status: "unknown"}, - }, nil -} - -// buildList requests a list of builds from the Laravel backend. -// Stub implementation: sends request via bridge, returns empty list. Awaiting Laravel backend. -func (s *Subsystem) buildList(_ context.Context, _ *mcp.CallToolRequest, input BuildListInput) (*mcp.CallToolResult, BuildListOutput, error) { - if s.bridge == nil { - return nil, BuildListOutput{}, errBridgeNotAvailable - } - _ = s.bridge.Send(BridgeMessage{ - Type: "build_list", - Data: map[string]any{"repo": input.Repo, "limit": input.Limit}, - }) - return nil, BuildListOutput{Builds: []BuildInfo{}}, nil -} - -// buildLogs requests build log output from the Laravel backend. -// Stub implementation: sends request via bridge, returns empty lines. Awaiting Laravel backend. -func (s *Subsystem) buildLogs(_ context.Context, _ *mcp.CallToolRequest, input BuildLogsInput) (*mcp.CallToolResult, BuildLogsOutput, error) { - if s.bridge == nil { - return nil, BuildLogsOutput{}, errBridgeNotAvailable - } - _ = s.bridge.Send(BridgeMessage{ - Type: "build_logs", - Data: map[string]any{"buildId": input.BuildID, "tail": input.Tail}, - }) - return nil, BuildLogsOutput{ - BuildID: input.BuildID, - Lines: []string{}, - }, nil -} diff --git a/mcp/ide/tools_chat.go b/mcp/ide/tools_chat.go deleted file mode 100644 index bbdc6b0..0000000 --- a/mcp/ide/tools_chat.go +++ /dev/null @@ -1,201 +0,0 @@ -package ide - -import ( - "context" - "fmt" - "time" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// Chat tool input/output types. - -// ChatSendInput is the input for ide_chat_send. -type ChatSendInput struct { - SessionID string `json:"sessionId"` - Message string `json:"message"` -} - -// ChatSendOutput is the output for ide_chat_send. -type ChatSendOutput struct { - Sent bool `json:"sent"` - SessionID string `json:"sessionId"` - Timestamp time.Time `json:"timestamp"` -} - -// ChatHistoryInput is the input for ide_chat_history. -type ChatHistoryInput struct { - SessionID string `json:"sessionId"` - Limit int `json:"limit,omitempty"` -} - -// ChatMessage represents a single message in history. -type ChatMessage struct { - Role string `json:"role"` - Content string `json:"content"` - Timestamp time.Time `json:"timestamp"` -} - -// ChatHistoryOutput is the output for ide_chat_history. -type ChatHistoryOutput struct { - SessionID string `json:"sessionId"` - Messages []ChatMessage `json:"messages"` -} - -// SessionListInput is the input for ide_session_list. -type SessionListInput struct{} - -// Session represents an agent session. -type Session struct { - ID string `json:"id"` - Name string `json:"name"` - Status string `json:"status"` - CreatedAt time.Time `json:"createdAt"` -} - -// SessionListOutput is the output for ide_session_list. -type SessionListOutput struct { - Sessions []Session `json:"sessions"` -} - -// SessionCreateInput is the input for ide_session_create. -type SessionCreateInput struct { - Name string `json:"name"` -} - -// SessionCreateOutput is the output for ide_session_create. -type SessionCreateOutput struct { - Session Session `json:"session"` -} - -// PlanStatusInput is the input for ide_plan_status. -type PlanStatusInput struct { - SessionID string `json:"sessionId"` -} - -// PlanStep is a single step in an agent plan. -type PlanStep struct { - Name string `json:"name"` - Status string `json:"status"` -} - -// PlanStatusOutput is the output for ide_plan_status. -type PlanStatusOutput struct { - SessionID string `json:"sessionId"` - Status string `json:"status"` - Steps []PlanStep `json:"steps"` -} - -func (s *Subsystem) registerChatTools(server *mcp.Server) { - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_chat_send", - Description: "Send a message to an agent chat session", - }, s.chatSend) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_chat_history", - Description: "Retrieve message history for a chat session", - }, s.chatHistory) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_session_list", - Description: "List active agent sessions", - }, s.sessionList) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_session_create", - Description: "Create a new agent session", - }, s.sessionCreate) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_plan_status", - Description: "Get the current plan status for a session", - }, s.planStatus) -} - -// chatSend forwards a chat message to the Laravel backend via bridge. -// Stub implementation: delegates to bridge, real response arrives via WebSocket subscription. -func (s *Subsystem) chatSend(_ context.Context, _ *mcp.CallToolRequest, input ChatSendInput) (*mcp.CallToolResult, ChatSendOutput, error) { - if s.bridge == nil { - return nil, ChatSendOutput{}, errBridgeNotAvailable - } - err := s.bridge.Send(BridgeMessage{ - Type: "chat_send", - Channel: "chat:" + input.SessionID, - SessionID: input.SessionID, - Data: input.Message, - }) - if err != nil { - return nil, ChatSendOutput{}, fmt.Errorf("failed to send message: %w", err) - } - return nil, ChatSendOutput{ - Sent: true, - SessionID: input.SessionID, - Timestamp: time.Now(), - }, nil -} - -// chatHistory requests message history from the Laravel backend. -// Stub implementation: sends request via bridge, returns empty messages. Real data arrives via WebSocket. -func (s *Subsystem) chatHistory(_ context.Context, _ *mcp.CallToolRequest, input ChatHistoryInput) (*mcp.CallToolResult, ChatHistoryOutput, error) { - if s.bridge == nil { - return nil, ChatHistoryOutput{}, errBridgeNotAvailable - } - // Request history via bridge; for now return placeholder indicating the - // request was forwarded. Real data arrives via WebSocket subscription. - _ = s.bridge.Send(BridgeMessage{ - Type: "chat_history", - SessionID: input.SessionID, - Data: map[string]any{"limit": input.Limit}, - }) - return nil, ChatHistoryOutput{ - SessionID: input.SessionID, - Messages: []ChatMessage{}, - }, nil -} - -// sessionList requests the session list from the Laravel backend. -// Stub implementation: sends request via bridge, returns empty sessions. Awaiting Laravel backend. -func (s *Subsystem) sessionList(_ context.Context, _ *mcp.CallToolRequest, _ SessionListInput) (*mcp.CallToolResult, SessionListOutput, error) { - if s.bridge == nil { - return nil, SessionListOutput{}, errBridgeNotAvailable - } - _ = s.bridge.Send(BridgeMessage{Type: "session_list"}) - return nil, SessionListOutput{Sessions: []Session{}}, nil -} - -// sessionCreate requests a new session from the Laravel backend. -// Stub implementation: sends request via bridge, returns placeholder session. Awaiting Laravel backend. -func (s *Subsystem) sessionCreate(_ context.Context, _ *mcp.CallToolRequest, input SessionCreateInput) (*mcp.CallToolResult, SessionCreateOutput, error) { - if s.bridge == nil { - return nil, SessionCreateOutput{}, errBridgeNotAvailable - } - _ = s.bridge.Send(BridgeMessage{ - Type: "session_create", - Data: map[string]any{"name": input.Name}, - }) - return nil, SessionCreateOutput{ - Session: Session{ - Name: input.Name, - Status: "creating", - CreatedAt: time.Now(), - }, - }, nil -} - -// planStatus requests plan status from the Laravel backend. -// Stub implementation: sends request via bridge, returns "unknown" status. Awaiting Laravel backend. -func (s *Subsystem) planStatus(_ context.Context, _ *mcp.CallToolRequest, input PlanStatusInput) (*mcp.CallToolResult, PlanStatusOutput, error) { - if s.bridge == nil { - return nil, PlanStatusOutput{}, errBridgeNotAvailable - } - _ = s.bridge.Send(BridgeMessage{ - Type: "plan_status", - SessionID: input.SessionID, - }) - return nil, PlanStatusOutput{ - SessionID: input.SessionID, - Status: "unknown", - Steps: []PlanStep{}, - }, nil -} diff --git a/mcp/ide/tools_dashboard.go b/mcp/ide/tools_dashboard.go deleted file mode 100644 index 6b660bf..0000000 --- a/mcp/ide/tools_dashboard.go +++ /dev/null @@ -1,132 +0,0 @@ -package ide - -import ( - "context" - "time" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// Dashboard tool input/output types. - -// DashboardOverviewInput is the input for ide_dashboard_overview. -type DashboardOverviewInput struct{} - -// DashboardOverview contains high-level platform stats. -type DashboardOverview struct { - Repos int `json:"repos"` - Services int `json:"services"` - ActiveSessions int `json:"activeSessions"` - RecentBuilds int `json:"recentBuilds"` - BridgeOnline bool `json:"bridgeOnline"` -} - -// DashboardOverviewOutput is the output for ide_dashboard_overview. -type DashboardOverviewOutput struct { - Overview DashboardOverview `json:"overview"` -} - -// DashboardActivityInput is the input for ide_dashboard_activity. -type DashboardActivityInput struct { - Limit int `json:"limit,omitempty"` -} - -// ActivityEvent represents a single activity feed item. -type ActivityEvent struct { - Type string `json:"type"` - Message string `json:"message"` - Timestamp time.Time `json:"timestamp"` -} - -// DashboardActivityOutput is the output for ide_dashboard_activity. -type DashboardActivityOutput struct { - Events []ActivityEvent `json:"events"` -} - -// DashboardMetricsInput is the input for ide_dashboard_metrics. -type DashboardMetricsInput struct { - Period string `json:"period,omitempty"` // "1h", "24h", "7d" -} - -// DashboardMetrics contains aggregate metrics. -type DashboardMetrics struct { - BuildsTotal int `json:"buildsTotal"` - BuildsSuccess int `json:"buildsSuccess"` - BuildsFailed int `json:"buildsFailed"` - AvgBuildTime string `json:"avgBuildTime"` - AgentSessions int `json:"agentSessions"` - MessagesTotal int `json:"messagesTotal"` - SuccessRate float64 `json:"successRate"` -} - -// DashboardMetricsOutput is the output for ide_dashboard_metrics. -type DashboardMetricsOutput struct { - Period string `json:"period"` - Metrics DashboardMetrics `json:"metrics"` -} - -func (s *Subsystem) registerDashboardTools(server *mcp.Server) { - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_dashboard_overview", - Description: "Get a high-level overview of the platform (repos, services, sessions, builds)", - }, s.dashboardOverview) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_dashboard_activity", - Description: "Get the recent activity feed", - }, s.dashboardActivity) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ide_dashboard_metrics", - Description: "Get aggregate build and agent metrics for a time period", - }, s.dashboardMetrics) -} - -// dashboardOverview returns a platform overview with bridge status. -// Stub implementation: only BridgeOnline is live; other fields return zero values. Awaiting Laravel backend. -func (s *Subsystem) dashboardOverview(_ context.Context, _ *mcp.CallToolRequest, _ DashboardOverviewInput) (*mcp.CallToolResult, DashboardOverviewOutput, error) { - connected := s.bridge != nil && s.bridge.Connected() - - if s.bridge != nil { - _ = s.bridge.Send(BridgeMessage{Type: "dashboard_overview"}) - } - - return nil, DashboardOverviewOutput{ - Overview: DashboardOverview{ - BridgeOnline: connected, - }, - }, nil -} - -// dashboardActivity requests the activity feed from the Laravel backend. -// Stub implementation: sends request via bridge, returns empty events. Awaiting Laravel backend. -func (s *Subsystem) dashboardActivity(_ context.Context, _ *mcp.CallToolRequest, input DashboardActivityInput) (*mcp.CallToolResult, DashboardActivityOutput, error) { - if s.bridge == nil { - return nil, DashboardActivityOutput{}, errBridgeNotAvailable - } - _ = s.bridge.Send(BridgeMessage{ - Type: "dashboard_activity", - Data: map[string]any{"limit": input.Limit}, - }) - return nil, DashboardActivityOutput{Events: []ActivityEvent{}}, nil -} - -// dashboardMetrics requests aggregate metrics from the Laravel backend. -// Stub implementation: sends request via bridge, returns zero metrics. Awaiting Laravel backend. -func (s *Subsystem) dashboardMetrics(_ context.Context, _ *mcp.CallToolRequest, input DashboardMetricsInput) (*mcp.CallToolResult, DashboardMetricsOutput, error) { - if s.bridge == nil { - return nil, DashboardMetricsOutput{}, errBridgeNotAvailable - } - period := input.Period - if period == "" { - period = "24h" - } - _ = s.bridge.Send(BridgeMessage{ - Type: "dashboard_metrics", - Data: map[string]any{"period": period}, - }) - return nil, DashboardMetricsOutput{ - Period: period, - Metrics: DashboardMetrics{}, - }, nil -} diff --git a/mcp/ide/tools_test.go b/mcp/ide/tools_test.go deleted file mode 100644 index 21a01fa..0000000 --- a/mcp/ide/tools_test.go +++ /dev/null @@ -1,781 +0,0 @@ -package ide - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" - - "forge.lthn.ai/core/go-ws" -) - -// --- Helpers --- - -// newNilBridgeSubsystem returns a Subsystem with no hub/bridge (headless mode). -func newNilBridgeSubsystem() *Subsystem { - return New(nil) -} - -// newConnectedSubsystem returns a Subsystem with a connected bridge and a -// running echo WS server. Caller must cancel ctx and close server when done. -func newConnectedSubsystem(t *testing.T) (*Subsystem, context.CancelFunc, *httptest.Server) { - t.Helper() - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := testUpgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - for { - mt, data, err := conn.ReadMessage() - if err != nil { - break - } - _ = conn.WriteMessage(mt, data) - } - })) - - hub := ws.NewHub() - ctx, cancel := context.WithCancel(context.Background()) - go hub.Run(ctx) - - sub := New(hub, - WithLaravelURL(wsURL(ts)), - WithReconnectInterval(50*time.Millisecond), - ) - sub.StartBridge(ctx) - - waitConnected(t, sub.Bridge(), 2*time.Second) - return sub, cancel, ts -} - -// --- 4.3: Chat tool tests --- - -// TestChatSend_Bad_NilBridge verifies chatSend returns error without a bridge. -func TestChatSend_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.chatSend(context.Background(), nil, ChatSendInput{ - SessionID: "s1", - Message: "hello", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestChatSend_Good_Connected verifies chatSend succeeds with a connected bridge. -func TestChatSend_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.chatSend(context.Background(), nil, ChatSendInput{ - SessionID: "sess-42", - Message: "hello", - }) - if err != nil { - t.Fatalf("chatSend failed: %v", err) - } - if !out.Sent { - t.Error("expected Sent=true") - } - if out.SessionID != "sess-42" { - t.Errorf("expected sessionId 'sess-42', got %q", out.SessionID) - } - if out.Timestamp.IsZero() { - t.Error("expected non-zero timestamp") - } -} - -// TestChatHistory_Bad_NilBridge verifies chatHistory returns error without a bridge. -func TestChatHistory_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.chatHistory(context.Background(), nil, ChatHistoryInput{ - SessionID: "s1", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestChatHistory_Good_Connected verifies chatHistory succeeds and returns empty messages. -func TestChatHistory_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.chatHistory(context.Background(), nil, ChatHistoryInput{ - SessionID: "sess-1", - Limit: 50, - }) - if err != nil { - t.Fatalf("chatHistory failed: %v", err) - } - if out.SessionID != "sess-1" { - t.Errorf("expected sessionId 'sess-1', got %q", out.SessionID) - } - if out.Messages == nil { - t.Error("expected non-nil messages slice") - } - if len(out.Messages) != 0 { - t.Errorf("expected 0 messages (stub), got %d", len(out.Messages)) - } -} - -// TestSessionList_Bad_NilBridge verifies sessionList returns error without a bridge. -func TestSessionList_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.sessionList(context.Background(), nil, SessionListInput{}) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestSessionList_Good_Connected verifies sessionList returns empty sessions. -func TestSessionList_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.sessionList(context.Background(), nil, SessionListInput{}) - if err != nil { - t.Fatalf("sessionList failed: %v", err) - } - if out.Sessions == nil { - t.Error("expected non-nil sessions slice") - } - if len(out.Sessions) != 0 { - t.Errorf("expected 0 sessions (stub), got %d", len(out.Sessions)) - } -} - -// TestSessionCreate_Bad_NilBridge verifies sessionCreate returns error without a bridge. -func TestSessionCreate_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.sessionCreate(context.Background(), nil, SessionCreateInput{ - Name: "test", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestSessionCreate_Good_Connected verifies sessionCreate returns a session stub. -func TestSessionCreate_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.sessionCreate(context.Background(), nil, SessionCreateInput{ - Name: "my-session", - }) - if err != nil { - t.Fatalf("sessionCreate failed: %v", err) - } - if out.Session.Name != "my-session" { - t.Errorf("expected name 'my-session', got %q", out.Session.Name) - } - if out.Session.Status != "creating" { - t.Errorf("expected status 'creating', got %q", out.Session.Status) - } - if out.Session.CreatedAt.IsZero() { - t.Error("expected non-zero CreatedAt") - } -} - -// TestPlanStatus_Bad_NilBridge verifies planStatus returns error without a bridge. -func TestPlanStatus_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.planStatus(context.Background(), nil, PlanStatusInput{ - SessionID: "s1", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestPlanStatus_Good_Connected verifies planStatus returns a stub status. -func TestPlanStatus_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.planStatus(context.Background(), nil, PlanStatusInput{ - SessionID: "sess-7", - }) - if err != nil { - t.Fatalf("planStatus failed: %v", err) - } - if out.SessionID != "sess-7" { - t.Errorf("expected sessionId 'sess-7', got %q", out.SessionID) - } - if out.Status != "unknown" { - t.Errorf("expected status 'unknown', got %q", out.Status) - } - if out.Steps == nil { - t.Error("expected non-nil steps slice") - } -} - -// --- 4.3: Build tool tests --- - -// TestBuildStatus_Bad_NilBridge verifies buildStatus returns error without a bridge. -func TestBuildStatus_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.buildStatus(context.Background(), nil, BuildStatusInput{ - BuildID: "b1", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestBuildStatus_Good_Connected verifies buildStatus returns a stub. -func TestBuildStatus_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.buildStatus(context.Background(), nil, BuildStatusInput{ - BuildID: "build-99", - }) - if err != nil { - t.Fatalf("buildStatus failed: %v", err) - } - if out.Build.ID != "build-99" { - t.Errorf("expected build ID 'build-99', got %q", out.Build.ID) - } - if out.Build.Status != "unknown" { - t.Errorf("expected status 'unknown', got %q", out.Build.Status) - } -} - -// TestBuildList_Bad_NilBridge verifies buildList returns error without a bridge. -func TestBuildList_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.buildList(context.Background(), nil, BuildListInput{ - Repo: "core-php", - Limit: 10, - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestBuildList_Good_Connected verifies buildList returns an empty list. -func TestBuildList_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.buildList(context.Background(), nil, BuildListInput{ - Repo: "core-php", - Limit: 10, - }) - if err != nil { - t.Fatalf("buildList failed: %v", err) - } - if out.Builds == nil { - t.Error("expected non-nil builds slice") - } - if len(out.Builds) != 0 { - t.Errorf("expected 0 builds (stub), got %d", len(out.Builds)) - } -} - -// TestBuildLogs_Bad_NilBridge verifies buildLogs returns error without a bridge. -func TestBuildLogs_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.buildLogs(context.Background(), nil, BuildLogsInput{ - BuildID: "b1", - Tail: 100, - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestBuildLogs_Good_Connected verifies buildLogs returns empty lines. -func TestBuildLogs_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.buildLogs(context.Background(), nil, BuildLogsInput{ - BuildID: "build-55", - Tail: 50, - }) - if err != nil { - t.Fatalf("buildLogs failed: %v", err) - } - if out.BuildID != "build-55" { - t.Errorf("expected buildId 'build-55', got %q", out.BuildID) - } - if out.Lines == nil { - t.Error("expected non-nil lines slice") - } - if len(out.Lines) != 0 { - t.Errorf("expected 0 lines (stub), got %d", len(out.Lines)) - } -} - -// --- 4.3: Dashboard tool tests --- - -// TestDashboardOverview_Good_NilBridge verifies dashboardOverview works without bridge -// (it does not return error — it reports BridgeOnline=false). -func TestDashboardOverview_Good_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, out, err := sub.dashboardOverview(context.Background(), nil, DashboardOverviewInput{}) - if err != nil { - t.Fatalf("dashboardOverview failed: %v", err) - } - if out.Overview.BridgeOnline { - t.Error("expected BridgeOnline=false when bridge is nil") - } -} - -// TestDashboardOverview_Good_Connected verifies dashboardOverview reports bridge online. -func TestDashboardOverview_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.dashboardOverview(context.Background(), nil, DashboardOverviewInput{}) - if err != nil { - t.Fatalf("dashboardOverview failed: %v", err) - } - if !out.Overview.BridgeOnline { - t.Error("expected BridgeOnline=true when bridge is connected") - } -} - -// TestDashboardActivity_Bad_NilBridge verifies dashboardActivity returns error without bridge. -func TestDashboardActivity_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.dashboardActivity(context.Background(), nil, DashboardActivityInput{ - Limit: 10, - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestDashboardActivity_Good_Connected verifies dashboardActivity returns empty events. -func TestDashboardActivity_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.dashboardActivity(context.Background(), nil, DashboardActivityInput{ - Limit: 20, - }) - if err != nil { - t.Fatalf("dashboardActivity failed: %v", err) - } - if out.Events == nil { - t.Error("expected non-nil events slice") - } - if len(out.Events) != 0 { - t.Errorf("expected 0 events (stub), got %d", len(out.Events)) - } -} - -// TestDashboardMetrics_Bad_NilBridge verifies dashboardMetrics returns error without bridge. -func TestDashboardMetrics_Bad_NilBridge(t *testing.T) { - sub := newNilBridgeSubsystem() - _, _, err := sub.dashboardMetrics(context.Background(), nil, DashboardMetricsInput{ - Period: "1h", - }) - if err == nil { - t.Error("expected error when bridge is nil") - } -} - -// TestDashboardMetrics_Good_Connected verifies dashboardMetrics returns empty metrics. -func TestDashboardMetrics_Good_Connected(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.dashboardMetrics(context.Background(), nil, DashboardMetricsInput{ - Period: "7d", - }) - if err != nil { - t.Fatalf("dashboardMetrics failed: %v", err) - } - if out.Period != "7d" { - t.Errorf("expected period '7d', got %q", out.Period) - } -} - -// TestDashboardMetrics_Good_DefaultPeriod verifies the default period is "24h". -func TestDashboardMetrics_Good_DefaultPeriod(t *testing.T) { - sub, cancel, ts := newConnectedSubsystem(t) - defer cancel() - defer ts.Close() - - _, out, err := sub.dashboardMetrics(context.Background(), nil, DashboardMetricsInput{}) - if err != nil { - t.Fatalf("dashboardMetrics failed: %v", err) - } - if out.Period != "24h" { - t.Errorf("expected default period '24h', got %q", out.Period) - } -} - -// --- Struct serialisation round-trip tests --- - -// TestChatSendInput_Good_RoundTrip verifies JSON serialisation of ChatSendInput. -func TestChatSendInput_Good_RoundTrip(t *testing.T) { - in := ChatSendInput{SessionID: "s1", Message: "hello"} - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out ChatSendInput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out != in { - t.Errorf("round-trip mismatch: %+v != %+v", out, in) - } -} - -// TestChatSendOutput_Good_RoundTrip verifies JSON serialisation of ChatSendOutput. -func TestChatSendOutput_Good_RoundTrip(t *testing.T) { - in := ChatSendOutput{Sent: true, SessionID: "s1", Timestamp: time.Now().Truncate(time.Second)} - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out ChatSendOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.Sent != in.Sent || out.SessionID != in.SessionID { - t.Errorf("round-trip mismatch: %+v != %+v", out, in) - } -} - -// TestChatHistoryOutput_Good_RoundTrip verifies ChatHistoryOutput JSON round-trip. -func TestChatHistoryOutput_Good_RoundTrip(t *testing.T) { - in := ChatHistoryOutput{ - SessionID: "s1", - Messages: []ChatMessage{ - {Role: "user", Content: "hello", Timestamp: time.Now().Truncate(time.Second)}, - {Role: "assistant", Content: "hi", Timestamp: time.Now().Truncate(time.Second)}, - }, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out ChatHistoryOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.SessionID != in.SessionID { - t.Errorf("sessionId mismatch: %q != %q", out.SessionID, in.SessionID) - } - if len(out.Messages) != 2 { - t.Errorf("expected 2 messages, got %d", len(out.Messages)) - } -} - -// TestSessionListOutput_Good_RoundTrip verifies SessionListOutput JSON round-trip. -func TestSessionListOutput_Good_RoundTrip(t *testing.T) { - in := SessionListOutput{ - Sessions: []Session{ - {ID: "s1", Name: "test", Status: "active", CreatedAt: time.Now().Truncate(time.Second)}, - }, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out SessionListOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if len(out.Sessions) != 1 || out.Sessions[0].ID != "s1" { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -// TestPlanStatusOutput_Good_RoundTrip verifies PlanStatusOutput JSON round-trip. -func TestPlanStatusOutput_Good_RoundTrip(t *testing.T) { - in := PlanStatusOutput{ - SessionID: "s1", - Status: "running", - Steps: []PlanStep{{Name: "step1", Status: "done"}, {Name: "step2", Status: "pending"}}, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out PlanStatusOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.SessionID != "s1" || len(out.Steps) != 2 { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -// TestBuildStatusOutput_Good_RoundTrip verifies BuildStatusOutput JSON round-trip. -func TestBuildStatusOutput_Good_RoundTrip(t *testing.T) { - in := BuildStatusOutput{ - Build: BuildInfo{ - ID: "b1", - Repo: "core-php", - Branch: "main", - Status: "success", - Duration: "2m30s", - StartedAt: time.Now().Truncate(time.Second), - }, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out BuildStatusOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.Build.ID != "b1" || out.Build.Status != "success" { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -// TestBuildListOutput_Good_RoundTrip verifies BuildListOutput JSON round-trip. -func TestBuildListOutput_Good_RoundTrip(t *testing.T) { - in := BuildListOutput{ - Builds: []BuildInfo{ - {ID: "b1", Repo: "core-php", Branch: "main", Status: "success"}, - {ID: "b2", Repo: "core-admin", Branch: "dev", Status: "failed"}, - }, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out BuildListOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if len(out.Builds) != 2 { - t.Errorf("expected 2 builds, got %d", len(out.Builds)) - } -} - -// TestBuildLogsOutput_Good_RoundTrip verifies BuildLogsOutput JSON round-trip. -func TestBuildLogsOutput_Good_RoundTrip(t *testing.T) { - in := BuildLogsOutput{ - BuildID: "b1", - Lines: []string{"line1", "line2", "line3"}, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out BuildLogsOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.BuildID != "b1" || len(out.Lines) != 3 { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -// TestDashboardOverviewOutput_Good_RoundTrip verifies DashboardOverviewOutput JSON round-trip. -func TestDashboardOverviewOutput_Good_RoundTrip(t *testing.T) { - in := DashboardOverviewOutput{ - Overview: DashboardOverview{ - Repos: 18, - Services: 5, - ActiveSessions: 3, - RecentBuilds: 12, - BridgeOnline: true, - }, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out DashboardOverviewOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.Overview.Repos != 18 || !out.Overview.BridgeOnline { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -// TestDashboardActivityOutput_Good_RoundTrip verifies DashboardActivityOutput JSON round-trip. -func TestDashboardActivityOutput_Good_RoundTrip(t *testing.T) { - in := DashboardActivityOutput{ - Events: []ActivityEvent{ - {Type: "deploy", Message: "deployed v1.2", Timestamp: time.Now().Truncate(time.Second)}, - }, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out DashboardActivityOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if len(out.Events) != 1 || out.Events[0].Type != "deploy" { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -// TestDashboardMetricsOutput_Good_RoundTrip verifies DashboardMetricsOutput JSON round-trip. -func TestDashboardMetricsOutput_Good_RoundTrip(t *testing.T) { - in := DashboardMetricsOutput{ - Period: "24h", - Metrics: DashboardMetrics{ - BuildsTotal: 100, - BuildsSuccess: 90, - BuildsFailed: 10, - AvgBuildTime: "3m", - AgentSessions: 5, - MessagesTotal: 500, - SuccessRate: 0.9, - }, - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out DashboardMetricsOutput - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.Period != "24h" || out.Metrics.BuildsTotal != 100 || out.Metrics.SuccessRate != 0.9 { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -// TestBridgeMessage_Good_RoundTrip verifies BridgeMessage JSON round-trip. -func TestBridgeMessage_Good_RoundTrip(t *testing.T) { - in := BridgeMessage{ - Type: "test", - Channel: "ch1", - SessionID: "s1", - Data: "payload", - Timestamp: time.Now().Truncate(time.Second), - } - data, err := json.Marshal(in) - if err != nil { - t.Fatalf("marshal failed: %v", err) - } - var out BridgeMessage - if err := json.Unmarshal(data, &out); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - if out.Type != "test" || out.Channel != "ch1" || out.SessionID != "s1" { - t.Errorf("round-trip mismatch: %+v", out) - } -} - -// --- Subsystem integration tests --- - -// TestSubsystem_Good_RegisterTools verifies RegisterTools does not panic. -func TestSubsystem_Good_RegisterTools(t *testing.T) { - // RegisterTools requires a real mcp.Server which is complex to construct - // in isolation. This test verifies the Subsystem can be created and - // the Bridge/Shutdown path works end-to-end. - sub := New(nil) - if sub.Bridge() != nil { - t.Error("expected nil bridge with nil hub") - } - if err := sub.Shutdown(context.Background()); err != nil { - t.Errorf("Shutdown failed: %v", err) - } -} - -// TestSubsystem_Good_StartBridgeNilHub verifies StartBridge is a no-op with nil hub. -func TestSubsystem_Good_StartBridgeNilHub(t *testing.T) { - sub := New(nil) - // Should not panic - sub.StartBridge(context.Background()) -} - -// TestSubsystem_Good_WithOptions verifies all config options apply correctly. -func TestSubsystem_Good_WithOptions(t *testing.T) { - hub := ws.NewHub() - sub := New(hub, - WithLaravelURL("ws://custom:1234/ws"), - WithWorkspaceRoot("/tmp/test"), - WithReconnectInterval(5*time.Second), - WithToken("secret-123"), - ) - - if sub.cfg.LaravelWSURL != "ws://custom:1234/ws" { - t.Errorf("expected custom URL, got %q", sub.cfg.LaravelWSURL) - } - if sub.cfg.WorkspaceRoot != "/tmp/test" { - t.Errorf("expected workspace '/tmp/test', got %q", sub.cfg.WorkspaceRoot) - } - if sub.cfg.ReconnectInterval != 5*time.Second { - t.Errorf("expected 5s reconnect interval, got %v", sub.cfg.ReconnectInterval) - } - if sub.cfg.Token != "secret-123" { - t.Errorf("expected token 'secret-123', got %q", sub.cfg.Token) - } -} - -// --- Tool sends correct bridge message type --- - -// TestChatSend_Good_BridgeMessageType verifies the bridge receives the correct message type. -func TestChatSend_Good_BridgeMessageType(t *testing.T) { - msgCh := make(chan BridgeMessage, 1) - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := testUpgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - _, data, err := conn.ReadMessage() - if err != nil { - return - } - var msg BridgeMessage - json.Unmarshal(data, &msg) - msgCh <- msg - // Keep alive - for { - if _, _, err := conn.ReadMessage(); err != nil { - break - } - } - })) - defer ts.Close() - - hub := ws.NewHub() - ctx := t.Context() - go hub.Run(ctx) - - sub := New(hub, WithLaravelURL(wsURL(ts)), WithReconnectInterval(50*time.Millisecond)) - sub.StartBridge(ctx) - waitConnected(t, sub.Bridge(), 2*time.Second) - - sub.chatSend(ctx, nil, ChatSendInput{SessionID: "s1", Message: "test"}) - - select { - case received := <-msgCh: - if received.Type != "chat_send" { - t.Errorf("expected bridge message type 'chat_send', got %q", received.Type) - } - if received.Channel != "chat:s1" { - t.Errorf("expected channel 'chat:s1', got %q", received.Channel) - } - case <-time.After(2 * time.Second): - t.Fatal("timed out waiting for bridge message") - } -} diff --git a/mcp/integration_test.go b/mcp/integration_test.go deleted file mode 100644 index de35e66..0000000 --- a/mcp/integration_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package mcp - -import ( - "context" - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestIntegration_FileTools(t *testing.T) { - tmpDir := t.TempDir() - s, err := New(WithWorkspaceRoot(tmpDir)) - assert.NoError(t, err) - - ctx := context.Background() - - // 1. Test file_write - writeInput := WriteFileInput{ - Path: "test.txt", - Content: "hello world", - } - _, writeOutput, err := s.writeFile(ctx, nil, writeInput) - assert.NoError(t, err) - assert.True(t, writeOutput.Success) - assert.Equal(t, "test.txt", writeOutput.Path) - - // Verify on disk - content, _ := os.ReadFile(filepath.Join(tmpDir, "test.txt")) - assert.Equal(t, "hello world", string(content)) - - // 2. Test file_read - readInput := ReadFileInput{ - Path: "test.txt", - } - _, readOutput, err := s.readFile(ctx, nil, readInput) - assert.NoError(t, err) - assert.Equal(t, "hello world", readOutput.Content) - assert.Equal(t, "plaintext", readOutput.Language) - - // 3. Test file_edit (replace_all=false) - editInput := EditDiffInput{ - Path: "test.txt", - OldString: "world", - NewString: "mcp", - } - _, editOutput, err := s.editDiff(ctx, nil, editInput) - assert.NoError(t, err) - assert.True(t, editOutput.Success) - assert.Equal(t, 1, editOutput.Replacements) - - // Verify change - _, readOutput, _ = s.readFile(ctx, nil, readInput) - assert.Equal(t, "hello mcp", readOutput.Content) - - // 4. Test file_edit (replace_all=true) - _ = s.medium.Write("multi.txt", "abc abc abc") - editInputMulti := EditDiffInput{ - Path: "multi.txt", - OldString: "abc", - NewString: "xyz", - ReplaceAll: true, - } - _, editOutput, err = s.editDiff(ctx, nil, editInputMulti) - assert.NoError(t, err) - assert.Equal(t, 3, editOutput.Replacements) - - content, _ = os.ReadFile(filepath.Join(tmpDir, "multi.txt")) - assert.Equal(t, "xyz xyz xyz", string(content)) - - // 5. Test dir_list - _ = s.medium.EnsureDir("subdir") - _ = s.medium.Write("subdir/file1.txt", "content1") - - listInput := ListDirectoryInput{ - Path: "subdir", - } - _, listOutput, err := s.listDirectory(ctx, nil, listInput) - assert.NoError(t, err) - assert.Len(t, listOutput.Entries, 1) - assert.Equal(t, "file1.txt", listOutput.Entries[0].Name) - assert.False(t, listOutput.Entries[0].IsDir) -} - -func TestIntegration_ErrorPaths(t *testing.T) { - tmpDir := t.TempDir() - s, err := New(WithWorkspaceRoot(tmpDir)) - assert.NoError(t, err) - - ctx := context.Background() - - // Read nonexistent file - _, _, err = s.readFile(ctx, nil, ReadFileInput{Path: "nonexistent.txt"}) - assert.Error(t, err) - - // Edit nonexistent file - _, _, err = s.editDiff(ctx, nil, EditDiffInput{ - Path: "nonexistent.txt", - OldString: "foo", - NewString: "bar", - }) - assert.Error(t, err) - - // Edit with empty old_string - _, _, err = s.editDiff(ctx, nil, EditDiffInput{ - Path: "test.txt", - OldString: "", - NewString: "bar", - }) - assert.Error(t, err) - - // Edit with old_string not found - _ = s.medium.Write("test.txt", "hello") - _, _, err = s.editDiff(ctx, nil, EditDiffInput{ - Path: "test.txt", - OldString: "missing", - NewString: "bar", - }) - assert.Error(t, err) -} diff --git a/mcp/iter_test.go b/mcp/iter_test.go deleted file mode 100644 index 5c9b274..0000000 --- a/mcp/iter_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -package mcp - -import ( - "slices" - "testing" -) - -func TestService_Iterators(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - // Test ToolsSeq - tools := slices.Collect(svc.ToolsSeq()) - if len(tools) == 0 { - t.Error("expected non-empty ToolsSeq") - } - if len(tools) != len(svc.Tools()) { - t.Errorf("ToolsSeq length %d != Tools() length %d", len(tools), len(svc.Tools())) - } - - // Test SubsystemsSeq - subsystems := slices.Collect(svc.SubsystemsSeq()) - if len(subsystems) != len(svc.Subsystems()) { - t.Errorf("SubsystemsSeq length %d != Subsystems() length %d", len(subsystems), len(svc.Subsystems())) - } -} - -func TestRegistry_SplitTagSeq(t *testing.T) { - tag := "name,omitempty,json" - parts := slices.Collect(splitTagSeq(tag)) - expected := []string{"name", "omitempty", "json"} - - if !slices.Equal(parts, expected) { - t.Errorf("expected %v, got %v", expected, parts) - } -} diff --git a/mcp/mcp.go b/mcp/mcp.go deleted file mode 100644 index 7854cf3..0000000 --- a/mcp/mcp.go +++ /dev/null @@ -1,580 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -// Package mcp provides a lightweight MCP (Model Context Protocol) server for CLI use. -// For full GUI integration (display, webview, process management), see core-gui/pkg/mcp. -package mcp - -import ( - "context" - "errors" - "fmt" - "iter" - "net/http" - "os" - "path/filepath" - "slices" - "strings" - - "forge.lthn.ai/core/go-io" - "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-process" - "forge.lthn.ai/core/go-ws" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// Service provides a lightweight MCP server with file operations only. -// For full GUI features, use the core-gui package. -type Service struct { - server *mcp.Server - workspaceRoot string // Root directory for file operations (empty = unrestricted) - medium io.Medium // Filesystem medium for sandboxed operations - subsystems []Subsystem // Additional subsystems registered via WithSubsystem - logger *log.Logger // Logger for tool execution auditing - processService *process.Service // Process management service (optional) - wsHub *ws.Hub // WebSocket hub for real-time streaming (optional) - wsServer *http.Server // WebSocket HTTP server (optional) - wsAddr string // WebSocket server address - tools []ToolRecord // Parallel tool registry for REST bridge -} - -// Option configures a Service. -type Option func(*Service) error - -// WithWorkspaceRoot restricts file operations to the given directory. -// All paths are validated to be within this directory. -// An empty string disables the restriction (not recommended). -func WithWorkspaceRoot(root string) Option { - return func(s *Service) error { - if root == "" { - // Explicitly disable restriction - use unsandboxed global - s.workspaceRoot = "" - s.medium = io.Local - return nil - } - // Create sandboxed medium for this workspace - abs, err := filepath.Abs(root) - if err != nil { - return fmt.Errorf("invalid workspace root: %w", err) - } - m, err := io.NewSandboxed(abs) - if err != nil { - return fmt.Errorf("failed to create workspace medium: %w", err) - } - s.workspaceRoot = abs - s.medium = m - return nil - } -} - -// New creates a new MCP service with file operations. -// By default, restricts file access to the current working directory. -// Use WithWorkspaceRoot("") to disable restrictions (not recommended). -// Returns an error if initialization fails. -func New(opts ...Option) (*Service, error) { - impl := &mcp.Implementation{ - Name: "core-cli", - Version: "0.1.0", - } - - server := mcp.NewServer(impl, nil) - s := &Service{ - server: server, - logger: log.Default(), - } - - // Default to current working directory with sandboxed medium - cwd, err := os.Getwd() - if err != nil { - return nil, fmt.Errorf("failed to get working directory: %w", err) - } - s.workspaceRoot = cwd - m, err := io.NewSandboxed(cwd) - if err != nil { - return nil, fmt.Errorf("failed to create sandboxed medium: %w", err) - } - s.medium = m - - // Apply options - for _, opt := range opts { - if err := opt(s); err != nil { - return nil, fmt.Errorf("failed to apply option: %w", err) - } - } - - s.registerTools(s.server) - - // Register subsystem tools. - for _, sub := range s.subsystems { - sub.RegisterTools(s.server) - } - - return s, nil -} - -// Subsystems returns the registered subsystems. -func (s *Service) Subsystems() []Subsystem { - return s.subsystems -} - -// SubsystemsSeq returns an iterator over the registered subsystems. -func (s *Service) SubsystemsSeq() iter.Seq[Subsystem] { - return slices.Values(s.subsystems) -} - -// Tools returns all recorded tool metadata. -func (s *Service) Tools() []ToolRecord { - return s.tools -} - -// ToolsSeq returns an iterator over all recorded tool metadata. -func (s *Service) ToolsSeq() iter.Seq[ToolRecord] { - return slices.Values(s.tools) -} - -// Shutdown gracefully shuts down all subsystems that support it. -func (s *Service) Shutdown(ctx context.Context) error { - for _, sub := range s.subsystems { - if sh, ok := sub.(SubsystemWithShutdown); ok { - if err := sh.Shutdown(ctx); err != nil { - return fmt.Errorf("shutdown %s: %w", sub.Name(), err) - } - } - } - return nil -} - -// WithProcessService configures the process management service. -func WithProcessService(ps *process.Service) Option { - return func(s *Service) error { - s.processService = ps - return nil - } -} - -// WithWSHub configures the WebSocket hub for real-time streaming. -func WithWSHub(hub *ws.Hub) Option { - return func(s *Service) error { - s.wsHub = hub - return nil - } -} - -// WSHub returns the WebSocket hub. -func (s *Service) WSHub() *ws.Hub { - return s.wsHub -} - -// ProcessService returns the process service. -func (s *Service) ProcessService() *process.Service { - return s.processService -} - -// registerTools adds file operation tools to the MCP server. -func (s *Service) registerTools(server *mcp.Server) { - // File operations - addToolRecorded(s, server, "files", &mcp.Tool{ - Name: "file_read", - Description: "Read the contents of a file", - }, s.readFile) - - addToolRecorded(s, server, "files", &mcp.Tool{ - Name: "file_write", - Description: "Write content to a file", - }, s.writeFile) - - addToolRecorded(s, server, "files", &mcp.Tool{ - Name: "file_delete", - Description: "Delete a file or empty directory", - }, s.deleteFile) - - addToolRecorded(s, server, "files", &mcp.Tool{ - Name: "file_rename", - Description: "Rename or move a file", - }, s.renameFile) - - addToolRecorded(s, server, "files", &mcp.Tool{ - Name: "file_exists", - Description: "Check if a file or directory exists", - }, s.fileExists) - - addToolRecorded(s, server, "files", &mcp.Tool{ - Name: "file_edit", - Description: "Edit a file by replacing old_string with new_string. Use replace_all=true to replace all occurrences.", - }, s.editDiff) - - // Directory operations - addToolRecorded(s, server, "files", &mcp.Tool{ - Name: "dir_list", - Description: "List contents of a directory", - }, s.listDirectory) - - addToolRecorded(s, server, "files", &mcp.Tool{ - Name: "dir_create", - Description: "Create a new directory", - }, s.createDirectory) - - // Language detection - addToolRecorded(s, server, "language", &mcp.Tool{ - Name: "lang_detect", - Description: "Detect the programming language of a file", - }, s.detectLanguage) - - addToolRecorded(s, server, "language", &mcp.Tool{ - Name: "lang_list", - Description: "Get list of supported programming languages", - }, s.getSupportedLanguages) -} - -// Tool input/output types for MCP file operations. - -// ReadFileInput contains parameters for reading a file. -type ReadFileInput struct { - Path string `json:"path"` -} - -// ReadFileOutput contains the result of reading a file. -type ReadFileOutput struct { - Content string `json:"content"` - Language string `json:"language"` - Path string `json:"path"` -} - -// WriteFileInput contains parameters for writing a file. -type WriteFileInput struct { - Path string `json:"path"` - Content string `json:"content"` -} - -// WriteFileOutput contains the result of writing a file. -type WriteFileOutput struct { - Success bool `json:"success"` - Path string `json:"path"` -} - -// ListDirectoryInput contains parameters for listing a directory. -type ListDirectoryInput struct { - Path string `json:"path"` -} - -// ListDirectoryOutput contains the result of listing a directory. -type ListDirectoryOutput struct { - Entries []DirectoryEntry `json:"entries"` - Path string `json:"path"` -} - -// DirectoryEntry represents a single entry in a directory listing. -type DirectoryEntry struct { - Name string `json:"name"` - Path string `json:"path"` - IsDir bool `json:"isDir"` - Size int64 `json:"size"` -} - -// CreateDirectoryInput contains parameters for creating a directory. -type CreateDirectoryInput struct { - Path string `json:"path"` -} - -// CreateDirectoryOutput contains the result of creating a directory. -type CreateDirectoryOutput struct { - Success bool `json:"success"` - Path string `json:"path"` -} - -// DeleteFileInput contains parameters for deleting a file. -type DeleteFileInput struct { - Path string `json:"path"` -} - -// DeleteFileOutput contains the result of deleting a file. -type DeleteFileOutput struct { - Success bool `json:"success"` - Path string `json:"path"` -} - -// RenameFileInput contains parameters for renaming a file. -type RenameFileInput struct { - OldPath string `json:"oldPath"` - NewPath string `json:"newPath"` -} - -// RenameFileOutput contains the result of renaming a file. -type RenameFileOutput struct { - Success bool `json:"success"` - OldPath string `json:"oldPath"` - NewPath string `json:"newPath"` -} - -// FileExistsInput contains parameters for checking file existence. -type FileExistsInput struct { - Path string `json:"path"` -} - -// FileExistsOutput contains the result of checking file existence. -type FileExistsOutput struct { - Exists bool `json:"exists"` - IsDir bool `json:"isDir"` - Path string `json:"path"` -} - -// DetectLanguageInput contains parameters for detecting file language. -type DetectLanguageInput struct { - Path string `json:"path"` -} - -// DetectLanguageOutput contains the detected programming language. -type DetectLanguageOutput struct { - Language string `json:"language"` - Path string `json:"path"` -} - -// GetSupportedLanguagesInput is an empty struct for the languages query. -type GetSupportedLanguagesInput struct{} - -// GetSupportedLanguagesOutput contains the list of supported languages. -type GetSupportedLanguagesOutput struct { - Languages []LanguageInfo `json:"languages"` -} - -// LanguageInfo describes a supported programming language. -type LanguageInfo struct { - ID string `json:"id"` - Name string `json:"name"` - Extensions []string `json:"extensions"` -} - -// EditDiffInput contains parameters for editing a file via diff. -type EditDiffInput struct { - Path string `json:"path"` - OldString string `json:"old_string"` - NewString string `json:"new_string"` - ReplaceAll bool `json:"replace_all,omitempty"` -} - -// EditDiffOutput contains the result of a diff-based edit operation. -type EditDiffOutput struct { - Path string `json:"path"` - Success bool `json:"success"` - Replacements int `json:"replacements"` -} - -// Tool handlers - -func (s *Service) readFile(ctx context.Context, req *mcp.CallToolRequest, input ReadFileInput) (*mcp.CallToolResult, ReadFileOutput, error) { - content, err := s.medium.Read(input.Path) - if err != nil { - return nil, ReadFileOutput{}, fmt.Errorf("failed to read file: %w", err) - } - return nil, ReadFileOutput{ - Content: content, - Language: detectLanguageFromPath(input.Path), - Path: input.Path, - }, nil -} - -func (s *Service) writeFile(ctx context.Context, req *mcp.CallToolRequest, input WriteFileInput) (*mcp.CallToolResult, WriteFileOutput, error) { - // Medium.Write creates parent directories automatically - if err := s.medium.Write(input.Path, input.Content); err != nil { - return nil, WriteFileOutput{}, fmt.Errorf("failed to write file: %w", err) - } - return nil, WriteFileOutput{Success: true, Path: input.Path}, nil -} - -func (s *Service) listDirectory(ctx context.Context, req *mcp.CallToolRequest, input ListDirectoryInput) (*mcp.CallToolResult, ListDirectoryOutput, error) { - entries, err := s.medium.List(input.Path) - if err != nil { - return nil, ListDirectoryOutput{}, fmt.Errorf("failed to list directory: %w", err) - } - result := make([]DirectoryEntry, 0, len(entries)) - for _, e := range entries { - info, _ := e.Info() - var size int64 - if info != nil { - size = info.Size() - } - result = append(result, DirectoryEntry{ - Name: e.Name(), - Path: filepath.Join(input.Path, e.Name()), // Note: This might be relative path, client might expect absolute? - // Issue 103 says "Replace ... with local.Medium sandboxing". - // Previous code returned `filepath.Join(input.Path, e.Name())`. - // If input.Path is relative, this preserves it. - IsDir: e.IsDir(), - Size: size, - }) - } - return nil, ListDirectoryOutput{Entries: result, Path: input.Path}, nil -} - -func (s *Service) createDirectory(ctx context.Context, req *mcp.CallToolRequest, input CreateDirectoryInput) (*mcp.CallToolResult, CreateDirectoryOutput, error) { - if err := s.medium.EnsureDir(input.Path); err != nil { - return nil, CreateDirectoryOutput{}, fmt.Errorf("failed to create directory: %w", err) - } - return nil, CreateDirectoryOutput{Success: true, Path: input.Path}, nil -} - -func (s *Service) deleteFile(ctx context.Context, req *mcp.CallToolRequest, input DeleteFileInput) (*mcp.CallToolResult, DeleteFileOutput, error) { - if err := s.medium.Delete(input.Path); err != nil { - return nil, DeleteFileOutput{}, fmt.Errorf("failed to delete file: %w", err) - } - return nil, DeleteFileOutput{Success: true, Path: input.Path}, nil -} - -func (s *Service) renameFile(ctx context.Context, req *mcp.CallToolRequest, input RenameFileInput) (*mcp.CallToolResult, RenameFileOutput, error) { - if err := s.medium.Rename(input.OldPath, input.NewPath); err != nil { - return nil, RenameFileOutput{}, fmt.Errorf("failed to rename file: %w", err) - } - return nil, RenameFileOutput{Success: true, OldPath: input.OldPath, NewPath: input.NewPath}, nil -} - -func (s *Service) fileExists(ctx context.Context, req *mcp.CallToolRequest, input FileExistsInput) (*mcp.CallToolResult, FileExistsOutput, error) { - exists := s.medium.IsFile(input.Path) - if exists { - return nil, FileExistsOutput{Exists: true, IsDir: false, Path: input.Path}, nil - } - // Check if it's a directory by attempting to list it - // List might fail if it's a file too (but we checked IsFile) or if doesn't exist. - _, err := s.medium.List(input.Path) - isDir := err == nil - - // If List failed, it might mean it doesn't exist OR it's a special file or permissions. - // Assuming if List works, it's a directory. - - // Refinement: If it doesn't exist, List returns error. - - return nil, FileExistsOutput{Exists: isDir, IsDir: isDir, Path: input.Path}, nil -} - -func (s *Service) detectLanguage(ctx context.Context, req *mcp.CallToolRequest, input DetectLanguageInput) (*mcp.CallToolResult, DetectLanguageOutput, error) { - lang := detectLanguageFromPath(input.Path) - return nil, DetectLanguageOutput{Language: lang, Path: input.Path}, nil -} - -func (s *Service) getSupportedLanguages(ctx context.Context, req *mcp.CallToolRequest, input GetSupportedLanguagesInput) (*mcp.CallToolResult, GetSupportedLanguagesOutput, error) { - languages := []LanguageInfo{ - {ID: "typescript", Name: "TypeScript", Extensions: []string{".ts", ".tsx"}}, - {ID: "javascript", Name: "JavaScript", Extensions: []string{".js", ".jsx"}}, - {ID: "go", Name: "Go", Extensions: []string{".go"}}, - {ID: "python", Name: "Python", Extensions: []string{".py"}}, - {ID: "rust", Name: "Rust", Extensions: []string{".rs"}}, - {ID: "java", Name: "Java", Extensions: []string{".java"}}, - {ID: "php", Name: "PHP", Extensions: []string{".php"}}, - {ID: "ruby", Name: "Ruby", Extensions: []string{".rb"}}, - {ID: "html", Name: "HTML", Extensions: []string{".html", ".htm"}}, - {ID: "css", Name: "CSS", Extensions: []string{".css"}}, - {ID: "json", Name: "JSON", Extensions: []string{".json"}}, - {ID: "yaml", Name: "YAML", Extensions: []string{".yaml", ".yml"}}, - {ID: "markdown", Name: "Markdown", Extensions: []string{".md", ".markdown"}}, - {ID: "sql", Name: "SQL", Extensions: []string{".sql"}}, - {ID: "shell", Name: "Shell", Extensions: []string{".sh", ".bash"}}, - } - return nil, GetSupportedLanguagesOutput{Languages: languages}, nil -} - -func (s *Service) editDiff(ctx context.Context, req *mcp.CallToolRequest, input EditDiffInput) (*mcp.CallToolResult, EditDiffOutput, error) { - if input.OldString == "" { - return nil, EditDiffOutput{}, errors.New("old_string cannot be empty") - } - - content, err := s.medium.Read(input.Path) - if err != nil { - return nil, EditDiffOutput{}, fmt.Errorf("failed to read file: %w", err) - } - - count := 0 - - if input.ReplaceAll { - count = strings.Count(content, input.OldString) - if count == 0 { - return nil, EditDiffOutput{}, errors.New("old_string not found in file") - } - content = strings.ReplaceAll(content, input.OldString, input.NewString) - } else { - if !strings.Contains(content, input.OldString) { - return nil, EditDiffOutput{}, errors.New("old_string not found in file") - } - content = strings.Replace(content, input.OldString, input.NewString, 1) - count = 1 - } - - if err := s.medium.Write(input.Path, content); err != nil { - return nil, EditDiffOutput{}, fmt.Errorf("failed to write file: %w", err) - } - - return nil, EditDiffOutput{ - Path: input.Path, - Success: true, - Replacements: count, - }, nil -} - -// detectLanguageFromPath maps file extensions to language IDs. -func detectLanguageFromPath(path string) string { - ext := filepath.Ext(path) - switch ext { - case ".ts", ".tsx": - return "typescript" - case ".js", ".jsx": - return "javascript" - case ".go": - return "go" - case ".py": - return "python" - case ".rs": - return "rust" - case ".rb": - return "ruby" - case ".java": - return "java" - case ".php": - return "php" - case ".c", ".h": - return "c" - case ".cpp", ".hpp", ".cc", ".cxx": - return "cpp" - case ".cs": - return "csharp" - case ".html", ".htm": - return "html" - case ".css": - return "css" - case ".scss": - return "scss" - case ".json": - return "json" - case ".yaml", ".yml": - return "yaml" - case ".xml": - return "xml" - case ".md", ".markdown": - return "markdown" - case ".sql": - return "sql" - case ".sh", ".bash": - return "shell" - case ".swift": - return "swift" - case ".kt", ".kts": - return "kotlin" - default: - if filepath.Base(path) == "Dockerfile" { - return "dockerfile" - } - return "plaintext" - } -} - -// Run starts the MCP server. -// If MCP_ADDR is set, it starts a TCP server. -// Otherwise, it starts a Stdio server. -func (s *Service) Run(ctx context.Context) error { - addr := os.Getenv("MCP_ADDR") - if addr != "" { - return s.ServeTCP(ctx, addr) - } - return s.server.Run(ctx, &mcp.StdioTransport{}) -} - -// Server returns the underlying MCP server for advanced configuration. -func (s *Service) Server() *mcp.Server { - return s.server -} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go deleted file mode 100644 index a1701de..0000000 --- a/mcp/mcp_test.go +++ /dev/null @@ -1,180 +0,0 @@ -package mcp - -import ( - "os" - "path/filepath" - "testing" -) - -func TestNew_Good_DefaultWorkspace(t *testing.T) { - cwd, err := os.Getwd() - if err != nil { - t.Fatalf("Failed to get working directory: %v", err) - } - - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.workspaceRoot != cwd { - t.Errorf("Expected default workspace root %s, got %s", cwd, s.workspaceRoot) - } - if s.medium == nil { - t.Error("Expected medium to be set") - } -} - -func TestNew_Good_CustomWorkspace(t *testing.T) { - tmpDir := t.TempDir() - - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.workspaceRoot != tmpDir { - t.Errorf("Expected workspace root %s, got %s", tmpDir, s.workspaceRoot) - } - if s.medium == nil { - t.Error("Expected medium to be set") - } -} - -func TestNew_Good_NoRestriction(t *testing.T) { - s, err := New(WithWorkspaceRoot("")) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.workspaceRoot != "" { - t.Errorf("Expected empty workspace root, got %s", s.workspaceRoot) - } - if s.medium == nil { - t.Error("Expected medium to be set (unsandboxed)") - } -} - -func TestMedium_Good_ReadWrite(t *testing.T) { - tmpDir := t.TempDir() - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - // Write a file - testContent := "hello world" - err = s.medium.Write("test.txt", testContent) - if err != nil { - t.Fatalf("Failed to write file: %v", err) - } - - // Read it back - content, err := s.medium.Read("test.txt") - if err != nil { - t.Fatalf("Failed to read file: %v", err) - } - if content != testContent { - t.Errorf("Expected content %q, got %q", testContent, content) - } - - // Verify file exists on disk - diskPath := filepath.Join(tmpDir, "test.txt") - if _, err := os.Stat(diskPath); os.IsNotExist(err) { - t.Error("File should exist on disk") - } -} - -func TestMedium_Good_EnsureDir(t *testing.T) { - tmpDir := t.TempDir() - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - err = s.medium.EnsureDir("subdir/nested") - if err != nil { - t.Fatalf("Failed to create directory: %v", err) - } - - // Verify directory exists - diskPath := filepath.Join(tmpDir, "subdir", "nested") - info, err := os.Stat(diskPath) - if os.IsNotExist(err) { - t.Error("Directory should exist on disk") - } - if err == nil && !info.IsDir() { - t.Error("Path should be a directory") - } -} - -func TestMedium_Good_IsFile(t *testing.T) { - tmpDir := t.TempDir() - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - // File doesn't exist yet - if s.medium.IsFile("test.txt") { - t.Error("File should not exist yet") - } - - // Create the file - _ = s.medium.Write("test.txt", "content") - - // Now it should exist - if !s.medium.IsFile("test.txt") { - t.Error("File should exist after write") - } -} - -func TestSandboxing_Traversal_Sanitized(t *testing.T) { - tmpDir := t.TempDir() - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - // Path traversal is sanitized (.. becomes .), so ../secret.txt becomes - // ./secret.txt in the workspace. Since that file doesn't exist, we get - // a file not found error (not a traversal error). - _, err = s.medium.Read("../secret.txt") - if err == nil { - t.Error("Expected error (file not found)") - } - - // Absolute paths are allowed through - they access the real filesystem. - // This is intentional for full filesystem access. Callers wanting sandboxing - // should validate inputs before calling Medium. -} - -func TestSandboxing_Symlinks_Blocked(t *testing.T) { - tmpDir := t.TempDir() - outsideDir := t.TempDir() - - // Create a target file outside workspace - targetFile := filepath.Join(outsideDir, "secret.txt") - if err := os.WriteFile(targetFile, []byte("secret"), 0644); err != nil { - t.Fatalf("Failed to create target file: %v", err) - } - - // Create symlink inside workspace pointing outside - symlinkPath := filepath.Join(tmpDir, "link") - if err := os.Symlink(targetFile, symlinkPath); err != nil { - t.Skipf("Symlinks not supported: %v", err) - } - - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - // Symlinks pointing outside the sandbox root are blocked (security feature). - // The sandbox resolves the symlink target and rejects it because it escapes - // the workspace boundary. - _, err = s.medium.Read("link") - if err == nil { - t.Error("Expected permission denied for symlink escaping sandbox, but read succeeded") - } -} diff --git a/mcp/registry.go b/mcp/registry.go deleted file mode 100644 index 21ae123..0000000 --- a/mcp/registry.go +++ /dev/null @@ -1,149 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -package mcp - -import ( - "context" - "encoding/json" - "iter" - "reflect" - "strings" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// RESTHandler handles a tool call from a REST endpoint. -// It receives raw JSON input and returns the typed output or an error. -type RESTHandler func(ctx context.Context, body []byte) (any, error) - -// ToolRecord captures metadata about a registered MCP tool. -type ToolRecord struct { - Name string // Tool name, e.g. "file_read" - Description string // Human-readable description - Group string // Subsystem group name, e.g. "files", "rag" - InputSchema map[string]any // JSON Schema from Go struct reflection - OutputSchema map[string]any // JSON Schema from Go struct reflection - RESTHandler RESTHandler // REST-callable handler created at registration time -} - -// addToolRecorded registers a tool with the MCP server AND records its metadata. -// This is a generic function that captures the In/Out types for schema extraction. -// It also creates a RESTHandler closure that can unmarshal JSON to the correct -// input type and call the handler directly, enabling the MCP-to-REST bridge. -func addToolRecorded[In, Out any](s *Service, server *mcp.Server, group string, t *mcp.Tool, h mcp.ToolHandlerFor[In, Out]) { - mcp.AddTool(server, t, h) - - restHandler := func(ctx context.Context, body []byte) (any, error) { - var input In - if len(body) > 0 { - if err := json.Unmarshal(body, &input); err != nil { - return nil, err - } - } - // nil: REST callers have no MCP request context. - // Tool handlers called via REST must not dereference CallToolRequest. - _, output, err := h(ctx, nil, input) - return output, err - } - - s.tools = append(s.tools, ToolRecord{ - Name: t.Name, - Description: t.Description, - Group: group, - InputSchema: structSchema(new(In)), - OutputSchema: structSchema(new(Out)), - RESTHandler: restHandler, - }) -} - -// structSchema builds a simple JSON Schema from a struct's json tags via reflection. -// Returns nil for non-struct types or empty structs. -func structSchema(v any) map[string]any { - t := reflect.TypeOf(v) - if t == nil { - return nil - } - if t.Kind() == reflect.Pointer { - t = t.Elem() - } - if t.Kind() != reflect.Struct { - return nil - } - if t.NumField() == 0 { - return map[string]any{"type": "object", "properties": map[string]any{}} - } - - properties := make(map[string]any) - required := make([]string, 0) - - for f := range t.Fields() { - f := f - if !f.IsExported() { - continue - } - jsonTag := f.Tag.Get("json") - if jsonTag == "-" { - continue - } - name := f.Name - isOptional := false - if jsonTag != "" { - parts := splitTag(jsonTag) - name = parts[0] - for _, p := range parts[1:] { - if p == "omitempty" { - isOptional = true - } - } - } - - prop := map[string]any{ - "type": goTypeToJSONType(f.Type), - } - properties[name] = prop - - if !isOptional { - required = append(required, name) - } - } - - schema := map[string]any{ - "type": "object", - "properties": properties, - } - if len(required) > 0 { - schema["required"] = required - } - return schema -} - -// splitTag splits a struct tag value by commas. -func splitTag(tag string) []string { - return strings.Split(tag, ",") -} - -// splitTagSeq returns an iterator over the tag parts. -func splitTagSeq(tag string) iter.Seq[string] { - return strings.SplitSeq(tag, ",") -} - -// goTypeToJSONType maps Go types to JSON Schema types. -func goTypeToJSONType(t reflect.Type) string { - switch t.Kind() { - case reflect.String: - return "string" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return "integer" - case reflect.Float32, reflect.Float64: - return "number" - case reflect.Bool: - return "boolean" - case reflect.Slice, reflect.Array: - return "array" - case reflect.Map, reflect.Struct: - return "object" - default: - return "string" - } -} diff --git a/mcp/registry_test.go b/mcp/registry_test.go deleted file mode 100644 index 15cdc14..0000000 --- a/mcp/registry_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: EUPL-1.2 - -package mcp - -import ( - "testing" -) - -func TestToolRegistry_Good_RecordsTools(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - tools := svc.Tools() - if len(tools) == 0 { - t.Fatal("expected non-empty tool registry") - } - - found := false - for _, tr := range tools { - if tr.Name == "file_read" { - found = true - break - } - } - if !found { - t.Error("expected file_read in tool registry") - } -} - -func TestToolRegistry_Good_SchemaExtraction(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - var record ToolRecord - for _, tr := range svc.Tools() { - if tr.Name == "file_read" { - record = tr - break - } - } - if record.Name == "" { - t.Fatal("file_read not found in registry") - } - - if record.InputSchema == nil { - t.Fatal("expected non-nil InputSchema for file_read") - } - - props, ok := record.InputSchema["properties"].(map[string]any) - if !ok { - t.Fatal("expected properties map in InputSchema") - } - - if _, ok := props["path"]; !ok { - t.Error("expected 'path' property in file_read InputSchema") - } -} - -func TestToolRegistry_Good_ToolCount(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - tools := svc.Tools() - // Built-in tools: file_read, file_write, file_delete, file_rename, - // file_exists, file_edit, dir_list, dir_create, lang_detect, lang_list - const expectedCount = 10 - if len(tools) != expectedCount { - t.Errorf("expected %d tools, got %d", expectedCount, len(tools)) - for _, tr := range tools { - t.Logf(" - %s (%s)", tr.Name, tr.Group) - } - } -} - -func TestToolRegistry_Good_GroupAssignment(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - fileTools := []string{"file_read", "file_write", "file_delete", "file_rename", "file_exists", "file_edit", "dir_list", "dir_create"} - langTools := []string{"lang_detect", "lang_list"} - - byName := make(map[string]ToolRecord) - for _, tr := range svc.Tools() { - byName[tr.Name] = tr - } - - for _, name := range fileTools { - tr, ok := byName[name] - if !ok { - t.Errorf("tool %s not found in registry", name) - continue - } - if tr.Group != "files" { - t.Errorf("tool %s: expected group 'files', got %q", name, tr.Group) - } - } - - for _, name := range langTools { - tr, ok := byName[name] - if !ok { - t.Errorf("tool %s not found in registry", name) - continue - } - if tr.Group != "language" { - t.Errorf("tool %s: expected group 'language', got %q", name, tr.Group) - } - } -} - -func TestToolRegistry_Good_ToolRecordFields(t *testing.T) { - svc, err := New(WithWorkspaceRoot(t.TempDir())) - if err != nil { - t.Fatal(err) - } - - var record ToolRecord - for _, tr := range svc.Tools() { - if tr.Name == "file_write" { - record = tr - break - } - } - if record.Name == "" { - t.Fatal("file_write not found in registry") - } - - if record.Name != "file_write" { - t.Errorf("expected Name 'file_write', got %q", record.Name) - } - if record.Description == "" { - t.Error("expected non-empty Description") - } - if record.Group == "" { - t.Error("expected non-empty Group") - } - if record.InputSchema == nil { - t.Error("expected non-nil InputSchema") - } - if record.OutputSchema == nil { - t.Error("expected non-nil OutputSchema") - } -} diff --git a/mcp/subsystem.go b/mcp/subsystem.go deleted file mode 100644 index 56bd6f7..0000000 --- a/mcp/subsystem.go +++ /dev/null @@ -1,32 +0,0 @@ -package mcp - -import ( - "context" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// Subsystem registers additional MCP tools at startup. -// Implementations should be safe to call concurrently. -type Subsystem interface { - // Name returns a human-readable identifier for logging. - Name() string - - // RegisterTools adds tools to the MCP server during initialisation. - RegisterTools(server *mcp.Server) -} - -// SubsystemWithShutdown extends Subsystem with graceful cleanup. -type SubsystemWithShutdown interface { - Subsystem - Shutdown(ctx context.Context) error -} - -// WithSubsystem registers a subsystem whose tools will be added -// after the built-in tools during New(). -func WithSubsystem(sub Subsystem) Option { - return func(s *Service) error { - s.subsystems = append(s.subsystems, sub) - return nil - } -} diff --git a/mcp/subsystem_test.go b/mcp/subsystem_test.go deleted file mode 100644 index 5e823f7..0000000 --- a/mcp/subsystem_test.go +++ /dev/null @@ -1,114 +0,0 @@ -package mcp - -import ( - "context" - "testing" - - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// stubSubsystem is a minimal Subsystem for testing. -type stubSubsystem struct { - name string - toolsRegistered bool -} - -func (s *stubSubsystem) Name() string { return s.name } - -func (s *stubSubsystem) RegisterTools(server *mcp.Server) { - s.toolsRegistered = true -} - -// shutdownSubsystem tracks Shutdown calls. -type shutdownSubsystem struct { - stubSubsystem - shutdownCalled bool - shutdownErr error -} - -func (s *shutdownSubsystem) Shutdown(_ context.Context) error { - s.shutdownCalled = true - return s.shutdownErr -} - -func TestWithSubsystem_Good_Registration(t *testing.T) { - sub := &stubSubsystem{name: "test-sub"} - svc, err := New(WithSubsystem(sub)) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - - if len(svc.Subsystems()) != 1 { - t.Fatalf("expected 1 subsystem, got %d", len(svc.Subsystems())) - } - if svc.Subsystems()[0].Name() != "test-sub" { - t.Errorf("expected name 'test-sub', got %q", svc.Subsystems()[0].Name()) - } -} - -func TestWithSubsystem_Good_ToolsRegistered(t *testing.T) { - sub := &stubSubsystem{name: "tools-sub"} - _, err := New(WithSubsystem(sub)) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - if !sub.toolsRegistered { - t.Error("expected RegisterTools to have been called") - } -} - -func TestWithSubsystem_Good_MultipleSubsystems(t *testing.T) { - sub1 := &stubSubsystem{name: "sub-1"} - sub2 := &stubSubsystem{name: "sub-2"} - svc, err := New(WithSubsystem(sub1), WithSubsystem(sub2)) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - if len(svc.Subsystems()) != 2 { - t.Fatalf("expected 2 subsystems, got %d", len(svc.Subsystems())) - } - if !sub1.toolsRegistered || !sub2.toolsRegistered { - t.Error("expected all subsystems to have RegisterTools called") - } -} - -func TestSubsystemShutdown_Good(t *testing.T) { - sub := &shutdownSubsystem{stubSubsystem: stubSubsystem{name: "shutdown-sub"}} - svc, err := New(WithSubsystem(sub)) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - if err := svc.Shutdown(context.Background()); err != nil { - t.Fatalf("Shutdown() failed: %v", err) - } - if !sub.shutdownCalled { - t.Error("expected Shutdown to have been called") - } -} - -func TestSubsystemShutdown_Bad_Error(t *testing.T) { - sub := &shutdownSubsystem{ - stubSubsystem: stubSubsystem{name: "fail-sub"}, - shutdownErr: context.DeadlineExceeded, - } - svc, err := New(WithSubsystem(sub)) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - err = svc.Shutdown(context.Background()) - if err == nil { - t.Fatal("expected error from Shutdown") - } -} - -func TestSubsystemShutdown_Good_NoShutdownInterface(t *testing.T) { - // A plain Subsystem (without Shutdown) should not cause errors. - sub := &stubSubsystem{name: "plain-sub"} - svc, err := New(WithSubsystem(sub)) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - if err := svc.Shutdown(context.Background()); err != nil { - t.Fatalf("Shutdown() should succeed for non-shutdown subsystem: %v", err) - } -} diff --git a/mcp/tools_metrics.go b/mcp/tools_metrics.go deleted file mode 100644 index d0e3811..0000000 --- a/mcp/tools_metrics.go +++ /dev/null @@ -1,213 +0,0 @@ -package mcp - -import ( - "context" - "errors" - "fmt" - "strconv" - "strings" - "time" - - "forge.lthn.ai/core/go-ai/ai" - "forge.lthn.ai/core/go-log" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// Default values for metrics operations. -const ( - DefaultMetricsSince = "7d" - DefaultMetricsLimit = 10 -) - -// MetricsRecordInput contains parameters for recording a metrics event. -type MetricsRecordInput struct { - Type string `json:"type"` // Event type (required) - AgentID string `json:"agent_id,omitempty"` // Agent identifier - Repo string `json:"repo,omitempty"` // Repository name - Data map[string]any `json:"data,omitempty"` // Additional event data -} - -// MetricsRecordOutput contains the result of recording a metrics event. -type MetricsRecordOutput struct { - Success bool `json:"success"` - Timestamp time.Time `json:"timestamp"` -} - -// MetricsQueryInput contains parameters for querying metrics. -type MetricsQueryInput struct { - Since string `json:"since,omitempty"` // Time range like "7d", "24h", "30m" (default: "7d") -} - -// MetricsQueryOutput contains the results of a metrics query. -type MetricsQueryOutput struct { - Total int `json:"total"` - ByType []MetricCount `json:"by_type"` - ByRepo []MetricCount `json:"by_repo"` - ByAgent []MetricCount `json:"by_agent"` - Events []MetricEventBrief `json:"events"` // Most recent 10 events -} - -// MetricCount represents a count for a specific key. -type MetricCount struct { - Key string `json:"key"` - Count int `json:"count"` -} - -// MetricEventBrief represents a brief summary of an event. -type MetricEventBrief struct { - Type string `json:"type"` - Timestamp time.Time `json:"timestamp"` - AgentID string `json:"agent_id,omitempty"` - Repo string `json:"repo,omitempty"` -} - -// registerMetricsTools adds metrics tools to the MCP server. -func (s *Service) registerMetricsTools(server *mcp.Server) { - mcp.AddTool(server, &mcp.Tool{ - Name: "metrics_record", - Description: "Record a metrics event for AI/security tracking. Events are stored in daily JSONL files.", - }, s.metricsRecord) - - mcp.AddTool(server, &mcp.Tool{ - Name: "metrics_query", - Description: "Query metrics events and get aggregated statistics by type, repo, and agent.", - }, s.metricsQuery) -} - -// metricsRecord handles the metrics_record tool call. -func (s *Service) metricsRecord(ctx context.Context, req *mcp.CallToolRequest, input MetricsRecordInput) (*mcp.CallToolResult, MetricsRecordOutput, error) { - s.logger.Info("MCP tool execution", "tool", "metrics_record", "type", input.Type, "agent_id", input.AgentID, "repo", input.Repo, "user", log.Username()) - - // Validate input - if input.Type == "" { - return nil, MetricsRecordOutput{}, errors.New("type cannot be empty") - } - - // Create the event - event := ai.Event{ - Type: input.Type, - Timestamp: time.Now(), - AgentID: input.AgentID, - Repo: input.Repo, - Data: input.Data, - } - - // Record the event - if err := ai.Record(event); err != nil { - log.Error("mcp: metrics record failed", "type", input.Type, "err", err) - return nil, MetricsRecordOutput{}, fmt.Errorf("failed to record metrics: %w", err) - } - - return nil, MetricsRecordOutput{ - Success: true, - Timestamp: event.Timestamp, - }, nil -} - -// metricsQuery handles the metrics_query tool call. -func (s *Service) metricsQuery(ctx context.Context, req *mcp.CallToolRequest, input MetricsQueryInput) (*mcp.CallToolResult, MetricsQueryOutput, error) { - // Apply defaults - since := input.Since - if since == "" { - since = DefaultMetricsSince - } - - s.logger.Info("MCP tool execution", "tool", "metrics_query", "since", since, "user", log.Username()) - - // Parse the duration - duration, err := parseDuration(since) - if err != nil { - return nil, MetricsQueryOutput{}, fmt.Errorf("invalid since value: %w", err) - } - - sinceTime := time.Now().Add(-duration) - - // Read events - events, err := ai.ReadEvents(sinceTime) - if err != nil { - log.Error("mcp: metrics query failed", "since", since, "err", err) - return nil, MetricsQueryOutput{}, fmt.Errorf("failed to read metrics: %w", err) - } - - // Get summary - summary := ai.Summary(events) - - // Build output - output := MetricsQueryOutput{ - Total: summary["total"].(int), - ByType: convertMetricCounts(summary["by_type"]), - ByRepo: convertMetricCounts(summary["by_repo"]), - ByAgent: convertMetricCounts(summary["by_agent"]), - Events: make([]MetricEventBrief, 0, DefaultMetricsLimit), - } - - // Get recent events (last 10, most recent first) - startIdx := max(len(events)-DefaultMetricsLimit, 0) - for i := len(events) - 1; i >= startIdx; i-- { - ev := events[i] - output.Events = append(output.Events, MetricEventBrief{ - Type: ev.Type, - Timestamp: ev.Timestamp, - AgentID: ev.AgentID, - Repo: ev.Repo, - }) - } - - return nil, output, nil -} - -// convertMetricCounts converts the summary map format to MetricCount slice. -func convertMetricCounts(data any) []MetricCount { - if data == nil { - return []MetricCount{} - } - - items, ok := data.([]map[string]any) - if !ok { - return []MetricCount{} - } - - result := make([]MetricCount, len(items)) - for i, item := range items { - key, _ := item["key"].(string) - count, _ := item["count"].(int) - result[i] = MetricCount{Key: key, Count: count} - } - return result -} - -// parseDuration parses a duration string like "7d", "24h", "30m". -func parseDuration(s string) (time.Duration, error) { - if s == "" { - return 0, errors.New("duration cannot be empty") - } - - s = strings.TrimSpace(s) - if len(s) < 2 { - return 0, fmt.Errorf("invalid duration format: %q", s) - } - - // Get the numeric part and unit - unit := s[len(s)-1] - numStr := s[:len(s)-1] - - num, err := strconv.Atoi(numStr) - if err != nil { - return 0, fmt.Errorf("invalid duration number: %q", numStr) - } - - if num <= 0 { - return 0, fmt.Errorf("duration must be positive: %d", num) - } - - switch unit { - case 'd': - return time.Duration(num) * 24 * time.Hour, nil - case 'h': - return time.Duration(num) * time.Hour, nil - case 'm': - return time.Duration(num) * time.Minute, nil - default: - return 0, fmt.Errorf("invalid duration unit: %q (expected d, h, or m)", string(unit)) - } -} diff --git a/mcp/tools_metrics_test.go b/mcp/tools_metrics_test.go deleted file mode 100644 index c34ee6c..0000000 --- a/mcp/tools_metrics_test.go +++ /dev/null @@ -1,207 +0,0 @@ -package mcp - -import ( - "testing" - "time" -) - -// TestMetricsToolsRegistered_Good verifies that metrics tools are registered with the MCP server. -func TestMetricsToolsRegistered_Good(t *testing.T) { - // Create a new MCP service - this should register all tools including metrics - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - // The server should have registered the metrics tools - // We verify by checking that the server and logger exist - if s.server == nil { - t.Fatal("Server should not be nil") - } - - if s.logger == nil { - t.Error("Logger should not be nil") - } -} - -// TestMetricsRecordInput_Good verifies the MetricsRecordInput struct has expected fields. -func TestMetricsRecordInput_Good(t *testing.T) { - input := MetricsRecordInput{ - Type: "tool_call", - AgentID: "agent-123", - Repo: "host-uk/core", - Data: map[string]any{"tool": "file_read", "duration_ms": 150}, - } - - if input.Type != "tool_call" { - t.Errorf("Expected type 'tool_call', got %q", input.Type) - } - if input.AgentID != "agent-123" { - t.Errorf("Expected agent_id 'agent-123', got %q", input.AgentID) - } - if input.Repo != "host-uk/core" { - t.Errorf("Expected repo 'host-uk/core', got %q", input.Repo) - } - if input.Data["tool"] != "file_read" { - t.Errorf("Expected data[tool] 'file_read', got %v", input.Data["tool"]) - } -} - -// TestMetricsRecordOutput_Good verifies the MetricsRecordOutput struct has expected fields. -func TestMetricsRecordOutput_Good(t *testing.T) { - ts := time.Now() - output := MetricsRecordOutput{ - Success: true, - Timestamp: ts, - } - - if !output.Success { - t.Error("Expected success to be true") - } - if output.Timestamp != ts { - t.Errorf("Expected timestamp %v, got %v", ts, output.Timestamp) - } -} - -// TestMetricsQueryInput_Good verifies the MetricsQueryInput struct has expected fields. -func TestMetricsQueryInput_Good(t *testing.T) { - input := MetricsQueryInput{ - Since: "7d", - } - - if input.Since != "7d" { - t.Errorf("Expected since '7d', got %q", input.Since) - } -} - -// TestMetricsQueryInput_Defaults verifies default values are handled correctly. -func TestMetricsQueryInput_Defaults(t *testing.T) { - input := MetricsQueryInput{} - - // Empty since should use default when processed - if input.Since != "" { - t.Errorf("Expected empty since before defaults, got %q", input.Since) - } -} - -// TestMetricsQueryOutput_Good verifies the MetricsQueryOutput struct has expected fields. -func TestMetricsQueryOutput_Good(t *testing.T) { - output := MetricsQueryOutput{ - Total: 100, - ByType: []MetricCount{ - {Key: "tool_call", Count: 50}, - {Key: "query", Count: 30}, - }, - ByRepo: []MetricCount{ - {Key: "host-uk/core", Count: 40}, - }, - ByAgent: []MetricCount{ - {Key: "agent-123", Count: 25}, - }, - Events: []MetricEventBrief{ - {Type: "tool_call", Timestamp: time.Now(), AgentID: "agent-1", Repo: "host-uk/core"}, - }, - } - - if output.Total != 100 { - t.Errorf("Expected total 100, got %d", output.Total) - } - if len(output.ByType) != 2 { - t.Errorf("Expected 2 ByType entries, got %d", len(output.ByType)) - } - if output.ByType[0].Key != "tool_call" { - t.Errorf("Expected ByType[0].Key 'tool_call', got %q", output.ByType[0].Key) - } - if output.ByType[0].Count != 50 { - t.Errorf("Expected ByType[0].Count 50, got %d", output.ByType[0].Count) - } - if len(output.Events) != 1 { - t.Errorf("Expected 1 event, got %d", len(output.Events)) - } -} - -// TestMetricCount_Good verifies the MetricCount struct has expected fields. -func TestMetricCount_Good(t *testing.T) { - mc := MetricCount{ - Key: "tool_call", - Count: 42, - } - - if mc.Key != "tool_call" { - t.Errorf("Expected key 'tool_call', got %q", mc.Key) - } - if mc.Count != 42 { - t.Errorf("Expected count 42, got %d", mc.Count) - } -} - -// TestMetricEventBrief_Good verifies the MetricEventBrief struct has expected fields. -func TestMetricEventBrief_Good(t *testing.T) { - ts := time.Now() - ev := MetricEventBrief{ - Type: "tool_call", - Timestamp: ts, - AgentID: "agent-123", - Repo: "host-uk/core", - } - - if ev.Type != "tool_call" { - t.Errorf("Expected type 'tool_call', got %q", ev.Type) - } - if ev.Timestamp != ts { - t.Errorf("Expected timestamp %v, got %v", ts, ev.Timestamp) - } - if ev.AgentID != "agent-123" { - t.Errorf("Expected agent_id 'agent-123', got %q", ev.AgentID) - } - if ev.Repo != "host-uk/core" { - t.Errorf("Expected repo 'host-uk/core', got %q", ev.Repo) - } -} - -// TestParseDuration_Good verifies the parseDuration helper handles various formats. -func TestParseDuration_Good(t *testing.T) { - tests := []struct { - input string - expected time.Duration - }{ - {"7d", 7 * 24 * time.Hour}, - {"24h", 24 * time.Hour}, - {"30m", 30 * time.Minute}, - {"1d", 24 * time.Hour}, - {"14d", 14 * 24 * time.Hour}, - {"1h", time.Hour}, - {"10m", 10 * time.Minute}, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - d, err := parseDuration(tc.input) - if err != nil { - t.Fatalf("parseDuration(%q) returned error: %v", tc.input, err) - } - if d != tc.expected { - t.Errorf("parseDuration(%q) = %v, want %v", tc.input, d, tc.expected) - } - }) - } -} - -// TestParseDuration_Bad verifies parseDuration returns errors for invalid input. -func TestParseDuration_Bad(t *testing.T) { - tests := []string{ - "", - "abc", - "7x", - "-7d", - } - - for _, input := range tests { - t.Run(input, func(t *testing.T) { - _, err := parseDuration(input) - if err == nil { - t.Errorf("parseDuration(%q) should return error", input) - } - }) - } -} diff --git a/mcp/tools_ml.go b/mcp/tools_ml.go deleted file mode 100644 index 55a0f08..0000000 --- a/mcp/tools_ml.go +++ /dev/null @@ -1,290 +0,0 @@ -package mcp - -import ( - "context" - "errors" - "fmt" - "strings" - - "forge.lthn.ai/core/go-inference" - "forge.lthn.ai/core/go-ml" - "forge.lthn.ai/core/go-log" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// MLSubsystem exposes ML inference and scoring tools via MCP. -type MLSubsystem struct { - service *ml.Service - logger *log.Logger -} - -// NewMLSubsystem creates an MCP subsystem for ML tools. -func NewMLSubsystem(svc *ml.Service) *MLSubsystem { - return &MLSubsystem{ - service: svc, - logger: log.Default(), - } -} - -func (m *MLSubsystem) Name() string { return "ml" } - -// RegisterTools adds ML tools to the MCP server. -func (m *MLSubsystem) RegisterTools(server *mcp.Server) { - mcp.AddTool(server, &mcp.Tool{ - Name: "ml_generate", - Description: "Generate text via a configured ML inference backend.", - }, m.mlGenerate) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ml_score", - Description: "Score a prompt/response pair using heuristic and LLM judge suites.", - }, m.mlScore) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ml_probe", - Description: "Run capability probes against an inference backend.", - }, m.mlProbe) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ml_status", - Description: "Show training and generation progress from InfluxDB.", - }, m.mlStatus) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ml_backends", - Description: "List available inference backends and their status.", - }, m.mlBackends) -} - -// --- Input/Output types --- - -// MLGenerateInput contains parameters for text generation. -type MLGenerateInput struct { - Prompt string `json:"prompt"` // The prompt to generate from - Backend string `json:"backend,omitempty"` // Backend name (default: service default) - Model string `json:"model,omitempty"` // Model override - Temperature float64 `json:"temperature,omitempty"` // Sampling temperature - MaxTokens int `json:"max_tokens,omitempty"` // Maximum tokens to generate -} - -// MLGenerateOutput contains the generation result. -type MLGenerateOutput struct { - Response string `json:"response"` - Backend string `json:"backend"` - Model string `json:"model,omitempty"` -} - -// MLScoreInput contains parameters for scoring a response. -type MLScoreInput struct { - Prompt string `json:"prompt"` // The original prompt - Response string `json:"response"` // The model response to score - Suites string `json:"suites,omitempty"` // Comma-separated suites (default: heuristic) -} - -// MLScoreOutput contains the scoring result. -type MLScoreOutput struct { - Heuristic *ml.HeuristicScores `json:"heuristic,omitempty"` - Semantic *ml.SemanticScores `json:"semantic,omitempty"` - Content *ml.ContentScores `json:"content,omitempty"` -} - -// MLProbeInput contains parameters for running probes. -type MLProbeInput struct { - Backend string `json:"backend,omitempty"` // Backend name - Categories string `json:"categories,omitempty"` // Comma-separated categories to run -} - -// MLProbeOutput contains probe results. -type MLProbeOutput struct { - Total int `json:"total"` - Results []MLProbeResultItem `json:"results"` -} - -// MLProbeResultItem is a single probe result. -type MLProbeResultItem struct { - ID string `json:"id"` - Category string `json:"category"` - Response string `json:"response"` -} - -// MLStatusInput contains parameters for the status query. -type MLStatusInput struct { - InfluxURL string `json:"influx_url,omitempty"` // InfluxDB URL override - InfluxDB string `json:"influx_db,omitempty"` // InfluxDB database override -} - -// MLStatusOutput contains pipeline status. -type MLStatusOutput struct { - Status string `json:"status"` -} - -// MLBackendsInput is empty — lists all backends. -type MLBackendsInput struct{} - -// MLBackendsOutput lists available backends. -type MLBackendsOutput struct { - Backends []MLBackendInfo `json:"backends"` - Default string `json:"default"` -} - -// MLBackendInfo describes a single backend. -type MLBackendInfo struct { - Name string `json:"name"` - Available bool `json:"available"` -} - -// --- Tool handlers --- - -// mlGenerate delegates to go-ml.Service.Generate, which internally uses -// InferenceAdapter to route generation through an inference.TextModel. -// Flow: go-ai → go-ml.Service.Generate → InferenceAdapter → inference.TextModel. -func (m *MLSubsystem) mlGenerate(ctx context.Context, req *mcp.CallToolRequest, input MLGenerateInput) (*mcp.CallToolResult, MLGenerateOutput, error) { - m.logger.Info("MCP tool execution", "tool", "ml_generate", "backend", input.Backend, "user", log.Username()) - - if input.Prompt == "" { - return nil, MLGenerateOutput{}, errors.New("prompt cannot be empty") - } - - opts := ml.GenOpts{ - Temperature: input.Temperature, - MaxTokens: input.MaxTokens, - Model: input.Model, - } - - result, err := m.service.Generate(ctx, input.Backend, input.Prompt, opts) - if err != nil { - return nil, MLGenerateOutput{}, fmt.Errorf("generate: %w", err) - } - - return nil, MLGenerateOutput{ - Response: result.Text, - Backend: input.Backend, - Model: input.Model, - }, nil -} - -func (m *MLSubsystem) mlScore(ctx context.Context, req *mcp.CallToolRequest, input MLScoreInput) (*mcp.CallToolResult, MLScoreOutput, error) { - m.logger.Info("MCP tool execution", "tool", "ml_score", "suites", input.Suites, "user", log.Username()) - - if input.Prompt == "" || input.Response == "" { - return nil, MLScoreOutput{}, errors.New("prompt and response cannot be empty") - } - - suites := input.Suites - if suites == "" { - suites = "heuristic" - } - - output := MLScoreOutput{} - - for suite := range strings.SplitSeq(suites, ",") { - suite = strings.TrimSpace(suite) - switch suite { - case "heuristic": - output.Heuristic = ml.ScoreHeuristic(input.Response) - case "semantic": - judge := m.service.Judge() - if judge == nil { - return nil, MLScoreOutput{}, errors.New("semantic scoring requires a judge backend") - } - s, err := judge.ScoreSemantic(ctx, input.Prompt, input.Response) - if err != nil { - return nil, MLScoreOutput{}, fmt.Errorf("semantic score: %w", err) - } - output.Semantic = s - case "content": - return nil, MLScoreOutput{}, errors.New("content scoring requires a ContentProbe — use ml_probe instead") - } - } - - return nil, output, nil -} - -// mlProbe runs capability probes by generating responses via go-ml.Service. -// Flow: go-ai → go-ml.Service.Generate → InferenceAdapter → inference.TextModel. -func (m *MLSubsystem) mlProbe(ctx context.Context, req *mcp.CallToolRequest, input MLProbeInput) (*mcp.CallToolResult, MLProbeOutput, error) { - m.logger.Info("MCP tool execution", "tool", "ml_probe", "backend", input.Backend, "user", log.Username()) - - // Filter probes by category if specified. - probes := ml.CapabilityProbes - if input.Categories != "" { - cats := make(map[string]bool) - for c := range strings.SplitSeq(input.Categories, ",") { - cats[strings.TrimSpace(c)] = true - } - var filtered []ml.Probe - for _, p := range probes { - if cats[p.Category] { - filtered = append(filtered, p) - } - } - probes = filtered - } - - var results []MLProbeResultItem - for _, probe := range probes { - result, err := m.service.Generate(ctx, input.Backend, probe.Prompt, ml.GenOpts{Temperature: 0.7, MaxTokens: 2048}) - respText := result.Text - if err != nil { - respText = fmt.Sprintf("error: %v", err) - } - results = append(results, MLProbeResultItem{ - ID: probe.ID, - Category: probe.Category, - Response: respText, - }) - } - - return nil, MLProbeOutput{ - Total: len(results), - Results: results, - }, nil -} - -func (m *MLSubsystem) mlStatus(ctx context.Context, req *mcp.CallToolRequest, input MLStatusInput) (*mcp.CallToolResult, MLStatusOutput, error) { - m.logger.Info("MCP tool execution", "tool", "ml_status", "user", log.Username()) - - url := input.InfluxURL - db := input.InfluxDB - if url == "" { - url = "http://localhost:8086" - } - if db == "" { - db = "lem" - } - - influx := ml.NewInfluxClient(url, db) - var buf strings.Builder - if err := ml.PrintStatus(influx, &buf); err != nil { - return nil, MLStatusOutput{}, fmt.Errorf("status: %w", err) - } - - return nil, MLStatusOutput{Status: buf.String()}, nil -} - -// mlBackends enumerates registered backends via the go-inference registry, -// bypassing go-ml.Service entirely. This is the canonical source of truth -// for backend availability since all backends register with inference.Register(). -func (m *MLSubsystem) mlBackends(ctx context.Context, req *mcp.CallToolRequest, input MLBackendsInput) (*mcp.CallToolResult, MLBackendsOutput, error) { - m.logger.Info("MCP tool execution", "tool", "ml_backends", "user", log.Username()) - - names := inference.List() - backends := make([]MLBackendInfo, 0, len(names)) - for _, name := range names { - b, ok := inference.Get(name) - backends = append(backends, MLBackendInfo{ - Name: name, - Available: ok && b.Available(), - }) - } - - defaultName := "" - if db, err := inference.Default(); err == nil { - defaultName = db.Name() - } - - return nil, MLBackendsOutput{ - Backends: backends, - Default: defaultName, - }, nil -} diff --git a/mcp/tools_ml_test.go b/mcp/tools_ml_test.go deleted file mode 100644 index 902405f..0000000 --- a/mcp/tools_ml_test.go +++ /dev/null @@ -1,479 +0,0 @@ -package mcp - -import ( - "context" - "fmt" - "strings" - "testing" - - "forge.lthn.ai/core/go-inference" - "forge.lthn.ai/core/go-ml" - "forge.lthn.ai/core/go/pkg/core" - "forge.lthn.ai/core/go-log" -) - -// --- Mock backend for inference registry --- - -// mockInferenceBackend implements inference.Backend for CI testing of ml_backends. -type mockInferenceBackend struct { - name string - available bool -} - -func (m *mockInferenceBackend) Name() string { return m.name } -func (m *mockInferenceBackend) Available() bool { return m.available } -func (m *mockInferenceBackend) LoadModel(_ string, _ ...inference.LoadOption) (inference.TextModel, error) { - return nil, fmt.Errorf("mock backend: LoadModel not implemented") -} - -// --- Mock ml.Backend for Generate --- - -// mockMLBackend implements ml.Backend for CI testing. -type mockMLBackend struct { - name string - available bool - generateResp string - generateErr error -} - -func (m *mockMLBackend) Name() string { return m.name } -func (m *mockMLBackend) Available() bool { return m.available } - -func (m *mockMLBackend) Generate(_ context.Context, _ string, _ ml.GenOpts) (ml.Result, error) { - return ml.Result{Text: m.generateResp}, m.generateErr -} - -func (m *mockMLBackend) Chat(_ context.Context, _ []ml.Message, _ ml.GenOpts) (ml.Result, error) { - return ml.Result{Text: m.generateResp}, m.generateErr -} - -// newTestMLSubsystem creates an MLSubsystem with a real ml.Service for testing. -func newTestMLSubsystem(t *testing.T, backends ...ml.Backend) *MLSubsystem { - t.Helper() - c, err := core.New( - core.WithName("ml", ml.NewService(ml.Options{})), - ) - if err != nil { - t.Fatalf("Failed to create framework core: %v", err) - } - svc, err := core.ServiceFor[*ml.Service](c, "ml") - if err != nil { - t.Fatalf("Failed to get ML service: %v", err) - } - // Register mock backends - for _, b := range backends { - svc.RegisterBackend(b.Name(), b) - } - return &MLSubsystem{ - service: svc, - logger: log.Default(), - } -} - -// --- Input/Output struct tests --- - -// TestMLGenerateInput_Good verifies all fields can be set. -func TestMLGenerateInput_Good(t *testing.T) { - input := MLGenerateInput{ - Prompt: "Hello world", - Backend: "test", - Model: "test-model", - Temperature: 0.7, - MaxTokens: 100, - } - if input.Prompt != "Hello world" { - t.Errorf("Expected prompt 'Hello world', got %q", input.Prompt) - } - if input.Temperature != 0.7 { - t.Errorf("Expected temperature 0.7, got %f", input.Temperature) - } - if input.MaxTokens != 100 { - t.Errorf("Expected max_tokens 100, got %d", input.MaxTokens) - } -} - -// TestMLScoreInput_Good verifies all fields can be set. -func TestMLScoreInput_Good(t *testing.T) { - input := MLScoreInput{ - Prompt: "test prompt", - Response: "test response", - Suites: "heuristic,semantic", - } - if input.Prompt != "test prompt" { - t.Errorf("Expected prompt 'test prompt', got %q", input.Prompt) - } - if input.Response != "test response" { - t.Errorf("Expected response 'test response', got %q", input.Response) - } -} - -// TestMLProbeInput_Good verifies all fields can be set. -func TestMLProbeInput_Good(t *testing.T) { - input := MLProbeInput{ - Backend: "test", - Categories: "reasoning,code", - } - if input.Backend != "test" { - t.Errorf("Expected backend 'test', got %q", input.Backend) - } -} - -// TestMLStatusInput_Good verifies all fields can be set. -func TestMLStatusInput_Good(t *testing.T) { - input := MLStatusInput{ - InfluxURL: "http://localhost:8086", - InfluxDB: "lem", - } - if input.InfluxURL != "http://localhost:8086" { - t.Errorf("Expected InfluxURL, got %q", input.InfluxURL) - } -} - -// TestMLBackendsInput_Good verifies empty struct. -func TestMLBackendsInput_Good(t *testing.T) { - _ = MLBackendsInput{} -} - -// TestMLBackendsOutput_Good verifies struct fields. -func TestMLBackendsOutput_Good(t *testing.T) { - output := MLBackendsOutput{ - Backends: []MLBackendInfo{ - {Name: "ollama", Available: true}, - {Name: "llama", Available: false}, - }, - Default: "ollama", - } - if len(output.Backends) != 2 { - t.Fatalf("Expected 2 backends, got %d", len(output.Backends)) - } - if output.Default != "ollama" { - t.Errorf("Expected default 'ollama', got %q", output.Default) - } - if !output.Backends[0].Available { - t.Error("Expected first backend to be available") - } -} - -// TestMLProbeOutput_Good verifies struct fields. -func TestMLProbeOutput_Good(t *testing.T) { - output := MLProbeOutput{ - Total: 2, - Results: []MLProbeResultItem{ - {ID: "probe-1", Category: "reasoning", Response: "test"}, - {ID: "probe-2", Category: "code", Response: "test2"}, - }, - } - if output.Total != 2 { - t.Errorf("Expected total 2, got %d", output.Total) - } - if output.Results[0].ID != "probe-1" { - t.Errorf("Expected ID 'probe-1', got %q", output.Results[0].ID) - } -} - -// TestMLStatusOutput_Good verifies struct fields. -func TestMLStatusOutput_Good(t *testing.T) { - output := MLStatusOutput{Status: "OK: 5 training runs"} - if output.Status != "OK: 5 training runs" { - t.Errorf("Unexpected status: %q", output.Status) - } -} - -// TestMLGenerateOutput_Good verifies struct fields. -func TestMLGenerateOutput_Good(t *testing.T) { - output := MLGenerateOutput{ - Response: "Generated text here", - Backend: "ollama", - Model: "qwen3:8b", - } - if output.Response != "Generated text here" { - t.Errorf("Unexpected response: %q", output.Response) - } -} - -// TestMLScoreOutput_Good verifies struct fields. -func TestMLScoreOutput_Good(t *testing.T) { - output := MLScoreOutput{ - Heuristic: &ml.HeuristicScores{}, - } - if output.Heuristic == nil { - t.Error("Expected Heuristic to be set") - } - if output.Semantic != nil { - t.Error("Expected Semantic to be nil") - } -} - -// --- Handler validation tests --- - -// TestMLGenerate_Bad_EmptyPrompt verifies empty prompt returns error. -func TestMLGenerate_Bad_EmptyPrompt(t *testing.T) { - m := newTestMLSubsystem(t) - ctx := context.Background() - - _, _, err := m.mlGenerate(ctx, nil, MLGenerateInput{}) - if err == nil { - t.Fatal("Expected error for empty prompt") - } - if !strings.Contains(err.Error(), "prompt cannot be empty") { - t.Errorf("Unexpected error: %v", err) - } -} - -// TestMLGenerate_Good_WithMockBackend verifies generate works with a mock backend. -func TestMLGenerate_Good_WithMockBackend(t *testing.T) { - mock := &mockMLBackend{ - name: "test-mock", - available: true, - generateResp: "mock response", - } - m := newTestMLSubsystem(t, mock) - ctx := context.Background() - - _, out, err := m.mlGenerate(ctx, nil, MLGenerateInput{ - Prompt: "test", - Backend: "test-mock", - }) - if err != nil { - t.Fatalf("mlGenerate failed: %v", err) - } - if out.Response != "mock response" { - t.Errorf("Expected 'mock response', got %q", out.Response) - } -} - -// TestMLGenerate_Bad_NoBackend verifies generate fails gracefully without a backend. -func TestMLGenerate_Bad_NoBackend(t *testing.T) { - m := newTestMLSubsystem(t) - ctx := context.Background() - - _, _, err := m.mlGenerate(ctx, nil, MLGenerateInput{ - Prompt: "test", - Backend: "nonexistent", - }) - if err == nil { - t.Fatal("Expected error for missing backend") - } - if !strings.Contains(err.Error(), "no backend available") { - t.Errorf("Unexpected error: %v", err) - } -} - -// TestMLScore_Bad_EmptyPrompt verifies empty prompt returns error. -func TestMLScore_Bad_EmptyPrompt(t *testing.T) { - m := newTestMLSubsystem(t) - ctx := context.Background() - - _, _, err := m.mlScore(ctx, nil, MLScoreInput{Response: "some"}) - if err == nil { - t.Fatal("Expected error for empty prompt") - } -} - -// TestMLScore_Bad_EmptyResponse verifies empty response returns error. -func TestMLScore_Bad_EmptyResponse(t *testing.T) { - m := newTestMLSubsystem(t) - ctx := context.Background() - - _, _, err := m.mlScore(ctx, nil, MLScoreInput{Prompt: "some"}) - if err == nil { - t.Fatal("Expected error for empty response") - } -} - -// TestMLScore_Good_Heuristic verifies heuristic scoring without live services. -func TestMLScore_Good_Heuristic(t *testing.T) { - m := newTestMLSubsystem(t) - ctx := context.Background() - - _, out, err := m.mlScore(ctx, nil, MLScoreInput{ - Prompt: "What is Go?", - Response: "Go is a statically typed, compiled programming language designed at Google.", - Suites: "heuristic", - }) - if err != nil { - t.Fatalf("mlScore failed: %v", err) - } - if out.Heuristic == nil { - t.Fatal("Expected heuristic scores to be set") - } -} - -// TestMLScore_Good_DefaultSuite verifies default suite is heuristic. -func TestMLScore_Good_DefaultSuite(t *testing.T) { - m := newTestMLSubsystem(t) - ctx := context.Background() - - _, out, err := m.mlScore(ctx, nil, MLScoreInput{ - Prompt: "What is Go?", - Response: "Go is a statically typed, compiled programming language designed at Google.", - }) - if err != nil { - t.Fatalf("mlScore failed: %v", err) - } - if out.Heuristic == nil { - t.Fatal("Expected heuristic scores (default suite)") - } -} - -// TestMLScore_Bad_SemanticNoJudge verifies semantic scoring fails without a judge. -func TestMLScore_Bad_SemanticNoJudge(t *testing.T) { - m := newTestMLSubsystem(t) - ctx := context.Background() - - _, _, err := m.mlScore(ctx, nil, MLScoreInput{ - Prompt: "test", - Response: "test", - Suites: "semantic", - }) - if err == nil { - t.Fatal("Expected error for semantic scoring without judge") - } - if !strings.Contains(err.Error(), "requires a judge") { - t.Errorf("Unexpected error: %v", err) - } -} - -// TestMLScore_Bad_ContentSuite verifies content suite redirects to ml_probe. -func TestMLScore_Bad_ContentSuite(t *testing.T) { - m := newTestMLSubsystem(t) - ctx := context.Background() - - _, _, err := m.mlScore(ctx, nil, MLScoreInput{ - Prompt: "test", - Response: "test", - Suites: "content", - }) - if err == nil { - t.Fatal("Expected error for content suite") - } - if !strings.Contains(err.Error(), "ContentProbe") { - t.Errorf("Unexpected error: %v", err) - } -} - -// TestMLProbe_Good_WithMockBackend verifies probes run with mock backend. -func TestMLProbe_Good_WithMockBackend(t *testing.T) { - mock := &mockMLBackend{ - name: "probe-mock", - available: true, - generateResp: "probe response", - } - m := newTestMLSubsystem(t, mock) - ctx := context.Background() - - _, out, err := m.mlProbe(ctx, nil, MLProbeInput{ - Backend: "probe-mock", - Categories: "reasoning", - }) - if err != nil { - t.Fatalf("mlProbe failed: %v", err) - } - // Should have run probes in the "reasoning" category - for _, r := range out.Results { - if r.Category != "reasoning" { - t.Errorf("Expected category 'reasoning', got %q", r.Category) - } - if r.Response != "probe response" { - t.Errorf("Expected 'probe response', got %q", r.Response) - } - } - if out.Total != len(out.Results) { - t.Errorf("Expected total %d, got %d", len(out.Results), out.Total) - } -} - -// TestMLProbe_Good_NoCategory verifies all probes run without category filter. -func TestMLProbe_Good_NoCategory(t *testing.T) { - mock := &mockMLBackend{ - name: "all-probe-mock", - available: true, - generateResp: "ok", - } - m := newTestMLSubsystem(t, mock) - ctx := context.Background() - - _, out, err := m.mlProbe(ctx, nil, MLProbeInput{Backend: "all-probe-mock"}) - if err != nil { - t.Fatalf("mlProbe failed: %v", err) - } - // Should run all 23 probes - if out.Total != len(ml.CapabilityProbes) { - t.Errorf("Expected %d probes, got %d", len(ml.CapabilityProbes), out.Total) - } -} - -// TestMLBackends_Good_EmptyRegistry verifies empty result when no backends registered. -func TestMLBackends_Good_EmptyRegistry(t *testing.T) { - m := newTestMLSubsystem(t) - ctx := context.Background() - - // Note: inference.List() returns global registry state. - // This test verifies the handler runs without panic. - _, out, err := m.mlBackends(ctx, nil, MLBackendsInput{}) - if err != nil { - t.Fatalf("mlBackends failed: %v", err) - } - // We can't guarantee what's in the global registry, but it should not panic - _ = out -} - -// TestMLBackends_Good_WithMockInferenceBackend verifies registered backend appears. -func TestMLBackends_Good_WithMockInferenceBackend(t *testing.T) { - // Register a mock backend in the global inference registry - mock := &mockInferenceBackend{name: "test-ci-mock", available: true} - inference.Register(mock) - - m := newTestMLSubsystem(t) - ctx := context.Background() - - _, out, err := m.mlBackends(ctx, nil, MLBackendsInput{}) - if err != nil { - t.Fatalf("mlBackends failed: %v", err) - } - - found := false - for _, b := range out.Backends { - if b.Name == "test-ci-mock" { - found = true - if !b.Available { - t.Error("Expected mock backend to be available") - } - } - } - if !found { - t.Error("Expected to find 'test-ci-mock' in backends list") - } -} - -// TestMLSubsystem_Good_Name verifies subsystem name. -func TestMLSubsystem_Good_Name(t *testing.T) { - m := newTestMLSubsystem(t) - if m.Name() != "ml" { - t.Errorf("Expected name 'ml', got %q", m.Name()) - } -} - -// TestNewMLSubsystem_Good verifies constructor. -func TestNewMLSubsystem_Good(t *testing.T) { - c, err := core.New( - core.WithName("ml", ml.NewService(ml.Options{})), - ) - if err != nil { - t.Fatalf("Failed to create core: %v", err) - } - svc, err := core.ServiceFor[*ml.Service](c, "ml") - if err != nil { - t.Fatalf("Failed to get service: %v", err) - } - sub := NewMLSubsystem(svc) - if sub == nil { - t.Fatal("Expected non-nil subsystem") - } - if sub.service != svc { - t.Error("Expected service to be set") - } - if sub.logger == nil { - t.Error("Expected logger to be set") - } -} diff --git a/mcp/tools_process.go b/mcp/tools_process.go deleted file mode 100644 index 9fab75b..0000000 --- a/mcp/tools_process.go +++ /dev/null @@ -1,305 +0,0 @@ -package mcp - -import ( - "context" - "errors" - "fmt" - "time" - - "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-process" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// errIDEmpty is returned when a process tool call omits the required ID. -var errIDEmpty = errors.New("id cannot be empty") - -// ProcessStartInput contains parameters for starting a new process. -type ProcessStartInput struct { - Command string `json:"command"` // The command to run - Args []string `json:"args,omitempty"` // Command arguments - Dir string `json:"dir,omitempty"` // Working directory - Env []string `json:"env,omitempty"` // Environment variables (KEY=VALUE format) -} - -// ProcessStartOutput contains the result of starting a process. -type ProcessStartOutput struct { - ID string `json:"id"` - PID int `json:"pid"` - Command string `json:"command"` - Args []string `json:"args"` - StartedAt time.Time `json:"startedAt"` -} - -// ProcessStopInput contains parameters for gracefully stopping a process. -type ProcessStopInput struct { - ID string `json:"id"` // Process ID to stop -} - -// ProcessStopOutput contains the result of stopping a process. -type ProcessStopOutput struct { - ID string `json:"id"` - Success bool `json:"success"` - Message string `json:"message,omitempty"` -} - -// ProcessKillInput contains parameters for force killing a process. -type ProcessKillInput struct { - ID string `json:"id"` // Process ID to kill -} - -// ProcessKillOutput contains the result of killing a process. -type ProcessKillOutput struct { - ID string `json:"id"` - Success bool `json:"success"` - Message string `json:"message,omitempty"` -} - -// ProcessListInput contains parameters for listing processes. -type ProcessListInput struct { - RunningOnly bool `json:"running_only,omitempty"` // If true, only return running processes -} - -// ProcessListOutput contains the list of processes. -type ProcessListOutput struct { - Processes []ProcessInfo `json:"processes"` - Total int `json:"total"` -} - -// ProcessInfo represents information about a process. -type ProcessInfo struct { - ID string `json:"id"` - Command string `json:"command"` - Args []string `json:"args"` - Dir string `json:"dir"` - Status string `json:"status"` - PID int `json:"pid"` - ExitCode int `json:"exitCode"` - StartedAt time.Time `json:"startedAt"` - Duration time.Duration `json:"duration"` -} - -// ProcessOutputInput contains parameters for getting process output. -type ProcessOutputInput struct { - ID string `json:"id"` // Process ID -} - -// ProcessOutputOutput contains the captured output of a process. -type ProcessOutputOutput struct { - ID string `json:"id"` - Output string `json:"output"` -} - -// ProcessInputInput contains parameters for sending input to a process. -type ProcessInputInput struct { - ID string `json:"id"` // Process ID - Input string `json:"input"` // Input to send to stdin -} - -// ProcessInputOutput contains the result of sending input to a process. -type ProcessInputOutput struct { - ID string `json:"id"` - Success bool `json:"success"` - Message string `json:"message,omitempty"` -} - -// registerProcessTools adds process management tools to the MCP server. -// Returns false if process service is not available. -func (s *Service) registerProcessTools(server *mcp.Server) bool { - if s.processService == nil { - return false - } - - mcp.AddTool(server, &mcp.Tool{ - Name: "process_start", - Description: "Start a new external process. Returns process ID for tracking.", - }, s.processStart) - - mcp.AddTool(server, &mcp.Tool{ - Name: "process_stop", - Description: "Gracefully stop a running process by ID.", - }, s.processStop) - - mcp.AddTool(server, &mcp.Tool{ - Name: "process_kill", - Description: "Force kill a process by ID. Use when process_stop doesn't work.", - }, s.processKill) - - mcp.AddTool(server, &mcp.Tool{ - Name: "process_list", - Description: "List all managed processes. Use running_only=true for only active processes.", - }, s.processList) - - mcp.AddTool(server, &mcp.Tool{ - Name: "process_output", - Description: "Get the captured output of a process by ID.", - }, s.processOutput) - - mcp.AddTool(server, &mcp.Tool{ - Name: "process_input", - Description: "Send input to a running process stdin.", - }, s.processInput) - - return true -} - -// processStart handles the process_start tool call. -func (s *Service) processStart(ctx context.Context, req *mcp.CallToolRequest, input ProcessStartInput) (*mcp.CallToolResult, ProcessStartOutput, error) { - s.logger.Security("MCP tool execution", "tool", "process_start", "command", input.Command, "args", input.Args, "dir", input.Dir, "user", log.Username()) - - if input.Command == "" { - return nil, ProcessStartOutput{}, errors.New("command cannot be empty") - } - - opts := process.RunOptions{ - Command: input.Command, - Args: input.Args, - Dir: input.Dir, - Env: input.Env, - } - - proc, err := s.processService.StartWithOptions(ctx, opts) - if err != nil { - log.Error("mcp: process start failed", "command", input.Command, "err", err) - return nil, ProcessStartOutput{}, fmt.Errorf("failed to start process: %w", err) - } - - info := proc.Info() - return nil, ProcessStartOutput{ - ID: proc.ID, - PID: info.PID, - Command: proc.Command, - Args: proc.Args, - StartedAt: proc.StartedAt, - }, nil -} - -// processStop handles the process_stop tool call. -func (s *Service) processStop(ctx context.Context, req *mcp.CallToolRequest, input ProcessStopInput) (*mcp.CallToolResult, ProcessStopOutput, error) { - s.logger.Security("MCP tool execution", "tool", "process_stop", "id", input.ID, "user", log.Username()) - - if input.ID == "" { - return nil, ProcessStopOutput{}, errIDEmpty - } - - proc, err := s.processService.Get(input.ID) - if err != nil { - log.Error("mcp: process stop failed", "id", input.ID, "err", err) - return nil, ProcessStopOutput{}, fmt.Errorf("process not found: %w", err) - } - - // For graceful stop, we use Kill() which sends SIGKILL - // A more sophisticated implementation could use SIGTERM first - if err := proc.Kill(); err != nil { - log.Error("mcp: process stop kill failed", "id", input.ID, "err", err) - return nil, ProcessStopOutput{}, fmt.Errorf("failed to stop process: %w", err) - } - - return nil, ProcessStopOutput{ - ID: input.ID, - Success: true, - Message: "Process stop signal sent", - }, nil -} - -// processKill handles the process_kill tool call. -func (s *Service) processKill(ctx context.Context, req *mcp.CallToolRequest, input ProcessKillInput) (*mcp.CallToolResult, ProcessKillOutput, error) { - s.logger.Security("MCP tool execution", "tool", "process_kill", "id", input.ID, "user", log.Username()) - - if input.ID == "" { - return nil, ProcessKillOutput{}, errIDEmpty - } - - if err := s.processService.Kill(input.ID); err != nil { - log.Error("mcp: process kill failed", "id", input.ID, "err", err) - return nil, ProcessKillOutput{}, fmt.Errorf("failed to kill process: %w", err) - } - - return nil, ProcessKillOutput{ - ID: input.ID, - Success: true, - Message: "Process killed", - }, nil -} - -// processList handles the process_list tool call. -func (s *Service) processList(ctx context.Context, req *mcp.CallToolRequest, input ProcessListInput) (*mcp.CallToolResult, ProcessListOutput, error) { - s.logger.Info("MCP tool execution", "tool", "process_list", "running_only", input.RunningOnly, "user", log.Username()) - - var procs []*process.Process - if input.RunningOnly { - procs = s.processService.Running() - } else { - procs = s.processService.List() - } - - result := make([]ProcessInfo, len(procs)) - for i, p := range procs { - info := p.Info() - result[i] = ProcessInfo{ - ID: info.ID, - Command: info.Command, - Args: info.Args, - Dir: info.Dir, - Status: string(info.Status), - PID: info.PID, - ExitCode: info.ExitCode, - StartedAt: info.StartedAt, - Duration: info.Duration, - } - } - - return nil, ProcessListOutput{ - Processes: result, - Total: len(result), - }, nil -} - -// processOutput handles the process_output tool call. -func (s *Service) processOutput(ctx context.Context, req *mcp.CallToolRequest, input ProcessOutputInput) (*mcp.CallToolResult, ProcessOutputOutput, error) { - s.logger.Info("MCP tool execution", "tool", "process_output", "id", input.ID, "user", log.Username()) - - if input.ID == "" { - return nil, ProcessOutputOutput{}, errIDEmpty - } - - output, err := s.processService.Output(input.ID) - if err != nil { - log.Error("mcp: process output failed", "id", input.ID, "err", err) - return nil, ProcessOutputOutput{}, fmt.Errorf("failed to get process output: %w", err) - } - - return nil, ProcessOutputOutput{ - ID: input.ID, - Output: output, - }, nil -} - -// processInput handles the process_input tool call. -func (s *Service) processInput(ctx context.Context, req *mcp.CallToolRequest, input ProcessInputInput) (*mcp.CallToolResult, ProcessInputOutput, error) { - s.logger.Security("MCP tool execution", "tool", "process_input", "id", input.ID, "user", log.Username()) - - if input.ID == "" { - return nil, ProcessInputOutput{}, errIDEmpty - } - if input.Input == "" { - return nil, ProcessInputOutput{}, errors.New("input cannot be empty") - } - - proc, err := s.processService.Get(input.ID) - if err != nil { - log.Error("mcp: process input get failed", "id", input.ID, "err", err) - return nil, ProcessInputOutput{}, fmt.Errorf("process not found: %w", err) - } - - if err := proc.SendInput(input.Input); err != nil { - log.Error("mcp: process input send failed", "id", input.ID, "err", err) - return nil, ProcessInputOutput{}, fmt.Errorf("failed to send input: %w", err) - } - - return nil, ProcessInputOutput{ - ID: input.ID, - Success: true, - Message: "Input sent successfully", - }, nil -} diff --git a/mcp/tools_process_ci_test.go b/mcp/tools_process_ci_test.go deleted file mode 100644 index d8ea037..0000000 --- a/mcp/tools_process_ci_test.go +++ /dev/null @@ -1,515 +0,0 @@ -package mcp - -import ( - "context" - "strings" - "testing" - "time" - - "forge.lthn.ai/core/go/pkg/core" - "forge.lthn.ai/core/go-process" -) - -// newTestProcessService creates a real process.Service backed by a core.Core for CI tests. -func newTestProcessService(t *testing.T) *process.Service { - t.Helper() - c, err := core.New( - core.WithName("process", process.NewService(process.Options{})), - ) - if err != nil { - t.Fatalf("Failed to create framework core: %v", err) - } - svc, err := core.ServiceFor[*process.Service](c, "process") - if err != nil { - t.Fatalf("Failed to get process service: %v", err) - } - // Start services (calls OnStartup) - if err := c.ServiceStartup(context.Background(), nil); err != nil { - t.Fatalf("Failed to start core: %v", err) - } - t.Cleanup(func() { - _ = c.ServiceShutdown(context.Background()) - }) - return svc -} - -// newTestMCPWithProcess creates an MCP Service wired to a real process.Service. -func newTestMCPWithProcess(t *testing.T) (*Service, *process.Service) { - t.Helper() - ps := newTestProcessService(t) - s, err := New(WithProcessService(ps)) - if err != nil { - t.Fatalf("Failed to create MCP service: %v", err) - } - return s, ps -} - -// --- CI-safe handler tests --- - -// TestProcessStart_Good_Echo starts "echo hello" and verifies the output. -func TestProcessStart_Good_Echo(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, out, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "echo", - Args: []string{"hello"}, - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - if out.ID == "" { - t.Error("Expected non-empty process ID") - } - if out.Command != "echo" { - t.Errorf("Expected command 'echo', got %q", out.Command) - } - if out.PID <= 0 { - t.Errorf("Expected positive PID, got %d", out.PID) - } - if out.StartedAt.IsZero() { - t.Error("Expected non-zero StartedAt") - } -} - -// TestProcessStart_Bad_EmptyCommand verifies empty command returns an error. -func TestProcessStart_Bad_EmptyCommand(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processStart(ctx, nil, ProcessStartInput{}) - if err == nil { - t.Fatal("Expected error for empty command") - } - if !strings.Contains(err.Error(), "command cannot be empty") { - t.Errorf("Unexpected error: %v", err) - } -} - -// TestProcessStart_Bad_NonexistentCommand verifies an invalid command returns an error. -func TestProcessStart_Bad_NonexistentCommand(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "/nonexistent/binary/that/does/not/exist", - }) - if err == nil { - t.Fatal("Expected error for nonexistent command") - } -} - -// TestProcessList_Good_Empty verifies list is empty initially. -func TestProcessList_Good_Empty(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, out, err := s.processList(ctx, nil, ProcessListInput{}) - if err != nil { - t.Fatalf("processList failed: %v", err) - } - if out.Total != 0 { - t.Errorf("Expected 0 processes, got %d", out.Total) - } -} - -// TestProcessList_Good_AfterStart verifies a started process appears in list. -func TestProcessList_Good_AfterStart(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - // Start a short-lived process - _, startOut, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "echo", - Args: []string{"listing"}, - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - - // Give it a moment to register - time.Sleep(50 * time.Millisecond) - - // List all processes (including exited) - _, listOut, err := s.processList(ctx, nil, ProcessListInput{}) - if err != nil { - t.Fatalf("processList failed: %v", err) - } - if listOut.Total < 1 { - t.Fatalf("Expected at least 1 process, got %d", listOut.Total) - } - - found := false - for _, p := range listOut.Processes { - if p.ID == startOut.ID { - found = true - if p.Command != "echo" { - t.Errorf("Expected command 'echo', got %q", p.Command) - } - } - } - if !found { - t.Errorf("Process %s not found in list", startOut.ID) - } -} - -// TestProcessList_Good_RunningOnly verifies filtering for running-only processes. -func TestProcessList_Good_RunningOnly(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - // Start a process that exits quickly - _, _, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "echo", - Args: []string{"done"}, - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - - // Wait for it to exit - time.Sleep(100 * time.Millisecond) - - // Running-only should be empty now - _, listOut, err := s.processList(ctx, nil, ProcessListInput{RunningOnly: true}) - if err != nil { - t.Fatalf("processList failed: %v", err) - } - if listOut.Total != 0 { - t.Errorf("Expected 0 running processes after echo exits, got %d", listOut.Total) - } -} - -// TestProcessOutput_Good_Echo verifies output capture from echo. -func TestProcessOutput_Good_Echo(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, startOut, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "echo", - Args: []string{"output_test"}, - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - - // Wait for process to complete and output to be captured - time.Sleep(200 * time.Millisecond) - - _, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID}) - if err != nil { - t.Fatalf("processOutput failed: %v", err) - } - if !strings.Contains(outputOut.Output, "output_test") { - t.Errorf("Expected output to contain 'output_test', got %q", outputOut.Output) - } -} - -// TestProcessOutput_Bad_EmptyID verifies empty ID returns error. -func TestProcessOutput_Bad_EmptyID(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processOutput(ctx, nil, ProcessOutputInput{}) - if err == nil { - t.Fatal("Expected error for empty ID") - } - if !strings.Contains(err.Error(), "id cannot be empty") { - t.Errorf("Unexpected error: %v", err) - } -} - -// TestProcessOutput_Bad_NotFound verifies nonexistent ID returns error. -func TestProcessOutput_Bad_NotFound(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: "nonexistent-id"}) - if err == nil { - t.Fatal("Expected error for nonexistent ID") - } -} - -// TestProcessStop_Good_LongRunning starts a sleep, stops it, and verifies. -func TestProcessStop_Good_LongRunning(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - // Start a process that sleeps for a while - _, startOut, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "sleep", - Args: []string{"10"}, - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - - // Verify it's running - time.Sleep(50 * time.Millisecond) - _, listOut, _ := s.processList(ctx, nil, ProcessListInput{RunningOnly: true}) - if listOut.Total < 1 { - t.Fatal("Expected at least 1 running process") - } - - // Stop it - _, stopOut, err := s.processStop(ctx, nil, ProcessStopInput{ID: startOut.ID}) - if err != nil { - t.Fatalf("processStop failed: %v", err) - } - if !stopOut.Success { - t.Error("Expected stop to succeed") - } - if stopOut.ID != startOut.ID { - t.Errorf("Expected ID %q, got %q", startOut.ID, stopOut.ID) - } -} - -// TestProcessStop_Bad_EmptyID verifies empty ID returns error. -func TestProcessStop_Bad_EmptyID(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processStop(ctx, nil, ProcessStopInput{}) - if err == nil { - t.Fatal("Expected error for empty ID") - } -} - -// TestProcessStop_Bad_NotFound verifies nonexistent ID returns error. -func TestProcessStop_Bad_NotFound(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processStop(ctx, nil, ProcessStopInput{ID: "nonexistent-id"}) - if err == nil { - t.Fatal("Expected error for nonexistent ID") - } -} - -// TestProcessKill_Good_LongRunning starts a sleep, kills it, and verifies. -func TestProcessKill_Good_LongRunning(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, startOut, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "sleep", - Args: []string{"10"}, - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - - time.Sleep(50 * time.Millisecond) - - _, killOut, err := s.processKill(ctx, nil, ProcessKillInput{ID: startOut.ID}) - if err != nil { - t.Fatalf("processKill failed: %v", err) - } - if !killOut.Success { - t.Error("Expected kill to succeed") - } - if killOut.Message != "Process killed" { - t.Errorf("Expected message 'Process killed', got %q", killOut.Message) - } -} - -// TestProcessKill_Bad_EmptyID verifies empty ID returns error. -func TestProcessKill_Bad_EmptyID(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processKill(ctx, nil, ProcessKillInput{}) - if err == nil { - t.Fatal("Expected error for empty ID") - } -} - -// TestProcessKill_Bad_NotFound verifies nonexistent ID returns error. -func TestProcessKill_Bad_NotFound(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processKill(ctx, nil, ProcessKillInput{ID: "nonexistent-id"}) - if err == nil { - t.Fatal("Expected error for nonexistent ID") - } -} - -// TestProcessInput_Bad_EmptyID verifies empty ID returns error. -func TestProcessInput_Bad_EmptyID(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processInput(ctx, nil, ProcessInputInput{}) - if err == nil { - t.Fatal("Expected error for empty ID") - } -} - -// TestProcessInput_Bad_EmptyInput verifies empty input string returns error. -func TestProcessInput_Bad_EmptyInput(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processInput(ctx, nil, ProcessInputInput{ID: "some-id"}) - if err == nil { - t.Fatal("Expected error for empty input") - } -} - -// TestProcessInput_Bad_NotFound verifies nonexistent process ID returns error. -func TestProcessInput_Bad_NotFound(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, _, err := s.processInput(ctx, nil, ProcessInputInput{ - ID: "nonexistent-id", - Input: "hello\n", - }) - if err == nil { - t.Fatal("Expected error for nonexistent ID") - } -} - -// TestProcessInput_Good_Cat sends input to cat and reads it back. -func TestProcessInput_Good_Cat(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - // Start cat which reads stdin and echoes to stdout - _, startOut, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "cat", - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - - time.Sleep(50 * time.Millisecond) - - // Send input - _, inputOut, err := s.processInput(ctx, nil, ProcessInputInput{ - ID: startOut.ID, - Input: "stdin_test\n", - }) - if err != nil { - t.Fatalf("processInput failed: %v", err) - } - if !inputOut.Success { - t.Error("Expected input to succeed") - } - - // Wait for output capture - time.Sleep(100 * time.Millisecond) - - // Read output - _, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID}) - if err != nil { - t.Fatalf("processOutput failed: %v", err) - } - if !strings.Contains(outputOut.Output, "stdin_test") { - t.Errorf("Expected output to contain 'stdin_test', got %q", outputOut.Output) - } - - // Kill the cat process (it's still running) - _, _, _ = s.processKill(ctx, nil, ProcessKillInput{ID: startOut.ID}) -} - -// TestProcessStart_Good_WithDir verifies working directory is respected. -func TestProcessStart_Good_WithDir(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - dir := t.TempDir() - - _, startOut, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "pwd", - Dir: dir, - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - - time.Sleep(200 * time.Millisecond) - - _, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID}) - if err != nil { - t.Fatalf("processOutput failed: %v", err) - } - if !strings.Contains(outputOut.Output, dir) { - t.Errorf("Expected output to contain dir %q, got %q", dir, outputOut.Output) - } -} - -// TestProcessStart_Good_WithEnv verifies environment variables are passed. -func TestProcessStart_Good_WithEnv(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - _, startOut, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "env", - Env: []string{"TEST_MCP_VAR=hello_from_test"}, - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - - time.Sleep(200 * time.Millisecond) - - _, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID}) - if err != nil { - t.Fatalf("processOutput failed: %v", err) - } - if !strings.Contains(outputOut.Output, "TEST_MCP_VAR=hello_from_test") { - t.Errorf("Expected output to contain env var, got %q", outputOut.Output) - } -} - -// TestProcessToolsRegistered_Good_WithService verifies tools are registered when service is provided. -func TestProcessToolsRegistered_Good_WithService(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - if s.processService == nil { - t.Error("Expected process service to be set") - } -} - -// TestProcessFullLifecycle_Good tests the start → list → output → kill → list cycle. -func TestProcessFullLifecycle_Good(t *testing.T) { - s, _ := newTestMCPWithProcess(t) - ctx := context.Background() - - // 1. Start - _, startOut, err := s.processStart(ctx, nil, ProcessStartInput{ - Command: "sleep", - Args: []string{"10"}, - }) - if err != nil { - t.Fatalf("processStart failed: %v", err) - } - - time.Sleep(50 * time.Millisecond) - - // 2. List (should be running) - _, listOut, _ := s.processList(ctx, nil, ProcessListInput{RunningOnly: true}) - if listOut.Total < 1 { - t.Fatal("Expected at least 1 running process") - } - - // 3. Kill - _, killOut, err := s.processKill(ctx, nil, ProcessKillInput{ID: startOut.ID}) - if err != nil { - t.Fatalf("processKill failed: %v", err) - } - if !killOut.Success { - t.Error("Expected kill to succeed") - } - - // 4. Wait for exit - time.Sleep(100 * time.Millisecond) - - // 5. Should not be running anymore - _, listOut, _ = s.processList(ctx, nil, ProcessListInput{RunningOnly: true}) - for _, p := range listOut.Processes { - if p.ID == startOut.ID { - t.Errorf("Process %s should not be running after kill", startOut.ID) - } - } -} diff --git a/mcp/tools_process_test.go b/mcp/tools_process_test.go deleted file mode 100644 index 724e2e4..0000000 --- a/mcp/tools_process_test.go +++ /dev/null @@ -1,290 +0,0 @@ -package mcp - -import ( - "testing" - "time" -) - -// TestProcessToolsRegistered_Good verifies that process tools are registered when process service is available. -func TestProcessToolsRegistered_Good(t *testing.T) { - // Create a new MCP service without process service - tools should not be registered - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.processService != nil { - t.Error("Process service should be nil by default") - } - - if s.server == nil { - t.Fatal("Server should not be nil") - } -} - -// TestProcessStartInput_Good verifies the ProcessStartInput struct has expected fields. -func TestProcessStartInput_Good(t *testing.T) { - input := ProcessStartInput{ - Command: "echo", - Args: []string{"hello", "world"}, - Dir: "/tmp", - Env: []string{"FOO=bar"}, - } - - if input.Command != "echo" { - t.Errorf("Expected command 'echo', got %q", input.Command) - } - if len(input.Args) != 2 { - t.Errorf("Expected 2 args, got %d", len(input.Args)) - } - if input.Dir != "/tmp" { - t.Errorf("Expected dir '/tmp', got %q", input.Dir) - } - if len(input.Env) != 1 { - t.Errorf("Expected 1 env var, got %d", len(input.Env)) - } -} - -// TestProcessStartOutput_Good verifies the ProcessStartOutput struct has expected fields. -func TestProcessStartOutput_Good(t *testing.T) { - now := time.Now() - output := ProcessStartOutput{ - ID: "proc-1", - PID: 12345, - Command: "echo", - Args: []string{"hello"}, - StartedAt: now, - } - - if output.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", output.ID) - } - if output.PID != 12345 { - t.Errorf("Expected PID 12345, got %d", output.PID) - } - if output.Command != "echo" { - t.Errorf("Expected command 'echo', got %q", output.Command) - } - if !output.StartedAt.Equal(now) { - t.Errorf("Expected StartedAt %v, got %v", now, output.StartedAt) - } -} - -// TestProcessStopInput_Good verifies the ProcessStopInput struct has expected fields. -func TestProcessStopInput_Good(t *testing.T) { - input := ProcessStopInput{ - ID: "proc-1", - } - - if input.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", input.ID) - } -} - -// TestProcessStopOutput_Good verifies the ProcessStopOutput struct has expected fields. -func TestProcessStopOutput_Good(t *testing.T) { - output := ProcessStopOutput{ - ID: "proc-1", - Success: true, - Message: "Process stopped", - } - - if output.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", output.ID) - } - if !output.Success { - t.Error("Expected Success to be true") - } - if output.Message != "Process stopped" { - t.Errorf("Expected message 'Process stopped', got %q", output.Message) - } -} - -// TestProcessKillInput_Good verifies the ProcessKillInput struct has expected fields. -func TestProcessKillInput_Good(t *testing.T) { - input := ProcessKillInput{ - ID: "proc-1", - } - - if input.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", input.ID) - } -} - -// TestProcessKillOutput_Good verifies the ProcessKillOutput struct has expected fields. -func TestProcessKillOutput_Good(t *testing.T) { - output := ProcessKillOutput{ - ID: "proc-1", - Success: true, - Message: "Process killed", - } - - if output.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", output.ID) - } - if !output.Success { - t.Error("Expected Success to be true") - } -} - -// TestProcessListInput_Good verifies the ProcessListInput struct has expected fields. -func TestProcessListInput_Good(t *testing.T) { - input := ProcessListInput{ - RunningOnly: true, - } - - if !input.RunningOnly { - t.Error("Expected RunningOnly to be true") - } -} - -// TestProcessListInput_Defaults verifies default values. -func TestProcessListInput_Defaults(t *testing.T) { - input := ProcessListInput{} - - if input.RunningOnly { - t.Error("Expected RunningOnly to default to false") - } -} - -// TestProcessListOutput_Good verifies the ProcessListOutput struct has expected fields. -func TestProcessListOutput_Good(t *testing.T) { - now := time.Now() - output := ProcessListOutput{ - Processes: []ProcessInfo{ - { - ID: "proc-1", - Command: "echo", - Args: []string{"hello"}, - Dir: "/tmp", - Status: "running", - PID: 12345, - ExitCode: 0, - StartedAt: now, - Duration: 5 * time.Second, - }, - }, - Total: 1, - } - - if len(output.Processes) != 1 { - t.Fatalf("Expected 1 process, got %d", len(output.Processes)) - } - if output.Total != 1 { - t.Errorf("Expected total 1, got %d", output.Total) - } - - proc := output.Processes[0] - if proc.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", proc.ID) - } - if proc.Status != "running" { - t.Errorf("Expected status 'running', got %q", proc.Status) - } - if proc.PID != 12345 { - t.Errorf("Expected PID 12345, got %d", proc.PID) - } -} - -// TestProcessOutputInput_Good verifies the ProcessOutputInput struct has expected fields. -func TestProcessOutputInput_Good(t *testing.T) { - input := ProcessOutputInput{ - ID: "proc-1", - } - - if input.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", input.ID) - } -} - -// TestProcessOutputOutput_Good verifies the ProcessOutputOutput struct has expected fields. -func TestProcessOutputOutput_Good(t *testing.T) { - output := ProcessOutputOutput{ - ID: "proc-1", - Output: "hello world\n", - } - - if output.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", output.ID) - } - if output.Output != "hello world\n" { - t.Errorf("Expected output 'hello world\\n', got %q", output.Output) - } -} - -// TestProcessInputInput_Good verifies the ProcessInputInput struct has expected fields. -func TestProcessInputInput_Good(t *testing.T) { - input := ProcessInputInput{ - ID: "proc-1", - Input: "test input\n", - } - - if input.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", input.ID) - } - if input.Input != "test input\n" { - t.Errorf("Expected input 'test input\\n', got %q", input.Input) - } -} - -// TestProcessInputOutput_Good verifies the ProcessInputOutput struct has expected fields. -func TestProcessInputOutput_Good(t *testing.T) { - output := ProcessInputOutput{ - ID: "proc-1", - Success: true, - Message: "Input sent", - } - - if output.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", output.ID) - } - if !output.Success { - t.Error("Expected Success to be true") - } -} - -// TestProcessInfo_Good verifies the ProcessInfo struct has expected fields. -func TestProcessInfo_Good(t *testing.T) { - now := time.Now() - info := ProcessInfo{ - ID: "proc-1", - Command: "echo", - Args: []string{"hello"}, - Dir: "/tmp", - Status: "exited", - PID: 12345, - ExitCode: 0, - StartedAt: now, - Duration: 2 * time.Second, - } - - if info.ID != "proc-1" { - t.Errorf("Expected ID 'proc-1', got %q", info.ID) - } - if info.Command != "echo" { - t.Errorf("Expected command 'echo', got %q", info.Command) - } - if info.Status != "exited" { - t.Errorf("Expected status 'exited', got %q", info.Status) - } - if info.ExitCode != 0 { - t.Errorf("Expected exit code 0, got %d", info.ExitCode) - } - if info.Duration != 2*time.Second { - t.Errorf("Expected duration 2s, got %v", info.Duration) - } -} - -// TestWithProcessService_Good verifies the WithProcessService option. -func TestWithProcessService_Good(t *testing.T) { - // Note: We can't easily create a real process.Service here without Core, - // so we just verify the option doesn't panic with nil. - s, err := New(WithProcessService(nil)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.processService != nil { - t.Error("Expected processService to be nil when passed nil") - } -} diff --git a/mcp/tools_rag.go b/mcp/tools_rag.go deleted file mode 100644 index 89499f1..0000000 --- a/mcp/tools_rag.go +++ /dev/null @@ -1,233 +0,0 @@ -package mcp - -import ( - "context" - "errors" - "fmt" - - "forge.lthn.ai/core/go-rag" - "forge.lthn.ai/core/go-log" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// Default values for RAG operations. -const ( - DefaultRAGCollection = "hostuk-docs" - DefaultRAGTopK = 5 -) - -// RAGQueryInput contains parameters for querying the RAG vector database. -type RAGQueryInput struct { - Question string `json:"question"` // The question or search query - Collection string `json:"collection,omitempty"` // Collection name (default: hostuk-docs) - TopK int `json:"topK,omitempty"` // Number of results to return (default: 5) -} - -// RAGQueryResult represents a single query result. -type RAGQueryResult struct { - Content string `json:"content"` - Source string `json:"source"` - Section string `json:"section,omitempty"` - Category string `json:"category,omitempty"` - ChunkIndex int `json:"chunkIndex,omitempty"` - Score float32 `json:"score"` -} - -// RAGQueryOutput contains the results of a RAG query. -type RAGQueryOutput struct { - Results []RAGQueryResult `json:"results"` - Query string `json:"query"` - Collection string `json:"collection"` - Context string `json:"context"` -} - -// RAGIngestInput contains parameters for ingesting documents into the RAG database. -type RAGIngestInput struct { - Path string `json:"path"` // File or directory path to ingest - Collection string `json:"collection,omitempty"` // Collection name (default: hostuk-docs) - Recreate bool `json:"recreate,omitempty"` // Whether to recreate the collection -} - -// RAGIngestOutput contains the result of a RAG ingest operation. -type RAGIngestOutput struct { - Success bool `json:"success"` - Path string `json:"path"` - Collection string `json:"collection"` - Chunks int `json:"chunks"` - Message string `json:"message,omitempty"` -} - -// RAGCollectionsInput contains parameters for listing collections. -type RAGCollectionsInput struct { - ShowStats bool `json:"show_stats,omitempty"` // Include collection stats (point count, status) -} - -// CollectionInfo contains information about a collection. -type CollectionInfo struct { - Name string `json:"name"` - PointsCount uint64 `json:"points_count"` - Status string `json:"status"` -} - -// RAGCollectionsOutput contains the list of available collections. -type RAGCollectionsOutput struct { - Collections []CollectionInfo `json:"collections"` -} - -// registerRAGTools adds RAG tools to the MCP server. -func (s *Service) registerRAGTools(server *mcp.Server) { - mcp.AddTool(server, &mcp.Tool{ - Name: "rag_query", - Description: "Query the RAG vector database for relevant documentation. Returns semantically similar content based on the query.", - }, s.ragQuery) - - mcp.AddTool(server, &mcp.Tool{ - Name: "rag_ingest", - Description: "Ingest documents into the RAG vector database. Supports both single files and directories.", - }, s.ragIngest) - - mcp.AddTool(server, &mcp.Tool{ - Name: "rag_collections", - Description: "List all available collections in the RAG vector database.", - }, s.ragCollections) -} - -// ragQuery handles the rag_query tool call. -func (s *Service) ragQuery(ctx context.Context, req *mcp.CallToolRequest, input RAGQueryInput) (*mcp.CallToolResult, RAGQueryOutput, error) { - // Apply defaults - collection := input.Collection - if collection == "" { - collection = DefaultRAGCollection - } - topK := input.TopK - if topK <= 0 { - topK = DefaultRAGTopK - } - - s.logger.Info("MCP tool execution", "tool", "rag_query", "question", input.Question, "collection", collection, "topK", topK, "user", log.Username()) - - // Validate input - if input.Question == "" { - return nil, RAGQueryOutput{}, errors.New("question cannot be empty") - } - - // Call the RAG query function - results, err := rag.QueryDocs(ctx, input.Question, collection, topK) - if err != nil { - log.Error("mcp: rag query failed", "question", input.Question, "collection", collection, "err", err) - return nil, RAGQueryOutput{}, fmt.Errorf("failed to query RAG: %w", err) - } - - // Convert results - output := RAGQueryOutput{ - Results: make([]RAGQueryResult, len(results)), - Query: input.Question, - Collection: collection, - Context: rag.FormatResultsContext(results), - } - for i, r := range results { - output.Results[i] = RAGQueryResult{ - Content: r.Text, - Source: r.Source, - Section: r.Section, - Category: r.Category, - ChunkIndex: r.ChunkIndex, - Score: r.Score, - } - } - - return nil, output, nil -} - -// ragIngest handles the rag_ingest tool call. -func (s *Service) ragIngest(ctx context.Context, req *mcp.CallToolRequest, input RAGIngestInput) (*mcp.CallToolResult, RAGIngestOutput, error) { - // Apply defaults - collection := input.Collection - if collection == "" { - collection = DefaultRAGCollection - } - - s.logger.Security("MCP tool execution", "tool", "rag_ingest", "path", input.Path, "collection", collection, "recreate", input.Recreate, "user", log.Username()) - - // Validate input - if input.Path == "" { - return nil, RAGIngestOutput{}, errors.New("path cannot be empty") - } - - // Check if path is a file or directory using the medium - info, err := s.medium.Stat(input.Path) - if err != nil { - log.Error("mcp: rag ingest stat failed", "path", input.Path, "err", err) - return nil, RAGIngestOutput{}, fmt.Errorf("failed to access path: %w", err) - } - - var message string - var chunks int - if info.IsDir() { - // Ingest directory - err = rag.IngestDirectory(ctx, input.Path, collection, input.Recreate) - if err != nil { - log.Error("mcp: rag ingest directory failed", "path", input.Path, "collection", collection, "err", err) - return nil, RAGIngestOutput{}, fmt.Errorf("failed to ingest directory: %w", err) - } - message = fmt.Sprintf("Successfully ingested directory %s into collection %s", input.Path, collection) - } else { - // Ingest single file - chunks, err = rag.IngestSingleFile(ctx, input.Path, collection) - if err != nil { - log.Error("mcp: rag ingest file failed", "path", input.Path, "collection", collection, "err", err) - return nil, RAGIngestOutput{}, fmt.Errorf("failed to ingest file: %w", err) - } - message = fmt.Sprintf("Successfully ingested file %s (%d chunks) into collection %s", input.Path, chunks, collection) - } - - return nil, RAGIngestOutput{ - Success: true, - Path: input.Path, - Collection: collection, - Chunks: chunks, - Message: message, - }, nil -} - -// ragCollections handles the rag_collections tool call. -func (s *Service) ragCollections(ctx context.Context, req *mcp.CallToolRequest, input RAGCollectionsInput) (*mcp.CallToolResult, RAGCollectionsOutput, error) { - s.logger.Info("MCP tool execution", "tool", "rag_collections", "show_stats", input.ShowStats, "user", log.Username()) - - // Create Qdrant client with default config - qdrantClient, err := rag.NewQdrantClient(rag.DefaultQdrantConfig()) - if err != nil { - log.Error("mcp: rag collections connect failed", "err", err) - return nil, RAGCollectionsOutput{}, fmt.Errorf("failed to connect to Qdrant: %w", err) - } - defer func() { _ = qdrantClient.Close() }() - - // List collections - collectionNames, err := qdrantClient.ListCollections(ctx) - if err != nil { - log.Error("mcp: rag collections list failed", "err", err) - return nil, RAGCollectionsOutput{}, fmt.Errorf("failed to list collections: %w", err) - } - - // Build collection info list - collections := make([]CollectionInfo, len(collectionNames)) - for i, name := range collectionNames { - collections[i] = CollectionInfo{Name: name} - - // Fetch stats if requested - if input.ShowStats { - info, err := qdrantClient.CollectionInfo(ctx, name) - if err != nil { - log.Error("mcp: rag collection info failed", "collection", name, "err", err) - // Continue with defaults on error - continue - } - collections[i].PointsCount = info.PointCount - collections[i].Status = info.Status - } - } - - return nil, RAGCollectionsOutput{ - Collections: collections, - }, nil -} diff --git a/mcp/tools_rag_ci_test.go b/mcp/tools_rag_ci_test.go deleted file mode 100644 index fb7d853..0000000 --- a/mcp/tools_rag_ci_test.go +++ /dev/null @@ -1,181 +0,0 @@ -package mcp - -import ( - "context" - "strings" - "testing" -) - -// RAG tools use package-level functions (rag.QueryDocs, rag.IngestDirectory, etc.) -// which require live Qdrant + Ollama services. Since those are not injectable, -// we test handler input validation, default application, and struct behaviour -// at the MCP handler level without requiring live services. - -// --- ragQuery validation --- - -// TestRagQuery_Bad_EmptyQuestion verifies empty question returns error. -func TestRagQuery_Bad_EmptyQuestion(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - ctx := context.Background() - - _, _, err = s.ragQuery(ctx, nil, RAGQueryInput{}) - if err == nil { - t.Fatal("Expected error for empty question") - } - if !strings.Contains(err.Error(), "question cannot be empty") { - t.Errorf("Unexpected error: %v", err) - } -} - -// TestRagQuery_Good_DefaultsApplied verifies defaults are applied before validation. -// Because the handler applies defaults then validates, a non-empty question with -// zero Collection/TopK should have defaults applied. We cannot verify the actual -// query (needs live Qdrant), but we can verify it gets past validation. -func TestRagQuery_Good_DefaultsApplied(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - ctx := context.Background() - - // This will fail when it tries to connect to Qdrant, but AFTER applying defaults. - // The error should NOT be about empty question. - _, _, err = s.ragQuery(ctx, nil, RAGQueryInput{Question: "test query"}) - if err == nil { - t.Skip("RAG query succeeded — live Qdrant available, skip default test") - } - // The error should be about connection failure, not validation - if strings.Contains(err.Error(), "question cannot be empty") { - t.Error("Defaults should have been applied before validation check") - } -} - -// --- ragIngest validation --- - -// TestRagIngest_Bad_EmptyPath verifies empty path returns error. -func TestRagIngest_Bad_EmptyPath(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - ctx := context.Background() - - _, _, err = s.ragIngest(ctx, nil, RAGIngestInput{}) - if err == nil { - t.Fatal("Expected error for empty path") - } - if !strings.Contains(err.Error(), "path cannot be empty") { - t.Errorf("Unexpected error: %v", err) - } -} - -// TestRagIngest_Bad_NonexistentPath verifies nonexistent path returns error. -func TestRagIngest_Bad_NonexistentPath(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - ctx := context.Background() - - _, _, err = s.ragIngest(ctx, nil, RAGIngestInput{ - Path: "/nonexistent/path/that/does/not/exist/at/all", - }) - if err == nil { - t.Fatal("Expected error for nonexistent path") - } -} - -// TestRagIngest_Good_DefaultCollection verifies the default collection is applied. -func TestRagIngest_Good_DefaultCollection(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - ctx := context.Background() - - // Use a real but inaccessible path to trigger stat error (not validation error). - // The collection default should be applied first. - _, _, err = s.ragIngest(ctx, nil, RAGIngestInput{ - Path: "/nonexistent/path/for/default/test", - }) - if err == nil { - t.Skip("Ingest succeeded unexpectedly") - } - // The error should NOT be about empty path - if strings.Contains(err.Error(), "path cannot be empty") { - t.Error("Default collection should have been applied") - } -} - -// --- ragCollections validation --- - -// TestRagCollections_Bad_NoQdrant verifies graceful error when Qdrant is not available. -func TestRagCollections_Bad_NoQdrant(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - ctx := context.Background() - - _, _, err = s.ragCollections(ctx, nil, RAGCollectionsInput{}) - if err == nil { - t.Skip("Qdrant is available — skip connection error test") - } - // Should get a connection error, not a panic - if !strings.Contains(err.Error(), "failed to connect") && !strings.Contains(err.Error(), "failed to list") { - t.Logf("Got error (expected connection failure): %v", err) - } -} - -// --- Struct round-trip tests --- - -// TestRAGQueryResult_Good_AllFields verifies all fields can be set and read. -func TestRAGQueryResult_Good_AllFields(t *testing.T) { - r := RAGQueryResult{ - Content: "test content", - Source: "source.md", - Section: "Overview", - Category: "docs", - ChunkIndex: 3, - Score: 0.88, - } - - if r.Content != "test content" { - t.Errorf("Expected content 'test content', got %q", r.Content) - } - if r.ChunkIndex != 3 { - t.Errorf("Expected chunkIndex 3, got %d", r.ChunkIndex) - } - if r.Score != 0.88 { - t.Errorf("Expected score 0.88, got %f", r.Score) - } -} - -// TestCollectionInfo_Good_AllFields verifies CollectionInfo field access. -func TestCollectionInfo_Good_AllFields(t *testing.T) { - c := CollectionInfo{ - Name: "test-collection", - PointsCount: 12345, - Status: "green", - } - - if c.Name != "test-collection" { - t.Errorf("Expected name 'test-collection', got %q", c.Name) - } - if c.PointsCount != 12345 { - t.Errorf("Expected PointsCount 12345, got %d", c.PointsCount) - } -} - -// TestRAGDefaults_Good verifies default constants are sensible. -func TestRAGDefaults_Good(t *testing.T) { - if DefaultRAGCollection != "hostuk-docs" { - t.Errorf("Expected default collection 'hostuk-docs', got %q", DefaultRAGCollection) - } - if DefaultRAGTopK != 5 { - t.Errorf("Expected default topK 5, got %d", DefaultRAGTopK) - } -} diff --git a/mcp/tools_rag_test.go b/mcp/tools_rag_test.go deleted file mode 100644 index 1c344f3..0000000 --- a/mcp/tools_rag_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package mcp - -import ( - "testing" -) - -// TestRAGToolsRegistered_Good verifies that RAG tools are registered with the MCP server. -func TestRAGToolsRegistered_Good(t *testing.T) { - // Create a new MCP service - this should register all tools including RAG - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - // The server should have registered the RAG tools - // We verify by checking that the tool handlers exist on the service - // (The actual MCP registration is tested by the SDK) - - if s.server == nil { - t.Fatal("Server should not be nil") - } - - // Verify the service was created with expected defaults - if s.logger == nil { - t.Error("Logger should not be nil") - } -} - -// TestRAGQueryInput_Good verifies the RAGQueryInput struct has expected fields. -func TestRAGQueryInput_Good(t *testing.T) { - input := RAGQueryInput{ - Question: "test question", - Collection: "test-collection", - TopK: 10, - } - - if input.Question != "test question" { - t.Errorf("Expected question 'test question', got %q", input.Question) - } - if input.Collection != "test-collection" { - t.Errorf("Expected collection 'test-collection', got %q", input.Collection) - } - if input.TopK != 10 { - t.Errorf("Expected topK 10, got %d", input.TopK) - } -} - -// TestRAGQueryInput_Defaults verifies default values are handled correctly. -func TestRAGQueryInput_Defaults(t *testing.T) { - // Empty input should use defaults when processed - input := RAGQueryInput{ - Question: "test", - } - - // Defaults should be applied in the handler, not in the struct - if input.Collection != "" { - t.Errorf("Expected empty collection before defaults, got %q", input.Collection) - } - if input.TopK != 0 { - t.Errorf("Expected zero topK before defaults, got %d", input.TopK) - } -} - -// TestRAGIngestInput_Good verifies the RAGIngestInput struct has expected fields. -func TestRAGIngestInput_Good(t *testing.T) { - input := RAGIngestInput{ - Path: "/path/to/docs", - Collection: "my-collection", - Recreate: true, - } - - if input.Path != "/path/to/docs" { - t.Errorf("Expected path '/path/to/docs', got %q", input.Path) - } - if input.Collection != "my-collection" { - t.Errorf("Expected collection 'my-collection', got %q", input.Collection) - } - if !input.Recreate { - t.Error("Expected recreate to be true") - } -} - -// TestRAGCollectionsInput_Good verifies the RAGCollectionsInput struct exists. -func TestRAGCollectionsInput_Good(t *testing.T) { - // RAGCollectionsInput has optional ShowStats parameter - input := RAGCollectionsInput{} - if input.ShowStats { - t.Error("Expected ShowStats to default to false") - } -} - -// TestRAGQueryOutput_Good verifies the RAGQueryOutput struct has expected fields. -func TestRAGQueryOutput_Good(t *testing.T) { - output := RAGQueryOutput{ - Results: []RAGQueryResult{ - { - Content: "some content", - Source: "doc.md", - Section: "Introduction", - Category: "docs", - Score: 0.95, - }, - }, - Query: "test query", - Collection: "test-collection", - Context: "...", - } - - if len(output.Results) != 1 { - t.Fatalf("Expected 1 result, got %d", len(output.Results)) - } - if output.Results[0].Content != "some content" { - t.Errorf("Expected content 'some content', got %q", output.Results[0].Content) - } - if output.Results[0].Score != 0.95 { - t.Errorf("Expected score 0.95, got %f", output.Results[0].Score) - } - if output.Context == "" { - t.Error("Expected context to be set") - } -} - -// TestRAGIngestOutput_Good verifies the RAGIngestOutput struct has expected fields. -func TestRAGIngestOutput_Good(t *testing.T) { - output := RAGIngestOutput{ - Success: true, - Path: "/path/to/docs", - Collection: "my-collection", - Chunks: 10, - Message: "Ingested successfully", - } - - if !output.Success { - t.Error("Expected success to be true") - } - if output.Path != "/path/to/docs" { - t.Errorf("Expected path '/path/to/docs', got %q", output.Path) - } - if output.Chunks != 10 { - t.Errorf("Expected chunks 10, got %d", output.Chunks) - } -} - -// TestRAGCollectionsOutput_Good verifies the RAGCollectionsOutput struct has expected fields. -func TestRAGCollectionsOutput_Good(t *testing.T) { - output := RAGCollectionsOutput{ - Collections: []CollectionInfo{ - {Name: "collection1", PointsCount: 100, Status: "green"}, - {Name: "collection2", PointsCount: 200, Status: "green"}, - }, - } - - if len(output.Collections) != 2 { - t.Fatalf("Expected 2 collections, got %d", len(output.Collections)) - } - if output.Collections[0].Name != "collection1" { - t.Errorf("Expected 'collection1', got %q", output.Collections[0].Name) - } - if output.Collections[0].PointsCount != 100 { - t.Errorf("Expected PointsCount 100, got %d", output.Collections[0].PointsCount) - } -} - -// TestRAGCollectionsInput_Good verifies the RAGCollectionsInput struct has expected fields. -func TestRAGCollectionsInput_ShowStats(t *testing.T) { - input := RAGCollectionsInput{ - ShowStats: true, - } - - if !input.ShowStats { - t.Error("Expected ShowStats to be true") - } -} diff --git a/mcp/tools_webview.go b/mcp/tools_webview.go deleted file mode 100644 index 8fbf941..0000000 --- a/mcp/tools_webview.go +++ /dev/null @@ -1,497 +0,0 @@ -package mcp - -import ( - "context" - "encoding/base64" - "errors" - "fmt" - "time" - - "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-webview" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// webviewInstance holds the current webview connection. -// This is managed by the MCP service. -var webviewInstance *webview.Webview - -// Sentinel errors for webview tools. -var ( - errNotConnected = errors.New("not connected; use webview_connect first") - errSelectorRequired = errors.New("selector is required") -) - -// WebviewConnectInput contains parameters for connecting to Chrome DevTools. -type WebviewConnectInput struct { - DebugURL string `json:"debug_url"` // Chrome DevTools URL (e.g., http://localhost:9222) - Timeout int `json:"timeout,omitempty"` // Default timeout in seconds (default: 30) -} - -// WebviewConnectOutput contains the result of connecting to Chrome. -type WebviewConnectOutput struct { - Success bool `json:"success"` - Message string `json:"message,omitempty"` -} - -// WebviewNavigateInput contains parameters for navigating to a URL. -type WebviewNavigateInput struct { - URL string `json:"url"` // URL to navigate to -} - -// WebviewNavigateOutput contains the result of navigation. -type WebviewNavigateOutput struct { - Success bool `json:"success"` - URL string `json:"url"` -} - -// WebviewClickInput contains parameters for clicking an element. -type WebviewClickInput struct { - Selector string `json:"selector"` // CSS selector -} - -// WebviewClickOutput contains the result of a click action. -type WebviewClickOutput struct { - Success bool `json:"success"` -} - -// WebviewTypeInput contains parameters for typing text. -type WebviewTypeInput struct { - Selector string `json:"selector"` // CSS selector - Text string `json:"text"` // Text to type -} - -// WebviewTypeOutput contains the result of a type action. -type WebviewTypeOutput struct { - Success bool `json:"success"` -} - -// WebviewQueryInput contains parameters for querying an element. -type WebviewQueryInput struct { - Selector string `json:"selector"` // CSS selector - All bool `json:"all,omitempty"` // If true, return all matching elements -} - -// WebviewQueryOutput contains the result of a query. -type WebviewQueryOutput struct { - Found bool `json:"found"` - Count int `json:"count"` - Elements []WebviewElementInfo `json:"elements,omitempty"` -} - -// WebviewElementInfo represents information about a DOM element. -type WebviewElementInfo struct { - NodeID int `json:"nodeId"` - TagName string `json:"tagName"` - Attributes map[string]string `json:"attributes,omitempty"` - BoundingBox *webview.BoundingBox `json:"boundingBox,omitempty"` -} - -// WebviewConsoleInput contains parameters for getting console output. -type WebviewConsoleInput struct { - Clear bool `json:"clear,omitempty"` // If true, clear console after getting messages -} - -// WebviewConsoleOutput contains console messages. -type WebviewConsoleOutput struct { - Messages []WebviewConsoleMessage `json:"messages"` - Count int `json:"count"` -} - -// WebviewConsoleMessage represents a console message. -type WebviewConsoleMessage struct { - Type string `json:"type"` - Text string `json:"text"` - Timestamp string `json:"timestamp"` - URL string `json:"url,omitempty"` - Line int `json:"line,omitempty"` -} - -// WebviewEvalInput contains parameters for evaluating JavaScript. -type WebviewEvalInput struct { - Script string `json:"script"` // JavaScript to evaluate -} - -// WebviewEvalOutput contains the result of JavaScript evaluation. -type WebviewEvalOutput struct { - Success bool `json:"success"` - Result any `json:"result,omitempty"` - Error string `json:"error,omitempty"` -} - -// WebviewScreenshotInput contains parameters for taking a screenshot. -type WebviewScreenshotInput struct { - Format string `json:"format,omitempty"` // "png" or "jpeg" (default: png) -} - -// WebviewScreenshotOutput contains the screenshot data. -type WebviewScreenshotOutput struct { - Success bool `json:"success"` - Data string `json:"data"` // Base64 encoded image - Format string `json:"format"` -} - -// WebviewWaitInput contains parameters for waiting operations. -type WebviewWaitInput struct { - Selector string `json:"selector,omitempty"` // Wait for selector - Timeout int `json:"timeout,omitempty"` // Timeout in seconds -} - -// WebviewWaitOutput contains the result of waiting. -type WebviewWaitOutput struct { - Success bool `json:"success"` - Message string `json:"message,omitempty"` -} - -// WebviewDisconnectInput contains parameters for disconnecting. -type WebviewDisconnectInput struct{} - -// WebviewDisconnectOutput contains the result of disconnecting. -type WebviewDisconnectOutput struct { - Success bool `json:"success"` - Message string `json:"message,omitempty"` -} - -// registerWebviewTools adds webview tools to the MCP server. -func (s *Service) registerWebviewTools(server *mcp.Server) { - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_connect", - Description: "Connect to Chrome DevTools Protocol. Start Chrome with --remote-debugging-port=9222 first.", - }, s.webviewConnect) - - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_disconnect", - Description: "Disconnect from Chrome DevTools.", - }, s.webviewDisconnect) - - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_navigate", - Description: "Navigate the browser to a URL.", - }, s.webviewNavigate) - - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_click", - Description: "Click on an element by CSS selector.", - }, s.webviewClick) - - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_type", - Description: "Type text into an element by CSS selector.", - }, s.webviewType) - - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_query", - Description: "Query DOM elements by CSS selector.", - }, s.webviewQuery) - - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_console", - Description: "Get browser console output.", - }, s.webviewConsole) - - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_eval", - Description: "Evaluate JavaScript in the browser context.", - }, s.webviewEval) - - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_screenshot", - Description: "Capture a screenshot of the browser window.", - }, s.webviewScreenshot) - - mcp.AddTool(server, &mcp.Tool{ - Name: "webview_wait", - Description: "Wait for an element to appear by CSS selector.", - }, s.webviewWait) -} - -// webviewConnect handles the webview_connect tool call. -func (s *Service) webviewConnect(ctx context.Context, req *mcp.CallToolRequest, input WebviewConnectInput) (*mcp.CallToolResult, WebviewConnectOutput, error) { - s.logger.Security("MCP tool execution", "tool", "webview_connect", "debug_url", input.DebugURL, "user", log.Username()) - - if input.DebugURL == "" { - return nil, WebviewConnectOutput{}, errors.New("debug_url is required") - } - - // Close existing connection if any - if webviewInstance != nil { - _ = webviewInstance.Close() - webviewInstance = nil - } - - // Set up options - opts := []webview.Option{ - webview.WithDebugURL(input.DebugURL), - } - - if input.Timeout > 0 { - opts = append(opts, webview.WithTimeout(time.Duration(input.Timeout)*time.Second)) - } - - // Create new webview instance - wv, err := webview.New(opts...) - if err != nil { - log.Error("mcp: webview connect failed", "debug_url", input.DebugURL, "err", err) - return nil, WebviewConnectOutput{}, fmt.Errorf("failed to connect: %w", err) - } - - webviewInstance = wv - - return nil, WebviewConnectOutput{ - Success: true, - Message: fmt.Sprintf("Connected to Chrome DevTools at %s", input.DebugURL), - }, nil -} - -// webviewDisconnect handles the webview_disconnect tool call. -func (s *Service) webviewDisconnect(ctx context.Context, req *mcp.CallToolRequest, input WebviewDisconnectInput) (*mcp.CallToolResult, WebviewDisconnectOutput, error) { - s.logger.Info("MCP tool execution", "tool", "webview_disconnect", "user", log.Username()) - - if webviewInstance == nil { - return nil, WebviewDisconnectOutput{ - Success: true, - Message: "No active connection", - }, nil - } - - if err := webviewInstance.Close(); err != nil { - log.Error("mcp: webview disconnect failed", "err", err) - return nil, WebviewDisconnectOutput{}, fmt.Errorf("failed to disconnect: %w", err) - } - - webviewInstance = nil - - return nil, WebviewDisconnectOutput{ - Success: true, - Message: "Disconnected from Chrome DevTools", - }, nil -} - -// webviewNavigate handles the webview_navigate tool call. -func (s *Service) webviewNavigate(ctx context.Context, req *mcp.CallToolRequest, input WebviewNavigateInput) (*mcp.CallToolResult, WebviewNavigateOutput, error) { - s.logger.Info("MCP tool execution", "tool", "webview_navigate", "url", input.URL, "user", log.Username()) - - if webviewInstance == nil { - return nil, WebviewNavigateOutput{}, errNotConnected - } - - if input.URL == "" { - return nil, WebviewNavigateOutput{}, errors.New("url is required") - } - - if err := webviewInstance.Navigate(input.URL); err != nil { - log.Error("mcp: webview navigate failed", "url", input.URL, "err", err) - return nil, WebviewNavigateOutput{}, fmt.Errorf("failed to navigate: %w", err) - } - - return nil, WebviewNavigateOutput{ - Success: true, - URL: input.URL, - }, nil -} - -// webviewClick handles the webview_click tool call. -func (s *Service) webviewClick(ctx context.Context, req *mcp.CallToolRequest, input WebviewClickInput) (*mcp.CallToolResult, WebviewClickOutput, error) { - s.logger.Info("MCP tool execution", "tool", "webview_click", "selector", input.Selector, "user", log.Username()) - - if webviewInstance == nil { - return nil, WebviewClickOutput{}, errNotConnected - } - - if input.Selector == "" { - return nil, WebviewClickOutput{}, errSelectorRequired - } - - if err := webviewInstance.Click(input.Selector); err != nil { - log.Error("mcp: webview click failed", "selector", input.Selector, "err", err) - return nil, WebviewClickOutput{}, fmt.Errorf("failed to click: %w", err) - } - - return nil, WebviewClickOutput{Success: true}, nil -} - -// webviewType handles the webview_type tool call. -func (s *Service) webviewType(ctx context.Context, req *mcp.CallToolRequest, input WebviewTypeInput) (*mcp.CallToolResult, WebviewTypeOutput, error) { - s.logger.Info("MCP tool execution", "tool", "webview_type", "selector", input.Selector, "user", log.Username()) - - if webviewInstance == nil { - return nil, WebviewTypeOutput{}, errNotConnected - } - - if input.Selector == "" { - return nil, WebviewTypeOutput{}, errSelectorRequired - } - - if err := webviewInstance.Type(input.Selector, input.Text); err != nil { - log.Error("mcp: webview type failed", "selector", input.Selector, "err", err) - return nil, WebviewTypeOutput{}, fmt.Errorf("failed to type: %w", err) - } - - return nil, WebviewTypeOutput{Success: true}, nil -} - -// webviewQuery handles the webview_query tool call. -func (s *Service) webviewQuery(ctx context.Context, req *mcp.CallToolRequest, input WebviewQueryInput) (*mcp.CallToolResult, WebviewQueryOutput, error) { - s.logger.Info("MCP tool execution", "tool", "webview_query", "selector", input.Selector, "all", input.All, "user", log.Username()) - - if webviewInstance == nil { - return nil, WebviewQueryOutput{}, errNotConnected - } - - if input.Selector == "" { - return nil, WebviewQueryOutput{}, errSelectorRequired - } - - if input.All { - elements, err := webviewInstance.QuerySelectorAll(input.Selector) - if err != nil { - log.Error("mcp: webview query all failed", "selector", input.Selector, "err", err) - return nil, WebviewQueryOutput{}, fmt.Errorf("failed to query: %w", err) - } - - output := WebviewQueryOutput{ - Found: len(elements) > 0, - Count: len(elements), - Elements: make([]WebviewElementInfo, len(elements)), - } - - for i, elem := range elements { - output.Elements[i] = WebviewElementInfo{ - NodeID: elem.NodeID, - TagName: elem.TagName, - Attributes: elem.Attributes, - BoundingBox: elem.BoundingBox, - } - } - - return nil, output, nil - } - - elem, err := webviewInstance.QuerySelector(input.Selector) - if err != nil { - // Element not found is not necessarily an error - return nil, WebviewQueryOutput{ - Found: false, - Count: 0, - }, nil - } - - return nil, WebviewQueryOutput{ - Found: true, - Count: 1, - Elements: []WebviewElementInfo{{ - NodeID: elem.NodeID, - TagName: elem.TagName, - Attributes: elem.Attributes, - BoundingBox: elem.BoundingBox, - }}, - }, nil -} - -// webviewConsole handles the webview_console tool call. -func (s *Service) webviewConsole(ctx context.Context, req *mcp.CallToolRequest, input WebviewConsoleInput) (*mcp.CallToolResult, WebviewConsoleOutput, error) { - s.logger.Info("MCP tool execution", "tool", "webview_console", "clear", input.Clear, "user", log.Username()) - - if webviewInstance == nil { - return nil, WebviewConsoleOutput{}, errNotConnected - } - - messages := webviewInstance.GetConsole() - - output := WebviewConsoleOutput{ - Messages: make([]WebviewConsoleMessage, len(messages)), - Count: len(messages), - } - - for i, msg := range messages { - output.Messages[i] = WebviewConsoleMessage{ - Type: msg.Type, - Text: msg.Text, - Timestamp: msg.Timestamp.Format(time.RFC3339), - URL: msg.URL, - Line: msg.Line, - } - } - - if input.Clear { - webviewInstance.ClearConsole() - } - - return nil, output, nil -} - -// webviewEval handles the webview_eval tool call. -func (s *Service) webviewEval(ctx context.Context, req *mcp.CallToolRequest, input WebviewEvalInput) (*mcp.CallToolResult, WebviewEvalOutput, error) { - s.logger.Security("MCP tool execution", "tool", "webview_eval", "user", log.Username()) - - if webviewInstance == nil { - return nil, WebviewEvalOutput{}, errNotConnected - } - - if input.Script == "" { - return nil, WebviewEvalOutput{}, errors.New("script is required") - } - - result, err := webviewInstance.Evaluate(input.Script) - if err != nil { - log.Error("mcp: webview eval failed", "err", err) - return nil, WebviewEvalOutput{ - Success: false, - Error: err.Error(), - }, nil - } - - return nil, WebviewEvalOutput{ - Success: true, - Result: result, - }, nil -} - -// webviewScreenshot handles the webview_screenshot tool call. -func (s *Service) webviewScreenshot(ctx context.Context, req *mcp.CallToolRequest, input WebviewScreenshotInput) (*mcp.CallToolResult, WebviewScreenshotOutput, error) { - s.logger.Info("MCP tool execution", "tool", "webview_screenshot", "format", input.Format, "user", log.Username()) - - if webviewInstance == nil { - return nil, WebviewScreenshotOutput{}, errNotConnected - } - - format := input.Format - if format == "" { - format = "png" - } - - data, err := webviewInstance.Screenshot() - if err != nil { - log.Error("mcp: webview screenshot failed", "err", err) - return nil, WebviewScreenshotOutput{}, fmt.Errorf("failed to capture screenshot: %w", err) - } - - return nil, WebviewScreenshotOutput{ - Success: true, - Data: base64.StdEncoding.EncodeToString(data), - Format: format, - }, nil -} - -// webviewWait handles the webview_wait tool call. -func (s *Service) webviewWait(ctx context.Context, req *mcp.CallToolRequest, input WebviewWaitInput) (*mcp.CallToolResult, WebviewWaitOutput, error) { - s.logger.Info("MCP tool execution", "tool", "webview_wait", "selector", input.Selector, "timeout", input.Timeout, "user", log.Username()) - - if webviewInstance == nil { - return nil, WebviewWaitOutput{}, errNotConnected - } - - if input.Selector == "" { - return nil, WebviewWaitOutput{}, errSelectorRequired - } - - if err := webviewInstance.WaitForSelector(input.Selector); err != nil { - log.Error("mcp: webview wait failed", "selector", input.Selector, "err", err) - return nil, WebviewWaitOutput{}, fmt.Errorf("failed to wait for selector: %w", err) - } - - return nil, WebviewWaitOutput{ - Success: true, - Message: fmt.Sprintf("Element found: %s", input.Selector), - }, nil -} diff --git a/mcp/tools_webview_test.go b/mcp/tools_webview_test.go deleted file mode 100644 index abb00fa..0000000 --- a/mcp/tools_webview_test.go +++ /dev/null @@ -1,452 +0,0 @@ -package mcp - -import ( - "testing" - "time" - - "forge.lthn.ai/core/go-webview" -) - -// skipIfShort skips webview tests in short mode (go test -short). -// Webview tool handlers require a running Chrome instance with -// --remote-debugging-port, which is not available in CI. -// Struct-level tests below are safe without Chrome, but any future -// tests that call webview tool handlers MUST use this guard. -func skipIfShort(t *testing.T) { - t.Helper() - if testing.Short() { - t.Skip("webview tests skipped in short mode (no Chrome available)") - } -} - -// TestWebviewToolsRegistered_Good verifies that webview tools are registered with the MCP server. -func TestWebviewToolsRegistered_Good(t *testing.T) { - // Create a new MCP service - this should register all tools including webview - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - // The server should have registered the webview tools - if s.server == nil { - t.Fatal("Server should not be nil") - } - - // Verify the service was created with expected defaults - if s.logger == nil { - t.Error("Logger should not be nil") - } -} - -// TestWebviewToolHandlers_RequiresChrome demonstrates the CI guard -// for tests that would require a running Chrome instance. Any future -// test that calls webview tool handlers (webviewConnect, webviewNavigate, -// etc.) should call skipIfShort(t) at the top. -func TestWebviewToolHandlers_RequiresChrome(t *testing.T) { - skipIfShort(t) - - // This test verifies that webview tool handlers correctly reject - // calls when not connected to Chrome. - tmpDir := t.TempDir() - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx := t.Context() - - // webview_navigate should fail without a connection - _, _, err = s.webviewNavigate(ctx, nil, WebviewNavigateInput{URL: "https://example.com"}) - if err == nil { - t.Error("Expected error when navigating without a webview connection") - } - - // webview_click should fail without a connection - _, _, err = s.webviewClick(ctx, nil, WebviewClickInput{Selector: "#btn"}) - if err == nil { - t.Error("Expected error when clicking without a webview connection") - } - - // webview_eval should fail without a connection - _, _, err = s.webviewEval(ctx, nil, WebviewEvalInput{Script: "1+1"}) - if err == nil { - t.Error("Expected error when evaluating without a webview connection") - } - - // webview_connect with invalid URL should fail - _, _, err = s.webviewConnect(ctx, nil, WebviewConnectInput{DebugURL: ""}) - if err == nil { - t.Error("Expected error when connecting with empty debug URL") - } -} - -// TestWebviewConnectInput_Good verifies the WebviewConnectInput struct has expected fields. -func TestWebviewConnectInput_Good(t *testing.T) { - input := WebviewConnectInput{ - DebugURL: "http://localhost:9222", - Timeout: 30, - } - - if input.DebugURL != "http://localhost:9222" { - t.Errorf("Expected debug_url 'http://localhost:9222', got %q", input.DebugURL) - } - if input.Timeout != 30 { - t.Errorf("Expected timeout 30, got %d", input.Timeout) - } -} - -// TestWebviewNavigateInput_Good verifies the WebviewNavigateInput struct has expected fields. -func TestWebviewNavigateInput_Good(t *testing.T) { - input := WebviewNavigateInput{ - URL: "https://example.com", - } - - if input.URL != "https://example.com" { - t.Errorf("Expected URL 'https://example.com', got %q", input.URL) - } -} - -// TestWebviewClickInput_Good verifies the WebviewClickInput struct has expected fields. -func TestWebviewClickInput_Good(t *testing.T) { - input := WebviewClickInput{ - Selector: "#submit-button", - } - - if input.Selector != "#submit-button" { - t.Errorf("Expected selector '#submit-button', got %q", input.Selector) - } -} - -// TestWebviewTypeInput_Good verifies the WebviewTypeInput struct has expected fields. -func TestWebviewTypeInput_Good(t *testing.T) { - input := WebviewTypeInput{ - Selector: "#email-input", - Text: "test@example.com", - } - - if input.Selector != "#email-input" { - t.Errorf("Expected selector '#email-input', got %q", input.Selector) - } - if input.Text != "test@example.com" { - t.Errorf("Expected text 'test@example.com', got %q", input.Text) - } -} - -// TestWebviewQueryInput_Good verifies the WebviewQueryInput struct has expected fields. -func TestWebviewQueryInput_Good(t *testing.T) { - input := WebviewQueryInput{ - Selector: "div.container", - All: true, - } - - if input.Selector != "div.container" { - t.Errorf("Expected selector 'div.container', got %q", input.Selector) - } - if !input.All { - t.Error("Expected all to be true") - } -} - -// TestWebviewQueryInput_Defaults verifies default values are handled correctly. -func TestWebviewQueryInput_Defaults(t *testing.T) { - input := WebviewQueryInput{ - Selector: ".test", - } - - if input.All { - t.Error("Expected all to default to false") - } -} - -// TestWebviewConsoleInput_Good verifies the WebviewConsoleInput struct has expected fields. -func TestWebviewConsoleInput_Good(t *testing.T) { - input := WebviewConsoleInput{ - Clear: true, - } - - if !input.Clear { - t.Error("Expected clear to be true") - } -} - -// TestWebviewEvalInput_Good verifies the WebviewEvalInput struct has expected fields. -func TestWebviewEvalInput_Good(t *testing.T) { - input := WebviewEvalInput{ - Script: "document.title", - } - - if input.Script != "document.title" { - t.Errorf("Expected script 'document.title', got %q", input.Script) - } -} - -// TestWebviewScreenshotInput_Good verifies the WebviewScreenshotInput struct has expected fields. -func TestWebviewScreenshotInput_Good(t *testing.T) { - input := WebviewScreenshotInput{ - Format: "png", - } - - if input.Format != "png" { - t.Errorf("Expected format 'png', got %q", input.Format) - } -} - -// TestWebviewScreenshotInput_Defaults verifies default values are handled correctly. -func TestWebviewScreenshotInput_Defaults(t *testing.T) { - input := WebviewScreenshotInput{} - - if input.Format != "" { - t.Errorf("Expected format to default to empty, got %q", input.Format) - } -} - -// TestWebviewWaitInput_Good verifies the WebviewWaitInput struct has expected fields. -func TestWebviewWaitInput_Good(t *testing.T) { - input := WebviewWaitInput{ - Selector: "#loading", - Timeout: 10, - } - - if input.Selector != "#loading" { - t.Errorf("Expected selector '#loading', got %q", input.Selector) - } - if input.Timeout != 10 { - t.Errorf("Expected timeout 10, got %d", input.Timeout) - } -} - -// TestWebviewConnectOutput_Good verifies the WebviewConnectOutput struct has expected fields. -func TestWebviewConnectOutput_Good(t *testing.T) { - output := WebviewConnectOutput{ - Success: true, - Message: "Connected to Chrome DevTools", - } - - if !output.Success { - t.Error("Expected success to be true") - } - if output.Message == "" { - t.Error("Expected message to be set") - } -} - -// TestWebviewNavigateOutput_Good verifies the WebviewNavigateOutput struct has expected fields. -func TestWebviewNavigateOutput_Good(t *testing.T) { - output := WebviewNavigateOutput{ - Success: true, - URL: "https://example.com", - } - - if !output.Success { - t.Error("Expected success to be true") - } - if output.URL != "https://example.com" { - t.Errorf("Expected URL 'https://example.com', got %q", output.URL) - } -} - -// TestWebviewQueryOutput_Good verifies the WebviewQueryOutput struct has expected fields. -func TestWebviewQueryOutput_Good(t *testing.T) { - output := WebviewQueryOutput{ - Found: true, - Count: 3, - Elements: []WebviewElementInfo{ - { - NodeID: 1, - TagName: "DIV", - Attributes: map[string]string{ - "class": "container", - }, - }, - }, - } - - if !output.Found { - t.Error("Expected found to be true") - } - if output.Count != 3 { - t.Errorf("Expected count 3, got %d", output.Count) - } - if len(output.Elements) != 1 { - t.Fatalf("Expected 1 element, got %d", len(output.Elements)) - } - if output.Elements[0].TagName != "DIV" { - t.Errorf("Expected tagName 'DIV', got %q", output.Elements[0].TagName) - } -} - -// TestWebviewConsoleOutput_Good verifies the WebviewConsoleOutput struct has expected fields. -func TestWebviewConsoleOutput_Good(t *testing.T) { - output := WebviewConsoleOutput{ - Messages: []WebviewConsoleMessage{ - { - Type: "log", - Text: "Hello, world!", - Timestamp: "2024-01-01T00:00:00Z", - }, - { - Type: "error", - Text: "An error occurred", - Timestamp: "2024-01-01T00:00:01Z", - URL: "https://example.com/script.js", - Line: 42, - }, - }, - Count: 2, - } - - if output.Count != 2 { - t.Errorf("Expected count 2, got %d", output.Count) - } - if len(output.Messages) != 2 { - t.Fatalf("Expected 2 messages, got %d", len(output.Messages)) - } - if output.Messages[0].Type != "log" { - t.Errorf("Expected type 'log', got %q", output.Messages[0].Type) - } - if output.Messages[1].Line != 42 { - t.Errorf("Expected line 42, got %d", output.Messages[1].Line) - } -} - -// TestWebviewEvalOutput_Good verifies the WebviewEvalOutput struct has expected fields. -func TestWebviewEvalOutput_Good(t *testing.T) { - output := WebviewEvalOutput{ - Success: true, - Result: "Example Page", - } - - if !output.Success { - t.Error("Expected success to be true") - } - if output.Result != "Example Page" { - t.Errorf("Expected result 'Example Page', got %v", output.Result) - } -} - -// TestWebviewEvalOutput_Error verifies the WebviewEvalOutput struct handles errors. -func TestWebviewEvalOutput_Error(t *testing.T) { - output := WebviewEvalOutput{ - Success: false, - Error: "ReferenceError: foo is not defined", - } - - if output.Success { - t.Error("Expected success to be false") - } - if output.Error == "" { - t.Error("Expected error message to be set") - } -} - -// TestWebviewScreenshotOutput_Good verifies the WebviewScreenshotOutput struct has expected fields. -func TestWebviewScreenshotOutput_Good(t *testing.T) { - output := WebviewScreenshotOutput{ - Success: true, - Data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==", - Format: "png", - } - - if !output.Success { - t.Error("Expected success to be true") - } - if output.Data == "" { - t.Error("Expected data to be set") - } - if output.Format != "png" { - t.Errorf("Expected format 'png', got %q", output.Format) - } -} - -// TestWebviewElementInfo_Good verifies the WebviewElementInfo struct has expected fields. -func TestWebviewElementInfo_Good(t *testing.T) { - elem := WebviewElementInfo{ - NodeID: 123, - TagName: "INPUT", - Attributes: map[string]string{ - "type": "text", - "name": "email", - "class": "form-control", - }, - BoundingBox: &webview.BoundingBox{ - X: 100, - Y: 200, - Width: 300, - Height: 50, - }, - } - - if elem.NodeID != 123 { - t.Errorf("Expected nodeId 123, got %d", elem.NodeID) - } - if elem.TagName != "INPUT" { - t.Errorf("Expected tagName 'INPUT', got %q", elem.TagName) - } - if elem.Attributes["type"] != "text" { - t.Errorf("Expected type attribute 'text', got %q", elem.Attributes["type"]) - } - if elem.BoundingBox == nil { - t.Fatal("Expected bounding box to be set") - } - if elem.BoundingBox.Width != 300 { - t.Errorf("Expected width 300, got %f", elem.BoundingBox.Width) - } -} - -// TestWebviewConsoleMessage_Good verifies the WebviewConsoleMessage struct has expected fields. -func TestWebviewConsoleMessage_Good(t *testing.T) { - msg := WebviewConsoleMessage{ - Type: "error", - Text: "Failed to load resource", - Timestamp: time.Now().Format(time.RFC3339), - URL: "https://example.com/api/data", - Line: 1, - } - - if msg.Type != "error" { - t.Errorf("Expected type 'error', got %q", msg.Type) - } - if msg.Text == "" { - t.Error("Expected text to be set") - } - if msg.URL == "" { - t.Error("Expected URL to be set") - } -} - -// TestWebviewDisconnectInput_Good verifies the WebviewDisconnectInput struct exists. -func TestWebviewDisconnectInput_Good(t *testing.T) { - // WebviewDisconnectInput has no fields - input := WebviewDisconnectInput{} - _ = input // Just verify the struct exists -} - -// TestWebviewDisconnectOutput_Good verifies the WebviewDisconnectOutput struct has expected fields. -func TestWebviewDisconnectOutput_Good(t *testing.T) { - output := WebviewDisconnectOutput{ - Success: true, - Message: "Disconnected from Chrome DevTools", - } - - if !output.Success { - t.Error("Expected success to be true") - } - if output.Message == "" { - t.Error("Expected message to be set") - } -} - -// TestWebviewWaitOutput_Good verifies the WebviewWaitOutput struct has expected fields. -func TestWebviewWaitOutput_Good(t *testing.T) { - output := WebviewWaitOutput{ - Success: true, - Message: "Element found: #login-form", - } - - if !output.Success { - t.Error("Expected success to be true") - } - if output.Message == "" { - t.Error("Expected message to be set") - } -} diff --git a/mcp/tools_ws.go b/mcp/tools_ws.go deleted file mode 100644 index ccae53c..0000000 --- a/mcp/tools_ws.go +++ /dev/null @@ -1,142 +0,0 @@ -package mcp - -import ( - "context" - "fmt" - "net" - "net/http" - - "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-ws" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// WSStartInput contains parameters for starting the WebSocket server. -type WSStartInput struct { - Addr string `json:"addr,omitempty"` // Address to listen on (default: ":8080") -} - -// WSStartOutput contains the result of starting the WebSocket server. -type WSStartOutput struct { - Success bool `json:"success"` - Addr string `json:"addr"` - Message string `json:"message,omitempty"` -} - -// WSInfoInput contains parameters for getting WebSocket hub info. -type WSInfoInput struct{} - -// WSInfoOutput contains WebSocket hub statistics. -type WSInfoOutput struct { - Clients int `json:"clients"` - Channels int `json:"channels"` -} - -// registerWSTools adds WebSocket tools to the MCP server. -// Returns false if WebSocket hub is not available. -func (s *Service) registerWSTools(server *mcp.Server) bool { - if s.wsHub == nil { - return false - } - - mcp.AddTool(server, &mcp.Tool{ - Name: "ws_start", - Description: "Start the WebSocket server for real-time process output streaming.", - }, s.wsStart) - - mcp.AddTool(server, &mcp.Tool{ - Name: "ws_info", - Description: "Get WebSocket hub statistics (connected clients and active channels).", - }, s.wsInfo) - - return true -} - -// wsStart handles the ws_start tool call. -func (s *Service) wsStart(ctx context.Context, req *mcp.CallToolRequest, input WSStartInput) (*mcp.CallToolResult, WSStartOutput, error) { - addr := input.Addr - if addr == "" { - addr = ":8080" - } - - s.logger.Security("MCP tool execution", "tool", "ws_start", "addr", addr, "user", log.Username()) - - // Check if server is already running - if s.wsServer != nil { - return nil, WSStartOutput{ - Success: true, - Addr: s.wsAddr, - Message: "WebSocket server already running", - }, nil - } - - // Create HTTP server with WebSocket handler - mux := http.NewServeMux() - mux.HandleFunc("/ws", s.wsHub.Handler()) - - server := &http.Server{ - Addr: addr, - Handler: mux, - } - - // Start listener to get actual address - ln, err := net.Listen("tcp", addr) - if err != nil { - log.Error("mcp: ws start listen failed", "addr", addr, "err", err) - return nil, WSStartOutput{}, fmt.Errorf("failed to listen on %s: %w", addr, err) - } - - actualAddr := ln.Addr().String() - s.wsServer = server - s.wsAddr = actualAddr - - // Start server in background - go func() { - if err := server.Serve(ln); err != nil && err != http.ErrServerClosed { - log.Error("mcp: ws server error", "err", err) - } - }() - - return nil, WSStartOutput{ - Success: true, - Addr: actualAddr, - Message: fmt.Sprintf("WebSocket server started at ws://%s/ws", actualAddr), - }, nil -} - -// wsInfo handles the ws_info tool call. -func (s *Service) wsInfo(ctx context.Context, req *mcp.CallToolRequest, input WSInfoInput) (*mcp.CallToolResult, WSInfoOutput, error) { - s.logger.Info("MCP tool execution", "tool", "ws_info", "user", log.Username()) - - stats := s.wsHub.Stats() - - return nil, WSInfoOutput{ - Clients: stats.Clients, - Channels: stats.Channels, - }, nil -} - -// ProcessEventCallback is a callback function for process events. -// It can be registered with the process service to forward events to WebSocket. -type ProcessEventCallback struct { - hub *ws.Hub -} - -// NewProcessEventCallback creates a callback that forwards process events to WebSocket. -func NewProcessEventCallback(hub *ws.Hub) *ProcessEventCallback { - return &ProcessEventCallback{hub: hub} -} - -// OnProcessOutput forwards process output to WebSocket subscribers. -func (c *ProcessEventCallback) OnProcessOutput(processID string, line string) { - if c.hub != nil { - _ = c.hub.SendProcessOutput(processID, line) - } -} - -// OnProcessStatus forwards process status changes to WebSocket subscribers. -func (c *ProcessEventCallback) OnProcessStatus(processID string, status string, exitCode int) { - if c.hub != nil { - _ = c.hub.SendProcessStatus(processID, status, exitCode) - } -} diff --git a/mcp/tools_ws_test.go b/mcp/tools_ws_test.go deleted file mode 100644 index 2ffaa51..0000000 --- a/mcp/tools_ws_test.go +++ /dev/null @@ -1,174 +0,0 @@ -package mcp - -import ( - "testing" - - "forge.lthn.ai/core/go-ws" -) - -// TestWSToolsRegistered_Good verifies that WebSocket tools are registered when hub is available. -func TestWSToolsRegistered_Good(t *testing.T) { - // Create a new MCP service without ws hub - tools should not be registered - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.wsHub != nil { - t.Error("WS hub should be nil by default") - } - - if s.server == nil { - t.Fatal("Server should not be nil") - } -} - -// TestWSStartInput_Good verifies the WSStartInput struct has expected fields. -func TestWSStartInput_Good(t *testing.T) { - input := WSStartInput{ - Addr: ":9090", - } - - if input.Addr != ":9090" { - t.Errorf("Expected addr ':9090', got %q", input.Addr) - } -} - -// TestWSStartInput_Defaults verifies default values. -func TestWSStartInput_Defaults(t *testing.T) { - input := WSStartInput{} - - if input.Addr != "" { - t.Errorf("Expected addr to default to empty, got %q", input.Addr) - } -} - -// TestWSStartOutput_Good verifies the WSStartOutput struct has expected fields. -func TestWSStartOutput_Good(t *testing.T) { - output := WSStartOutput{ - Success: true, - Addr: "127.0.0.1:8080", - Message: "WebSocket server started", - } - - if !output.Success { - t.Error("Expected Success to be true") - } - if output.Addr != "127.0.0.1:8080" { - t.Errorf("Expected addr '127.0.0.1:8080', got %q", output.Addr) - } - if output.Message != "WebSocket server started" { - t.Errorf("Expected message 'WebSocket server started', got %q", output.Message) - } -} - -// TestWSInfoInput_Good verifies the WSInfoInput struct exists (it's empty). -func TestWSInfoInput_Good(t *testing.T) { - input := WSInfoInput{} - _ = input // Just verify it compiles -} - -// TestWSInfoOutput_Good verifies the WSInfoOutput struct has expected fields. -func TestWSInfoOutput_Good(t *testing.T) { - output := WSInfoOutput{ - Clients: 5, - Channels: 3, - } - - if output.Clients != 5 { - t.Errorf("Expected clients 5, got %d", output.Clients) - } - if output.Channels != 3 { - t.Errorf("Expected channels 3, got %d", output.Channels) - } -} - -// TestWithWSHub_Good verifies the WithWSHub option. -func TestWithWSHub_Good(t *testing.T) { - hub := ws.NewHub() - - s, err := New(WithWSHub(hub)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.wsHub != hub { - t.Error("Expected wsHub to be set") - } -} - -// TestWithWSHub_Nil verifies the WithWSHub option with nil. -func TestWithWSHub_Nil(t *testing.T) { - s, err := New(WithWSHub(nil)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.wsHub != nil { - t.Error("Expected wsHub to be nil when passed nil") - } -} - -// TestProcessEventCallback_Good verifies the ProcessEventCallback struct. -func TestProcessEventCallback_Good(t *testing.T) { - hub := ws.NewHub() - callback := NewProcessEventCallback(hub) - - if callback.hub != hub { - t.Error("Expected callback hub to be set") - } - - // Test that methods don't panic - callback.OnProcessOutput("proc-1", "test output") - callback.OnProcessStatus("proc-1", "exited", 0) -} - -// TestProcessEventCallback_NilHub verifies the ProcessEventCallback with nil hub doesn't panic. -func TestProcessEventCallback_NilHub(t *testing.T) { - callback := NewProcessEventCallback(nil) - - if callback.hub != nil { - t.Error("Expected callback hub to be nil") - } - - // Test that methods don't panic with nil hub - callback.OnProcessOutput("proc-1", "test output") - callback.OnProcessStatus("proc-1", "exited", 0) -} - -// TestServiceWSHub_Good verifies the WSHub getter method. -func TestServiceWSHub_Good(t *testing.T) { - hub := ws.NewHub() - s, err := New(WithWSHub(hub)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.WSHub() != hub { - t.Error("Expected WSHub() to return the hub") - } -} - -// TestServiceWSHub_Nil verifies the WSHub getter returns nil when not configured. -func TestServiceWSHub_Nil(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.WSHub() != nil { - t.Error("Expected WSHub() to return nil when not configured") - } -} - -// TestServiceProcessService_Nil verifies the ProcessService getter returns nil when not configured. -func TestServiceProcessService_Nil(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - if s.ProcessService() != nil { - t.Error("Expected ProcessService() to return nil when not configured") - } -} diff --git a/mcp/transport_e2e_test.go b/mcp/transport_e2e_test.go deleted file mode 100644 index 1a9a8d0..0000000 --- a/mcp/transport_e2e_test.go +++ /dev/null @@ -1,742 +0,0 @@ -package mcp - -import ( - "bufio" - "encoding/json" - "fmt" - "net" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "context" -) - -// jsonRPCRequest builds a raw JSON-RPC 2.0 request string with newline delimiter. -func jsonRPCRequest(id int, method string, params any) string { - msg := map[string]any{ - "jsonrpc": "2.0", - "id": id, - "method": method, - } - if params != nil { - msg["params"] = params - } - data, _ := json.Marshal(msg) - return string(data) + "\n" -} - -// jsonRPCNotification builds a raw JSON-RPC 2.0 notification (no id). -func jsonRPCNotification(method string) string { - msg := map[string]any{ - "jsonrpc": "2.0", - "method": method, - } - data, _ := json.Marshal(msg) - return string(data) + "\n" -} - -// readJSONRPCResponse reads a single line-delimited JSON-RPC response and -// returns the decoded map. It handles the case where the server sends a -// ping request interleaved with responses (responds to it and keeps reading). -func readJSONRPCResponse(t *testing.T, scanner *bufio.Scanner, conn net.Conn) map[string]any { - t.Helper() - for { - if !scanner.Scan() { - if err := scanner.Err(); err != nil { - t.Fatalf("scanner error: %v", err) - } - t.Fatal("unexpected EOF reading JSON-RPC response") - } - line := scanner.Text() - var msg map[string]any - if err := json.Unmarshal([]byte(line), &msg); err != nil { - t.Fatalf("failed to unmarshal response: %v\nraw: %s", err, line) - } - - // If this is a server-initiated request (e.g. ping), respond and keep reading. - if method, ok := msg["method"]; ok { - if id, hasID := msg["id"]; hasID { - resp := map[string]any{ - "jsonrpc": "2.0", - "id": id, - "result": map[string]any{}, - } - data, _ := json.Marshal(resp) - _, _ = conn.Write(append(data, '\n')) - _ = method // consume - continue - } - // Notification from server — ignore and keep reading - continue - } - - return msg - } -} - -// --- TCP E2E Tests --- - -func TestTCPTransport_E2E_FullRoundTrip(t *testing.T) { - // Create a temp workspace with a known file - tmpDir := t.TempDir() - testContent := "hello from tcp e2e test" - if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(testContent), 0644); err != nil { - t.Fatalf("Failed to create test file: %v", err) - } - - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Start TCP server on a random port - errCh := make(chan error, 1) - go func() { - errCh <- s.ServeTCP(ctx, "127.0.0.1:0") - }() - - // Wait for the server to start and get the actual address. - // ServeTCP creates its own listener internally, so we need to probe. - // We'll retry connecting for up to 2 seconds. - var conn net.Conn - deadline := time.Now().Add(2 * time.Second) - // Since ServeTCP binds :0, we can't predict the port. Instead, create - // our own listener to find a free port, close it, then pass that port - // to ServeTCP. This is a known race, but fine for tests. - cancel() - <-errCh - - // Restart with a known port: find a free port first - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Failed to find free port: %v", err) - } - addr := ln.Addr().String() - ln.Close() - - ctx2, cancel2 := context.WithCancel(context.Background()) - defer cancel2() - - errCh2 := make(chan error, 1) - go func() { - errCh2 <- s.ServeTCP(ctx2, addr) - }() - - // Wait for server to accept connections - deadline = time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond) - if err == nil { - break - } - time.Sleep(50 * time.Millisecond) - } - if err != nil { - t.Fatalf("Failed to connect to TCP server at %s: %v", addr, err) - } - defer conn.Close() - - // Set a read deadline to avoid hanging - conn.SetDeadline(time.Now().Add(10 * time.Second)) - - scanner := bufio.NewScanner(conn) - scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) - - // Step 1: Send initialize request - initReq := jsonRPCRequest(1, "initialize", map[string]any{ - "protocolVersion": "2024-11-05", - "capabilities": map[string]any{}, - "clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"}, - }) - if _, err := conn.Write([]byte(initReq)); err != nil { - t.Fatalf("Failed to send initialize: %v", err) - } - - // Read initialize response - initResp := readJSONRPCResponse(t, scanner, conn) - if initResp["error"] != nil { - t.Fatalf("Initialize returned error: %v", initResp["error"]) - } - result, ok := initResp["result"].(map[string]any) - if !ok { - t.Fatalf("Expected result object, got %T", initResp["result"]) - } - serverInfo, _ := result["serverInfo"].(map[string]any) - if serverInfo["name"] != "core-cli" { - t.Errorf("Expected server name 'core-cli', got %v", serverInfo["name"]) - } - - // Step 2: Send notifications/initialized - if _, err := conn.Write([]byte(jsonRPCNotification("notifications/initialized"))); err != nil { - t.Fatalf("Failed to send initialized notification: %v", err) - } - - // Step 3: Send tools/list - if _, err := conn.Write([]byte(jsonRPCRequest(2, "tools/list", nil))); err != nil { - t.Fatalf("Failed to send tools/list: %v", err) - } - - toolsResp := readJSONRPCResponse(t, scanner, conn) - if toolsResp["error"] != nil { - t.Fatalf("tools/list returned error: %v", toolsResp["error"]) - } - - toolsResult, ok := toolsResp["result"].(map[string]any) - if !ok { - t.Fatalf("Expected result object for tools/list, got %T", toolsResp["result"]) - } - tools, ok := toolsResult["tools"].([]any) - if !ok || len(tools) == 0 { - t.Fatal("Expected non-empty tools list") - } - - // Verify file_read is among the tools - foundFileRead := false - for _, tool := range tools { - toolMap, _ := tool.(map[string]any) - if toolMap["name"] == "file_read" { - foundFileRead = true - break - } - } - if !foundFileRead { - t.Error("Expected file_read tool in tools/list response") - } - - // Step 4: Call file_read - callReq := jsonRPCRequest(3, "tools/call", map[string]any{ - "name": "file_read", - "arguments": map[string]any{"path": "test.txt"}, - }) - if _, err := conn.Write([]byte(callReq)); err != nil { - t.Fatalf("Failed to send tools/call: %v", err) - } - - callResp := readJSONRPCResponse(t, scanner, conn) - if callResp["error"] != nil { - t.Fatalf("tools/call file_read returned error: %v", callResp["error"]) - } - - callResult, ok := callResp["result"].(map[string]any) - if !ok { - t.Fatalf("Expected result object for tools/call, got %T", callResp["result"]) - } - - // The MCP SDK wraps tool results in content array - content, ok := callResult["content"].([]any) - if !ok || len(content) == 0 { - t.Fatal("Expected non-empty content in tools/call response") - } - - firstContent, _ := content[0].(map[string]any) - text, _ := firstContent["text"].(string) - if !strings.Contains(text, testContent) { - t.Errorf("Expected file content to contain %q, got %q", testContent, text) - } - - // Graceful shutdown - cancel2() - err = <-errCh2 - if err != nil { - t.Errorf("ServeTCP returned error: %v", err) - } -} - -func TestTCPTransport_E2E_FileWrite(t *testing.T) { - tmpDir := t.TempDir() - - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Find free port - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Failed to find free port: %v", err) - } - addr := ln.Addr().String() - ln.Close() - - errCh := make(chan error, 1) - go func() { - errCh <- s.ServeTCP(ctx, addr) - }() - - // Connect - var conn net.Conn - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond) - if err == nil { - break - } - time.Sleep(50 * time.Millisecond) - } - if err != nil { - t.Fatalf("Failed to connect: %v", err) - } - defer conn.Close() - - conn.SetDeadline(time.Now().Add(10 * time.Second)) - scanner := bufio.NewScanner(conn) - scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) - - // Initialize handshake - conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{ - "protocolVersion": "2024-11-05", - "capabilities": map[string]any{}, - "clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"}, - }))) - readJSONRPCResponse(t, scanner, conn) - conn.Write([]byte(jsonRPCNotification("notifications/initialized"))) - - // Write a file - writeContent := "written via tcp transport" - conn.Write([]byte(jsonRPCRequest(2, "tools/call", map[string]any{ - "name": "file_write", - "arguments": map[string]any{"path": "tcp-written.txt", "content": writeContent}, - }))) - writeResp := readJSONRPCResponse(t, scanner, conn) - if writeResp["error"] != nil { - t.Fatalf("file_write returned error: %v", writeResp["error"]) - } - - // Verify file on disk - diskContent, err := os.ReadFile(filepath.Join(tmpDir, "tcp-written.txt")) - if err != nil { - t.Fatalf("Failed to read written file: %v", err) - } - if string(diskContent) != writeContent { - t.Errorf("Expected %q on disk, got %q", writeContent, string(diskContent)) - } - - cancel() - <-errCh -} - -// --- Unix Socket E2E Tests --- - -// shortSocketPath returns a Unix socket path under /tmp that fits within -// the macOS 104-byte sun_path limit. t.TempDir() paths on macOS are -// often too long (>104 bytes) for Unix sockets. -func shortSocketPath(t *testing.T, suffix string) string { - t.Helper() - path := fmt.Sprintf("/tmp/mcp-test-%s-%d.sock", suffix, os.Getpid()) - t.Cleanup(func() { os.Remove(path) }) - return path -} - -func TestUnixTransport_E2E_FullRoundTrip(t *testing.T) { - // Create a temp workspace with a known file - tmpDir := t.TempDir() - testContent := "hello from unix e2e test" - if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(testContent), 0644); err != nil { - t.Fatalf("Failed to create test file: %v", err) - } - - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Use a short socket path to avoid macOS 104-byte sun_path limit - socketPath := shortSocketPath(t, "full") - - errCh := make(chan error, 1) - go func() { - errCh <- s.ServeUnix(ctx, socketPath) - }() - - // Wait for socket to appear - var conn net.Conn - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - conn, err = net.DialTimeout("unix", socketPath, 200*time.Millisecond) - if err == nil { - break - } - time.Sleep(50 * time.Millisecond) - } - if err != nil { - t.Fatalf("Failed to connect to Unix socket at %s: %v", socketPath, err) - } - defer conn.Close() - - conn.SetDeadline(time.Now().Add(10 * time.Second)) - scanner := bufio.NewScanner(conn) - scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) - - // Step 1: Initialize - conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{ - "protocolVersion": "2024-11-05", - "capabilities": map[string]any{}, - "clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"}, - }))) - initResp := readJSONRPCResponse(t, scanner, conn) - if initResp["error"] != nil { - t.Fatalf("Initialize returned error: %v", initResp["error"]) - } - - // Step 2: Send initialized notification - conn.Write([]byte(jsonRPCNotification("notifications/initialized"))) - - // Step 3: tools/list - conn.Write([]byte(jsonRPCRequest(2, "tools/list", nil))) - toolsResp := readJSONRPCResponse(t, scanner, conn) - if toolsResp["error"] != nil { - t.Fatalf("tools/list returned error: %v", toolsResp["error"]) - } - - toolsResult, ok := toolsResp["result"].(map[string]any) - if !ok { - t.Fatalf("Expected result object, got %T", toolsResp["result"]) - } - tools, ok := toolsResult["tools"].([]any) - if !ok || len(tools) == 0 { - t.Fatal("Expected non-empty tools list") - } - - // Step 4: Call file_read - conn.Write([]byte(jsonRPCRequest(3, "tools/call", map[string]any{ - "name": "file_read", - "arguments": map[string]any{"path": "test.txt"}, - }))) - callResp := readJSONRPCResponse(t, scanner, conn) - if callResp["error"] != nil { - t.Fatalf("tools/call file_read returned error: %v", callResp["error"]) - } - - callResult, ok := callResp["result"].(map[string]any) - if !ok { - t.Fatalf("Expected result object, got %T", callResp["result"]) - } - content, ok := callResult["content"].([]any) - if !ok || len(content) == 0 { - t.Fatal("Expected non-empty content") - } - - firstContent, _ := content[0].(map[string]any) - text, _ := firstContent["text"].(string) - if !strings.Contains(text, testContent) { - t.Errorf("Expected content to contain %q, got %q", testContent, text) - } - - // Graceful shutdown - cancel() - err = <-errCh - if err != nil { - t.Errorf("ServeUnix returned error: %v", err) - } - - // Verify socket file is cleaned up - if _, err := os.Stat(socketPath); !os.IsNotExist(err) { - t.Error("Expected socket file to be cleaned up after shutdown") - } -} - -func TestUnixTransport_E2E_DirList(t *testing.T) { - tmpDir := t.TempDir() - - // Create some files and dirs - os.MkdirAll(filepath.Join(tmpDir, "subdir"), 0755) - os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("one"), 0644) - os.WriteFile(filepath.Join(tmpDir, "subdir", "file2.txt"), []byte("two"), 0644) - - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - socketPath := shortSocketPath(t, "dir") - - errCh := make(chan error, 1) - go func() { - errCh <- s.ServeUnix(ctx, socketPath) - }() - - var conn net.Conn - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - conn, err = net.DialTimeout("unix", socketPath, 200*time.Millisecond) - if err == nil { - break - } - time.Sleep(50 * time.Millisecond) - } - if err != nil { - t.Fatalf("Failed to connect: %v", err) - } - defer conn.Close() - - conn.SetDeadline(time.Now().Add(10 * time.Second)) - scanner := bufio.NewScanner(conn) - scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) - - // Initialize - conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{ - "protocolVersion": "2024-11-05", - "capabilities": map[string]any{}, - "clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"}, - }))) - readJSONRPCResponse(t, scanner, conn) - conn.Write([]byte(jsonRPCNotification("notifications/initialized"))) - - // Call dir_list on root - conn.Write([]byte(jsonRPCRequest(2, "tools/call", map[string]any{ - "name": "dir_list", - "arguments": map[string]any{"path": "."}, - }))) - dirResp := readJSONRPCResponse(t, scanner, conn) - if dirResp["error"] != nil { - t.Fatalf("dir_list returned error: %v", dirResp["error"]) - } - - dirResult, ok := dirResp["result"].(map[string]any) - if !ok { - t.Fatalf("Expected result object, got %T", dirResp["result"]) - } - dirContent, ok := dirResult["content"].([]any) - if !ok || len(dirContent) == 0 { - t.Fatal("Expected non-empty content in dir_list response") - } - - // The response content should mention our files - firstItem, _ := dirContent[0].(map[string]any) - text, _ := firstItem["text"].(string) - if !strings.Contains(text, "file1.txt") && !strings.Contains(text, "subdir") { - t.Errorf("Expected dir_list to contain file1.txt or subdir, got: %s", text) - } - - cancel() - <-errCh -} - -// --- Stdio Transport Tests --- - -func TestStdioTransport_Documented_Skip(t *testing.T) { - // The MCP SDK's StdioTransport binds directly to os.Stdin and os.Stdout, - // with no way to inject custom io.Reader/io.Writer. Testing stdio transport - // would require spawning the binary as a subprocess and piping JSON-RPC - // messages through its stdin/stdout. - // - // Since the core MCP protocol handling is identical across all transports - // (the transport layer only handles framing), and we thoroughly test the - // protocol via TCP and Unix socket e2e tests, the stdio transport is - // effectively covered. The only untested code path is the StdioTransport - // adapter itself, which is a thin wrapper in the upstream SDK. - // - // If process-level testing is needed in the future, the approach would be: - // 1. Build the binary: `go build -o /tmp/mcp-test ./cmd/...` - // 2. Spawn it: exec.Command("/tmp/mcp-test", "mcp", "serve") - // 3. Pipe JSON-RPC to stdin, read from stdout - // 4. Verify responses match expected protocol - t.Skip("stdio transport requires process spawning; protocol is covered by TCP and Unix e2e tests") -} - -// --- Helper: verify a specific tool exists in tools/list response --- - -func assertToolExists(t *testing.T, tools []any, name string) { - t.Helper() - for _, tool := range tools { - toolMap, _ := tool.(map[string]any) - if toolMap["name"] == name { - return - } - } - toolNames := make([]string, 0, len(tools)) - for _, tool := range tools { - toolMap, _ := tool.(map[string]any) - if n, ok := toolMap["name"].(string); ok { - toolNames = append(toolNames, n) - } - } - t.Errorf("Expected tool %q in list, got: %v", name, toolNames) -} - -func TestTCPTransport_E2E_ToolsDiscovery(t *testing.T) { - tmpDir := t.TempDir() - - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Failed to find free port: %v", err) - } - addr := ln.Addr().String() - ln.Close() - - errCh := make(chan error, 1) - go func() { - errCh <- s.ServeTCP(ctx, addr) - }() - - var conn net.Conn - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond) - if err == nil { - break - } - time.Sleep(50 * time.Millisecond) - } - if err != nil { - t.Fatalf("Failed to connect: %v", err) - } - defer conn.Close() - - conn.SetDeadline(time.Now().Add(10 * time.Second)) - scanner := bufio.NewScanner(conn) - scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) - - // Initialize - conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{ - "protocolVersion": "2024-11-05", - "capabilities": map[string]any{}, - "clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"}, - }))) - readJSONRPCResponse(t, scanner, conn) - conn.Write([]byte(jsonRPCNotification("notifications/initialized"))) - - // Get tools list - conn.Write([]byte(jsonRPCRequest(2, "tools/list", nil))) - toolsResp := readJSONRPCResponse(t, scanner, conn) - if toolsResp["error"] != nil { - t.Fatalf("tools/list error: %v", toolsResp["error"]) - } - toolsResult, _ := toolsResp["result"].(map[string]any) - tools, _ := toolsResult["tools"].([]any) - - // Verify all core tools are registered - expectedTools := []string{ - "file_read", "file_write", "file_delete", "file_rename", - "file_exists", "file_edit", "dir_list", "dir_create", - "lang_detect", "lang_list", - } - for _, name := range expectedTools { - assertToolExists(t, tools, name) - } - - // Log total tool count for visibility - t.Logf("Server registered %d tools", len(tools)) - - cancel() - <-errCh -} - -func TestTCPTransport_E2E_ErrorHandling(t *testing.T) { - tmpDir := t.TempDir() - - s, err := New(WithWorkspaceRoot(tmpDir)) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Failed to find free port: %v", err) - } - addr := ln.Addr().String() - ln.Close() - - errCh := make(chan error, 1) - go func() { - errCh <- s.ServeTCP(ctx, addr) - }() - - var conn net.Conn - deadline := time.Now().Add(2 * time.Second) - for time.Now().Before(deadline) { - conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond) - if err == nil { - break - } - time.Sleep(50 * time.Millisecond) - } - if err != nil { - t.Fatalf("Failed to connect: %v", err) - } - defer conn.Close() - - conn.SetDeadline(time.Now().Add(10 * time.Second)) - scanner := bufio.NewScanner(conn) - scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) - - // Initialize - conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{ - "protocolVersion": "2024-11-05", - "capabilities": map[string]any{}, - "clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"}, - }))) - readJSONRPCResponse(t, scanner, conn) - conn.Write([]byte(jsonRPCNotification("notifications/initialized"))) - - // Try to read a nonexistent file - conn.Write([]byte(jsonRPCRequest(2, "tools/call", map[string]any{ - "name": "file_read", - "arguments": map[string]any{"path": "nonexistent.txt"}, - }))) - errResp := readJSONRPCResponse(t, scanner, conn) - - // The MCP SDK wraps tool errors as isError content, not JSON-RPC errors. - // Check both possibilities. - if errResp["error"] != nil { - // JSON-RPC level error — this is acceptable - t.Logf("Got JSON-RPC error for nonexistent file: %v", errResp["error"]) - } else { - errResult, _ := errResp["result"].(map[string]any) - isError, _ := errResult["isError"].(bool) - if !isError { - // Check content for error indicator - content, _ := errResult["content"].([]any) - if len(content) > 0 { - firstContent, _ := content[0].(map[string]any) - text, _ := firstContent["text"].(string) - t.Logf("Tool response for nonexistent file: %s", text) - } - } - } - - // Verify tools/call without params returns an error - conn.Write([]byte(jsonRPCRequest(3, "tools/call", nil))) - noParamsResp := readJSONRPCResponse(t, scanner, conn) - if noParamsResp["error"] == nil { - t.Log("tools/call without params did not return JSON-RPC error (SDK may handle differently)") - } else { - errObj, _ := noParamsResp["error"].(map[string]any) - code, _ := errObj["code"].(float64) - if code != -32600 { - t.Logf("tools/call without params returned error code: %v", code) - } - } - - cancel() - <-errCh -} - -// Suppress "unused import" for fmt — used in helpers -var _ = fmt.Sprintf diff --git a/mcp/transport_stdio.go b/mcp/transport_stdio.go deleted file mode 100644 index 10ea27c..0000000 --- a/mcp/transport_stdio.go +++ /dev/null @@ -1,15 +0,0 @@ -package mcp - -import ( - "context" - - "forge.lthn.ai/core/go-log" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// ServeStdio starts the MCP server over stdin/stdout. -// This is the default transport for CLI integrations. -func (s *Service) ServeStdio(ctx context.Context) error { - s.logger.Info("MCP Stdio server starting", "user", log.Username()) - return s.server.Run(ctx, &mcp.StdioTransport{}) -} diff --git a/mcp/transport_tcp.go b/mcp/transport_tcp.go deleted file mode 100644 index eb7ec91..0000000 --- a/mcp/transport_tcp.go +++ /dev/null @@ -1,177 +0,0 @@ -package mcp - -import ( - "bufio" - "context" - "fmt" - "io" - "net" - "os" - "sync" - - "github.com/modelcontextprotocol/go-sdk/jsonrpc" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// DefaultTCPAddr is the default address for the MCP TCP server. -const DefaultTCPAddr = "127.0.0.1:9100" - -// diagMu protects diagWriter from concurrent access across tests and goroutines. -var diagMu sync.Mutex - -// diagWriter is the destination for warning and diagnostic messages. -// Use diagPrintf to write to it safely. -var diagWriter io.Writer = os.Stderr - -// diagPrintf writes a formatted message to diagWriter under the mutex. -func diagPrintf(format string, args ...any) { - diagMu.Lock() - defer diagMu.Unlock() - fmt.Fprintf(diagWriter, format, args...) -} - -// setDiagWriter swaps the diagnostic writer and returns the previous one. -// Used by tests to capture output without racing. -func setDiagWriter(w io.Writer) io.Writer { - diagMu.Lock() - defer diagMu.Unlock() - old := diagWriter - diagWriter = w - return old -} - -// maxMCPMessageSize is the maximum size for MCP JSON-RPC messages (10 MB). -const maxMCPMessageSize = 10 * 1024 * 1024 - -// TCPTransport manages a TCP listener for MCP. -type TCPTransport struct { - addr string - listener net.Listener -} - -// NewTCPTransport creates a new TCP transport listener. -// It listens on the provided address (e.g. "localhost:9100"). -// Defaults to 127.0.0.1 when the host component is empty (e.g. ":9100"). -// Emits a security warning when explicitly binding to 0.0.0.0 (all interfaces). -func NewTCPTransport(addr string) (*TCPTransport, error) { - host, port, _ := net.SplitHostPort(addr) - if host == "" { - addr = net.JoinHostPort("127.0.0.1", port) - } else if host == "0.0.0.0" { - diagPrintf("WARNING: MCP TCP server binding to all interfaces (%s). Use 127.0.0.1 for local-only access.\n", addr) - } - listener, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - return &TCPTransport{addr: addr, listener: listener}, nil -} - -// ServeTCP starts a TCP server for the MCP service. -// It accepts connections and spawns a new MCP server session for each connection. -func (s *Service) ServeTCP(ctx context.Context, addr string) error { - t, err := NewTCPTransport(addr) - if err != nil { - return err - } - defer func() { _ = t.listener.Close() }() - - // Close listener when context is cancelled to unblock Accept - go func() { - <-ctx.Done() - _ = t.listener.Close() - }() - - if addr == "" { - addr = t.listener.Addr().String() - } - diagPrintf("MCP TCP server listening on %s\n", addr) - - for { - conn, err := t.listener.Accept() - if err != nil { - select { - case <-ctx.Done(): - return nil - default: - diagPrintf("Accept error: %v\n", err) - continue - } - } - - go s.handleConnection(ctx, conn) - } -} - -func (s *Service) handleConnection(ctx context.Context, conn net.Conn) { - // Note: We don't defer conn.Close() here because it's closed by the Server/Transport - - // Create new server instance for this connection - impl := &mcp.Implementation{ - Name: "core-cli", - Version: "0.1.0", - } - server := mcp.NewServer(impl, nil) - s.registerTools(server) - - // Create transport for this connection - transport := &connTransport{conn: conn} - - // Run server (blocks until connection closed) - // Server.Run calls Connect, then Read loop. - if err := server.Run(ctx, transport); err != nil { - diagPrintf("Connection error: %v\n", err) - } -} - -// connTransport adapts net.Conn to mcp.Transport -type connTransport struct { - conn net.Conn -} - -func (t *connTransport) Connect(ctx context.Context) (mcp.Connection, error) { - scanner := bufio.NewScanner(t.conn) - scanner.Buffer(make([]byte, 64*1024), maxMCPMessageSize) - return &connConnection{ - conn: t.conn, - scanner: scanner, - }, nil -} - -// connConnection implements mcp.Connection -type connConnection struct { - conn net.Conn - scanner *bufio.Scanner -} - -func (c *connConnection) Read(ctx context.Context) (jsonrpc.Message, error) { - // Blocks until line is read - if !c.scanner.Scan() { - if err := c.scanner.Err(); err != nil { - return nil, err - } - // EOF - connection closed cleanly - return nil, io.EOF - } - line := c.scanner.Bytes() - return jsonrpc.DecodeMessage(line) -} - -func (c *connConnection) Write(ctx context.Context, msg jsonrpc.Message) error { - data, err := jsonrpc.EncodeMessage(msg) - if err != nil { - return err - } - // Append newline for line-delimited JSON - data = append(data, '\n') - _, err = c.conn.Write(data) - return err -} - -func (c *connConnection) Close() error { - return c.conn.Close() -} - -func (c *connConnection) SessionID() string { - return "tcp-session" // Unique ID might be better, but optional -} diff --git a/mcp/transport_tcp_test.go b/mcp/transport_tcp_test.go deleted file mode 100644 index ba9a229..0000000 --- a/mcp/transport_tcp_test.go +++ /dev/null @@ -1,184 +0,0 @@ -package mcp - -import ( - "bytes" - "context" - "net" - "os" - "strings" - "testing" - "time" -) - -func TestNewTCPTransport_Defaults(t *testing.T) { - // Test that empty string gets replaced with default address constant - // Note: We can't actually bind to 9100 as it may be in use, - // so we verify the address is set correctly before Listen is called - if DefaultTCPAddr != "127.0.0.1:9100" { - t.Errorf("Expected default constant 127.0.0.1:9100, got %s", DefaultTCPAddr) - } - - // Test with a dynamic port to verify transport creation works - tr, err := NewTCPTransport("127.0.0.1:0") - if err != nil { - t.Fatalf("Failed to create transport with dynamic port: %v", err) - } - defer tr.listener.Close() - - // Verify we got a valid address - if tr.addr != "127.0.0.1:0" { - t.Errorf("Expected address to be set, got %s", tr.addr) - } -} - -func TestNewTCPTransport_Warning(t *testing.T) { - // Capture warning output via setDiagWriter (mutex-protected, no race). - var buf bytes.Buffer - old := setDiagWriter(&buf) - defer setDiagWriter(old) - - // Trigger warning - tr, err := NewTCPTransport("0.0.0.0:9101") - if err != nil { - t.Fatalf("Failed to create transport: %v", err) - } - defer tr.listener.Close() - - output := buf.String() - if !strings.Contains(output, "WARNING") { - t.Error("Expected warning for binding to 0.0.0.0, but didn't find it in stderr") - } -} - -func TestServeTCP_Connection(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Use a random port for testing to avoid collisions - addr := "127.0.0.1:0" - - // Create transport first to get the actual address if we use :0 - tr, err := NewTCPTransport(addr) - if err != nil { - t.Fatalf("Failed to create transport: %v", err) - } - actualAddr := tr.listener.Addr().String() - tr.listener.Close() // Close it so ServeTCP can re-open it or use the same address - - // Start server in background - errCh := make(chan error, 1) - go func() { - errCh <- s.ServeTCP(ctx, actualAddr) - }() - - // Give it a moment to start - time.Sleep(100 * time.Millisecond) - - // Connect to the server - conn, err := net.Dial("tcp", actualAddr) - if err != nil { - t.Fatalf("Failed to connect to server: %v", err) - } - defer conn.Close() - - // Verify we can write to it - _, err = conn.Write([]byte("{}\n")) - if err != nil { - t.Errorf("Failed to write to connection: %v", err) - } - - // Shutdown server - cancel() - err = <-errCh - if err != nil { - t.Errorf("ServeTCP returned error: %v", err) - } -} - -func TestRun_TCPTrigger(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Set MCP_ADDR to empty to trigger default TCP - os.Setenv("MCP_ADDR", "") - defer os.Unsetenv("MCP_ADDR") - - // We use a random port for testing, but Run will try to use 127.0.0.1:9100 by default if we don't override. - // Since 9100 might be in use, we'll set MCP_ADDR to use :0 (random port) - os.Setenv("MCP_ADDR", "127.0.0.1:0") - - errCh := make(chan error, 1) - go func() { - errCh <- s.Run(ctx) - }() - - // Give it a moment to start - time.Sleep(100 * time.Millisecond) - - // Since we can't easily get the actual port used by Run (it's internal), - // we just verify it didn't immediately fail. - select { - case err := <-errCh: - t.Fatalf("Run failed immediately: %v", err) - default: - // still running, which is good - } - - cancel() - _ = <-errCh -} - -func TestServeTCP_MultipleConnections(t *testing.T) { - s, err := New() - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - addr := "127.0.0.1:0" - tr, err := NewTCPTransport(addr) - if err != nil { - t.Fatalf("Failed to create transport: %v", err) - } - actualAddr := tr.listener.Addr().String() - tr.listener.Close() - - errCh := make(chan error, 1) - go func() { - errCh <- s.ServeTCP(ctx, actualAddr) - }() - - time.Sleep(100 * time.Millisecond) - - // Connect multiple clients - const numClients = 3 - for i := range numClients { - conn, err := net.Dial("tcp", actualAddr) - if err != nil { - t.Fatalf("Client %d failed to connect: %v", i, err) - } - defer conn.Close() - _, err = conn.Write([]byte("{}\n")) - if err != nil { - t.Errorf("Client %d failed to write: %v", i, err) - } - } - - cancel() - err = <-errCh - if err != nil { - t.Errorf("ServeTCP returned error: %v", err) - } -} diff --git a/mcp/transport_unix.go b/mcp/transport_unix.go deleted file mode 100644 index c70d5d9..0000000 --- a/mcp/transport_unix.go +++ /dev/null @@ -1,52 +0,0 @@ -package mcp - -import ( - "context" - "net" - "os" - - "forge.lthn.ai/core/go-log" -) - -// ServeUnix starts a Unix domain socket server for the MCP service. -// The socket file is created at the given path and removed on shutdown. -// It accepts connections and spawns a new MCP server session for each connection. -func (s *Service) ServeUnix(ctx context.Context, socketPath string) error { - // Clean up any stale socket file - if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) { - s.logger.Warn("Failed to remove stale socket", "path", socketPath, "err", err) - } - - listener, err := net.Listen("unix", socketPath) - if err != nil { - return err - } - defer func() { - _ = listener.Close() - _ = os.Remove(socketPath) - }() - - // Close listener when context is cancelled to unblock Accept - go func() { - <-ctx.Done() - _ = listener.Close() - }() - - s.logger.Security("MCP Unix server listening", "path", socketPath, "user", log.Username()) - - for { - conn, err := listener.Accept() - if err != nil { - select { - case <-ctx.Done(): - return nil - default: - s.logger.Error("MCP Unix accept error", "err", err, "user", log.Username()) - continue - } - } - - s.logger.Security("MCP Unix connection accepted", "user", log.Username()) - go s.handleConnection(ctx, conn) - } -}