refactor: extract MCP server to core/mcp
Some checks failed
Security Scan / security (push) Successful in 9s
Test / test (push) Failing after 1m47s

Move mcp/, cmd/mcpcmd/, cmd/brain-seed/ to the new core/mcp repo.
Update daemon import to use forge.lthn.ai/core/mcp/pkg/mcp.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-03-09 18:40:50 +00:00
parent 413c637d26
commit 0202bec84a
43 changed files with 1 additions and 10067 deletions

View file

@ -1,502 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
// brain-seed imports Claude Code MEMORY.md files into the OpenBrain knowledge
// store via the MCP HTTP API (brain_remember tool). The Laravel app handles
// embedding, Qdrant storage, and MariaDB dual-write internally.
//
// Usage:
//
// go run ./cmd/brain-seed -api-key YOUR_KEY
// go run ./cmd/brain-seed -api-key YOUR_KEY -api https://lthn.sh/api/v1/mcp
// go run ./cmd/brain-seed -api-key YOUR_KEY -dry-run
// go run ./cmd/brain-seed -api-key YOUR_KEY -plans
// go run ./cmd/brain-seed -api-key YOUR_KEY -claude-md # Also import CLAUDE.md files
package main
import (
"bytes"
"crypto/tls"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"time"
)
var (
apiURL = flag.String("api", "https://lthn.sh/api/v1/mcp", "MCP API base URL")
apiKey = flag.String("api-key", "", "MCP API key (Bearer token)")
server = flag.String("server", "hosthub-agent", "MCP server ID")
agent = flag.String("agent", "charon", "Agent ID for attribution")
dryRun = flag.Bool("dry-run", false, "Preview without storing")
plans = flag.Bool("plans", false, "Also import plan documents")
claudeMd = flag.Bool("claude-md", false, "Also import CLAUDE.md files")
memoryPath = flag.String("memory-path", "", "Override memory scan path (default: ~/.claude/projects/*/memory/)")
planPath = flag.String("plan-path", "", "Override plan scan path (default: ~/Code/*/docs/plans/)")
codePath = flag.String("code-path", "", "Override code root for CLAUDE.md scan (default: ~/Code)")
maxChars = flag.Int("max-chars", 3800, "Max chars per section (embeddinggemma limit ~4000)")
)
// httpClient with TLS skip for non-public TLDs (.lthn.sh has real certs, but
// allow .lan/.local if someone has legacy config).
var httpClient = &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
},
}
func main() {
flag.Parse()
fmt.Println("OpenBrain Seed — MCP API Client")
fmt.Println(strings.Repeat("=", 55))
if *apiKey == "" && !*dryRun {
fmt.Println("ERROR: -api-key is required (or use -dry-run)")
fmt.Println(" Generate one at: https://lthn.sh/admin/mcp/api-keys")
os.Exit(1)
}
if *dryRun {
fmt.Println("[DRY RUN] — no data will be stored")
}
fmt.Printf("API: %s\n", *apiURL)
fmt.Printf("Server: %s | Agent: %s\n", *server, *agent)
// Discover memory files
memPath := *memoryPath
if memPath == "" {
home, _ := os.UserHomeDir()
memPath = filepath.Join(home, ".claude", "projects", "*", "memory")
}
memFiles, _ := filepath.Glob(filepath.Join(memPath, "*.md"))
fmt.Printf("\nFound %d memory files\n", len(memFiles))
// Discover plan files
var planFiles []string
if *plans {
pPath := *planPath
if pPath == "" {
home, _ := os.UserHomeDir()
pPath = filepath.Join(home, "Code", "*", "docs", "plans")
}
planFiles, _ = filepath.Glob(filepath.Join(pPath, "*.md"))
// Also check nested dirs (completed/, etc.)
nested, _ := filepath.Glob(filepath.Join(pPath, "*", "*.md"))
planFiles = append(planFiles, nested...)
// Also check host-uk nested repos
home, _ := os.UserHomeDir()
hostUkPath := filepath.Join(home, "Code", "host-uk", "*", "docs", "plans")
hostUkFiles, _ := filepath.Glob(filepath.Join(hostUkPath, "*.md"))
planFiles = append(planFiles, hostUkFiles...)
hostUkNested, _ := filepath.Glob(filepath.Join(hostUkPath, "*", "*.md"))
planFiles = append(planFiles, hostUkNested...)
fmt.Printf("Found %d plan files\n", len(planFiles))
}
// Discover CLAUDE.md files
var claudeFiles []string
if *claudeMd {
cPath := *codePath
if cPath == "" {
home, _ := os.UserHomeDir()
cPath = filepath.Join(home, "Code")
}
claudeFiles = discoverClaudeMdFiles(cPath)
fmt.Printf("Found %d CLAUDE.md files\n", len(claudeFiles))
}
imported := 0
skipped := 0
errors := 0
// Process memory files
fmt.Println("\n--- Memory Files ---")
for _, f := range memFiles {
project := extractProject(f)
sections := parseMarkdownSections(f)
filename := strings.TrimSuffix(filepath.Base(f), ".md")
if len(sections) == 0 {
fmt.Printf(" skip %s/%s (no sections)\n", project, filename)
skipped++
continue
}
for _, sec := range sections {
content := sec.heading + "\n\n" + sec.content
if strings.TrimSpace(sec.content) == "" {
skipped++
continue
}
memType := inferType(sec.heading, sec.content, "memory")
tags := buildTags(filename, "memory", project)
confidence := confidenceForSource("memory")
// Truncate to embedding model limit
content = truncate(content, *maxChars)
if *dryRun {
fmt.Printf(" [DRY] %s/%s :: %s (%s) — %d chars\n",
project, filename, sec.heading, memType, len(content))
imported++
continue
}
if err := callBrainRemember(content, memType, tags, project, confidence); err != nil {
fmt.Printf(" FAIL %s/%s :: %s — %v\n", project, filename, sec.heading, err)
errors++
continue
}
fmt.Printf(" ok %s/%s :: %s (%s)\n", project, filename, sec.heading, memType)
imported++
}
}
// Process plan files
if *plans && len(planFiles) > 0 {
fmt.Println("\n--- Plan Documents ---")
for _, f := range planFiles {
project := extractProjectFromPlan(f)
sections := parseMarkdownSections(f)
filename := strings.TrimSuffix(filepath.Base(f), ".md")
if len(sections) == 0 {
skipped++
continue
}
for _, sec := range sections {
content := sec.heading + "\n\n" + sec.content
if strings.TrimSpace(sec.content) == "" {
skipped++
continue
}
tags := buildTags(filename, "plans", project)
content = truncate(content, *maxChars)
if *dryRun {
fmt.Printf(" [DRY] %s :: %s / %s (plan) — %d chars\n",
project, filename, sec.heading, len(content))
imported++
continue
}
if err := callBrainRemember(content, "plan", tags, project, 0.6); err != nil {
fmt.Printf(" FAIL %s :: %s / %s — %v\n", project, filename, sec.heading, err)
errors++
continue
}
fmt.Printf(" ok %s :: %s / %s (plan)\n", project, filename, sec.heading)
imported++
}
}
}
// Process CLAUDE.md files
if *claudeMd && len(claudeFiles) > 0 {
fmt.Println("\n--- CLAUDE.md Files ---")
for _, f := range claudeFiles {
project := extractProjectFromClaudeMd(f)
sections := parseMarkdownSections(f)
if len(sections) == 0 {
skipped++
continue
}
for _, sec := range sections {
content := sec.heading + "\n\n" + sec.content
if strings.TrimSpace(sec.content) == "" {
skipped++
continue
}
tags := buildTags("CLAUDE", "claude-md", project)
content = truncate(content, *maxChars)
if *dryRun {
fmt.Printf(" [DRY] %s :: CLAUDE.md / %s (convention) — %d chars\n",
project, sec.heading, len(content))
imported++
continue
}
if err := callBrainRemember(content, "convention", tags, project, 0.9); err != nil {
fmt.Printf(" FAIL %s :: CLAUDE.md / %s — %v\n", project, sec.heading, err)
errors++
continue
}
fmt.Printf(" ok %s :: CLAUDE.md / %s (convention)\n", project, sec.heading)
imported++
}
}
}
fmt.Printf("\n%s\n", strings.Repeat("=", 55))
prefix := ""
if *dryRun {
prefix = "[DRY RUN] "
}
fmt.Printf("%sImported: %d | Skipped: %d | Errors: %d\n", prefix, imported, skipped, errors)
}
// callBrainRemember sends a memory to the MCP API via brain_remember tool.
func callBrainRemember(content, memType string, tags []string, project string, confidence float64) error {
args := map[string]any{
"content": content,
"type": memType,
"tags": tags,
"confidence": confidence,
}
if project != "" && project != "unknown" {
args["project"] = project
}
payload := map[string]any{
"server": *server,
"tool": "brain_remember",
"arguments": args,
}
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
req, err := http.NewRequest("POST", *apiURL+"/tools/call", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+*apiKey)
resp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("http: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
}
var result struct {
Success bool `json:"success"`
Error string `json:"error"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return fmt.Errorf("decode: %w", err)
}
if !result.Success {
return fmt.Errorf("API: %s", result.Error)
}
return nil
}
// truncate caps content to maxLen chars, appending an ellipsis if truncated.
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
// Find last space before limit to avoid splitting mid-word
cut := maxLen
if idx := strings.LastIndex(s[:maxLen], " "); idx > maxLen-200 {
cut = idx
}
return s[:cut] + "…"
}
// discoverClaudeMdFiles finds CLAUDE.md files across a code directory.
func discoverClaudeMdFiles(codePath string) []string {
var files []string
// Walk up to 4 levels deep, skip node_modules/vendor/.claude
_ = filepath.WalkDir(codePath, func(path string, d os.DirEntry, err error) error {
if err != nil {
return nil
}
if d.IsDir() {
name := d.Name()
if name == "node_modules" || name == "vendor" || name == ".claude" {
return filepath.SkipDir
}
// Limit depth
rel, _ := filepath.Rel(codePath, path)
if strings.Count(rel, string(os.PathSeparator)) > 3 {
return filepath.SkipDir
}
return nil
}
if d.Name() == "CLAUDE.md" {
files = append(files, path)
}
return nil
})
return files
}
// section is a parsed markdown section.
type section struct {
heading string
content string
}
var headingRe = regexp.MustCompile(`^#{1,3}\s+(.+)$`)
// parseMarkdownSections splits a markdown file by headings.
func parseMarkdownSections(path string) []section {
data, err := os.ReadFile(path)
if err != nil || len(data) == 0 {
return nil
}
var sections []section
lines := strings.Split(string(data), "\n")
var curHeading string
var curContent []string
for _, line := range lines {
if m := headingRe.FindStringSubmatch(line); m != nil {
if curHeading != "" && len(curContent) > 0 {
text := strings.TrimSpace(strings.Join(curContent, "\n"))
if text != "" {
sections = append(sections, section{curHeading, text})
}
}
curHeading = strings.TrimSpace(m[1])
curContent = nil
} else {
curContent = append(curContent, line)
}
}
// Flush last section
if curHeading != "" && len(curContent) > 0 {
text := strings.TrimSpace(strings.Join(curContent, "\n"))
if text != "" {
sections = append(sections, section{curHeading, text})
}
}
// If no headings found, treat entire file as one section
if len(sections) == 0 && strings.TrimSpace(string(data)) != "" {
sections = append(sections, section{
heading: strings.TrimSuffix(filepath.Base(path), ".md"),
content: strings.TrimSpace(string(data)),
})
}
return sections
}
// extractProject derives a project name from a Claude memory path.
// ~/.claude/projects/-Users-snider-Code-eaas/memory/MEMORY.md → "eaas"
func extractProject(path string) string {
re := regexp.MustCompile(`projects/[^/]*-([^-/]+)/memory/`)
if m := re.FindStringSubmatch(path); m != nil {
return m[1]
}
return "unknown"
}
// extractProjectFromPlan derives a project name from a plan path.
// ~/Code/eaas/docs/plans/foo.md → "eaas"
// ~/Code/host-uk/core/docs/plans/foo.md → "core"
func extractProjectFromPlan(path string) string {
// Check host-uk nested repos first
re := regexp.MustCompile(`Code/host-uk/([^/]+)/docs/plans/`)
if m := re.FindStringSubmatch(path); m != nil {
return m[1]
}
re = regexp.MustCompile(`Code/([^/]+)/docs/plans/`)
if m := re.FindStringSubmatch(path); m != nil {
return m[1]
}
return "unknown"
}
// extractProjectFromClaudeMd derives a project name from a CLAUDE.md path.
// ~/Code/host-uk/core/CLAUDE.md → "core"
// ~/Code/eaas/CLAUDE.md → "eaas"
func extractProjectFromClaudeMd(path string) string {
re := regexp.MustCompile(`Code/host-uk/([^/]+)/`)
if m := re.FindStringSubmatch(path); m != nil {
return m[1]
}
re = regexp.MustCompile(`Code/([^/]+)/`)
if m := re.FindStringSubmatch(path); m != nil {
return m[1]
}
return "unknown"
}
// inferType guesses the memory type from heading + content keywords.
func inferType(heading, content, source string) string {
// Source-specific defaults (match PHP BrainIngestCommand behaviour)
if source == "plans" {
return "plan"
}
if source == "claude-md" {
return "convention"
}
lower := strings.ToLower(heading + " " + content)
patterns := map[string][]string{
"architecture": {"architecture", "stack", "infrastructure", "layer", "service mesh"},
"convention": {"convention", "standard", "naming", "pattern", "rule", "coding"},
"decision": {"decision", "chose", "strategy", "approach", "domain"},
"bug": {"bug", "fix", "broken", "error", "issue", "lesson"},
"plan": {"plan", "todo", "roadmap", "milestone", "phase", "task"},
"research": {"research", "finding", "discovery", "analysis", "rfc"},
}
for t, keywords := range patterns {
for _, kw := range keywords {
if strings.Contains(lower, kw) {
return t
}
}
}
return "observation"
}
// buildTags creates the tag list for a memory.
func buildTags(filename, source, project string) []string {
tags := []string{"source:" + source}
if project != "" && project != "unknown" {
tags = append(tags, "project:"+project)
}
if filename != "MEMORY" && filename != "CLAUDE" {
tags = append(tags, strings.ReplaceAll(strings.ReplaceAll(filename, "-", " "), "_", " "))
}
return tags
}
// confidenceForSource returns a default confidence level matching the PHP ingest command.
func confidenceForSource(source string) float64 {
switch source {
case "claude-md":
return 0.9
case "memory":
return 0.8
case "plans":
return 0.6
default:
return 0.5
}
}

View file

@ -8,7 +8,7 @@ import (
"path/filepath"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/go-ai/mcp"
"forge.lthn.ai/core/mcp/pkg/mcp"
"forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-process"
)

View file

@ -1,92 +0,0 @@
// Package mcpcmd provides the MCP server command.
//
// Commands:
// - mcp serve: Start the MCP server for AI tool integration
package mcpcmd
import (
"context"
"os"
"os/signal"
"syscall"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/go-ai/mcp"
)
var workspaceFlag string
var mcpCmd = &cli.Command{
Use: "mcp",
Short: "MCP server for AI tool integration",
Long: "Model Context Protocol (MCP) server providing file operations, RAG, and metrics tools.",
}
var serveCmd = &cli.Command{
Use: "serve",
Short: "Start the MCP server",
Long: `Start the MCP server on stdio (default) or TCP.
The server provides file operations, RAG tools, and metrics tools for AI assistants.
Environment variables:
MCP_ADDR TCP address to listen on (e.g., "localhost:9999")
If not set, uses stdio transport.
Examples:
# Start with stdio transport (for Claude Code integration)
core mcp serve
# Start with workspace restriction
core mcp serve --workspace /path/to/project
# Start TCP server
MCP_ADDR=localhost:9999 core mcp serve`,
RunE: func(cmd *cli.Command, args []string) error {
return runServe()
},
}
func initFlags() {
cli.StringFlag(serveCmd, &workspaceFlag, "workspace", "w", "", "Restrict file operations to this directory (empty = unrestricted)")
}
// AddMCPCommands registers the 'mcp' command and all subcommands.
func AddMCPCommands(root *cli.Command) {
initFlags()
mcpCmd.AddCommand(serveCmd)
root.AddCommand(mcpCmd)
}
func runServe() error {
// Build MCP service options
var opts []mcp.Option
if workspaceFlag != "" {
opts = append(opts, mcp.WithWorkspaceRoot(workspaceFlag))
} else {
// Explicitly unrestricted when no workspace specified
opts = append(opts, mcp.WithWorkspaceRoot(""))
}
// Create the MCP service
svc, err := mcp.New(opts...)
if err != nil {
return cli.Wrap(err, "create MCP service")
}
// Set up signal handling for clean shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
// Run the server (blocks until context cancelled or error)
return svc.Run(ctx)
}

View file

@ -1,42 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
// Package brain provides an MCP subsystem that proxies OpenBrain knowledge
// store operations to the Laravel php-agentic backend via the IDE bridge.
package brain
import (
"context"
"errors"
"forge.lthn.ai/core/go-ai/mcp/ide"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// errBridgeNotAvailable is returned when a tool requires the Laravel bridge
// but it has not been initialised (headless mode).
var errBridgeNotAvailable = errors.New("brain: bridge not available")
// Subsystem implements mcp.Subsystem for OpenBrain knowledge store operations.
// It proxies brain_* tool calls to the Laravel backend via the shared IDE bridge.
type Subsystem struct {
bridge *ide.Bridge
}
// New creates a brain subsystem that uses the given IDE bridge for Laravel communication.
// Pass nil if headless (tools will return errBridgeNotAvailable).
func New(bridge *ide.Bridge) *Subsystem {
return &Subsystem{bridge: bridge}
}
// Name implements mcp.Subsystem.
func (s *Subsystem) Name() string { return "brain" }
// RegisterTools implements mcp.Subsystem.
func (s *Subsystem) RegisterTools(server *mcp.Server) {
s.registerBrainTools(server)
}
// Shutdown implements mcp.SubsystemWithShutdown.
func (s *Subsystem) Shutdown(_ context.Context) error {
return nil
}

View file

@ -1,229 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
package brain
import (
"context"
"encoding/json"
"testing"
"time"
)
// --- Nil bridge tests (headless mode) ---
func TestBrainRemember_Bad_NilBridge(t *testing.T) {
sub := New(nil)
_, _, err := sub.brainRemember(context.Background(), nil, RememberInput{
Content: "test memory",
Type: "observation",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
func TestBrainRecall_Bad_NilBridge(t *testing.T) {
sub := New(nil)
_, _, err := sub.brainRecall(context.Background(), nil, RecallInput{
Query: "how does scoring work?",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
func TestBrainForget_Bad_NilBridge(t *testing.T) {
sub := New(nil)
_, _, err := sub.brainForget(context.Background(), nil, ForgetInput{
ID: "550e8400-e29b-41d4-a716-446655440000",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
func TestBrainList_Bad_NilBridge(t *testing.T) {
sub := New(nil)
_, _, err := sub.brainList(context.Background(), nil, ListInput{
Project: "eaas",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// --- Subsystem interface tests ---
func TestSubsystem_Good_Name(t *testing.T) {
sub := New(nil)
if sub.Name() != "brain" {
t.Errorf("expected Name() = 'brain', got %q", sub.Name())
}
}
func TestSubsystem_Good_ShutdownNoop(t *testing.T) {
sub := New(nil)
if err := sub.Shutdown(context.Background()); err != nil {
t.Errorf("Shutdown failed: %v", err)
}
}
// --- Struct round-trip tests ---
func TestRememberInput_Good_RoundTrip(t *testing.T) {
in := RememberInput{
Content: "LEM scoring was blind to negative emotions",
Type: "bug",
Tags: []string{"scoring", "lem"},
Project: "eaas",
Confidence: 0.95,
Supersedes: "550e8400-e29b-41d4-a716-446655440000",
ExpiresIn: 24,
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out RememberInput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.Content != in.Content || out.Type != in.Type {
t.Errorf("round-trip mismatch: content or type")
}
if len(out.Tags) != 2 || out.Tags[0] != "scoring" {
t.Errorf("round-trip mismatch: tags")
}
if out.Confidence != 0.95 {
t.Errorf("round-trip mismatch: confidence %f != 0.95", out.Confidence)
}
}
func TestRememberOutput_Good_RoundTrip(t *testing.T) {
in := RememberOutput{
Success: true,
MemoryID: "550e8400-e29b-41d4-a716-446655440000",
Timestamp: time.Now().Truncate(time.Second),
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out RememberOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if !out.Success || out.MemoryID != in.MemoryID {
t.Errorf("round-trip mismatch: %+v != %+v", out, in)
}
}
func TestRecallInput_Good_RoundTrip(t *testing.T) {
in := RecallInput{
Query: "how does verdict classification work?",
TopK: 5,
Filter: RecallFilter{
Project: "eaas",
MinConfidence: 0.5,
},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out RecallInput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.Query != in.Query || out.TopK != 5 {
t.Errorf("round-trip mismatch: query or topK")
}
if out.Filter.Project != "eaas" || out.Filter.MinConfidence != 0.5 {
t.Errorf("round-trip mismatch: filter")
}
}
func TestMemory_Good_RoundTrip(t *testing.T) {
in := Memory{
ID: "550e8400-e29b-41d4-a716-446655440000",
AgentID: "virgil",
Type: "decision",
Content: "Use Qdrant for vector search",
Tags: []string{"architecture", "openbrain"},
Project: "php-agentic",
Confidence: 0.9,
CreatedAt: "2026-03-03T12:00:00+00:00",
UpdatedAt: "2026-03-03T12:00:00+00:00",
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out Memory
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.ID != in.ID || out.AgentID != "virgil" || out.Type != "decision" {
t.Errorf("round-trip mismatch: %+v", out)
}
}
func TestForgetInput_Good_RoundTrip(t *testing.T) {
in := ForgetInput{
ID: "550e8400-e29b-41d4-a716-446655440000",
Reason: "Superseded by new approach",
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out ForgetInput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.ID != in.ID || out.Reason != in.Reason {
t.Errorf("round-trip mismatch: %+v != %+v", out, in)
}
}
func TestListInput_Good_RoundTrip(t *testing.T) {
in := ListInput{
Project: "eaas",
Type: "decision",
AgentID: "charon",
Limit: 20,
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out ListInput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.Project != "eaas" || out.Type != "decision" || out.AgentID != "charon" || out.Limit != 20 {
t.Errorf("round-trip mismatch: %+v", out)
}
}
func TestListOutput_Good_RoundTrip(t *testing.T) {
in := ListOutput{
Success: true,
Count: 2,
Memories: []Memory{
{ID: "id-1", AgentID: "virgil", Type: "decision", Content: "memory 1", Confidence: 0.9, CreatedAt: "2026-03-03T12:00:00+00:00", UpdatedAt: "2026-03-03T12:00:00+00:00"},
{ID: "id-2", AgentID: "charon", Type: "bug", Content: "memory 2", Confidence: 0.8, CreatedAt: "2026-03-03T13:00:00+00:00", UpdatedAt: "2026-03-03T13:00:00+00:00"},
},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out ListOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if !out.Success || out.Count != 2 || len(out.Memories) != 2 {
t.Errorf("round-trip mismatch: %+v", out)
}
}

View file

@ -1,220 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
package brain
import (
"context"
"fmt"
"time"
"forge.lthn.ai/core/go-ai/mcp/ide"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// -- Input/Output types -------------------------------------------------------
// RememberInput is the input for brain_remember.
type RememberInput struct {
Content string `json:"content"`
Type string `json:"type"`
Tags []string `json:"tags,omitempty"`
Project string `json:"project,omitempty"`
Confidence float64 `json:"confidence,omitempty"`
Supersedes string `json:"supersedes,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
}
// RememberOutput is the output for brain_remember.
type RememberOutput struct {
Success bool `json:"success"`
MemoryID string `json:"memoryId,omitempty"`
Timestamp time.Time `json:"timestamp"`
}
// RecallInput is the input for brain_recall.
type RecallInput struct {
Query string `json:"query"`
TopK int `json:"top_k,omitempty"`
Filter RecallFilter `json:"filter,omitempty"`
}
// RecallFilter holds optional filter criteria for brain_recall.
type RecallFilter struct {
Project string `json:"project,omitempty"`
Type any `json:"type,omitempty"`
AgentID string `json:"agent_id,omitempty"`
MinConfidence float64 `json:"min_confidence,omitempty"`
}
// RecallOutput is the output for brain_recall.
type RecallOutput struct {
Success bool `json:"success"`
Count int `json:"count"`
Memories []Memory `json:"memories"`
}
// Memory is a single memory entry returned by recall or list.
type Memory struct {
ID string `json:"id"`
AgentID string `json:"agent_id"`
Type string `json:"type"`
Content string `json:"content"`
Tags []string `json:"tags,omitempty"`
Project string `json:"project,omitempty"`
Confidence float64 `json:"confidence"`
SupersedesID string `json:"supersedes_id,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// ForgetInput is the input for brain_forget.
type ForgetInput struct {
ID string `json:"id"`
Reason string `json:"reason,omitempty"`
}
// ForgetOutput is the output for brain_forget.
type ForgetOutput struct {
Success bool `json:"success"`
Forgotten string `json:"forgotten"`
Timestamp time.Time `json:"timestamp"`
}
// ListInput is the input for brain_list.
type ListInput struct {
Project string `json:"project,omitempty"`
Type string `json:"type,omitempty"`
AgentID string `json:"agent_id,omitempty"`
Limit int `json:"limit,omitempty"`
}
// ListOutput is the output for brain_list.
type ListOutput struct {
Success bool `json:"success"`
Count int `json:"count"`
Memories []Memory `json:"memories"`
}
// -- Tool registration --------------------------------------------------------
func (s *Subsystem) registerBrainTools(server *mcp.Server) {
mcp.AddTool(server, &mcp.Tool{
Name: "brain_remember",
Description: "Store a memory in the shared OpenBrain knowledge store. Persists decisions, observations, conventions, research, plans, bugs, or architecture knowledge for other agents.",
}, s.brainRemember)
mcp.AddTool(server, &mcp.Tool{
Name: "brain_recall",
Description: "Semantic search across the shared OpenBrain knowledge store. Returns memories ranked by similarity to your query, with optional filtering.",
}, s.brainRecall)
mcp.AddTool(server, &mcp.Tool{
Name: "brain_forget",
Description: "Remove a memory from the shared OpenBrain knowledge store. Permanently deletes from both database and vector index.",
}, s.brainForget)
mcp.AddTool(server, &mcp.Tool{
Name: "brain_list",
Description: "List memories in the shared OpenBrain knowledge store. Supports filtering by project, type, and agent. No vector search -- use brain_recall for semantic queries.",
}, s.brainList)
}
// -- Tool handlers ------------------------------------------------------------
func (s *Subsystem) brainRemember(_ context.Context, _ *mcp.CallToolRequest, input RememberInput) (*mcp.CallToolResult, RememberOutput, error) {
if s.bridge == nil {
return nil, RememberOutput{}, errBridgeNotAvailable
}
err := s.bridge.Send(ide.BridgeMessage{
Type: "brain_remember",
Data: map[string]any{
"content": input.Content,
"type": input.Type,
"tags": input.Tags,
"project": input.Project,
"confidence": input.Confidence,
"supersedes": input.Supersedes,
"expires_in": input.ExpiresIn,
},
})
if err != nil {
return nil, RememberOutput{}, fmt.Errorf("failed to send brain_remember: %w", err)
}
return nil, RememberOutput{
Success: true,
Timestamp: time.Now(),
}, nil
}
func (s *Subsystem) brainRecall(_ context.Context, _ *mcp.CallToolRequest, input RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
if s.bridge == nil {
return nil, RecallOutput{}, errBridgeNotAvailable
}
err := s.bridge.Send(ide.BridgeMessage{
Type: "brain_recall",
Data: map[string]any{
"query": input.Query,
"top_k": input.TopK,
"filter": input.Filter,
},
})
if err != nil {
return nil, RecallOutput{}, fmt.Errorf("failed to send brain_recall: %w", err)
}
return nil, RecallOutput{
Success: true,
Memories: []Memory{},
}, nil
}
func (s *Subsystem) brainForget(_ context.Context, _ *mcp.CallToolRequest, input ForgetInput) (*mcp.CallToolResult, ForgetOutput, error) {
if s.bridge == nil {
return nil, ForgetOutput{}, errBridgeNotAvailable
}
err := s.bridge.Send(ide.BridgeMessage{
Type: "brain_forget",
Data: map[string]any{
"id": input.ID,
"reason": input.Reason,
},
})
if err != nil {
return nil, ForgetOutput{}, fmt.Errorf("failed to send brain_forget: %w", err)
}
return nil, ForgetOutput{
Success: true,
Forgotten: input.ID,
Timestamp: time.Now(),
}, nil
}
func (s *Subsystem) brainList(_ context.Context, _ *mcp.CallToolRequest, input ListInput) (*mcp.CallToolResult, ListOutput, error) {
if s.bridge == nil {
return nil, ListOutput{}, errBridgeNotAvailable
}
err := s.bridge.Send(ide.BridgeMessage{
Type: "brain_list",
Data: map[string]any{
"project": input.Project,
"type": input.Type,
"agent_id": input.AgentID,
"limit": input.Limit,
},
})
if err != nil {
return nil, ListOutput{}, fmt.Errorf("failed to send brain_list: %w", err)
}
return nil, ListOutput{
Success: true,
Memories: []Memory{},
}, nil
}

View file

@ -1,64 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
package mcp
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/gin-gonic/gin"
api "forge.lthn.ai/core/go-api"
)
// maxBodySize is the maximum request body size accepted by bridged tool endpoints.
const maxBodySize = 10 << 20 // 10 MB
// BridgeToAPI populates a go-api ToolBridge from recorded MCP tools.
// Each tool becomes a POST endpoint that reads a JSON body, dispatches
// to the tool's RESTHandler (which knows the concrete input type), and
// wraps the result in the standard api.Response envelope.
func BridgeToAPI(svc *Service, bridge *api.ToolBridge) {
for rec := range svc.ToolsSeq() {
desc := api.ToolDescriptor{
Name: rec.Name,
Description: rec.Description,
Group: rec.Group,
InputSchema: rec.InputSchema,
OutputSchema: rec.OutputSchema,
}
// Capture the handler for the closure.
handler := rec.RESTHandler
bridge.Add(desc, func(c *gin.Context) {
var body []byte
if c.Request.Body != nil {
var err error
body, err = io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize))
if err != nil {
c.JSON(http.StatusBadRequest, api.Fail("invalid_request", "Failed to read request body"))
return
}
}
result, err := handler(c.Request.Context(), body)
if err != nil {
// Classify JSON parse errors as client errors (400),
// everything else as server errors (500).
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) {
c.JSON(http.StatusBadRequest, api.Fail("invalid_input", "Malformed JSON in request body"))
return
}
c.JSON(http.StatusInternalServerError, api.Fail("tool_error", "Tool execution failed"))
return
}
c.JSON(http.StatusOK, api.OK(result))
})
}
}

View file

@ -1,250 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
package mcp
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
api "forge.lthn.ai/core/go-api"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestBridgeToAPI_Good_AllTools(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
bridge := api.NewToolBridge("/tools")
BridgeToAPI(svc, bridge)
svcCount := len(svc.Tools())
bridgeCount := len(bridge.Tools())
if svcCount == 0 {
t.Fatal("expected non-zero tool count from service")
}
if bridgeCount != svcCount {
t.Fatalf("bridge tool count %d != service tool count %d", bridgeCount, svcCount)
}
// Verify names match.
svcNames := make(map[string]bool)
for _, tr := range svc.Tools() {
svcNames[tr.Name] = true
}
for _, td := range bridge.Tools() {
if !svcNames[td.Name] {
t.Errorf("bridge has tool %q not found in service", td.Name)
}
}
}
func TestBridgeToAPI_Good_DescribableGroup(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
bridge := api.NewToolBridge("/tools")
BridgeToAPI(svc, bridge)
// ToolBridge implements DescribableGroup.
var dg api.DescribableGroup = bridge
descs := dg.Describe()
if len(descs) != len(svc.Tools()) {
t.Fatalf("expected %d descriptions, got %d", len(svc.Tools()), len(descs))
}
for _, d := range descs {
if d.Method != "POST" {
t.Errorf("expected Method=POST for %s, got %q", d.Path, d.Method)
}
if d.Summary == "" {
t.Errorf("expected non-empty Summary for %s", d.Path)
}
if len(d.Tags) == 0 {
t.Errorf("expected non-empty Tags for %s", d.Path)
}
}
}
func TestBridgeToAPI_Good_FileRead(t *testing.T) {
tmpDir := t.TempDir()
// Create a test file in the workspace.
testContent := "hello from bridge test"
if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(testContent), 0644); err != nil {
t.Fatal(err)
}
svc, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatal(err)
}
bridge := api.NewToolBridge("/tools")
BridgeToAPI(svc, bridge)
// Register with a Gin engine and make a request.
engine := gin.New()
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
body := `{"path":"test.txt"}`
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", strings.NewReader(body))
req.Header.Set("Content-Type", "application/json")
engine.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
}
// Parse the response envelope.
var resp api.Response[ReadFileOutput]
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if !resp.Success {
t.Fatalf("expected Success=true, got error: %+v", resp.Error)
}
if resp.Data.Content != testContent {
t.Fatalf("expected content %q, got %q", testContent, resp.Data.Content)
}
if resp.Data.Path != "test.txt" {
t.Fatalf("expected path %q, got %q", "test.txt", resp.Data.Path)
}
}
func TestBridgeToAPI_Bad_InvalidJSON(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
bridge := api.NewToolBridge("/tools")
BridgeToAPI(svc, bridge)
engine := gin.New()
rg := engine.Group(bridge.BasePath())
bridge.RegisterRoutes(rg)
// Send malformed JSON.
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", strings.NewReader("{bad json"))
req.Header.Set("Content-Type", "application/json")
engine.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
// The handler unmarshals via RESTHandler which returns an error,
// but since it's a JSON parse error it ends up as tool_error.
// Check we get a non-200 with an error envelope.
if w.Code == http.StatusOK {
t.Fatalf("expected non-200 for invalid JSON, got 200")
}
}
var resp api.Response[any]
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if resp.Success {
t.Fatal("expected Success=false for invalid JSON")
}
if resp.Error == nil {
t.Fatal("expected error in response")
}
}
func TestBridgeToAPI_Good_EndToEnd(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
bridge := api.NewToolBridge("/tools")
BridgeToAPI(svc, bridge)
// Create an api.Engine with the bridge registered and Swagger enabled.
e, err := api.New(
api.WithSwagger("MCP Bridge Test", "Testing MCP-to-REST bridge", "0.1.0"),
)
if err != nil {
t.Fatal(err)
}
e.Register(bridge)
// Use a real test server because gin-swagger reads RequestURI
// which is not populated by httptest.NewRecorder.
srv := httptest.NewServer(e.Handler())
defer srv.Close()
// Verify the health endpoint still works.
resp, err := http.Get(srv.URL + "/health")
if err != nil {
t.Fatalf("health request failed: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected 200 for /health, got %d", resp.StatusCode)
}
// Verify a tool endpoint is reachable through the engine.
resp2, err := http.Post(srv.URL+"/tools/lang_list", "application/json", nil)
if err != nil {
t.Fatalf("lang_list request failed: %v", err)
}
defer resp2.Body.Close()
if resp2.StatusCode != http.StatusOK {
t.Fatalf("expected 200 for /tools/lang_list, got %d", resp2.StatusCode)
}
var langResp api.Response[GetSupportedLanguagesOutput]
if err := json.NewDecoder(resp2.Body).Decode(&langResp); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if !langResp.Success {
t.Fatalf("expected Success=true, got error: %+v", langResp.Error)
}
if len(langResp.Data.Languages) == 0 {
t.Fatal("expected non-empty languages list")
}
// Verify Swagger endpoint contains tool paths.
resp3, err := http.Get(srv.URL + "/swagger/doc.json")
if err != nil {
t.Fatalf("swagger request failed: %v", err)
}
defer resp3.Body.Close()
if resp3.StatusCode != http.StatusOK {
t.Fatalf("expected 200 for /swagger/doc.json, got %d", resp3.StatusCode)
}
var specDoc map[string]any
if err := json.NewDecoder(resp3.Body).Decode(&specDoc); err != nil {
t.Fatalf("swagger unmarshal error: %v", err)
}
paths, ok := specDoc["paths"].(map[string]any)
if !ok {
t.Fatal("expected 'paths' in swagger spec")
}
if _, ok := paths["/tools/file_read"]; !ok {
t.Error("expected /tools/file_read in swagger paths")
}
if _, ok := paths["/tools/lang_list"]; !ok {
t.Error("expected /tools/lang_list in swagger paths")
}
}

View file

@ -1,191 +0,0 @@
package ide
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"sync"
"time"
"forge.lthn.ai/core/go-ws"
"github.com/gorilla/websocket"
)
// BridgeMessage is the wire format between the IDE and Laravel.
type BridgeMessage struct {
Type string `json:"type"`
Channel string `json:"channel,omitempty"`
SessionID string `json:"sessionId,omitempty"`
Data any `json:"data,omitempty"`
Timestamp time.Time `json:"timestamp"`
}
// Bridge maintains a WebSocket connection to the Laravel core-agentic
// backend and forwards responses to a local ws.Hub.
type Bridge struct {
cfg Config
hub *ws.Hub
conn *websocket.Conn
mu sync.Mutex
connected bool
cancel context.CancelFunc
}
// NewBridge creates a bridge that will connect to the Laravel backend and
// forward incoming messages to the provided ws.Hub channels.
func NewBridge(hub *ws.Hub, cfg Config) *Bridge {
return &Bridge{cfg: cfg, hub: hub}
}
// Start begins the connection loop in a background goroutine.
// Call Shutdown to stop it.
func (b *Bridge) Start(ctx context.Context) {
ctx, b.cancel = context.WithCancel(ctx)
go b.connectLoop(ctx)
}
// Shutdown cleanly closes the bridge.
func (b *Bridge) Shutdown() {
if b.cancel != nil {
b.cancel()
}
b.mu.Lock()
defer b.mu.Unlock()
if b.conn != nil {
b.conn.Close()
b.conn = nil
}
b.connected = false
}
// Connected reports whether the bridge has an active connection.
func (b *Bridge) Connected() bool {
b.mu.Lock()
defer b.mu.Unlock()
return b.connected
}
// Send sends a message to the Laravel backend.
func (b *Bridge) Send(msg BridgeMessage) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.conn == nil {
return errors.New("bridge: not connected")
}
msg.Timestamp = time.Now()
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("bridge: marshal failed: %w", err)
}
return b.conn.WriteMessage(websocket.TextMessage, data)
}
// connectLoop reconnects to Laravel with exponential backoff.
func (b *Bridge) connectLoop(ctx context.Context) {
delay := b.cfg.ReconnectInterval
for {
select {
case <-ctx.Done():
return
default:
}
if err := b.dial(ctx); err != nil {
log.Printf("ide bridge: connect failed: %v", err)
select {
case <-ctx.Done():
return
case <-time.After(delay):
}
delay = min(delay*2, b.cfg.MaxReconnectInterval)
continue
}
// Reset backoff on successful connection
delay = b.cfg.ReconnectInterval
b.readLoop(ctx)
}
}
func (b *Bridge) dial(ctx context.Context) error {
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
var header http.Header
if b.cfg.Token != "" {
header = http.Header{}
header.Set("Authorization", "Bearer "+b.cfg.Token)
}
conn, _, err := dialer.DialContext(ctx, b.cfg.LaravelWSURL, header)
if err != nil {
return err
}
b.mu.Lock()
b.conn = conn
b.connected = true
b.mu.Unlock()
log.Printf("ide bridge: connected to %s", b.cfg.LaravelWSURL)
return nil
}
func (b *Bridge) readLoop(ctx context.Context) {
defer func() {
b.mu.Lock()
if b.conn != nil {
b.conn.Close()
}
b.connected = false
b.mu.Unlock()
}()
for {
select {
case <-ctx.Done():
return
default:
}
_, data, err := b.conn.ReadMessage()
if err != nil {
log.Printf("ide bridge: read error: %v", err)
return
}
var msg BridgeMessage
if err := json.Unmarshal(data, &msg); err != nil {
log.Printf("ide bridge: unmarshal error: %v", err)
continue
}
b.dispatch(msg)
}
}
// dispatch routes an incoming message to the appropriate ws.Hub channel.
func (b *Bridge) dispatch(msg BridgeMessage) {
if b.hub == nil {
return
}
wsMsg := ws.Message{
Type: ws.TypeEvent,
Data: msg.Data,
}
channel := msg.Channel
if channel == "" {
channel = "ide:" + msg.Type
}
if err := b.hub.SendToChannel(channel, wsMsg); err != nil {
log.Printf("ide bridge: dispatch to %s failed: %v", channel, err)
}
}

View file

@ -1,442 +0,0 @@
package ide
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"forge.lthn.ai/core/go-ws"
"github.com/gorilla/websocket"
)
var testUpgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
}
// echoServer creates a test WebSocket server that echoes messages back.
func echoServer(t *testing.T) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
defer conn.Close()
for {
mt, data, err := conn.ReadMessage()
if err != nil {
break
}
if err := conn.WriteMessage(mt, data); err != nil {
break
}
}
}))
}
func wsURL(ts *httptest.Server) string {
return "ws" + strings.TrimPrefix(ts.URL, "http")
}
// waitConnected polls bridge.Connected() until true or timeout.
func waitConnected(t *testing.T, bridge *Bridge, timeout time.Duration) {
t.Helper()
deadline := time.Now().Add(timeout)
for !bridge.Connected() && time.Now().Before(deadline) {
time.Sleep(50 * time.Millisecond)
}
if !bridge.Connected() {
t.Fatal("bridge did not connect within timeout")
}
}
func TestBridge_Good_ConnectAndSend(t *testing.T) {
ts := echoServer(t)
defer ts.Close()
hub := ws.NewHub()
ctx := t.Context()
go hub.Run(ctx)
cfg := DefaultConfig()
cfg.LaravelWSURL = wsURL(ts)
cfg.ReconnectInterval = 100 * time.Millisecond
bridge := NewBridge(hub, cfg)
bridge.Start(ctx)
waitConnected(t, bridge, 2*time.Second)
err := bridge.Send(BridgeMessage{
Type: "test",
Data: "hello",
})
if err != nil {
t.Fatalf("Send() failed: %v", err)
}
}
func TestBridge_Good_Shutdown(t *testing.T) {
ts := echoServer(t)
defer ts.Close()
hub := ws.NewHub()
ctx := t.Context()
go hub.Run(ctx)
cfg := DefaultConfig()
cfg.LaravelWSURL = wsURL(ts)
cfg.ReconnectInterval = 100 * time.Millisecond
bridge := NewBridge(hub, cfg)
bridge.Start(ctx)
waitConnected(t, bridge, 2*time.Second)
bridge.Shutdown()
if bridge.Connected() {
t.Error("bridge should be disconnected after Shutdown")
}
}
func TestBridge_Bad_SendWithoutConnection(t *testing.T) {
hub := ws.NewHub()
cfg := DefaultConfig()
bridge := NewBridge(hub, cfg)
err := bridge.Send(BridgeMessage{Type: "test"})
if err == nil {
t.Error("expected error when sending without connection")
}
}
func TestBridge_Good_MessageDispatch(t *testing.T) {
// Server that sends a message to the bridge on connect.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
msg := BridgeMessage{
Type: "chat_response",
Channel: "chat:session-1",
Data: "hello from laravel",
}
data, _ := json.Marshal(msg)
conn.WriteMessage(websocket.TextMessage, data)
// Keep connection open
for {
_, _, err := conn.ReadMessage()
if err != nil {
break
}
}
}))
defer ts.Close()
hub := ws.NewHub()
ctx := t.Context()
go hub.Run(ctx)
cfg := DefaultConfig()
cfg.LaravelWSURL = wsURL(ts)
cfg.ReconnectInterval = 100 * time.Millisecond
bridge := NewBridge(hub, cfg)
bridge.Start(ctx)
waitConnected(t, bridge, 2*time.Second)
// Give time for the dispatched message to be processed.
time.Sleep(200 * time.Millisecond)
// Verify hub stats — the message was dispatched (even without subscribers).
// This confirms the dispatch path ran without error.
}
func TestBridge_Good_Reconnect(t *testing.T) {
// Use atomic counter to avoid data race between HTTP handler goroutine
// and the test goroutine.
var callCount atomic.Int32
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
n := callCount.Add(1)
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
// Close immediately on first connection to force reconnect
if n == 1 {
conn.Close()
return
}
defer conn.Close()
for {
_, _, err := conn.ReadMessage()
if err != nil {
break
}
}
}))
defer ts.Close()
hub := ws.NewHub()
ctx := t.Context()
go hub.Run(ctx)
cfg := DefaultConfig()
cfg.LaravelWSURL = wsURL(ts)
cfg.ReconnectInterval = 100 * time.Millisecond
cfg.MaxReconnectInterval = 200 * time.Millisecond
bridge := NewBridge(hub, cfg)
bridge.Start(ctx)
waitConnected(t, bridge, 3*time.Second)
if callCount.Load() < 2 {
t.Errorf("expected at least 2 connection attempts, got %d", callCount.Load())
}
}
func TestBridge_Good_ExponentialBackoff(t *testing.T) {
// Track timestamps of dial attempts to verify backoff behaviour.
// The server rejects the WebSocket upgrade with HTTP 403, so dial()
// returns an error and the exponential backoff path fires.
var attempts []time.Time
var mu sync.Mutex
var attemptCount atomic.Int32
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
attempts = append(attempts, time.Now())
mu.Unlock()
attemptCount.Add(1)
// Reject the upgrade — this makes dial() fail, triggering backoff.
http.Error(w, "forbidden", http.StatusForbidden)
}))
defer ts.Close()
hub := ws.NewHub()
ctx := t.Context()
go hub.Run(ctx)
cfg := DefaultConfig()
cfg.LaravelWSURL = wsURL(ts)
cfg.ReconnectInterval = 100 * time.Millisecond
cfg.MaxReconnectInterval = 400 * time.Millisecond
bridge := NewBridge(hub, cfg)
bridge.Start(ctx)
// Wait for at least 4 dial attempts.
deadline := time.Now().Add(5 * time.Second)
for attemptCount.Load() < 4 && time.Now().Before(deadline) {
time.Sleep(50 * time.Millisecond)
}
bridge.Shutdown()
mu.Lock()
defer mu.Unlock()
if len(attempts) < 4 {
t.Fatalf("expected at least 4 connection attempts, got %d", len(attempts))
}
// Verify exponential backoff: gap between attempts should increase.
// Expected delays: ~100ms, ~200ms, ~400ms (capped).
// Allow generous tolerance since timing is non-deterministic.
for i := 1; i < len(attempts) && i <= 3; i++ {
gap := attempts[i].Sub(attempts[i-1])
// Minimum expected delay doubles each time: 100, 200, 400.
// We check a lower bound (50% of expected) to be resilient.
expectedMin := time.Duration(50*(1<<(i-1))) * time.Millisecond
if gap < expectedMin {
t.Errorf("attempt %d->%d gap %v < expected minimum %v", i-1, i, gap, expectedMin)
}
}
// Verify the backoff caps at MaxReconnectInterval.
if len(attempts) >= 5 {
gap := attempts[4].Sub(attempts[3])
// After cap is hit, delay should not exceed MaxReconnectInterval + tolerance.
maxExpected := cfg.MaxReconnectInterval + 200*time.Millisecond
if gap > maxExpected {
t.Errorf("attempt 3->4 gap %v exceeded max backoff %v", gap, maxExpected)
}
}
}
func TestBridge_Good_ReconnectDetectsServerShutdown(t *testing.T) {
// Start a server that closes the WS connection on demand, then close
// the server entirely so the bridge cannot reconnect.
closeConn := make(chan struct{}, 1)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
// Wait for signal to close
<-closeConn
conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseGoingAway, "shutdown"))
}))
hub := ws.NewHub()
ctx := t.Context()
go hub.Run(ctx)
cfg := DefaultConfig()
cfg.LaravelWSURL = wsURL(ts)
// Use long reconnect so bridge stays disconnected after server dies.
cfg.ReconnectInterval = 5 * time.Second
cfg.MaxReconnectInterval = 5 * time.Second
bridge := NewBridge(hub, cfg)
bridge.Start(ctx)
waitConnected(t, bridge, 2*time.Second)
// Signal server handler to close the WS connection, then shut down
// the server so the reconnect dial() also fails.
closeConn <- struct{}{}
ts.Close()
// Wait for disconnection.
deadline := time.Now().Add(3 * time.Second)
for bridge.Connected() && time.Now().Before(deadline) {
time.Sleep(50 * time.Millisecond)
}
if bridge.Connected() {
t.Error("expected bridge to detect server-side connection close")
}
}
func TestBridge_Good_AuthHeader(t *testing.T) {
// Server that checks for the Authorization header on upgrade.
var receivedAuth atomic.Value
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuth.Store(r.Header.Get("Authorization"))
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
_, _, err := conn.ReadMessage()
if err != nil {
break
}
}
}))
defer ts.Close()
hub := ws.NewHub()
ctx := t.Context()
go hub.Run(ctx)
cfg := DefaultConfig()
cfg.LaravelWSURL = wsURL(ts)
cfg.ReconnectInterval = 100 * time.Millisecond
cfg.Token = "test-secret-token-42"
bridge := NewBridge(hub, cfg)
bridge.Start(ctx)
waitConnected(t, bridge, 2*time.Second)
auth, ok := receivedAuth.Load().(string)
if !ok || auth == "" {
t.Fatal("server did not receive Authorization header")
}
expected := "Bearer test-secret-token-42"
if auth != expected {
t.Errorf("expected auth header %q, got %q", expected, auth)
}
}
func TestBridge_Good_NoAuthHeaderWhenTokenEmpty(t *testing.T) {
// Verify that no Authorization header is sent when Token is empty.
var receivedAuth atomic.Value
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedAuth.Store(r.Header.Get("Authorization"))
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
_, _, err := conn.ReadMessage()
if err != nil {
break
}
}
}))
defer ts.Close()
hub := ws.NewHub()
ctx := t.Context()
go hub.Run(ctx)
cfg := DefaultConfig()
cfg.LaravelWSURL = wsURL(ts)
cfg.ReconnectInterval = 100 * time.Millisecond
// Token intentionally left empty
bridge := NewBridge(hub, cfg)
bridge.Start(ctx)
waitConnected(t, bridge, 2*time.Second)
auth, _ := receivedAuth.Load().(string)
if auth != "" {
t.Errorf("expected no Authorization header when token is empty, got %q", auth)
}
}
func TestBridge_Good_WithTokenOption(t *testing.T) {
// Verify the WithToken option function works.
cfg := DefaultConfig()
opt := WithToken("my-token")
opt(&cfg)
if cfg.Token != "my-token" {
t.Errorf("expected token 'my-token', got %q", cfg.Token)
}
}
func TestSubsystem_Good_Name(t *testing.T) {
sub := New(nil)
if sub.Name() != "ide" {
t.Errorf("expected name 'ide', got %q", sub.Name())
}
}
func TestSubsystem_Good_NilHub(t *testing.T) {
sub := New(nil)
if sub.Bridge() != nil {
t.Error("expected nil bridge when hub is nil")
}
// Shutdown should not panic
if err := sub.Shutdown(context.Background()); err != nil {
t.Errorf("Shutdown with nil bridge failed: %v", err)
}
}

View file

@ -1,57 +0,0 @@
// Package ide provides an MCP subsystem that bridges the desktop IDE to
// a Laravel core-agentic backend over WebSocket.
package ide
import "time"
// Config holds connection and workspace settings for the IDE subsystem.
type Config struct {
// LaravelWSURL is the WebSocket endpoint for the Laravel core-agentic backend.
LaravelWSURL string
// WorkspaceRoot is the local path used as the default workspace context.
WorkspaceRoot string
// Token is the Bearer token sent in the Authorization header during
// WebSocket upgrade. When empty, no auth header is sent.
Token string
// ReconnectInterval controls how long to wait between reconnect attempts.
ReconnectInterval time.Duration
// MaxReconnectInterval caps exponential backoff for reconnection.
MaxReconnectInterval time.Duration
}
// DefaultConfig returns sensible defaults for local development.
func DefaultConfig() Config {
return Config{
LaravelWSURL: "ws://localhost:9876/ws",
WorkspaceRoot: ".",
ReconnectInterval: 2 * time.Second,
MaxReconnectInterval: 30 * time.Second,
}
}
// Option configures the IDE subsystem.
type Option func(*Config)
// WithLaravelURL sets the Laravel WebSocket endpoint.
func WithLaravelURL(url string) Option {
return func(c *Config) { c.LaravelWSURL = url }
}
// WithWorkspaceRoot sets the workspace root directory.
func WithWorkspaceRoot(root string) Option {
return func(c *Config) { c.WorkspaceRoot = root }
}
// WithReconnectInterval sets the base reconnect interval.
func WithReconnectInterval(d time.Duration) Option {
return func(c *Config) { c.ReconnectInterval = d }
}
// WithToken sets the Bearer token for WebSocket authentication.
func WithToken(token string) Option {
return func(c *Config) { c.Token = token }
}

View file

@ -1,62 +0,0 @@
package ide
import (
"context"
"errors"
"forge.lthn.ai/core/go-ws"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// errBridgeNotAvailable is returned when a tool requires the Laravel bridge
// but it has not been initialised (headless mode).
var errBridgeNotAvailable = errors.New("bridge not available")
// Subsystem implements mcp.Subsystem and mcp.SubsystemWithShutdown for the IDE.
type Subsystem struct {
cfg Config
bridge *Bridge
hub *ws.Hub
}
// New creates an IDE subsystem. The ws.Hub is used for real-time forwarding;
// pass nil if headless (tools still work but real-time streaming is disabled).
func New(hub *ws.Hub, opts ...Option) *Subsystem {
cfg := DefaultConfig()
for _, opt := range opts {
opt(&cfg)
}
var bridge *Bridge
if hub != nil {
bridge = NewBridge(hub, cfg)
}
return &Subsystem{cfg: cfg, bridge: bridge, hub: hub}
}
// Name implements mcp.Subsystem.
func (s *Subsystem) Name() string { return "ide" }
// RegisterTools implements mcp.Subsystem.
func (s *Subsystem) RegisterTools(server *mcp.Server) {
s.registerChatTools(server)
s.registerBuildTools(server)
s.registerDashboardTools(server)
}
// Shutdown implements mcp.SubsystemWithShutdown.
func (s *Subsystem) Shutdown(_ context.Context) error {
if s.bridge != nil {
s.bridge.Shutdown()
}
return nil
}
// Bridge returns the Laravel WebSocket bridge (may be nil in headless mode).
func (s *Subsystem) Bridge() *Bridge { return s.bridge }
// StartBridge begins the background connection to the Laravel backend.
func (s *Subsystem) StartBridge(ctx context.Context) {
if s.bridge != nil {
s.bridge.Start(ctx)
}
}

View file

@ -1,114 +0,0 @@
package ide
import (
"context"
"time"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// Build tool input/output types.
// BuildStatusInput is the input for ide_build_status.
type BuildStatusInput struct {
BuildID string `json:"buildId"`
}
// BuildInfo represents a single build.
type BuildInfo struct {
ID string `json:"id"`
Repo string `json:"repo"`
Branch string `json:"branch"`
Status string `json:"status"`
Duration string `json:"duration,omitempty"`
StartedAt time.Time `json:"startedAt"`
}
// BuildStatusOutput is the output for ide_build_status.
type BuildStatusOutput struct {
Build BuildInfo `json:"build"`
}
// BuildListInput is the input for ide_build_list.
type BuildListInput struct {
Repo string `json:"repo,omitempty"`
Limit int `json:"limit,omitempty"`
}
// BuildListOutput is the output for ide_build_list.
type BuildListOutput struct {
Builds []BuildInfo `json:"builds"`
}
// BuildLogsInput is the input for ide_build_logs.
type BuildLogsInput struct {
BuildID string `json:"buildId"`
Tail int `json:"tail,omitempty"`
}
// BuildLogsOutput is the output for ide_build_logs.
type BuildLogsOutput struct {
BuildID string `json:"buildId"`
Lines []string `json:"lines"`
}
func (s *Subsystem) registerBuildTools(server *mcp.Server) {
mcp.AddTool(server, &mcp.Tool{
Name: "ide_build_status",
Description: "Get the status of a specific build",
}, s.buildStatus)
mcp.AddTool(server, &mcp.Tool{
Name: "ide_build_list",
Description: "List recent builds, optionally filtered by repository",
}, s.buildList)
mcp.AddTool(server, &mcp.Tool{
Name: "ide_build_logs",
Description: "Retrieve log output for a build",
}, s.buildLogs)
}
// buildStatus requests build status from the Laravel backend.
// Stub implementation: sends request via bridge, returns "unknown" status. Awaiting Laravel backend.
func (s *Subsystem) buildStatus(_ context.Context, _ *mcp.CallToolRequest, input BuildStatusInput) (*mcp.CallToolResult, BuildStatusOutput, error) {
if s.bridge == nil {
return nil, BuildStatusOutput{}, errBridgeNotAvailable
}
_ = s.bridge.Send(BridgeMessage{
Type: "build_status",
Data: map[string]any{"buildId": input.BuildID},
})
return nil, BuildStatusOutput{
Build: BuildInfo{ID: input.BuildID, Status: "unknown"},
}, nil
}
// buildList requests a list of builds from the Laravel backend.
// Stub implementation: sends request via bridge, returns empty list. Awaiting Laravel backend.
func (s *Subsystem) buildList(_ context.Context, _ *mcp.CallToolRequest, input BuildListInput) (*mcp.CallToolResult, BuildListOutput, error) {
if s.bridge == nil {
return nil, BuildListOutput{}, errBridgeNotAvailable
}
_ = s.bridge.Send(BridgeMessage{
Type: "build_list",
Data: map[string]any{"repo": input.Repo, "limit": input.Limit},
})
return nil, BuildListOutput{Builds: []BuildInfo{}}, nil
}
// buildLogs requests build log output from the Laravel backend.
// Stub implementation: sends request via bridge, returns empty lines. Awaiting Laravel backend.
func (s *Subsystem) buildLogs(_ context.Context, _ *mcp.CallToolRequest, input BuildLogsInput) (*mcp.CallToolResult, BuildLogsOutput, error) {
if s.bridge == nil {
return nil, BuildLogsOutput{}, errBridgeNotAvailable
}
_ = s.bridge.Send(BridgeMessage{
Type: "build_logs",
Data: map[string]any{"buildId": input.BuildID, "tail": input.Tail},
})
return nil, BuildLogsOutput{
BuildID: input.BuildID,
Lines: []string{},
}, nil
}

View file

@ -1,201 +0,0 @@
package ide
import (
"context"
"fmt"
"time"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// Chat tool input/output types.
// ChatSendInput is the input for ide_chat_send.
type ChatSendInput struct {
SessionID string `json:"sessionId"`
Message string `json:"message"`
}
// ChatSendOutput is the output for ide_chat_send.
type ChatSendOutput struct {
Sent bool `json:"sent"`
SessionID string `json:"sessionId"`
Timestamp time.Time `json:"timestamp"`
}
// ChatHistoryInput is the input for ide_chat_history.
type ChatHistoryInput struct {
SessionID string `json:"sessionId"`
Limit int `json:"limit,omitempty"`
}
// ChatMessage represents a single message in history.
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Timestamp time.Time `json:"timestamp"`
}
// ChatHistoryOutput is the output for ide_chat_history.
type ChatHistoryOutput struct {
SessionID string `json:"sessionId"`
Messages []ChatMessage `json:"messages"`
}
// SessionListInput is the input for ide_session_list.
type SessionListInput struct{}
// Session represents an agent session.
type Session struct {
ID string `json:"id"`
Name string `json:"name"`
Status string `json:"status"`
CreatedAt time.Time `json:"createdAt"`
}
// SessionListOutput is the output for ide_session_list.
type SessionListOutput struct {
Sessions []Session `json:"sessions"`
}
// SessionCreateInput is the input for ide_session_create.
type SessionCreateInput struct {
Name string `json:"name"`
}
// SessionCreateOutput is the output for ide_session_create.
type SessionCreateOutput struct {
Session Session `json:"session"`
}
// PlanStatusInput is the input for ide_plan_status.
type PlanStatusInput struct {
SessionID string `json:"sessionId"`
}
// PlanStep is a single step in an agent plan.
type PlanStep struct {
Name string `json:"name"`
Status string `json:"status"`
}
// PlanStatusOutput is the output for ide_plan_status.
type PlanStatusOutput struct {
SessionID string `json:"sessionId"`
Status string `json:"status"`
Steps []PlanStep `json:"steps"`
}
func (s *Subsystem) registerChatTools(server *mcp.Server) {
mcp.AddTool(server, &mcp.Tool{
Name: "ide_chat_send",
Description: "Send a message to an agent chat session",
}, s.chatSend)
mcp.AddTool(server, &mcp.Tool{
Name: "ide_chat_history",
Description: "Retrieve message history for a chat session",
}, s.chatHistory)
mcp.AddTool(server, &mcp.Tool{
Name: "ide_session_list",
Description: "List active agent sessions",
}, s.sessionList)
mcp.AddTool(server, &mcp.Tool{
Name: "ide_session_create",
Description: "Create a new agent session",
}, s.sessionCreate)
mcp.AddTool(server, &mcp.Tool{
Name: "ide_plan_status",
Description: "Get the current plan status for a session",
}, s.planStatus)
}
// chatSend forwards a chat message to the Laravel backend via bridge.
// Stub implementation: delegates to bridge, real response arrives via WebSocket subscription.
func (s *Subsystem) chatSend(_ context.Context, _ *mcp.CallToolRequest, input ChatSendInput) (*mcp.CallToolResult, ChatSendOutput, error) {
if s.bridge == nil {
return nil, ChatSendOutput{}, errBridgeNotAvailable
}
err := s.bridge.Send(BridgeMessage{
Type: "chat_send",
Channel: "chat:" + input.SessionID,
SessionID: input.SessionID,
Data: input.Message,
})
if err != nil {
return nil, ChatSendOutput{}, fmt.Errorf("failed to send message: %w", err)
}
return nil, ChatSendOutput{
Sent: true,
SessionID: input.SessionID,
Timestamp: time.Now(),
}, nil
}
// chatHistory requests message history from the Laravel backend.
// Stub implementation: sends request via bridge, returns empty messages. Real data arrives via WebSocket.
func (s *Subsystem) chatHistory(_ context.Context, _ *mcp.CallToolRequest, input ChatHistoryInput) (*mcp.CallToolResult, ChatHistoryOutput, error) {
if s.bridge == nil {
return nil, ChatHistoryOutput{}, errBridgeNotAvailable
}
// Request history via bridge; for now return placeholder indicating the
// request was forwarded. Real data arrives via WebSocket subscription.
_ = s.bridge.Send(BridgeMessage{
Type: "chat_history",
SessionID: input.SessionID,
Data: map[string]any{"limit": input.Limit},
})
return nil, ChatHistoryOutput{
SessionID: input.SessionID,
Messages: []ChatMessage{},
}, nil
}
// sessionList requests the session list from the Laravel backend.
// Stub implementation: sends request via bridge, returns empty sessions. Awaiting Laravel backend.
func (s *Subsystem) sessionList(_ context.Context, _ *mcp.CallToolRequest, _ SessionListInput) (*mcp.CallToolResult, SessionListOutput, error) {
if s.bridge == nil {
return nil, SessionListOutput{}, errBridgeNotAvailable
}
_ = s.bridge.Send(BridgeMessage{Type: "session_list"})
return nil, SessionListOutput{Sessions: []Session{}}, nil
}
// sessionCreate requests a new session from the Laravel backend.
// Stub implementation: sends request via bridge, returns placeholder session. Awaiting Laravel backend.
func (s *Subsystem) sessionCreate(_ context.Context, _ *mcp.CallToolRequest, input SessionCreateInput) (*mcp.CallToolResult, SessionCreateOutput, error) {
if s.bridge == nil {
return nil, SessionCreateOutput{}, errBridgeNotAvailable
}
_ = s.bridge.Send(BridgeMessage{
Type: "session_create",
Data: map[string]any{"name": input.Name},
})
return nil, SessionCreateOutput{
Session: Session{
Name: input.Name,
Status: "creating",
CreatedAt: time.Now(),
},
}, nil
}
// planStatus requests plan status from the Laravel backend.
// Stub implementation: sends request via bridge, returns "unknown" status. Awaiting Laravel backend.
func (s *Subsystem) planStatus(_ context.Context, _ *mcp.CallToolRequest, input PlanStatusInput) (*mcp.CallToolResult, PlanStatusOutput, error) {
if s.bridge == nil {
return nil, PlanStatusOutput{}, errBridgeNotAvailable
}
_ = s.bridge.Send(BridgeMessage{
Type: "plan_status",
SessionID: input.SessionID,
})
return nil, PlanStatusOutput{
SessionID: input.SessionID,
Status: "unknown",
Steps: []PlanStep{},
}, nil
}

View file

@ -1,132 +0,0 @@
package ide
import (
"context"
"time"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// Dashboard tool input/output types.
// DashboardOverviewInput is the input for ide_dashboard_overview.
type DashboardOverviewInput struct{}
// DashboardOverview contains high-level platform stats.
type DashboardOverview struct {
Repos int `json:"repos"`
Services int `json:"services"`
ActiveSessions int `json:"activeSessions"`
RecentBuilds int `json:"recentBuilds"`
BridgeOnline bool `json:"bridgeOnline"`
}
// DashboardOverviewOutput is the output for ide_dashboard_overview.
type DashboardOverviewOutput struct {
Overview DashboardOverview `json:"overview"`
}
// DashboardActivityInput is the input for ide_dashboard_activity.
type DashboardActivityInput struct {
Limit int `json:"limit,omitempty"`
}
// ActivityEvent represents a single activity feed item.
type ActivityEvent struct {
Type string `json:"type"`
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
}
// DashboardActivityOutput is the output for ide_dashboard_activity.
type DashboardActivityOutput struct {
Events []ActivityEvent `json:"events"`
}
// DashboardMetricsInput is the input for ide_dashboard_metrics.
type DashboardMetricsInput struct {
Period string `json:"period,omitempty"` // "1h", "24h", "7d"
}
// DashboardMetrics contains aggregate metrics.
type DashboardMetrics struct {
BuildsTotal int `json:"buildsTotal"`
BuildsSuccess int `json:"buildsSuccess"`
BuildsFailed int `json:"buildsFailed"`
AvgBuildTime string `json:"avgBuildTime"`
AgentSessions int `json:"agentSessions"`
MessagesTotal int `json:"messagesTotal"`
SuccessRate float64 `json:"successRate"`
}
// DashboardMetricsOutput is the output for ide_dashboard_metrics.
type DashboardMetricsOutput struct {
Period string `json:"period"`
Metrics DashboardMetrics `json:"metrics"`
}
func (s *Subsystem) registerDashboardTools(server *mcp.Server) {
mcp.AddTool(server, &mcp.Tool{
Name: "ide_dashboard_overview",
Description: "Get a high-level overview of the platform (repos, services, sessions, builds)",
}, s.dashboardOverview)
mcp.AddTool(server, &mcp.Tool{
Name: "ide_dashboard_activity",
Description: "Get the recent activity feed",
}, s.dashboardActivity)
mcp.AddTool(server, &mcp.Tool{
Name: "ide_dashboard_metrics",
Description: "Get aggregate build and agent metrics for a time period",
}, s.dashboardMetrics)
}
// dashboardOverview returns a platform overview with bridge status.
// Stub implementation: only BridgeOnline is live; other fields return zero values. Awaiting Laravel backend.
func (s *Subsystem) dashboardOverview(_ context.Context, _ *mcp.CallToolRequest, _ DashboardOverviewInput) (*mcp.CallToolResult, DashboardOverviewOutput, error) {
connected := s.bridge != nil && s.bridge.Connected()
if s.bridge != nil {
_ = s.bridge.Send(BridgeMessage{Type: "dashboard_overview"})
}
return nil, DashboardOverviewOutput{
Overview: DashboardOverview{
BridgeOnline: connected,
},
}, nil
}
// dashboardActivity requests the activity feed from the Laravel backend.
// Stub implementation: sends request via bridge, returns empty events. Awaiting Laravel backend.
func (s *Subsystem) dashboardActivity(_ context.Context, _ *mcp.CallToolRequest, input DashboardActivityInput) (*mcp.CallToolResult, DashboardActivityOutput, error) {
if s.bridge == nil {
return nil, DashboardActivityOutput{}, errBridgeNotAvailable
}
_ = s.bridge.Send(BridgeMessage{
Type: "dashboard_activity",
Data: map[string]any{"limit": input.Limit},
})
return nil, DashboardActivityOutput{Events: []ActivityEvent{}}, nil
}
// dashboardMetrics requests aggregate metrics from the Laravel backend.
// Stub implementation: sends request via bridge, returns zero metrics. Awaiting Laravel backend.
func (s *Subsystem) dashboardMetrics(_ context.Context, _ *mcp.CallToolRequest, input DashboardMetricsInput) (*mcp.CallToolResult, DashboardMetricsOutput, error) {
if s.bridge == nil {
return nil, DashboardMetricsOutput{}, errBridgeNotAvailable
}
period := input.Period
if period == "" {
period = "24h"
}
_ = s.bridge.Send(BridgeMessage{
Type: "dashboard_metrics",
Data: map[string]any{"period": period},
})
return nil, DashboardMetricsOutput{
Period: period,
Metrics: DashboardMetrics{},
}, nil
}

View file

@ -1,781 +0,0 @@
package ide
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"forge.lthn.ai/core/go-ws"
)
// --- Helpers ---
// newNilBridgeSubsystem returns a Subsystem with no hub/bridge (headless mode).
func newNilBridgeSubsystem() *Subsystem {
return New(nil)
}
// newConnectedSubsystem returns a Subsystem with a connected bridge and a
// running echo WS server. Caller must cancel ctx and close server when done.
func newConnectedSubsystem(t *testing.T) (*Subsystem, context.CancelFunc, *httptest.Server) {
t.Helper()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, data, err := conn.ReadMessage()
if err != nil {
break
}
_ = conn.WriteMessage(mt, data)
}
}))
hub := ws.NewHub()
ctx, cancel := context.WithCancel(context.Background())
go hub.Run(ctx)
sub := New(hub,
WithLaravelURL(wsURL(ts)),
WithReconnectInterval(50*time.Millisecond),
)
sub.StartBridge(ctx)
waitConnected(t, sub.Bridge(), 2*time.Second)
return sub, cancel, ts
}
// --- 4.3: Chat tool tests ---
// TestChatSend_Bad_NilBridge verifies chatSend returns error without a bridge.
func TestChatSend_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.chatSend(context.Background(), nil, ChatSendInput{
SessionID: "s1",
Message: "hello",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestChatSend_Good_Connected verifies chatSend succeeds with a connected bridge.
func TestChatSend_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.chatSend(context.Background(), nil, ChatSendInput{
SessionID: "sess-42",
Message: "hello",
})
if err != nil {
t.Fatalf("chatSend failed: %v", err)
}
if !out.Sent {
t.Error("expected Sent=true")
}
if out.SessionID != "sess-42" {
t.Errorf("expected sessionId 'sess-42', got %q", out.SessionID)
}
if out.Timestamp.IsZero() {
t.Error("expected non-zero timestamp")
}
}
// TestChatHistory_Bad_NilBridge verifies chatHistory returns error without a bridge.
func TestChatHistory_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.chatHistory(context.Background(), nil, ChatHistoryInput{
SessionID: "s1",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestChatHistory_Good_Connected verifies chatHistory succeeds and returns empty messages.
func TestChatHistory_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.chatHistory(context.Background(), nil, ChatHistoryInput{
SessionID: "sess-1",
Limit: 50,
})
if err != nil {
t.Fatalf("chatHistory failed: %v", err)
}
if out.SessionID != "sess-1" {
t.Errorf("expected sessionId 'sess-1', got %q", out.SessionID)
}
if out.Messages == nil {
t.Error("expected non-nil messages slice")
}
if len(out.Messages) != 0 {
t.Errorf("expected 0 messages (stub), got %d", len(out.Messages))
}
}
// TestSessionList_Bad_NilBridge verifies sessionList returns error without a bridge.
func TestSessionList_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.sessionList(context.Background(), nil, SessionListInput{})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestSessionList_Good_Connected verifies sessionList returns empty sessions.
func TestSessionList_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.sessionList(context.Background(), nil, SessionListInput{})
if err != nil {
t.Fatalf("sessionList failed: %v", err)
}
if out.Sessions == nil {
t.Error("expected non-nil sessions slice")
}
if len(out.Sessions) != 0 {
t.Errorf("expected 0 sessions (stub), got %d", len(out.Sessions))
}
}
// TestSessionCreate_Bad_NilBridge verifies sessionCreate returns error without a bridge.
func TestSessionCreate_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.sessionCreate(context.Background(), nil, SessionCreateInput{
Name: "test",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestSessionCreate_Good_Connected verifies sessionCreate returns a session stub.
func TestSessionCreate_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.sessionCreate(context.Background(), nil, SessionCreateInput{
Name: "my-session",
})
if err != nil {
t.Fatalf("sessionCreate failed: %v", err)
}
if out.Session.Name != "my-session" {
t.Errorf("expected name 'my-session', got %q", out.Session.Name)
}
if out.Session.Status != "creating" {
t.Errorf("expected status 'creating', got %q", out.Session.Status)
}
if out.Session.CreatedAt.IsZero() {
t.Error("expected non-zero CreatedAt")
}
}
// TestPlanStatus_Bad_NilBridge verifies planStatus returns error without a bridge.
func TestPlanStatus_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.planStatus(context.Background(), nil, PlanStatusInput{
SessionID: "s1",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestPlanStatus_Good_Connected verifies planStatus returns a stub status.
func TestPlanStatus_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.planStatus(context.Background(), nil, PlanStatusInput{
SessionID: "sess-7",
})
if err != nil {
t.Fatalf("planStatus failed: %v", err)
}
if out.SessionID != "sess-7" {
t.Errorf("expected sessionId 'sess-7', got %q", out.SessionID)
}
if out.Status != "unknown" {
t.Errorf("expected status 'unknown', got %q", out.Status)
}
if out.Steps == nil {
t.Error("expected non-nil steps slice")
}
}
// --- 4.3: Build tool tests ---
// TestBuildStatus_Bad_NilBridge verifies buildStatus returns error without a bridge.
func TestBuildStatus_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.buildStatus(context.Background(), nil, BuildStatusInput{
BuildID: "b1",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestBuildStatus_Good_Connected verifies buildStatus returns a stub.
func TestBuildStatus_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.buildStatus(context.Background(), nil, BuildStatusInput{
BuildID: "build-99",
})
if err != nil {
t.Fatalf("buildStatus failed: %v", err)
}
if out.Build.ID != "build-99" {
t.Errorf("expected build ID 'build-99', got %q", out.Build.ID)
}
if out.Build.Status != "unknown" {
t.Errorf("expected status 'unknown', got %q", out.Build.Status)
}
}
// TestBuildList_Bad_NilBridge verifies buildList returns error without a bridge.
func TestBuildList_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.buildList(context.Background(), nil, BuildListInput{
Repo: "core-php",
Limit: 10,
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestBuildList_Good_Connected verifies buildList returns an empty list.
func TestBuildList_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.buildList(context.Background(), nil, BuildListInput{
Repo: "core-php",
Limit: 10,
})
if err != nil {
t.Fatalf("buildList failed: %v", err)
}
if out.Builds == nil {
t.Error("expected non-nil builds slice")
}
if len(out.Builds) != 0 {
t.Errorf("expected 0 builds (stub), got %d", len(out.Builds))
}
}
// TestBuildLogs_Bad_NilBridge verifies buildLogs returns error without a bridge.
func TestBuildLogs_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.buildLogs(context.Background(), nil, BuildLogsInput{
BuildID: "b1",
Tail: 100,
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestBuildLogs_Good_Connected verifies buildLogs returns empty lines.
func TestBuildLogs_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.buildLogs(context.Background(), nil, BuildLogsInput{
BuildID: "build-55",
Tail: 50,
})
if err != nil {
t.Fatalf("buildLogs failed: %v", err)
}
if out.BuildID != "build-55" {
t.Errorf("expected buildId 'build-55', got %q", out.BuildID)
}
if out.Lines == nil {
t.Error("expected non-nil lines slice")
}
if len(out.Lines) != 0 {
t.Errorf("expected 0 lines (stub), got %d", len(out.Lines))
}
}
// --- 4.3: Dashboard tool tests ---
// TestDashboardOverview_Good_NilBridge verifies dashboardOverview works without bridge
// (it does not return error — it reports BridgeOnline=false).
func TestDashboardOverview_Good_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, out, err := sub.dashboardOverview(context.Background(), nil, DashboardOverviewInput{})
if err != nil {
t.Fatalf("dashboardOverview failed: %v", err)
}
if out.Overview.BridgeOnline {
t.Error("expected BridgeOnline=false when bridge is nil")
}
}
// TestDashboardOverview_Good_Connected verifies dashboardOverview reports bridge online.
func TestDashboardOverview_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.dashboardOverview(context.Background(), nil, DashboardOverviewInput{})
if err != nil {
t.Fatalf("dashboardOverview failed: %v", err)
}
if !out.Overview.BridgeOnline {
t.Error("expected BridgeOnline=true when bridge is connected")
}
}
// TestDashboardActivity_Bad_NilBridge verifies dashboardActivity returns error without bridge.
func TestDashboardActivity_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.dashboardActivity(context.Background(), nil, DashboardActivityInput{
Limit: 10,
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestDashboardActivity_Good_Connected verifies dashboardActivity returns empty events.
func TestDashboardActivity_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.dashboardActivity(context.Background(), nil, DashboardActivityInput{
Limit: 20,
})
if err != nil {
t.Fatalf("dashboardActivity failed: %v", err)
}
if out.Events == nil {
t.Error("expected non-nil events slice")
}
if len(out.Events) != 0 {
t.Errorf("expected 0 events (stub), got %d", len(out.Events))
}
}
// TestDashboardMetrics_Bad_NilBridge verifies dashboardMetrics returns error without bridge.
func TestDashboardMetrics_Bad_NilBridge(t *testing.T) {
sub := newNilBridgeSubsystem()
_, _, err := sub.dashboardMetrics(context.Background(), nil, DashboardMetricsInput{
Period: "1h",
})
if err == nil {
t.Error("expected error when bridge is nil")
}
}
// TestDashboardMetrics_Good_Connected verifies dashboardMetrics returns empty metrics.
func TestDashboardMetrics_Good_Connected(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.dashboardMetrics(context.Background(), nil, DashboardMetricsInput{
Period: "7d",
})
if err != nil {
t.Fatalf("dashboardMetrics failed: %v", err)
}
if out.Period != "7d" {
t.Errorf("expected period '7d', got %q", out.Period)
}
}
// TestDashboardMetrics_Good_DefaultPeriod verifies the default period is "24h".
func TestDashboardMetrics_Good_DefaultPeriod(t *testing.T) {
sub, cancel, ts := newConnectedSubsystem(t)
defer cancel()
defer ts.Close()
_, out, err := sub.dashboardMetrics(context.Background(), nil, DashboardMetricsInput{})
if err != nil {
t.Fatalf("dashboardMetrics failed: %v", err)
}
if out.Period != "24h" {
t.Errorf("expected default period '24h', got %q", out.Period)
}
}
// --- Struct serialisation round-trip tests ---
// TestChatSendInput_Good_RoundTrip verifies JSON serialisation of ChatSendInput.
func TestChatSendInput_Good_RoundTrip(t *testing.T) {
in := ChatSendInput{SessionID: "s1", Message: "hello"}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out ChatSendInput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out != in {
t.Errorf("round-trip mismatch: %+v != %+v", out, in)
}
}
// TestChatSendOutput_Good_RoundTrip verifies JSON serialisation of ChatSendOutput.
func TestChatSendOutput_Good_RoundTrip(t *testing.T) {
in := ChatSendOutput{Sent: true, SessionID: "s1", Timestamp: time.Now().Truncate(time.Second)}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out ChatSendOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.Sent != in.Sent || out.SessionID != in.SessionID {
t.Errorf("round-trip mismatch: %+v != %+v", out, in)
}
}
// TestChatHistoryOutput_Good_RoundTrip verifies ChatHistoryOutput JSON round-trip.
func TestChatHistoryOutput_Good_RoundTrip(t *testing.T) {
in := ChatHistoryOutput{
SessionID: "s1",
Messages: []ChatMessage{
{Role: "user", Content: "hello", Timestamp: time.Now().Truncate(time.Second)},
{Role: "assistant", Content: "hi", Timestamp: time.Now().Truncate(time.Second)},
},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out ChatHistoryOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.SessionID != in.SessionID {
t.Errorf("sessionId mismatch: %q != %q", out.SessionID, in.SessionID)
}
if len(out.Messages) != 2 {
t.Errorf("expected 2 messages, got %d", len(out.Messages))
}
}
// TestSessionListOutput_Good_RoundTrip verifies SessionListOutput JSON round-trip.
func TestSessionListOutput_Good_RoundTrip(t *testing.T) {
in := SessionListOutput{
Sessions: []Session{
{ID: "s1", Name: "test", Status: "active", CreatedAt: time.Now().Truncate(time.Second)},
},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out SessionListOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if len(out.Sessions) != 1 || out.Sessions[0].ID != "s1" {
t.Errorf("round-trip mismatch: %+v", out)
}
}
// TestPlanStatusOutput_Good_RoundTrip verifies PlanStatusOutput JSON round-trip.
func TestPlanStatusOutput_Good_RoundTrip(t *testing.T) {
in := PlanStatusOutput{
SessionID: "s1",
Status: "running",
Steps: []PlanStep{{Name: "step1", Status: "done"}, {Name: "step2", Status: "pending"}},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out PlanStatusOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.SessionID != "s1" || len(out.Steps) != 2 {
t.Errorf("round-trip mismatch: %+v", out)
}
}
// TestBuildStatusOutput_Good_RoundTrip verifies BuildStatusOutput JSON round-trip.
func TestBuildStatusOutput_Good_RoundTrip(t *testing.T) {
in := BuildStatusOutput{
Build: BuildInfo{
ID: "b1",
Repo: "core-php",
Branch: "main",
Status: "success",
Duration: "2m30s",
StartedAt: time.Now().Truncate(time.Second),
},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out BuildStatusOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.Build.ID != "b1" || out.Build.Status != "success" {
t.Errorf("round-trip mismatch: %+v", out)
}
}
// TestBuildListOutput_Good_RoundTrip verifies BuildListOutput JSON round-trip.
func TestBuildListOutput_Good_RoundTrip(t *testing.T) {
in := BuildListOutput{
Builds: []BuildInfo{
{ID: "b1", Repo: "core-php", Branch: "main", Status: "success"},
{ID: "b2", Repo: "core-admin", Branch: "dev", Status: "failed"},
},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out BuildListOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if len(out.Builds) != 2 {
t.Errorf("expected 2 builds, got %d", len(out.Builds))
}
}
// TestBuildLogsOutput_Good_RoundTrip verifies BuildLogsOutput JSON round-trip.
func TestBuildLogsOutput_Good_RoundTrip(t *testing.T) {
in := BuildLogsOutput{
BuildID: "b1",
Lines: []string{"line1", "line2", "line3"},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out BuildLogsOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.BuildID != "b1" || len(out.Lines) != 3 {
t.Errorf("round-trip mismatch: %+v", out)
}
}
// TestDashboardOverviewOutput_Good_RoundTrip verifies DashboardOverviewOutput JSON round-trip.
func TestDashboardOverviewOutput_Good_RoundTrip(t *testing.T) {
in := DashboardOverviewOutput{
Overview: DashboardOverview{
Repos: 18,
Services: 5,
ActiveSessions: 3,
RecentBuilds: 12,
BridgeOnline: true,
},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out DashboardOverviewOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.Overview.Repos != 18 || !out.Overview.BridgeOnline {
t.Errorf("round-trip mismatch: %+v", out)
}
}
// TestDashboardActivityOutput_Good_RoundTrip verifies DashboardActivityOutput JSON round-trip.
func TestDashboardActivityOutput_Good_RoundTrip(t *testing.T) {
in := DashboardActivityOutput{
Events: []ActivityEvent{
{Type: "deploy", Message: "deployed v1.2", Timestamp: time.Now().Truncate(time.Second)},
},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out DashboardActivityOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if len(out.Events) != 1 || out.Events[0].Type != "deploy" {
t.Errorf("round-trip mismatch: %+v", out)
}
}
// TestDashboardMetricsOutput_Good_RoundTrip verifies DashboardMetricsOutput JSON round-trip.
func TestDashboardMetricsOutput_Good_RoundTrip(t *testing.T) {
in := DashboardMetricsOutput{
Period: "24h",
Metrics: DashboardMetrics{
BuildsTotal: 100,
BuildsSuccess: 90,
BuildsFailed: 10,
AvgBuildTime: "3m",
AgentSessions: 5,
MessagesTotal: 500,
SuccessRate: 0.9,
},
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out DashboardMetricsOutput
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.Period != "24h" || out.Metrics.BuildsTotal != 100 || out.Metrics.SuccessRate != 0.9 {
t.Errorf("round-trip mismatch: %+v", out)
}
}
// TestBridgeMessage_Good_RoundTrip verifies BridgeMessage JSON round-trip.
func TestBridgeMessage_Good_RoundTrip(t *testing.T) {
in := BridgeMessage{
Type: "test",
Channel: "ch1",
SessionID: "s1",
Data: "payload",
Timestamp: time.Now().Truncate(time.Second),
}
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var out BridgeMessage
if err := json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
if out.Type != "test" || out.Channel != "ch1" || out.SessionID != "s1" {
t.Errorf("round-trip mismatch: %+v", out)
}
}
// --- Subsystem integration tests ---
// TestSubsystem_Good_RegisterTools verifies RegisterTools does not panic.
func TestSubsystem_Good_RegisterTools(t *testing.T) {
// RegisterTools requires a real mcp.Server which is complex to construct
// in isolation. This test verifies the Subsystem can be created and
// the Bridge/Shutdown path works end-to-end.
sub := New(nil)
if sub.Bridge() != nil {
t.Error("expected nil bridge with nil hub")
}
if err := sub.Shutdown(context.Background()); err != nil {
t.Errorf("Shutdown failed: %v", err)
}
}
// TestSubsystem_Good_StartBridgeNilHub verifies StartBridge is a no-op with nil hub.
func TestSubsystem_Good_StartBridgeNilHub(t *testing.T) {
sub := New(nil)
// Should not panic
sub.StartBridge(context.Background())
}
// TestSubsystem_Good_WithOptions verifies all config options apply correctly.
func TestSubsystem_Good_WithOptions(t *testing.T) {
hub := ws.NewHub()
sub := New(hub,
WithLaravelURL("ws://custom:1234/ws"),
WithWorkspaceRoot("/tmp/test"),
WithReconnectInterval(5*time.Second),
WithToken("secret-123"),
)
if sub.cfg.LaravelWSURL != "ws://custom:1234/ws" {
t.Errorf("expected custom URL, got %q", sub.cfg.LaravelWSURL)
}
if sub.cfg.WorkspaceRoot != "/tmp/test" {
t.Errorf("expected workspace '/tmp/test', got %q", sub.cfg.WorkspaceRoot)
}
if sub.cfg.ReconnectInterval != 5*time.Second {
t.Errorf("expected 5s reconnect interval, got %v", sub.cfg.ReconnectInterval)
}
if sub.cfg.Token != "secret-123" {
t.Errorf("expected token 'secret-123', got %q", sub.cfg.Token)
}
}
// --- Tool sends correct bridge message type ---
// TestChatSend_Good_BridgeMessageType verifies the bridge receives the correct message type.
func TestChatSend_Good_BridgeMessageType(t *testing.T) {
msgCh := make(chan BridgeMessage, 1)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := testUpgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
_, data, err := conn.ReadMessage()
if err != nil {
return
}
var msg BridgeMessage
json.Unmarshal(data, &msg)
msgCh <- msg
// Keep alive
for {
if _, _, err := conn.ReadMessage(); err != nil {
break
}
}
}))
defer ts.Close()
hub := ws.NewHub()
ctx := t.Context()
go hub.Run(ctx)
sub := New(hub, WithLaravelURL(wsURL(ts)), WithReconnectInterval(50*time.Millisecond))
sub.StartBridge(ctx)
waitConnected(t, sub.Bridge(), 2*time.Second)
sub.chatSend(ctx, nil, ChatSendInput{SessionID: "s1", Message: "test"})
select {
case received := <-msgCh:
if received.Type != "chat_send" {
t.Errorf("expected bridge message type 'chat_send', got %q", received.Type)
}
if received.Channel != "chat:s1" {
t.Errorf("expected channel 'chat:s1', got %q", received.Channel)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for bridge message")
}
}

View file

@ -1,121 +0,0 @@
package mcp
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIntegration_FileTools(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
assert.NoError(t, err)
ctx := context.Background()
// 1. Test file_write
writeInput := WriteFileInput{
Path: "test.txt",
Content: "hello world",
}
_, writeOutput, err := s.writeFile(ctx, nil, writeInput)
assert.NoError(t, err)
assert.True(t, writeOutput.Success)
assert.Equal(t, "test.txt", writeOutput.Path)
// Verify on disk
content, _ := os.ReadFile(filepath.Join(tmpDir, "test.txt"))
assert.Equal(t, "hello world", string(content))
// 2. Test file_read
readInput := ReadFileInput{
Path: "test.txt",
}
_, readOutput, err := s.readFile(ctx, nil, readInput)
assert.NoError(t, err)
assert.Equal(t, "hello world", readOutput.Content)
assert.Equal(t, "plaintext", readOutput.Language)
// 3. Test file_edit (replace_all=false)
editInput := EditDiffInput{
Path: "test.txt",
OldString: "world",
NewString: "mcp",
}
_, editOutput, err := s.editDiff(ctx, nil, editInput)
assert.NoError(t, err)
assert.True(t, editOutput.Success)
assert.Equal(t, 1, editOutput.Replacements)
// Verify change
_, readOutput, _ = s.readFile(ctx, nil, readInput)
assert.Equal(t, "hello mcp", readOutput.Content)
// 4. Test file_edit (replace_all=true)
_ = s.medium.Write("multi.txt", "abc abc abc")
editInputMulti := EditDiffInput{
Path: "multi.txt",
OldString: "abc",
NewString: "xyz",
ReplaceAll: true,
}
_, editOutput, err = s.editDiff(ctx, nil, editInputMulti)
assert.NoError(t, err)
assert.Equal(t, 3, editOutput.Replacements)
content, _ = os.ReadFile(filepath.Join(tmpDir, "multi.txt"))
assert.Equal(t, "xyz xyz xyz", string(content))
// 5. Test dir_list
_ = s.medium.EnsureDir("subdir")
_ = s.medium.Write("subdir/file1.txt", "content1")
listInput := ListDirectoryInput{
Path: "subdir",
}
_, listOutput, err := s.listDirectory(ctx, nil, listInput)
assert.NoError(t, err)
assert.Len(t, listOutput.Entries, 1)
assert.Equal(t, "file1.txt", listOutput.Entries[0].Name)
assert.False(t, listOutput.Entries[0].IsDir)
}
func TestIntegration_ErrorPaths(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
assert.NoError(t, err)
ctx := context.Background()
// Read nonexistent file
_, _, err = s.readFile(ctx, nil, ReadFileInput{Path: "nonexistent.txt"})
assert.Error(t, err)
// Edit nonexistent file
_, _, err = s.editDiff(ctx, nil, EditDiffInput{
Path: "nonexistent.txt",
OldString: "foo",
NewString: "bar",
})
assert.Error(t, err)
// Edit with empty old_string
_, _, err = s.editDiff(ctx, nil, EditDiffInput{
Path: "test.txt",
OldString: "",
NewString: "bar",
})
assert.Error(t, err)
// Edit with old_string not found
_ = s.medium.Write("test.txt", "hello")
_, _, err = s.editDiff(ctx, nil, EditDiffInput{
Path: "test.txt",
OldString: "missing",
NewString: "bar",
})
assert.Error(t, err)
}

View file

@ -1,40 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
package mcp
import (
"slices"
"testing"
)
func TestService_Iterators(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
// Test ToolsSeq
tools := slices.Collect(svc.ToolsSeq())
if len(tools) == 0 {
t.Error("expected non-empty ToolsSeq")
}
if len(tools) != len(svc.Tools()) {
t.Errorf("ToolsSeq length %d != Tools() length %d", len(tools), len(svc.Tools()))
}
// Test SubsystemsSeq
subsystems := slices.Collect(svc.SubsystemsSeq())
if len(subsystems) != len(svc.Subsystems()) {
t.Errorf("SubsystemsSeq length %d != Subsystems() length %d", len(subsystems), len(svc.Subsystems()))
}
}
func TestRegistry_SplitTagSeq(t *testing.T) {
tag := "name,omitempty,json"
parts := slices.Collect(splitTagSeq(tag))
expected := []string{"name", "omitempty", "json"}
if !slices.Equal(parts, expected) {
t.Errorf("expected %v, got %v", expected, parts)
}
}

View file

@ -1,580 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
// Package mcp provides a lightweight MCP (Model Context Protocol) server for CLI use.
// For full GUI integration (display, webview, process management), see core-gui/pkg/mcp.
package mcp
import (
"context"
"errors"
"fmt"
"iter"
"net/http"
"os"
"path/filepath"
"slices"
"strings"
"forge.lthn.ai/core/go-io"
"forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-process"
"forge.lthn.ai/core/go-ws"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// Service provides a lightweight MCP server with file operations only.
// For full GUI features, use the core-gui package.
type Service struct {
server *mcp.Server
workspaceRoot string // Root directory for file operations (empty = unrestricted)
medium io.Medium // Filesystem medium for sandboxed operations
subsystems []Subsystem // Additional subsystems registered via WithSubsystem
logger *log.Logger // Logger for tool execution auditing
processService *process.Service // Process management service (optional)
wsHub *ws.Hub // WebSocket hub for real-time streaming (optional)
wsServer *http.Server // WebSocket HTTP server (optional)
wsAddr string // WebSocket server address
tools []ToolRecord // Parallel tool registry for REST bridge
}
// Option configures a Service.
type Option func(*Service) error
// WithWorkspaceRoot restricts file operations to the given directory.
// All paths are validated to be within this directory.
// An empty string disables the restriction (not recommended).
func WithWorkspaceRoot(root string) Option {
return func(s *Service) error {
if root == "" {
// Explicitly disable restriction - use unsandboxed global
s.workspaceRoot = ""
s.medium = io.Local
return nil
}
// Create sandboxed medium for this workspace
abs, err := filepath.Abs(root)
if err != nil {
return fmt.Errorf("invalid workspace root: %w", err)
}
m, err := io.NewSandboxed(abs)
if err != nil {
return fmt.Errorf("failed to create workspace medium: %w", err)
}
s.workspaceRoot = abs
s.medium = m
return nil
}
}
// New creates a new MCP service with file operations.
// By default, restricts file access to the current working directory.
// Use WithWorkspaceRoot("") to disable restrictions (not recommended).
// Returns an error if initialization fails.
func New(opts ...Option) (*Service, error) {
impl := &mcp.Implementation{
Name: "core-cli",
Version: "0.1.0",
}
server := mcp.NewServer(impl, nil)
s := &Service{
server: server,
logger: log.Default(),
}
// Default to current working directory with sandboxed medium
cwd, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("failed to get working directory: %w", err)
}
s.workspaceRoot = cwd
m, err := io.NewSandboxed(cwd)
if err != nil {
return nil, fmt.Errorf("failed to create sandboxed medium: %w", err)
}
s.medium = m
// Apply options
for _, opt := range opts {
if err := opt(s); err != nil {
return nil, fmt.Errorf("failed to apply option: %w", err)
}
}
s.registerTools(s.server)
// Register subsystem tools.
for _, sub := range s.subsystems {
sub.RegisterTools(s.server)
}
return s, nil
}
// Subsystems returns the registered subsystems.
func (s *Service) Subsystems() []Subsystem {
return s.subsystems
}
// SubsystemsSeq returns an iterator over the registered subsystems.
func (s *Service) SubsystemsSeq() iter.Seq[Subsystem] {
return slices.Values(s.subsystems)
}
// Tools returns all recorded tool metadata.
func (s *Service) Tools() []ToolRecord {
return s.tools
}
// ToolsSeq returns an iterator over all recorded tool metadata.
func (s *Service) ToolsSeq() iter.Seq[ToolRecord] {
return slices.Values(s.tools)
}
// Shutdown gracefully shuts down all subsystems that support it.
func (s *Service) Shutdown(ctx context.Context) error {
for _, sub := range s.subsystems {
if sh, ok := sub.(SubsystemWithShutdown); ok {
if err := sh.Shutdown(ctx); err != nil {
return fmt.Errorf("shutdown %s: %w", sub.Name(), err)
}
}
}
return nil
}
// WithProcessService configures the process management service.
func WithProcessService(ps *process.Service) Option {
return func(s *Service) error {
s.processService = ps
return nil
}
}
// WithWSHub configures the WebSocket hub for real-time streaming.
func WithWSHub(hub *ws.Hub) Option {
return func(s *Service) error {
s.wsHub = hub
return nil
}
}
// WSHub returns the WebSocket hub.
func (s *Service) WSHub() *ws.Hub {
return s.wsHub
}
// ProcessService returns the process service.
func (s *Service) ProcessService() *process.Service {
return s.processService
}
// registerTools adds file operation tools to the MCP server.
func (s *Service) registerTools(server *mcp.Server) {
// File operations
addToolRecorded(s, server, "files", &mcp.Tool{
Name: "file_read",
Description: "Read the contents of a file",
}, s.readFile)
addToolRecorded(s, server, "files", &mcp.Tool{
Name: "file_write",
Description: "Write content to a file",
}, s.writeFile)
addToolRecorded(s, server, "files", &mcp.Tool{
Name: "file_delete",
Description: "Delete a file or empty directory",
}, s.deleteFile)
addToolRecorded(s, server, "files", &mcp.Tool{
Name: "file_rename",
Description: "Rename or move a file",
}, s.renameFile)
addToolRecorded(s, server, "files", &mcp.Tool{
Name: "file_exists",
Description: "Check if a file or directory exists",
}, s.fileExists)
addToolRecorded(s, server, "files", &mcp.Tool{
Name: "file_edit",
Description: "Edit a file by replacing old_string with new_string. Use replace_all=true to replace all occurrences.",
}, s.editDiff)
// Directory operations
addToolRecorded(s, server, "files", &mcp.Tool{
Name: "dir_list",
Description: "List contents of a directory",
}, s.listDirectory)
addToolRecorded(s, server, "files", &mcp.Tool{
Name: "dir_create",
Description: "Create a new directory",
}, s.createDirectory)
// Language detection
addToolRecorded(s, server, "language", &mcp.Tool{
Name: "lang_detect",
Description: "Detect the programming language of a file",
}, s.detectLanguage)
addToolRecorded(s, server, "language", &mcp.Tool{
Name: "lang_list",
Description: "Get list of supported programming languages",
}, s.getSupportedLanguages)
}
// Tool input/output types for MCP file operations.
// ReadFileInput contains parameters for reading a file.
type ReadFileInput struct {
Path string `json:"path"`
}
// ReadFileOutput contains the result of reading a file.
type ReadFileOutput struct {
Content string `json:"content"`
Language string `json:"language"`
Path string `json:"path"`
}
// WriteFileInput contains parameters for writing a file.
type WriteFileInput struct {
Path string `json:"path"`
Content string `json:"content"`
}
// WriteFileOutput contains the result of writing a file.
type WriteFileOutput struct {
Success bool `json:"success"`
Path string `json:"path"`
}
// ListDirectoryInput contains parameters for listing a directory.
type ListDirectoryInput struct {
Path string `json:"path"`
}
// ListDirectoryOutput contains the result of listing a directory.
type ListDirectoryOutput struct {
Entries []DirectoryEntry `json:"entries"`
Path string `json:"path"`
}
// DirectoryEntry represents a single entry in a directory listing.
type DirectoryEntry struct {
Name string `json:"name"`
Path string `json:"path"`
IsDir bool `json:"isDir"`
Size int64 `json:"size"`
}
// CreateDirectoryInput contains parameters for creating a directory.
type CreateDirectoryInput struct {
Path string `json:"path"`
}
// CreateDirectoryOutput contains the result of creating a directory.
type CreateDirectoryOutput struct {
Success bool `json:"success"`
Path string `json:"path"`
}
// DeleteFileInput contains parameters for deleting a file.
type DeleteFileInput struct {
Path string `json:"path"`
}
// DeleteFileOutput contains the result of deleting a file.
type DeleteFileOutput struct {
Success bool `json:"success"`
Path string `json:"path"`
}
// RenameFileInput contains parameters for renaming a file.
type RenameFileInput struct {
OldPath string `json:"oldPath"`
NewPath string `json:"newPath"`
}
// RenameFileOutput contains the result of renaming a file.
type RenameFileOutput struct {
Success bool `json:"success"`
OldPath string `json:"oldPath"`
NewPath string `json:"newPath"`
}
// FileExistsInput contains parameters for checking file existence.
type FileExistsInput struct {
Path string `json:"path"`
}
// FileExistsOutput contains the result of checking file existence.
type FileExistsOutput struct {
Exists bool `json:"exists"`
IsDir bool `json:"isDir"`
Path string `json:"path"`
}
// DetectLanguageInput contains parameters for detecting file language.
type DetectLanguageInput struct {
Path string `json:"path"`
}
// DetectLanguageOutput contains the detected programming language.
type DetectLanguageOutput struct {
Language string `json:"language"`
Path string `json:"path"`
}
// GetSupportedLanguagesInput is an empty struct for the languages query.
type GetSupportedLanguagesInput struct{}
// GetSupportedLanguagesOutput contains the list of supported languages.
type GetSupportedLanguagesOutput struct {
Languages []LanguageInfo `json:"languages"`
}
// LanguageInfo describes a supported programming language.
type LanguageInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Extensions []string `json:"extensions"`
}
// EditDiffInput contains parameters for editing a file via diff.
type EditDiffInput struct {
Path string `json:"path"`
OldString string `json:"old_string"`
NewString string `json:"new_string"`
ReplaceAll bool `json:"replace_all,omitempty"`
}
// EditDiffOutput contains the result of a diff-based edit operation.
type EditDiffOutput struct {
Path string `json:"path"`
Success bool `json:"success"`
Replacements int `json:"replacements"`
}
// Tool handlers
func (s *Service) readFile(ctx context.Context, req *mcp.CallToolRequest, input ReadFileInput) (*mcp.CallToolResult, ReadFileOutput, error) {
content, err := s.medium.Read(input.Path)
if err != nil {
return nil, ReadFileOutput{}, fmt.Errorf("failed to read file: %w", err)
}
return nil, ReadFileOutput{
Content: content,
Language: detectLanguageFromPath(input.Path),
Path: input.Path,
}, nil
}
func (s *Service) writeFile(ctx context.Context, req *mcp.CallToolRequest, input WriteFileInput) (*mcp.CallToolResult, WriteFileOutput, error) {
// Medium.Write creates parent directories automatically
if err := s.medium.Write(input.Path, input.Content); err != nil {
return nil, WriteFileOutput{}, fmt.Errorf("failed to write file: %w", err)
}
return nil, WriteFileOutput{Success: true, Path: input.Path}, nil
}
func (s *Service) listDirectory(ctx context.Context, req *mcp.CallToolRequest, input ListDirectoryInput) (*mcp.CallToolResult, ListDirectoryOutput, error) {
entries, err := s.medium.List(input.Path)
if err != nil {
return nil, ListDirectoryOutput{}, fmt.Errorf("failed to list directory: %w", err)
}
result := make([]DirectoryEntry, 0, len(entries))
for _, e := range entries {
info, _ := e.Info()
var size int64
if info != nil {
size = info.Size()
}
result = append(result, DirectoryEntry{
Name: e.Name(),
Path: filepath.Join(input.Path, e.Name()), // Note: This might be relative path, client might expect absolute?
// Issue 103 says "Replace ... with local.Medium sandboxing".
// Previous code returned `filepath.Join(input.Path, e.Name())`.
// If input.Path is relative, this preserves it.
IsDir: e.IsDir(),
Size: size,
})
}
return nil, ListDirectoryOutput{Entries: result, Path: input.Path}, nil
}
func (s *Service) createDirectory(ctx context.Context, req *mcp.CallToolRequest, input CreateDirectoryInput) (*mcp.CallToolResult, CreateDirectoryOutput, error) {
if err := s.medium.EnsureDir(input.Path); err != nil {
return nil, CreateDirectoryOutput{}, fmt.Errorf("failed to create directory: %w", err)
}
return nil, CreateDirectoryOutput{Success: true, Path: input.Path}, nil
}
func (s *Service) deleteFile(ctx context.Context, req *mcp.CallToolRequest, input DeleteFileInput) (*mcp.CallToolResult, DeleteFileOutput, error) {
if err := s.medium.Delete(input.Path); err != nil {
return nil, DeleteFileOutput{}, fmt.Errorf("failed to delete file: %w", err)
}
return nil, DeleteFileOutput{Success: true, Path: input.Path}, nil
}
func (s *Service) renameFile(ctx context.Context, req *mcp.CallToolRequest, input RenameFileInput) (*mcp.CallToolResult, RenameFileOutput, error) {
if err := s.medium.Rename(input.OldPath, input.NewPath); err != nil {
return nil, RenameFileOutput{}, fmt.Errorf("failed to rename file: %w", err)
}
return nil, RenameFileOutput{Success: true, OldPath: input.OldPath, NewPath: input.NewPath}, nil
}
func (s *Service) fileExists(ctx context.Context, req *mcp.CallToolRequest, input FileExistsInput) (*mcp.CallToolResult, FileExistsOutput, error) {
exists := s.medium.IsFile(input.Path)
if exists {
return nil, FileExistsOutput{Exists: true, IsDir: false, Path: input.Path}, nil
}
// Check if it's a directory by attempting to list it
// List might fail if it's a file too (but we checked IsFile) or if doesn't exist.
_, err := s.medium.List(input.Path)
isDir := err == nil
// If List failed, it might mean it doesn't exist OR it's a special file or permissions.
// Assuming if List works, it's a directory.
// Refinement: If it doesn't exist, List returns error.
return nil, FileExistsOutput{Exists: isDir, IsDir: isDir, Path: input.Path}, nil
}
func (s *Service) detectLanguage(ctx context.Context, req *mcp.CallToolRequest, input DetectLanguageInput) (*mcp.CallToolResult, DetectLanguageOutput, error) {
lang := detectLanguageFromPath(input.Path)
return nil, DetectLanguageOutput{Language: lang, Path: input.Path}, nil
}
func (s *Service) getSupportedLanguages(ctx context.Context, req *mcp.CallToolRequest, input GetSupportedLanguagesInput) (*mcp.CallToolResult, GetSupportedLanguagesOutput, error) {
languages := []LanguageInfo{
{ID: "typescript", Name: "TypeScript", Extensions: []string{".ts", ".tsx"}},
{ID: "javascript", Name: "JavaScript", Extensions: []string{".js", ".jsx"}},
{ID: "go", Name: "Go", Extensions: []string{".go"}},
{ID: "python", Name: "Python", Extensions: []string{".py"}},
{ID: "rust", Name: "Rust", Extensions: []string{".rs"}},
{ID: "java", Name: "Java", Extensions: []string{".java"}},
{ID: "php", Name: "PHP", Extensions: []string{".php"}},
{ID: "ruby", Name: "Ruby", Extensions: []string{".rb"}},
{ID: "html", Name: "HTML", Extensions: []string{".html", ".htm"}},
{ID: "css", Name: "CSS", Extensions: []string{".css"}},
{ID: "json", Name: "JSON", Extensions: []string{".json"}},
{ID: "yaml", Name: "YAML", Extensions: []string{".yaml", ".yml"}},
{ID: "markdown", Name: "Markdown", Extensions: []string{".md", ".markdown"}},
{ID: "sql", Name: "SQL", Extensions: []string{".sql"}},
{ID: "shell", Name: "Shell", Extensions: []string{".sh", ".bash"}},
}
return nil, GetSupportedLanguagesOutput{Languages: languages}, nil
}
func (s *Service) editDiff(ctx context.Context, req *mcp.CallToolRequest, input EditDiffInput) (*mcp.CallToolResult, EditDiffOutput, error) {
if input.OldString == "" {
return nil, EditDiffOutput{}, errors.New("old_string cannot be empty")
}
content, err := s.medium.Read(input.Path)
if err != nil {
return nil, EditDiffOutput{}, fmt.Errorf("failed to read file: %w", err)
}
count := 0
if input.ReplaceAll {
count = strings.Count(content, input.OldString)
if count == 0 {
return nil, EditDiffOutput{}, errors.New("old_string not found in file")
}
content = strings.ReplaceAll(content, input.OldString, input.NewString)
} else {
if !strings.Contains(content, input.OldString) {
return nil, EditDiffOutput{}, errors.New("old_string not found in file")
}
content = strings.Replace(content, input.OldString, input.NewString, 1)
count = 1
}
if err := s.medium.Write(input.Path, content); err != nil {
return nil, EditDiffOutput{}, fmt.Errorf("failed to write file: %w", err)
}
return nil, EditDiffOutput{
Path: input.Path,
Success: true,
Replacements: count,
}, nil
}
// detectLanguageFromPath maps file extensions to language IDs.
func detectLanguageFromPath(path string) string {
ext := filepath.Ext(path)
switch ext {
case ".ts", ".tsx":
return "typescript"
case ".js", ".jsx":
return "javascript"
case ".go":
return "go"
case ".py":
return "python"
case ".rs":
return "rust"
case ".rb":
return "ruby"
case ".java":
return "java"
case ".php":
return "php"
case ".c", ".h":
return "c"
case ".cpp", ".hpp", ".cc", ".cxx":
return "cpp"
case ".cs":
return "csharp"
case ".html", ".htm":
return "html"
case ".css":
return "css"
case ".scss":
return "scss"
case ".json":
return "json"
case ".yaml", ".yml":
return "yaml"
case ".xml":
return "xml"
case ".md", ".markdown":
return "markdown"
case ".sql":
return "sql"
case ".sh", ".bash":
return "shell"
case ".swift":
return "swift"
case ".kt", ".kts":
return "kotlin"
default:
if filepath.Base(path) == "Dockerfile" {
return "dockerfile"
}
return "plaintext"
}
}
// Run starts the MCP server.
// If MCP_ADDR is set, it starts a TCP server.
// Otherwise, it starts a Stdio server.
func (s *Service) Run(ctx context.Context) error {
addr := os.Getenv("MCP_ADDR")
if addr != "" {
return s.ServeTCP(ctx, addr)
}
return s.server.Run(ctx, &mcp.StdioTransport{})
}
// Server returns the underlying MCP server for advanced configuration.
func (s *Service) Server() *mcp.Server {
return s.server
}

View file

@ -1,180 +0,0 @@
package mcp
import (
"os"
"path/filepath"
"testing"
)
func TestNew_Good_DefaultWorkspace(t *testing.T) {
cwd, err := os.Getwd()
if err != nil {
t.Fatalf("Failed to get working directory: %v", err)
}
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.workspaceRoot != cwd {
t.Errorf("Expected default workspace root %s, got %s", cwd, s.workspaceRoot)
}
if s.medium == nil {
t.Error("Expected medium to be set")
}
}
func TestNew_Good_CustomWorkspace(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.workspaceRoot != tmpDir {
t.Errorf("Expected workspace root %s, got %s", tmpDir, s.workspaceRoot)
}
if s.medium == nil {
t.Error("Expected medium to be set")
}
}
func TestNew_Good_NoRestriction(t *testing.T) {
s, err := New(WithWorkspaceRoot(""))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.workspaceRoot != "" {
t.Errorf("Expected empty workspace root, got %s", s.workspaceRoot)
}
if s.medium == nil {
t.Error("Expected medium to be set (unsandboxed)")
}
}
func TestMedium_Good_ReadWrite(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
// Write a file
testContent := "hello world"
err = s.medium.Write("test.txt", testContent)
if err != nil {
t.Fatalf("Failed to write file: %v", err)
}
// Read it back
content, err := s.medium.Read("test.txt")
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if content != testContent {
t.Errorf("Expected content %q, got %q", testContent, content)
}
// Verify file exists on disk
diskPath := filepath.Join(tmpDir, "test.txt")
if _, err := os.Stat(diskPath); os.IsNotExist(err) {
t.Error("File should exist on disk")
}
}
func TestMedium_Good_EnsureDir(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
err = s.medium.EnsureDir("subdir/nested")
if err != nil {
t.Fatalf("Failed to create directory: %v", err)
}
// Verify directory exists
diskPath := filepath.Join(tmpDir, "subdir", "nested")
info, err := os.Stat(diskPath)
if os.IsNotExist(err) {
t.Error("Directory should exist on disk")
}
if err == nil && !info.IsDir() {
t.Error("Path should be a directory")
}
}
func TestMedium_Good_IsFile(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
// File doesn't exist yet
if s.medium.IsFile("test.txt") {
t.Error("File should not exist yet")
}
// Create the file
_ = s.medium.Write("test.txt", "content")
// Now it should exist
if !s.medium.IsFile("test.txt") {
t.Error("File should exist after write")
}
}
func TestSandboxing_Traversal_Sanitized(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
// Path traversal is sanitized (.. becomes .), so ../secret.txt becomes
// ./secret.txt in the workspace. Since that file doesn't exist, we get
// a file not found error (not a traversal error).
_, err = s.medium.Read("../secret.txt")
if err == nil {
t.Error("Expected error (file not found)")
}
// Absolute paths are allowed through - they access the real filesystem.
// This is intentional for full filesystem access. Callers wanting sandboxing
// should validate inputs before calling Medium.
}
func TestSandboxing_Symlinks_Blocked(t *testing.T) {
tmpDir := t.TempDir()
outsideDir := t.TempDir()
// Create a target file outside workspace
targetFile := filepath.Join(outsideDir, "secret.txt")
if err := os.WriteFile(targetFile, []byte("secret"), 0644); err != nil {
t.Fatalf("Failed to create target file: %v", err)
}
// Create symlink inside workspace pointing outside
symlinkPath := filepath.Join(tmpDir, "link")
if err := os.Symlink(targetFile, symlinkPath); err != nil {
t.Skipf("Symlinks not supported: %v", err)
}
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
// Symlinks pointing outside the sandbox root are blocked (security feature).
// The sandbox resolves the symlink target and rejects it because it escapes
// the workspace boundary.
_, err = s.medium.Read("link")
if err == nil {
t.Error("Expected permission denied for symlink escaping sandbox, but read succeeded")
}
}

View file

@ -1,149 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
package mcp
import (
"context"
"encoding/json"
"iter"
"reflect"
"strings"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// RESTHandler handles a tool call from a REST endpoint.
// It receives raw JSON input and returns the typed output or an error.
type RESTHandler func(ctx context.Context, body []byte) (any, error)
// ToolRecord captures metadata about a registered MCP tool.
type ToolRecord struct {
Name string // Tool name, e.g. "file_read"
Description string // Human-readable description
Group string // Subsystem group name, e.g. "files", "rag"
InputSchema map[string]any // JSON Schema from Go struct reflection
OutputSchema map[string]any // JSON Schema from Go struct reflection
RESTHandler RESTHandler // REST-callable handler created at registration time
}
// addToolRecorded registers a tool with the MCP server AND records its metadata.
// This is a generic function that captures the In/Out types for schema extraction.
// It also creates a RESTHandler closure that can unmarshal JSON to the correct
// input type and call the handler directly, enabling the MCP-to-REST bridge.
func addToolRecorded[In, Out any](s *Service, server *mcp.Server, group string, t *mcp.Tool, h mcp.ToolHandlerFor[In, Out]) {
mcp.AddTool(server, t, h)
restHandler := func(ctx context.Context, body []byte) (any, error) {
var input In
if len(body) > 0 {
if err := json.Unmarshal(body, &input); err != nil {
return nil, err
}
}
// nil: REST callers have no MCP request context.
// Tool handlers called via REST must not dereference CallToolRequest.
_, output, err := h(ctx, nil, input)
return output, err
}
s.tools = append(s.tools, ToolRecord{
Name: t.Name,
Description: t.Description,
Group: group,
InputSchema: structSchema(new(In)),
OutputSchema: structSchema(new(Out)),
RESTHandler: restHandler,
})
}
// structSchema builds a simple JSON Schema from a struct's json tags via reflection.
// Returns nil for non-struct types or empty structs.
func structSchema(v any) map[string]any {
t := reflect.TypeOf(v)
if t == nil {
return nil
}
if t.Kind() == reflect.Pointer {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return nil
}
if t.NumField() == 0 {
return map[string]any{"type": "object", "properties": map[string]any{}}
}
properties := make(map[string]any)
required := make([]string, 0)
for f := range t.Fields() {
f := f
if !f.IsExported() {
continue
}
jsonTag := f.Tag.Get("json")
if jsonTag == "-" {
continue
}
name := f.Name
isOptional := false
if jsonTag != "" {
parts := splitTag(jsonTag)
name = parts[0]
for _, p := range parts[1:] {
if p == "omitempty" {
isOptional = true
}
}
}
prop := map[string]any{
"type": goTypeToJSONType(f.Type),
}
properties[name] = prop
if !isOptional {
required = append(required, name)
}
}
schema := map[string]any{
"type": "object",
"properties": properties,
}
if len(required) > 0 {
schema["required"] = required
}
return schema
}
// splitTag splits a struct tag value by commas.
func splitTag(tag string) []string {
return strings.Split(tag, ",")
}
// splitTagSeq returns an iterator over the tag parts.
func splitTagSeq(tag string) iter.Seq[string] {
return strings.SplitSeq(tag, ",")
}
// goTypeToJSONType maps Go types to JSON Schema types.
func goTypeToJSONType(t reflect.Type) string {
switch t.Kind() {
case reflect.String:
return "string"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return "integer"
case reflect.Float32, reflect.Float64:
return "number"
case reflect.Bool:
return "boolean"
case reflect.Slice, reflect.Array:
return "array"
case reflect.Map, reflect.Struct:
return "object"
default:
return "string"
}
}

View file

@ -1,150 +0,0 @@
// SPDX-License-Identifier: EUPL-1.2
package mcp
import (
"testing"
)
func TestToolRegistry_Good_RecordsTools(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
tools := svc.Tools()
if len(tools) == 0 {
t.Fatal("expected non-empty tool registry")
}
found := false
for _, tr := range tools {
if tr.Name == "file_read" {
found = true
break
}
}
if !found {
t.Error("expected file_read in tool registry")
}
}
func TestToolRegistry_Good_SchemaExtraction(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
var record ToolRecord
for _, tr := range svc.Tools() {
if tr.Name == "file_read" {
record = tr
break
}
}
if record.Name == "" {
t.Fatal("file_read not found in registry")
}
if record.InputSchema == nil {
t.Fatal("expected non-nil InputSchema for file_read")
}
props, ok := record.InputSchema["properties"].(map[string]any)
if !ok {
t.Fatal("expected properties map in InputSchema")
}
if _, ok := props["path"]; !ok {
t.Error("expected 'path' property in file_read InputSchema")
}
}
func TestToolRegistry_Good_ToolCount(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
tools := svc.Tools()
// Built-in tools: file_read, file_write, file_delete, file_rename,
// file_exists, file_edit, dir_list, dir_create, lang_detect, lang_list
const expectedCount = 10
if len(tools) != expectedCount {
t.Errorf("expected %d tools, got %d", expectedCount, len(tools))
for _, tr := range tools {
t.Logf(" - %s (%s)", tr.Name, tr.Group)
}
}
}
func TestToolRegistry_Good_GroupAssignment(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
fileTools := []string{"file_read", "file_write", "file_delete", "file_rename", "file_exists", "file_edit", "dir_list", "dir_create"}
langTools := []string{"lang_detect", "lang_list"}
byName := make(map[string]ToolRecord)
for _, tr := range svc.Tools() {
byName[tr.Name] = tr
}
for _, name := range fileTools {
tr, ok := byName[name]
if !ok {
t.Errorf("tool %s not found in registry", name)
continue
}
if tr.Group != "files" {
t.Errorf("tool %s: expected group 'files', got %q", name, tr.Group)
}
}
for _, name := range langTools {
tr, ok := byName[name]
if !ok {
t.Errorf("tool %s not found in registry", name)
continue
}
if tr.Group != "language" {
t.Errorf("tool %s: expected group 'language', got %q", name, tr.Group)
}
}
}
func TestToolRegistry_Good_ToolRecordFields(t *testing.T) {
svc, err := New(WithWorkspaceRoot(t.TempDir()))
if err != nil {
t.Fatal(err)
}
var record ToolRecord
for _, tr := range svc.Tools() {
if tr.Name == "file_write" {
record = tr
break
}
}
if record.Name == "" {
t.Fatal("file_write not found in registry")
}
if record.Name != "file_write" {
t.Errorf("expected Name 'file_write', got %q", record.Name)
}
if record.Description == "" {
t.Error("expected non-empty Description")
}
if record.Group == "" {
t.Error("expected non-empty Group")
}
if record.InputSchema == nil {
t.Error("expected non-nil InputSchema")
}
if record.OutputSchema == nil {
t.Error("expected non-nil OutputSchema")
}
}

View file

@ -1,32 +0,0 @@
package mcp
import (
"context"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// Subsystem registers additional MCP tools at startup.
// Implementations should be safe to call concurrently.
type Subsystem interface {
// Name returns a human-readable identifier for logging.
Name() string
// RegisterTools adds tools to the MCP server during initialisation.
RegisterTools(server *mcp.Server)
}
// SubsystemWithShutdown extends Subsystem with graceful cleanup.
type SubsystemWithShutdown interface {
Subsystem
Shutdown(ctx context.Context) error
}
// WithSubsystem registers a subsystem whose tools will be added
// after the built-in tools during New().
func WithSubsystem(sub Subsystem) Option {
return func(s *Service) error {
s.subsystems = append(s.subsystems, sub)
return nil
}
}

View file

@ -1,114 +0,0 @@
package mcp
import (
"context"
"testing"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// stubSubsystem is a minimal Subsystem for testing.
type stubSubsystem struct {
name string
toolsRegistered bool
}
func (s *stubSubsystem) Name() string { return s.name }
func (s *stubSubsystem) RegisterTools(server *mcp.Server) {
s.toolsRegistered = true
}
// shutdownSubsystem tracks Shutdown calls.
type shutdownSubsystem struct {
stubSubsystem
shutdownCalled bool
shutdownErr error
}
func (s *shutdownSubsystem) Shutdown(_ context.Context) error {
s.shutdownCalled = true
return s.shutdownErr
}
func TestWithSubsystem_Good_Registration(t *testing.T) {
sub := &stubSubsystem{name: "test-sub"}
svc, err := New(WithSubsystem(sub))
if err != nil {
t.Fatalf("New() failed: %v", err)
}
if len(svc.Subsystems()) != 1 {
t.Fatalf("expected 1 subsystem, got %d", len(svc.Subsystems()))
}
if svc.Subsystems()[0].Name() != "test-sub" {
t.Errorf("expected name 'test-sub', got %q", svc.Subsystems()[0].Name())
}
}
func TestWithSubsystem_Good_ToolsRegistered(t *testing.T) {
sub := &stubSubsystem{name: "tools-sub"}
_, err := New(WithSubsystem(sub))
if err != nil {
t.Fatalf("New() failed: %v", err)
}
if !sub.toolsRegistered {
t.Error("expected RegisterTools to have been called")
}
}
func TestWithSubsystem_Good_MultipleSubsystems(t *testing.T) {
sub1 := &stubSubsystem{name: "sub-1"}
sub2 := &stubSubsystem{name: "sub-2"}
svc, err := New(WithSubsystem(sub1), WithSubsystem(sub2))
if err != nil {
t.Fatalf("New() failed: %v", err)
}
if len(svc.Subsystems()) != 2 {
t.Fatalf("expected 2 subsystems, got %d", len(svc.Subsystems()))
}
if !sub1.toolsRegistered || !sub2.toolsRegistered {
t.Error("expected all subsystems to have RegisterTools called")
}
}
func TestSubsystemShutdown_Good(t *testing.T) {
sub := &shutdownSubsystem{stubSubsystem: stubSubsystem{name: "shutdown-sub"}}
svc, err := New(WithSubsystem(sub))
if err != nil {
t.Fatalf("New() failed: %v", err)
}
if err := svc.Shutdown(context.Background()); err != nil {
t.Fatalf("Shutdown() failed: %v", err)
}
if !sub.shutdownCalled {
t.Error("expected Shutdown to have been called")
}
}
func TestSubsystemShutdown_Bad_Error(t *testing.T) {
sub := &shutdownSubsystem{
stubSubsystem: stubSubsystem{name: "fail-sub"},
shutdownErr: context.DeadlineExceeded,
}
svc, err := New(WithSubsystem(sub))
if err != nil {
t.Fatalf("New() failed: %v", err)
}
err = svc.Shutdown(context.Background())
if err == nil {
t.Fatal("expected error from Shutdown")
}
}
func TestSubsystemShutdown_Good_NoShutdownInterface(t *testing.T) {
// A plain Subsystem (without Shutdown) should not cause errors.
sub := &stubSubsystem{name: "plain-sub"}
svc, err := New(WithSubsystem(sub))
if err != nil {
t.Fatalf("New() failed: %v", err)
}
if err := svc.Shutdown(context.Background()); err != nil {
t.Fatalf("Shutdown() should succeed for non-shutdown subsystem: %v", err)
}
}

View file

@ -1,213 +0,0 @@
package mcp
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"forge.lthn.ai/core/go-ai/ai"
"forge.lthn.ai/core/go-log"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// Default values for metrics operations.
const (
DefaultMetricsSince = "7d"
DefaultMetricsLimit = 10
)
// MetricsRecordInput contains parameters for recording a metrics event.
type MetricsRecordInput struct {
Type string `json:"type"` // Event type (required)
AgentID string `json:"agent_id,omitempty"` // Agent identifier
Repo string `json:"repo,omitempty"` // Repository name
Data map[string]any `json:"data,omitempty"` // Additional event data
}
// MetricsRecordOutput contains the result of recording a metrics event.
type MetricsRecordOutput struct {
Success bool `json:"success"`
Timestamp time.Time `json:"timestamp"`
}
// MetricsQueryInput contains parameters for querying metrics.
type MetricsQueryInput struct {
Since string `json:"since,omitempty"` // Time range like "7d", "24h", "30m" (default: "7d")
}
// MetricsQueryOutput contains the results of a metrics query.
type MetricsQueryOutput struct {
Total int `json:"total"`
ByType []MetricCount `json:"by_type"`
ByRepo []MetricCount `json:"by_repo"`
ByAgent []MetricCount `json:"by_agent"`
Events []MetricEventBrief `json:"events"` // Most recent 10 events
}
// MetricCount represents a count for a specific key.
type MetricCount struct {
Key string `json:"key"`
Count int `json:"count"`
}
// MetricEventBrief represents a brief summary of an event.
type MetricEventBrief struct {
Type string `json:"type"`
Timestamp time.Time `json:"timestamp"`
AgentID string `json:"agent_id,omitempty"`
Repo string `json:"repo,omitempty"`
}
// registerMetricsTools adds metrics tools to the MCP server.
func (s *Service) registerMetricsTools(server *mcp.Server) {
mcp.AddTool(server, &mcp.Tool{
Name: "metrics_record",
Description: "Record a metrics event for AI/security tracking. Events are stored in daily JSONL files.",
}, s.metricsRecord)
mcp.AddTool(server, &mcp.Tool{
Name: "metrics_query",
Description: "Query metrics events and get aggregated statistics by type, repo, and agent.",
}, s.metricsQuery)
}
// metricsRecord handles the metrics_record tool call.
func (s *Service) metricsRecord(ctx context.Context, req *mcp.CallToolRequest, input MetricsRecordInput) (*mcp.CallToolResult, MetricsRecordOutput, error) {
s.logger.Info("MCP tool execution", "tool", "metrics_record", "type", input.Type, "agent_id", input.AgentID, "repo", input.Repo, "user", log.Username())
// Validate input
if input.Type == "" {
return nil, MetricsRecordOutput{}, errors.New("type cannot be empty")
}
// Create the event
event := ai.Event{
Type: input.Type,
Timestamp: time.Now(),
AgentID: input.AgentID,
Repo: input.Repo,
Data: input.Data,
}
// Record the event
if err := ai.Record(event); err != nil {
log.Error("mcp: metrics record failed", "type", input.Type, "err", err)
return nil, MetricsRecordOutput{}, fmt.Errorf("failed to record metrics: %w", err)
}
return nil, MetricsRecordOutput{
Success: true,
Timestamp: event.Timestamp,
}, nil
}
// metricsQuery handles the metrics_query tool call.
func (s *Service) metricsQuery(ctx context.Context, req *mcp.CallToolRequest, input MetricsQueryInput) (*mcp.CallToolResult, MetricsQueryOutput, error) {
// Apply defaults
since := input.Since
if since == "" {
since = DefaultMetricsSince
}
s.logger.Info("MCP tool execution", "tool", "metrics_query", "since", since, "user", log.Username())
// Parse the duration
duration, err := parseDuration(since)
if err != nil {
return nil, MetricsQueryOutput{}, fmt.Errorf("invalid since value: %w", err)
}
sinceTime := time.Now().Add(-duration)
// Read events
events, err := ai.ReadEvents(sinceTime)
if err != nil {
log.Error("mcp: metrics query failed", "since", since, "err", err)
return nil, MetricsQueryOutput{}, fmt.Errorf("failed to read metrics: %w", err)
}
// Get summary
summary := ai.Summary(events)
// Build output
output := MetricsQueryOutput{
Total: summary["total"].(int),
ByType: convertMetricCounts(summary["by_type"]),
ByRepo: convertMetricCounts(summary["by_repo"]),
ByAgent: convertMetricCounts(summary["by_agent"]),
Events: make([]MetricEventBrief, 0, DefaultMetricsLimit),
}
// Get recent events (last 10, most recent first)
startIdx := max(len(events)-DefaultMetricsLimit, 0)
for i := len(events) - 1; i >= startIdx; i-- {
ev := events[i]
output.Events = append(output.Events, MetricEventBrief{
Type: ev.Type,
Timestamp: ev.Timestamp,
AgentID: ev.AgentID,
Repo: ev.Repo,
})
}
return nil, output, nil
}
// convertMetricCounts converts the summary map format to MetricCount slice.
func convertMetricCounts(data any) []MetricCount {
if data == nil {
return []MetricCount{}
}
items, ok := data.([]map[string]any)
if !ok {
return []MetricCount{}
}
result := make([]MetricCount, len(items))
for i, item := range items {
key, _ := item["key"].(string)
count, _ := item["count"].(int)
result[i] = MetricCount{Key: key, Count: count}
}
return result
}
// parseDuration parses a duration string like "7d", "24h", "30m".
func parseDuration(s string) (time.Duration, error) {
if s == "" {
return 0, errors.New("duration cannot be empty")
}
s = strings.TrimSpace(s)
if len(s) < 2 {
return 0, fmt.Errorf("invalid duration format: %q", s)
}
// Get the numeric part and unit
unit := s[len(s)-1]
numStr := s[:len(s)-1]
num, err := strconv.Atoi(numStr)
if err != nil {
return 0, fmt.Errorf("invalid duration number: %q", numStr)
}
if num <= 0 {
return 0, fmt.Errorf("duration must be positive: %d", num)
}
switch unit {
case 'd':
return time.Duration(num) * 24 * time.Hour, nil
case 'h':
return time.Duration(num) * time.Hour, nil
case 'm':
return time.Duration(num) * time.Minute, nil
default:
return 0, fmt.Errorf("invalid duration unit: %q (expected d, h, or m)", string(unit))
}
}

View file

@ -1,207 +0,0 @@
package mcp
import (
"testing"
"time"
)
// TestMetricsToolsRegistered_Good verifies that metrics tools are registered with the MCP server.
func TestMetricsToolsRegistered_Good(t *testing.T) {
// Create a new MCP service - this should register all tools including metrics
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
// The server should have registered the metrics tools
// We verify by checking that the server and logger exist
if s.server == nil {
t.Fatal("Server should not be nil")
}
if s.logger == nil {
t.Error("Logger should not be nil")
}
}
// TestMetricsRecordInput_Good verifies the MetricsRecordInput struct has expected fields.
func TestMetricsRecordInput_Good(t *testing.T) {
input := MetricsRecordInput{
Type: "tool_call",
AgentID: "agent-123",
Repo: "host-uk/core",
Data: map[string]any{"tool": "file_read", "duration_ms": 150},
}
if input.Type != "tool_call" {
t.Errorf("Expected type 'tool_call', got %q", input.Type)
}
if input.AgentID != "agent-123" {
t.Errorf("Expected agent_id 'agent-123', got %q", input.AgentID)
}
if input.Repo != "host-uk/core" {
t.Errorf("Expected repo 'host-uk/core', got %q", input.Repo)
}
if input.Data["tool"] != "file_read" {
t.Errorf("Expected data[tool] 'file_read', got %v", input.Data["tool"])
}
}
// TestMetricsRecordOutput_Good verifies the MetricsRecordOutput struct has expected fields.
func TestMetricsRecordOutput_Good(t *testing.T) {
ts := time.Now()
output := MetricsRecordOutput{
Success: true,
Timestamp: ts,
}
if !output.Success {
t.Error("Expected success to be true")
}
if output.Timestamp != ts {
t.Errorf("Expected timestamp %v, got %v", ts, output.Timestamp)
}
}
// TestMetricsQueryInput_Good verifies the MetricsQueryInput struct has expected fields.
func TestMetricsQueryInput_Good(t *testing.T) {
input := MetricsQueryInput{
Since: "7d",
}
if input.Since != "7d" {
t.Errorf("Expected since '7d', got %q", input.Since)
}
}
// TestMetricsQueryInput_Defaults verifies default values are handled correctly.
func TestMetricsQueryInput_Defaults(t *testing.T) {
input := MetricsQueryInput{}
// Empty since should use default when processed
if input.Since != "" {
t.Errorf("Expected empty since before defaults, got %q", input.Since)
}
}
// TestMetricsQueryOutput_Good verifies the MetricsQueryOutput struct has expected fields.
func TestMetricsQueryOutput_Good(t *testing.T) {
output := MetricsQueryOutput{
Total: 100,
ByType: []MetricCount{
{Key: "tool_call", Count: 50},
{Key: "query", Count: 30},
},
ByRepo: []MetricCount{
{Key: "host-uk/core", Count: 40},
},
ByAgent: []MetricCount{
{Key: "agent-123", Count: 25},
},
Events: []MetricEventBrief{
{Type: "tool_call", Timestamp: time.Now(), AgentID: "agent-1", Repo: "host-uk/core"},
},
}
if output.Total != 100 {
t.Errorf("Expected total 100, got %d", output.Total)
}
if len(output.ByType) != 2 {
t.Errorf("Expected 2 ByType entries, got %d", len(output.ByType))
}
if output.ByType[0].Key != "tool_call" {
t.Errorf("Expected ByType[0].Key 'tool_call', got %q", output.ByType[0].Key)
}
if output.ByType[0].Count != 50 {
t.Errorf("Expected ByType[0].Count 50, got %d", output.ByType[0].Count)
}
if len(output.Events) != 1 {
t.Errorf("Expected 1 event, got %d", len(output.Events))
}
}
// TestMetricCount_Good verifies the MetricCount struct has expected fields.
func TestMetricCount_Good(t *testing.T) {
mc := MetricCount{
Key: "tool_call",
Count: 42,
}
if mc.Key != "tool_call" {
t.Errorf("Expected key 'tool_call', got %q", mc.Key)
}
if mc.Count != 42 {
t.Errorf("Expected count 42, got %d", mc.Count)
}
}
// TestMetricEventBrief_Good verifies the MetricEventBrief struct has expected fields.
func TestMetricEventBrief_Good(t *testing.T) {
ts := time.Now()
ev := MetricEventBrief{
Type: "tool_call",
Timestamp: ts,
AgentID: "agent-123",
Repo: "host-uk/core",
}
if ev.Type != "tool_call" {
t.Errorf("Expected type 'tool_call', got %q", ev.Type)
}
if ev.Timestamp != ts {
t.Errorf("Expected timestamp %v, got %v", ts, ev.Timestamp)
}
if ev.AgentID != "agent-123" {
t.Errorf("Expected agent_id 'agent-123', got %q", ev.AgentID)
}
if ev.Repo != "host-uk/core" {
t.Errorf("Expected repo 'host-uk/core', got %q", ev.Repo)
}
}
// TestParseDuration_Good verifies the parseDuration helper handles various formats.
func TestParseDuration_Good(t *testing.T) {
tests := []struct {
input string
expected time.Duration
}{
{"7d", 7 * 24 * time.Hour},
{"24h", 24 * time.Hour},
{"30m", 30 * time.Minute},
{"1d", 24 * time.Hour},
{"14d", 14 * 24 * time.Hour},
{"1h", time.Hour},
{"10m", 10 * time.Minute},
}
for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
d, err := parseDuration(tc.input)
if err != nil {
t.Fatalf("parseDuration(%q) returned error: %v", tc.input, err)
}
if d != tc.expected {
t.Errorf("parseDuration(%q) = %v, want %v", tc.input, d, tc.expected)
}
})
}
}
// TestParseDuration_Bad verifies parseDuration returns errors for invalid input.
func TestParseDuration_Bad(t *testing.T) {
tests := []string{
"",
"abc",
"7x",
"-7d",
}
for _, input := range tests {
t.Run(input, func(t *testing.T) {
_, err := parseDuration(input)
if err == nil {
t.Errorf("parseDuration(%q) should return error", input)
}
})
}
}

View file

@ -1,290 +0,0 @@
package mcp
import (
"context"
"errors"
"fmt"
"strings"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-ml"
"forge.lthn.ai/core/go-log"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// MLSubsystem exposes ML inference and scoring tools via MCP.
type MLSubsystem struct {
service *ml.Service
logger *log.Logger
}
// NewMLSubsystem creates an MCP subsystem for ML tools.
func NewMLSubsystem(svc *ml.Service) *MLSubsystem {
return &MLSubsystem{
service: svc,
logger: log.Default(),
}
}
func (m *MLSubsystem) Name() string { return "ml" }
// RegisterTools adds ML tools to the MCP server.
func (m *MLSubsystem) RegisterTools(server *mcp.Server) {
mcp.AddTool(server, &mcp.Tool{
Name: "ml_generate",
Description: "Generate text via a configured ML inference backend.",
}, m.mlGenerate)
mcp.AddTool(server, &mcp.Tool{
Name: "ml_score",
Description: "Score a prompt/response pair using heuristic and LLM judge suites.",
}, m.mlScore)
mcp.AddTool(server, &mcp.Tool{
Name: "ml_probe",
Description: "Run capability probes against an inference backend.",
}, m.mlProbe)
mcp.AddTool(server, &mcp.Tool{
Name: "ml_status",
Description: "Show training and generation progress from InfluxDB.",
}, m.mlStatus)
mcp.AddTool(server, &mcp.Tool{
Name: "ml_backends",
Description: "List available inference backends and their status.",
}, m.mlBackends)
}
// --- Input/Output types ---
// MLGenerateInput contains parameters for text generation.
type MLGenerateInput struct {
Prompt string `json:"prompt"` // The prompt to generate from
Backend string `json:"backend,omitempty"` // Backend name (default: service default)
Model string `json:"model,omitempty"` // Model override
Temperature float64 `json:"temperature,omitempty"` // Sampling temperature
MaxTokens int `json:"max_tokens,omitempty"` // Maximum tokens to generate
}
// MLGenerateOutput contains the generation result.
type MLGenerateOutput struct {
Response string `json:"response"`
Backend string `json:"backend"`
Model string `json:"model,omitempty"`
}
// MLScoreInput contains parameters for scoring a response.
type MLScoreInput struct {
Prompt string `json:"prompt"` // The original prompt
Response string `json:"response"` // The model response to score
Suites string `json:"suites,omitempty"` // Comma-separated suites (default: heuristic)
}
// MLScoreOutput contains the scoring result.
type MLScoreOutput struct {
Heuristic *ml.HeuristicScores `json:"heuristic,omitempty"`
Semantic *ml.SemanticScores `json:"semantic,omitempty"`
Content *ml.ContentScores `json:"content,omitempty"`
}
// MLProbeInput contains parameters for running probes.
type MLProbeInput struct {
Backend string `json:"backend,omitempty"` // Backend name
Categories string `json:"categories,omitempty"` // Comma-separated categories to run
}
// MLProbeOutput contains probe results.
type MLProbeOutput struct {
Total int `json:"total"`
Results []MLProbeResultItem `json:"results"`
}
// MLProbeResultItem is a single probe result.
type MLProbeResultItem struct {
ID string `json:"id"`
Category string `json:"category"`
Response string `json:"response"`
}
// MLStatusInput contains parameters for the status query.
type MLStatusInput struct {
InfluxURL string `json:"influx_url,omitempty"` // InfluxDB URL override
InfluxDB string `json:"influx_db,omitempty"` // InfluxDB database override
}
// MLStatusOutput contains pipeline status.
type MLStatusOutput struct {
Status string `json:"status"`
}
// MLBackendsInput is empty — lists all backends.
type MLBackendsInput struct{}
// MLBackendsOutput lists available backends.
type MLBackendsOutput struct {
Backends []MLBackendInfo `json:"backends"`
Default string `json:"default"`
}
// MLBackendInfo describes a single backend.
type MLBackendInfo struct {
Name string `json:"name"`
Available bool `json:"available"`
}
// --- Tool handlers ---
// mlGenerate delegates to go-ml.Service.Generate, which internally uses
// InferenceAdapter to route generation through an inference.TextModel.
// Flow: go-ai → go-ml.Service.Generate → InferenceAdapter → inference.TextModel.
func (m *MLSubsystem) mlGenerate(ctx context.Context, req *mcp.CallToolRequest, input MLGenerateInput) (*mcp.CallToolResult, MLGenerateOutput, error) {
m.logger.Info("MCP tool execution", "tool", "ml_generate", "backend", input.Backend, "user", log.Username())
if input.Prompt == "" {
return nil, MLGenerateOutput{}, errors.New("prompt cannot be empty")
}
opts := ml.GenOpts{
Temperature: input.Temperature,
MaxTokens: input.MaxTokens,
Model: input.Model,
}
result, err := m.service.Generate(ctx, input.Backend, input.Prompt, opts)
if err != nil {
return nil, MLGenerateOutput{}, fmt.Errorf("generate: %w", err)
}
return nil, MLGenerateOutput{
Response: result.Text,
Backend: input.Backend,
Model: input.Model,
}, nil
}
func (m *MLSubsystem) mlScore(ctx context.Context, req *mcp.CallToolRequest, input MLScoreInput) (*mcp.CallToolResult, MLScoreOutput, error) {
m.logger.Info("MCP tool execution", "tool", "ml_score", "suites", input.Suites, "user", log.Username())
if input.Prompt == "" || input.Response == "" {
return nil, MLScoreOutput{}, errors.New("prompt and response cannot be empty")
}
suites := input.Suites
if suites == "" {
suites = "heuristic"
}
output := MLScoreOutput{}
for suite := range strings.SplitSeq(suites, ",") {
suite = strings.TrimSpace(suite)
switch suite {
case "heuristic":
output.Heuristic = ml.ScoreHeuristic(input.Response)
case "semantic":
judge := m.service.Judge()
if judge == nil {
return nil, MLScoreOutput{}, errors.New("semantic scoring requires a judge backend")
}
s, err := judge.ScoreSemantic(ctx, input.Prompt, input.Response)
if err != nil {
return nil, MLScoreOutput{}, fmt.Errorf("semantic score: %w", err)
}
output.Semantic = s
case "content":
return nil, MLScoreOutput{}, errors.New("content scoring requires a ContentProbe — use ml_probe instead")
}
}
return nil, output, nil
}
// mlProbe runs capability probes by generating responses via go-ml.Service.
// Flow: go-ai → go-ml.Service.Generate → InferenceAdapter → inference.TextModel.
func (m *MLSubsystem) mlProbe(ctx context.Context, req *mcp.CallToolRequest, input MLProbeInput) (*mcp.CallToolResult, MLProbeOutput, error) {
m.logger.Info("MCP tool execution", "tool", "ml_probe", "backend", input.Backend, "user", log.Username())
// Filter probes by category if specified.
probes := ml.CapabilityProbes
if input.Categories != "" {
cats := make(map[string]bool)
for c := range strings.SplitSeq(input.Categories, ",") {
cats[strings.TrimSpace(c)] = true
}
var filtered []ml.Probe
for _, p := range probes {
if cats[p.Category] {
filtered = append(filtered, p)
}
}
probes = filtered
}
var results []MLProbeResultItem
for _, probe := range probes {
result, err := m.service.Generate(ctx, input.Backend, probe.Prompt, ml.GenOpts{Temperature: 0.7, MaxTokens: 2048})
respText := result.Text
if err != nil {
respText = fmt.Sprintf("error: %v", err)
}
results = append(results, MLProbeResultItem{
ID: probe.ID,
Category: probe.Category,
Response: respText,
})
}
return nil, MLProbeOutput{
Total: len(results),
Results: results,
}, nil
}
func (m *MLSubsystem) mlStatus(ctx context.Context, req *mcp.CallToolRequest, input MLStatusInput) (*mcp.CallToolResult, MLStatusOutput, error) {
m.logger.Info("MCP tool execution", "tool", "ml_status", "user", log.Username())
url := input.InfluxURL
db := input.InfluxDB
if url == "" {
url = "http://localhost:8086"
}
if db == "" {
db = "lem"
}
influx := ml.NewInfluxClient(url, db)
var buf strings.Builder
if err := ml.PrintStatus(influx, &buf); err != nil {
return nil, MLStatusOutput{}, fmt.Errorf("status: %w", err)
}
return nil, MLStatusOutput{Status: buf.String()}, nil
}
// mlBackends enumerates registered backends via the go-inference registry,
// bypassing go-ml.Service entirely. This is the canonical source of truth
// for backend availability since all backends register with inference.Register().
func (m *MLSubsystem) mlBackends(ctx context.Context, req *mcp.CallToolRequest, input MLBackendsInput) (*mcp.CallToolResult, MLBackendsOutput, error) {
m.logger.Info("MCP tool execution", "tool", "ml_backends", "user", log.Username())
names := inference.List()
backends := make([]MLBackendInfo, 0, len(names))
for _, name := range names {
b, ok := inference.Get(name)
backends = append(backends, MLBackendInfo{
Name: name,
Available: ok && b.Available(),
})
}
defaultName := ""
if db, err := inference.Default(); err == nil {
defaultName = db.Name()
}
return nil, MLBackendsOutput{
Backends: backends,
Default: defaultName,
}, nil
}

View file

@ -1,479 +0,0 @@
package mcp
import (
"context"
"fmt"
"strings"
"testing"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-ml"
"forge.lthn.ai/core/go/pkg/core"
"forge.lthn.ai/core/go-log"
)
// --- Mock backend for inference registry ---
// mockInferenceBackend implements inference.Backend for CI testing of ml_backends.
type mockInferenceBackend struct {
name string
available bool
}
func (m *mockInferenceBackend) Name() string { return m.name }
func (m *mockInferenceBackend) Available() bool { return m.available }
func (m *mockInferenceBackend) LoadModel(_ string, _ ...inference.LoadOption) (inference.TextModel, error) {
return nil, fmt.Errorf("mock backend: LoadModel not implemented")
}
// --- Mock ml.Backend for Generate ---
// mockMLBackend implements ml.Backend for CI testing.
type mockMLBackend struct {
name string
available bool
generateResp string
generateErr error
}
func (m *mockMLBackend) Name() string { return m.name }
func (m *mockMLBackend) Available() bool { return m.available }
func (m *mockMLBackend) Generate(_ context.Context, _ string, _ ml.GenOpts) (ml.Result, error) {
return ml.Result{Text: m.generateResp}, m.generateErr
}
func (m *mockMLBackend) Chat(_ context.Context, _ []ml.Message, _ ml.GenOpts) (ml.Result, error) {
return ml.Result{Text: m.generateResp}, m.generateErr
}
// newTestMLSubsystem creates an MLSubsystem with a real ml.Service for testing.
func newTestMLSubsystem(t *testing.T, backends ...ml.Backend) *MLSubsystem {
t.Helper()
c, err := core.New(
core.WithName("ml", ml.NewService(ml.Options{})),
)
if err != nil {
t.Fatalf("Failed to create framework core: %v", err)
}
svc, err := core.ServiceFor[*ml.Service](c, "ml")
if err != nil {
t.Fatalf("Failed to get ML service: %v", err)
}
// Register mock backends
for _, b := range backends {
svc.RegisterBackend(b.Name(), b)
}
return &MLSubsystem{
service: svc,
logger: log.Default(),
}
}
// --- Input/Output struct tests ---
// TestMLGenerateInput_Good verifies all fields can be set.
func TestMLGenerateInput_Good(t *testing.T) {
input := MLGenerateInput{
Prompt: "Hello world",
Backend: "test",
Model: "test-model",
Temperature: 0.7,
MaxTokens: 100,
}
if input.Prompt != "Hello world" {
t.Errorf("Expected prompt 'Hello world', got %q", input.Prompt)
}
if input.Temperature != 0.7 {
t.Errorf("Expected temperature 0.7, got %f", input.Temperature)
}
if input.MaxTokens != 100 {
t.Errorf("Expected max_tokens 100, got %d", input.MaxTokens)
}
}
// TestMLScoreInput_Good verifies all fields can be set.
func TestMLScoreInput_Good(t *testing.T) {
input := MLScoreInput{
Prompt: "test prompt",
Response: "test response",
Suites: "heuristic,semantic",
}
if input.Prompt != "test prompt" {
t.Errorf("Expected prompt 'test prompt', got %q", input.Prompt)
}
if input.Response != "test response" {
t.Errorf("Expected response 'test response', got %q", input.Response)
}
}
// TestMLProbeInput_Good verifies all fields can be set.
func TestMLProbeInput_Good(t *testing.T) {
input := MLProbeInput{
Backend: "test",
Categories: "reasoning,code",
}
if input.Backend != "test" {
t.Errorf("Expected backend 'test', got %q", input.Backend)
}
}
// TestMLStatusInput_Good verifies all fields can be set.
func TestMLStatusInput_Good(t *testing.T) {
input := MLStatusInput{
InfluxURL: "http://localhost:8086",
InfluxDB: "lem",
}
if input.InfluxURL != "http://localhost:8086" {
t.Errorf("Expected InfluxURL, got %q", input.InfluxURL)
}
}
// TestMLBackendsInput_Good verifies empty struct.
func TestMLBackendsInput_Good(t *testing.T) {
_ = MLBackendsInput{}
}
// TestMLBackendsOutput_Good verifies struct fields.
func TestMLBackendsOutput_Good(t *testing.T) {
output := MLBackendsOutput{
Backends: []MLBackendInfo{
{Name: "ollama", Available: true},
{Name: "llama", Available: false},
},
Default: "ollama",
}
if len(output.Backends) != 2 {
t.Fatalf("Expected 2 backends, got %d", len(output.Backends))
}
if output.Default != "ollama" {
t.Errorf("Expected default 'ollama', got %q", output.Default)
}
if !output.Backends[0].Available {
t.Error("Expected first backend to be available")
}
}
// TestMLProbeOutput_Good verifies struct fields.
func TestMLProbeOutput_Good(t *testing.T) {
output := MLProbeOutput{
Total: 2,
Results: []MLProbeResultItem{
{ID: "probe-1", Category: "reasoning", Response: "test"},
{ID: "probe-2", Category: "code", Response: "test2"},
},
}
if output.Total != 2 {
t.Errorf("Expected total 2, got %d", output.Total)
}
if output.Results[0].ID != "probe-1" {
t.Errorf("Expected ID 'probe-1', got %q", output.Results[0].ID)
}
}
// TestMLStatusOutput_Good verifies struct fields.
func TestMLStatusOutput_Good(t *testing.T) {
output := MLStatusOutput{Status: "OK: 5 training runs"}
if output.Status != "OK: 5 training runs" {
t.Errorf("Unexpected status: %q", output.Status)
}
}
// TestMLGenerateOutput_Good verifies struct fields.
func TestMLGenerateOutput_Good(t *testing.T) {
output := MLGenerateOutput{
Response: "Generated text here",
Backend: "ollama",
Model: "qwen3:8b",
}
if output.Response != "Generated text here" {
t.Errorf("Unexpected response: %q", output.Response)
}
}
// TestMLScoreOutput_Good verifies struct fields.
func TestMLScoreOutput_Good(t *testing.T) {
output := MLScoreOutput{
Heuristic: &ml.HeuristicScores{},
}
if output.Heuristic == nil {
t.Error("Expected Heuristic to be set")
}
if output.Semantic != nil {
t.Error("Expected Semantic to be nil")
}
}
// --- Handler validation tests ---
// TestMLGenerate_Bad_EmptyPrompt verifies empty prompt returns error.
func TestMLGenerate_Bad_EmptyPrompt(t *testing.T) {
m := newTestMLSubsystem(t)
ctx := context.Background()
_, _, err := m.mlGenerate(ctx, nil, MLGenerateInput{})
if err == nil {
t.Fatal("Expected error for empty prompt")
}
if !strings.Contains(err.Error(), "prompt cannot be empty") {
t.Errorf("Unexpected error: %v", err)
}
}
// TestMLGenerate_Good_WithMockBackend verifies generate works with a mock backend.
func TestMLGenerate_Good_WithMockBackend(t *testing.T) {
mock := &mockMLBackend{
name: "test-mock",
available: true,
generateResp: "mock response",
}
m := newTestMLSubsystem(t, mock)
ctx := context.Background()
_, out, err := m.mlGenerate(ctx, nil, MLGenerateInput{
Prompt: "test",
Backend: "test-mock",
})
if err != nil {
t.Fatalf("mlGenerate failed: %v", err)
}
if out.Response != "mock response" {
t.Errorf("Expected 'mock response', got %q", out.Response)
}
}
// TestMLGenerate_Bad_NoBackend verifies generate fails gracefully without a backend.
func TestMLGenerate_Bad_NoBackend(t *testing.T) {
m := newTestMLSubsystem(t)
ctx := context.Background()
_, _, err := m.mlGenerate(ctx, nil, MLGenerateInput{
Prompt: "test",
Backend: "nonexistent",
})
if err == nil {
t.Fatal("Expected error for missing backend")
}
if !strings.Contains(err.Error(), "no backend available") {
t.Errorf("Unexpected error: %v", err)
}
}
// TestMLScore_Bad_EmptyPrompt verifies empty prompt returns error.
func TestMLScore_Bad_EmptyPrompt(t *testing.T) {
m := newTestMLSubsystem(t)
ctx := context.Background()
_, _, err := m.mlScore(ctx, nil, MLScoreInput{Response: "some"})
if err == nil {
t.Fatal("Expected error for empty prompt")
}
}
// TestMLScore_Bad_EmptyResponse verifies empty response returns error.
func TestMLScore_Bad_EmptyResponse(t *testing.T) {
m := newTestMLSubsystem(t)
ctx := context.Background()
_, _, err := m.mlScore(ctx, nil, MLScoreInput{Prompt: "some"})
if err == nil {
t.Fatal("Expected error for empty response")
}
}
// TestMLScore_Good_Heuristic verifies heuristic scoring without live services.
func TestMLScore_Good_Heuristic(t *testing.T) {
m := newTestMLSubsystem(t)
ctx := context.Background()
_, out, err := m.mlScore(ctx, nil, MLScoreInput{
Prompt: "What is Go?",
Response: "Go is a statically typed, compiled programming language designed at Google.",
Suites: "heuristic",
})
if err != nil {
t.Fatalf("mlScore failed: %v", err)
}
if out.Heuristic == nil {
t.Fatal("Expected heuristic scores to be set")
}
}
// TestMLScore_Good_DefaultSuite verifies default suite is heuristic.
func TestMLScore_Good_DefaultSuite(t *testing.T) {
m := newTestMLSubsystem(t)
ctx := context.Background()
_, out, err := m.mlScore(ctx, nil, MLScoreInput{
Prompt: "What is Go?",
Response: "Go is a statically typed, compiled programming language designed at Google.",
})
if err != nil {
t.Fatalf("mlScore failed: %v", err)
}
if out.Heuristic == nil {
t.Fatal("Expected heuristic scores (default suite)")
}
}
// TestMLScore_Bad_SemanticNoJudge verifies semantic scoring fails without a judge.
func TestMLScore_Bad_SemanticNoJudge(t *testing.T) {
m := newTestMLSubsystem(t)
ctx := context.Background()
_, _, err := m.mlScore(ctx, nil, MLScoreInput{
Prompt: "test",
Response: "test",
Suites: "semantic",
})
if err == nil {
t.Fatal("Expected error for semantic scoring without judge")
}
if !strings.Contains(err.Error(), "requires a judge") {
t.Errorf("Unexpected error: %v", err)
}
}
// TestMLScore_Bad_ContentSuite verifies content suite redirects to ml_probe.
func TestMLScore_Bad_ContentSuite(t *testing.T) {
m := newTestMLSubsystem(t)
ctx := context.Background()
_, _, err := m.mlScore(ctx, nil, MLScoreInput{
Prompt: "test",
Response: "test",
Suites: "content",
})
if err == nil {
t.Fatal("Expected error for content suite")
}
if !strings.Contains(err.Error(), "ContentProbe") {
t.Errorf("Unexpected error: %v", err)
}
}
// TestMLProbe_Good_WithMockBackend verifies probes run with mock backend.
func TestMLProbe_Good_WithMockBackend(t *testing.T) {
mock := &mockMLBackend{
name: "probe-mock",
available: true,
generateResp: "probe response",
}
m := newTestMLSubsystem(t, mock)
ctx := context.Background()
_, out, err := m.mlProbe(ctx, nil, MLProbeInput{
Backend: "probe-mock",
Categories: "reasoning",
})
if err != nil {
t.Fatalf("mlProbe failed: %v", err)
}
// Should have run probes in the "reasoning" category
for _, r := range out.Results {
if r.Category != "reasoning" {
t.Errorf("Expected category 'reasoning', got %q", r.Category)
}
if r.Response != "probe response" {
t.Errorf("Expected 'probe response', got %q", r.Response)
}
}
if out.Total != len(out.Results) {
t.Errorf("Expected total %d, got %d", len(out.Results), out.Total)
}
}
// TestMLProbe_Good_NoCategory verifies all probes run without category filter.
func TestMLProbe_Good_NoCategory(t *testing.T) {
mock := &mockMLBackend{
name: "all-probe-mock",
available: true,
generateResp: "ok",
}
m := newTestMLSubsystem(t, mock)
ctx := context.Background()
_, out, err := m.mlProbe(ctx, nil, MLProbeInput{Backend: "all-probe-mock"})
if err != nil {
t.Fatalf("mlProbe failed: %v", err)
}
// Should run all 23 probes
if out.Total != len(ml.CapabilityProbes) {
t.Errorf("Expected %d probes, got %d", len(ml.CapabilityProbes), out.Total)
}
}
// TestMLBackends_Good_EmptyRegistry verifies empty result when no backends registered.
func TestMLBackends_Good_EmptyRegistry(t *testing.T) {
m := newTestMLSubsystem(t)
ctx := context.Background()
// Note: inference.List() returns global registry state.
// This test verifies the handler runs without panic.
_, out, err := m.mlBackends(ctx, nil, MLBackendsInput{})
if err != nil {
t.Fatalf("mlBackends failed: %v", err)
}
// We can't guarantee what's in the global registry, but it should not panic
_ = out
}
// TestMLBackends_Good_WithMockInferenceBackend verifies registered backend appears.
func TestMLBackends_Good_WithMockInferenceBackend(t *testing.T) {
// Register a mock backend in the global inference registry
mock := &mockInferenceBackend{name: "test-ci-mock", available: true}
inference.Register(mock)
m := newTestMLSubsystem(t)
ctx := context.Background()
_, out, err := m.mlBackends(ctx, nil, MLBackendsInput{})
if err != nil {
t.Fatalf("mlBackends failed: %v", err)
}
found := false
for _, b := range out.Backends {
if b.Name == "test-ci-mock" {
found = true
if !b.Available {
t.Error("Expected mock backend to be available")
}
}
}
if !found {
t.Error("Expected to find 'test-ci-mock' in backends list")
}
}
// TestMLSubsystem_Good_Name verifies subsystem name.
func TestMLSubsystem_Good_Name(t *testing.T) {
m := newTestMLSubsystem(t)
if m.Name() != "ml" {
t.Errorf("Expected name 'ml', got %q", m.Name())
}
}
// TestNewMLSubsystem_Good verifies constructor.
func TestNewMLSubsystem_Good(t *testing.T) {
c, err := core.New(
core.WithName("ml", ml.NewService(ml.Options{})),
)
if err != nil {
t.Fatalf("Failed to create core: %v", err)
}
svc, err := core.ServiceFor[*ml.Service](c, "ml")
if err != nil {
t.Fatalf("Failed to get service: %v", err)
}
sub := NewMLSubsystem(svc)
if sub == nil {
t.Fatal("Expected non-nil subsystem")
}
if sub.service != svc {
t.Error("Expected service to be set")
}
if sub.logger == nil {
t.Error("Expected logger to be set")
}
}

View file

@ -1,305 +0,0 @@
package mcp
import (
"context"
"errors"
"fmt"
"time"
"forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-process"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// errIDEmpty is returned when a process tool call omits the required ID.
var errIDEmpty = errors.New("id cannot be empty")
// ProcessStartInput contains parameters for starting a new process.
type ProcessStartInput struct {
Command string `json:"command"` // The command to run
Args []string `json:"args,omitempty"` // Command arguments
Dir string `json:"dir,omitempty"` // Working directory
Env []string `json:"env,omitempty"` // Environment variables (KEY=VALUE format)
}
// ProcessStartOutput contains the result of starting a process.
type ProcessStartOutput struct {
ID string `json:"id"`
PID int `json:"pid"`
Command string `json:"command"`
Args []string `json:"args"`
StartedAt time.Time `json:"startedAt"`
}
// ProcessStopInput contains parameters for gracefully stopping a process.
type ProcessStopInput struct {
ID string `json:"id"` // Process ID to stop
}
// ProcessStopOutput contains the result of stopping a process.
type ProcessStopOutput struct {
ID string `json:"id"`
Success bool `json:"success"`
Message string `json:"message,omitempty"`
}
// ProcessKillInput contains parameters for force killing a process.
type ProcessKillInput struct {
ID string `json:"id"` // Process ID to kill
}
// ProcessKillOutput contains the result of killing a process.
type ProcessKillOutput struct {
ID string `json:"id"`
Success bool `json:"success"`
Message string `json:"message,omitempty"`
}
// ProcessListInput contains parameters for listing processes.
type ProcessListInput struct {
RunningOnly bool `json:"running_only,omitempty"` // If true, only return running processes
}
// ProcessListOutput contains the list of processes.
type ProcessListOutput struct {
Processes []ProcessInfo `json:"processes"`
Total int `json:"total"`
}
// ProcessInfo represents information about a process.
type ProcessInfo struct {
ID string `json:"id"`
Command string `json:"command"`
Args []string `json:"args"`
Dir string `json:"dir"`
Status string `json:"status"`
PID int `json:"pid"`
ExitCode int `json:"exitCode"`
StartedAt time.Time `json:"startedAt"`
Duration time.Duration `json:"duration"`
}
// ProcessOutputInput contains parameters for getting process output.
type ProcessOutputInput struct {
ID string `json:"id"` // Process ID
}
// ProcessOutputOutput contains the captured output of a process.
type ProcessOutputOutput struct {
ID string `json:"id"`
Output string `json:"output"`
}
// ProcessInputInput contains parameters for sending input to a process.
type ProcessInputInput struct {
ID string `json:"id"` // Process ID
Input string `json:"input"` // Input to send to stdin
}
// ProcessInputOutput contains the result of sending input to a process.
type ProcessInputOutput struct {
ID string `json:"id"`
Success bool `json:"success"`
Message string `json:"message,omitempty"`
}
// registerProcessTools adds process management tools to the MCP server.
// Returns false if process service is not available.
func (s *Service) registerProcessTools(server *mcp.Server) bool {
if s.processService == nil {
return false
}
mcp.AddTool(server, &mcp.Tool{
Name: "process_start",
Description: "Start a new external process. Returns process ID for tracking.",
}, s.processStart)
mcp.AddTool(server, &mcp.Tool{
Name: "process_stop",
Description: "Gracefully stop a running process by ID.",
}, s.processStop)
mcp.AddTool(server, &mcp.Tool{
Name: "process_kill",
Description: "Force kill a process by ID. Use when process_stop doesn't work.",
}, s.processKill)
mcp.AddTool(server, &mcp.Tool{
Name: "process_list",
Description: "List all managed processes. Use running_only=true for only active processes.",
}, s.processList)
mcp.AddTool(server, &mcp.Tool{
Name: "process_output",
Description: "Get the captured output of a process by ID.",
}, s.processOutput)
mcp.AddTool(server, &mcp.Tool{
Name: "process_input",
Description: "Send input to a running process stdin.",
}, s.processInput)
return true
}
// processStart handles the process_start tool call.
func (s *Service) processStart(ctx context.Context, req *mcp.CallToolRequest, input ProcessStartInput) (*mcp.CallToolResult, ProcessStartOutput, error) {
s.logger.Security("MCP tool execution", "tool", "process_start", "command", input.Command, "args", input.Args, "dir", input.Dir, "user", log.Username())
if input.Command == "" {
return nil, ProcessStartOutput{}, errors.New("command cannot be empty")
}
opts := process.RunOptions{
Command: input.Command,
Args: input.Args,
Dir: input.Dir,
Env: input.Env,
}
proc, err := s.processService.StartWithOptions(ctx, opts)
if err != nil {
log.Error("mcp: process start failed", "command", input.Command, "err", err)
return nil, ProcessStartOutput{}, fmt.Errorf("failed to start process: %w", err)
}
info := proc.Info()
return nil, ProcessStartOutput{
ID: proc.ID,
PID: info.PID,
Command: proc.Command,
Args: proc.Args,
StartedAt: proc.StartedAt,
}, nil
}
// processStop handles the process_stop tool call.
func (s *Service) processStop(ctx context.Context, req *mcp.CallToolRequest, input ProcessStopInput) (*mcp.CallToolResult, ProcessStopOutput, error) {
s.logger.Security("MCP tool execution", "tool", "process_stop", "id", input.ID, "user", log.Username())
if input.ID == "" {
return nil, ProcessStopOutput{}, errIDEmpty
}
proc, err := s.processService.Get(input.ID)
if err != nil {
log.Error("mcp: process stop failed", "id", input.ID, "err", err)
return nil, ProcessStopOutput{}, fmt.Errorf("process not found: %w", err)
}
// For graceful stop, we use Kill() which sends SIGKILL
// A more sophisticated implementation could use SIGTERM first
if err := proc.Kill(); err != nil {
log.Error("mcp: process stop kill failed", "id", input.ID, "err", err)
return nil, ProcessStopOutput{}, fmt.Errorf("failed to stop process: %w", err)
}
return nil, ProcessStopOutput{
ID: input.ID,
Success: true,
Message: "Process stop signal sent",
}, nil
}
// processKill handles the process_kill tool call.
func (s *Service) processKill(ctx context.Context, req *mcp.CallToolRequest, input ProcessKillInput) (*mcp.CallToolResult, ProcessKillOutput, error) {
s.logger.Security("MCP tool execution", "tool", "process_kill", "id", input.ID, "user", log.Username())
if input.ID == "" {
return nil, ProcessKillOutput{}, errIDEmpty
}
if err := s.processService.Kill(input.ID); err != nil {
log.Error("mcp: process kill failed", "id", input.ID, "err", err)
return nil, ProcessKillOutput{}, fmt.Errorf("failed to kill process: %w", err)
}
return nil, ProcessKillOutput{
ID: input.ID,
Success: true,
Message: "Process killed",
}, nil
}
// processList handles the process_list tool call.
func (s *Service) processList(ctx context.Context, req *mcp.CallToolRequest, input ProcessListInput) (*mcp.CallToolResult, ProcessListOutput, error) {
s.logger.Info("MCP tool execution", "tool", "process_list", "running_only", input.RunningOnly, "user", log.Username())
var procs []*process.Process
if input.RunningOnly {
procs = s.processService.Running()
} else {
procs = s.processService.List()
}
result := make([]ProcessInfo, len(procs))
for i, p := range procs {
info := p.Info()
result[i] = ProcessInfo{
ID: info.ID,
Command: info.Command,
Args: info.Args,
Dir: info.Dir,
Status: string(info.Status),
PID: info.PID,
ExitCode: info.ExitCode,
StartedAt: info.StartedAt,
Duration: info.Duration,
}
}
return nil, ProcessListOutput{
Processes: result,
Total: len(result),
}, nil
}
// processOutput handles the process_output tool call.
func (s *Service) processOutput(ctx context.Context, req *mcp.CallToolRequest, input ProcessOutputInput) (*mcp.CallToolResult, ProcessOutputOutput, error) {
s.logger.Info("MCP tool execution", "tool", "process_output", "id", input.ID, "user", log.Username())
if input.ID == "" {
return nil, ProcessOutputOutput{}, errIDEmpty
}
output, err := s.processService.Output(input.ID)
if err != nil {
log.Error("mcp: process output failed", "id", input.ID, "err", err)
return nil, ProcessOutputOutput{}, fmt.Errorf("failed to get process output: %w", err)
}
return nil, ProcessOutputOutput{
ID: input.ID,
Output: output,
}, nil
}
// processInput handles the process_input tool call.
func (s *Service) processInput(ctx context.Context, req *mcp.CallToolRequest, input ProcessInputInput) (*mcp.CallToolResult, ProcessInputOutput, error) {
s.logger.Security("MCP tool execution", "tool", "process_input", "id", input.ID, "user", log.Username())
if input.ID == "" {
return nil, ProcessInputOutput{}, errIDEmpty
}
if input.Input == "" {
return nil, ProcessInputOutput{}, errors.New("input cannot be empty")
}
proc, err := s.processService.Get(input.ID)
if err != nil {
log.Error("mcp: process input get failed", "id", input.ID, "err", err)
return nil, ProcessInputOutput{}, fmt.Errorf("process not found: %w", err)
}
if err := proc.SendInput(input.Input); err != nil {
log.Error("mcp: process input send failed", "id", input.ID, "err", err)
return nil, ProcessInputOutput{}, fmt.Errorf("failed to send input: %w", err)
}
return nil, ProcessInputOutput{
ID: input.ID,
Success: true,
Message: "Input sent successfully",
}, nil
}

View file

@ -1,515 +0,0 @@
package mcp
import (
"context"
"strings"
"testing"
"time"
"forge.lthn.ai/core/go/pkg/core"
"forge.lthn.ai/core/go-process"
)
// newTestProcessService creates a real process.Service backed by a core.Core for CI tests.
func newTestProcessService(t *testing.T) *process.Service {
t.Helper()
c, err := core.New(
core.WithName("process", process.NewService(process.Options{})),
)
if err != nil {
t.Fatalf("Failed to create framework core: %v", err)
}
svc, err := core.ServiceFor[*process.Service](c, "process")
if err != nil {
t.Fatalf("Failed to get process service: %v", err)
}
// Start services (calls OnStartup)
if err := c.ServiceStartup(context.Background(), nil); err != nil {
t.Fatalf("Failed to start core: %v", err)
}
t.Cleanup(func() {
_ = c.ServiceShutdown(context.Background())
})
return svc
}
// newTestMCPWithProcess creates an MCP Service wired to a real process.Service.
func newTestMCPWithProcess(t *testing.T) (*Service, *process.Service) {
t.Helper()
ps := newTestProcessService(t)
s, err := New(WithProcessService(ps))
if err != nil {
t.Fatalf("Failed to create MCP service: %v", err)
}
return s, ps
}
// --- CI-safe handler tests ---
// TestProcessStart_Good_Echo starts "echo hello" and verifies the output.
func TestProcessStart_Good_Echo(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, out, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "echo",
Args: []string{"hello"},
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
if out.ID == "" {
t.Error("Expected non-empty process ID")
}
if out.Command != "echo" {
t.Errorf("Expected command 'echo', got %q", out.Command)
}
if out.PID <= 0 {
t.Errorf("Expected positive PID, got %d", out.PID)
}
if out.StartedAt.IsZero() {
t.Error("Expected non-zero StartedAt")
}
}
// TestProcessStart_Bad_EmptyCommand verifies empty command returns an error.
func TestProcessStart_Bad_EmptyCommand(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processStart(ctx, nil, ProcessStartInput{})
if err == nil {
t.Fatal("Expected error for empty command")
}
if !strings.Contains(err.Error(), "command cannot be empty") {
t.Errorf("Unexpected error: %v", err)
}
}
// TestProcessStart_Bad_NonexistentCommand verifies an invalid command returns an error.
func TestProcessStart_Bad_NonexistentCommand(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "/nonexistent/binary/that/does/not/exist",
})
if err == nil {
t.Fatal("Expected error for nonexistent command")
}
}
// TestProcessList_Good_Empty verifies list is empty initially.
func TestProcessList_Good_Empty(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, out, err := s.processList(ctx, nil, ProcessListInput{})
if err != nil {
t.Fatalf("processList failed: %v", err)
}
if out.Total != 0 {
t.Errorf("Expected 0 processes, got %d", out.Total)
}
}
// TestProcessList_Good_AfterStart verifies a started process appears in list.
func TestProcessList_Good_AfterStart(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
// Start a short-lived process
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "echo",
Args: []string{"listing"},
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
// Give it a moment to register
time.Sleep(50 * time.Millisecond)
// List all processes (including exited)
_, listOut, err := s.processList(ctx, nil, ProcessListInput{})
if err != nil {
t.Fatalf("processList failed: %v", err)
}
if listOut.Total < 1 {
t.Fatalf("Expected at least 1 process, got %d", listOut.Total)
}
found := false
for _, p := range listOut.Processes {
if p.ID == startOut.ID {
found = true
if p.Command != "echo" {
t.Errorf("Expected command 'echo', got %q", p.Command)
}
}
}
if !found {
t.Errorf("Process %s not found in list", startOut.ID)
}
}
// TestProcessList_Good_RunningOnly verifies filtering for running-only processes.
func TestProcessList_Good_RunningOnly(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
// Start a process that exits quickly
_, _, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "echo",
Args: []string{"done"},
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
// Wait for it to exit
time.Sleep(100 * time.Millisecond)
// Running-only should be empty now
_, listOut, err := s.processList(ctx, nil, ProcessListInput{RunningOnly: true})
if err != nil {
t.Fatalf("processList failed: %v", err)
}
if listOut.Total != 0 {
t.Errorf("Expected 0 running processes after echo exits, got %d", listOut.Total)
}
}
// TestProcessOutput_Good_Echo verifies output capture from echo.
func TestProcessOutput_Good_Echo(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "echo",
Args: []string{"output_test"},
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
// Wait for process to complete and output to be captured
time.Sleep(200 * time.Millisecond)
_, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID})
if err != nil {
t.Fatalf("processOutput failed: %v", err)
}
if !strings.Contains(outputOut.Output, "output_test") {
t.Errorf("Expected output to contain 'output_test', got %q", outputOut.Output)
}
}
// TestProcessOutput_Bad_EmptyID verifies empty ID returns error.
func TestProcessOutput_Bad_EmptyID(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processOutput(ctx, nil, ProcessOutputInput{})
if err == nil {
t.Fatal("Expected error for empty ID")
}
if !strings.Contains(err.Error(), "id cannot be empty") {
t.Errorf("Unexpected error: %v", err)
}
}
// TestProcessOutput_Bad_NotFound verifies nonexistent ID returns error.
func TestProcessOutput_Bad_NotFound(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: "nonexistent-id"})
if err == nil {
t.Fatal("Expected error for nonexistent ID")
}
}
// TestProcessStop_Good_LongRunning starts a sleep, stops it, and verifies.
func TestProcessStop_Good_LongRunning(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
// Start a process that sleeps for a while
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "sleep",
Args: []string{"10"},
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
// Verify it's running
time.Sleep(50 * time.Millisecond)
_, listOut, _ := s.processList(ctx, nil, ProcessListInput{RunningOnly: true})
if listOut.Total < 1 {
t.Fatal("Expected at least 1 running process")
}
// Stop it
_, stopOut, err := s.processStop(ctx, nil, ProcessStopInput{ID: startOut.ID})
if err != nil {
t.Fatalf("processStop failed: %v", err)
}
if !stopOut.Success {
t.Error("Expected stop to succeed")
}
if stopOut.ID != startOut.ID {
t.Errorf("Expected ID %q, got %q", startOut.ID, stopOut.ID)
}
}
// TestProcessStop_Bad_EmptyID verifies empty ID returns error.
func TestProcessStop_Bad_EmptyID(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processStop(ctx, nil, ProcessStopInput{})
if err == nil {
t.Fatal("Expected error for empty ID")
}
}
// TestProcessStop_Bad_NotFound verifies nonexistent ID returns error.
func TestProcessStop_Bad_NotFound(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processStop(ctx, nil, ProcessStopInput{ID: "nonexistent-id"})
if err == nil {
t.Fatal("Expected error for nonexistent ID")
}
}
// TestProcessKill_Good_LongRunning starts a sleep, kills it, and verifies.
func TestProcessKill_Good_LongRunning(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "sleep",
Args: []string{"10"},
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
time.Sleep(50 * time.Millisecond)
_, killOut, err := s.processKill(ctx, nil, ProcessKillInput{ID: startOut.ID})
if err != nil {
t.Fatalf("processKill failed: %v", err)
}
if !killOut.Success {
t.Error("Expected kill to succeed")
}
if killOut.Message != "Process killed" {
t.Errorf("Expected message 'Process killed', got %q", killOut.Message)
}
}
// TestProcessKill_Bad_EmptyID verifies empty ID returns error.
func TestProcessKill_Bad_EmptyID(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processKill(ctx, nil, ProcessKillInput{})
if err == nil {
t.Fatal("Expected error for empty ID")
}
}
// TestProcessKill_Bad_NotFound verifies nonexistent ID returns error.
func TestProcessKill_Bad_NotFound(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processKill(ctx, nil, ProcessKillInput{ID: "nonexistent-id"})
if err == nil {
t.Fatal("Expected error for nonexistent ID")
}
}
// TestProcessInput_Bad_EmptyID verifies empty ID returns error.
func TestProcessInput_Bad_EmptyID(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processInput(ctx, nil, ProcessInputInput{})
if err == nil {
t.Fatal("Expected error for empty ID")
}
}
// TestProcessInput_Bad_EmptyInput verifies empty input string returns error.
func TestProcessInput_Bad_EmptyInput(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processInput(ctx, nil, ProcessInputInput{ID: "some-id"})
if err == nil {
t.Fatal("Expected error for empty input")
}
}
// TestProcessInput_Bad_NotFound verifies nonexistent process ID returns error.
func TestProcessInput_Bad_NotFound(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, _, err := s.processInput(ctx, nil, ProcessInputInput{
ID: "nonexistent-id",
Input: "hello\n",
})
if err == nil {
t.Fatal("Expected error for nonexistent ID")
}
}
// TestProcessInput_Good_Cat sends input to cat and reads it back.
func TestProcessInput_Good_Cat(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
// Start cat which reads stdin and echoes to stdout
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "cat",
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Send input
_, inputOut, err := s.processInput(ctx, nil, ProcessInputInput{
ID: startOut.ID,
Input: "stdin_test\n",
})
if err != nil {
t.Fatalf("processInput failed: %v", err)
}
if !inputOut.Success {
t.Error("Expected input to succeed")
}
// Wait for output capture
time.Sleep(100 * time.Millisecond)
// Read output
_, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID})
if err != nil {
t.Fatalf("processOutput failed: %v", err)
}
if !strings.Contains(outputOut.Output, "stdin_test") {
t.Errorf("Expected output to contain 'stdin_test', got %q", outputOut.Output)
}
// Kill the cat process (it's still running)
_, _, _ = s.processKill(ctx, nil, ProcessKillInput{ID: startOut.ID})
}
// TestProcessStart_Good_WithDir verifies working directory is respected.
func TestProcessStart_Good_WithDir(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
dir := t.TempDir()
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "pwd",
Dir: dir,
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
time.Sleep(200 * time.Millisecond)
_, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID})
if err != nil {
t.Fatalf("processOutput failed: %v", err)
}
if !strings.Contains(outputOut.Output, dir) {
t.Errorf("Expected output to contain dir %q, got %q", dir, outputOut.Output)
}
}
// TestProcessStart_Good_WithEnv verifies environment variables are passed.
func TestProcessStart_Good_WithEnv(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "env",
Env: []string{"TEST_MCP_VAR=hello_from_test"},
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
time.Sleep(200 * time.Millisecond)
_, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID})
if err != nil {
t.Fatalf("processOutput failed: %v", err)
}
if !strings.Contains(outputOut.Output, "TEST_MCP_VAR=hello_from_test") {
t.Errorf("Expected output to contain env var, got %q", outputOut.Output)
}
}
// TestProcessToolsRegistered_Good_WithService verifies tools are registered when service is provided.
func TestProcessToolsRegistered_Good_WithService(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
if s.processService == nil {
t.Error("Expected process service to be set")
}
}
// TestProcessFullLifecycle_Good tests the start → list → output → kill → list cycle.
func TestProcessFullLifecycle_Good(t *testing.T) {
s, _ := newTestMCPWithProcess(t)
ctx := context.Background()
// 1. Start
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
Command: "sleep",
Args: []string{"10"},
})
if err != nil {
t.Fatalf("processStart failed: %v", err)
}
time.Sleep(50 * time.Millisecond)
// 2. List (should be running)
_, listOut, _ := s.processList(ctx, nil, ProcessListInput{RunningOnly: true})
if listOut.Total < 1 {
t.Fatal("Expected at least 1 running process")
}
// 3. Kill
_, killOut, err := s.processKill(ctx, nil, ProcessKillInput{ID: startOut.ID})
if err != nil {
t.Fatalf("processKill failed: %v", err)
}
if !killOut.Success {
t.Error("Expected kill to succeed")
}
// 4. Wait for exit
time.Sleep(100 * time.Millisecond)
// 5. Should not be running anymore
_, listOut, _ = s.processList(ctx, nil, ProcessListInput{RunningOnly: true})
for _, p := range listOut.Processes {
if p.ID == startOut.ID {
t.Errorf("Process %s should not be running after kill", startOut.ID)
}
}
}

View file

@ -1,290 +0,0 @@
package mcp
import (
"testing"
"time"
)
// TestProcessToolsRegistered_Good verifies that process tools are registered when process service is available.
func TestProcessToolsRegistered_Good(t *testing.T) {
// Create a new MCP service without process service - tools should not be registered
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.processService != nil {
t.Error("Process service should be nil by default")
}
if s.server == nil {
t.Fatal("Server should not be nil")
}
}
// TestProcessStartInput_Good verifies the ProcessStartInput struct has expected fields.
func TestProcessStartInput_Good(t *testing.T) {
input := ProcessStartInput{
Command: "echo",
Args: []string{"hello", "world"},
Dir: "/tmp",
Env: []string{"FOO=bar"},
}
if input.Command != "echo" {
t.Errorf("Expected command 'echo', got %q", input.Command)
}
if len(input.Args) != 2 {
t.Errorf("Expected 2 args, got %d", len(input.Args))
}
if input.Dir != "/tmp" {
t.Errorf("Expected dir '/tmp', got %q", input.Dir)
}
if len(input.Env) != 1 {
t.Errorf("Expected 1 env var, got %d", len(input.Env))
}
}
// TestProcessStartOutput_Good verifies the ProcessStartOutput struct has expected fields.
func TestProcessStartOutput_Good(t *testing.T) {
now := time.Now()
output := ProcessStartOutput{
ID: "proc-1",
PID: 12345,
Command: "echo",
Args: []string{"hello"},
StartedAt: now,
}
if output.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
}
if output.PID != 12345 {
t.Errorf("Expected PID 12345, got %d", output.PID)
}
if output.Command != "echo" {
t.Errorf("Expected command 'echo', got %q", output.Command)
}
if !output.StartedAt.Equal(now) {
t.Errorf("Expected StartedAt %v, got %v", now, output.StartedAt)
}
}
// TestProcessStopInput_Good verifies the ProcessStopInput struct has expected fields.
func TestProcessStopInput_Good(t *testing.T) {
input := ProcessStopInput{
ID: "proc-1",
}
if input.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", input.ID)
}
}
// TestProcessStopOutput_Good verifies the ProcessStopOutput struct has expected fields.
func TestProcessStopOutput_Good(t *testing.T) {
output := ProcessStopOutput{
ID: "proc-1",
Success: true,
Message: "Process stopped",
}
if output.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
}
if !output.Success {
t.Error("Expected Success to be true")
}
if output.Message != "Process stopped" {
t.Errorf("Expected message 'Process stopped', got %q", output.Message)
}
}
// TestProcessKillInput_Good verifies the ProcessKillInput struct has expected fields.
func TestProcessKillInput_Good(t *testing.T) {
input := ProcessKillInput{
ID: "proc-1",
}
if input.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", input.ID)
}
}
// TestProcessKillOutput_Good verifies the ProcessKillOutput struct has expected fields.
func TestProcessKillOutput_Good(t *testing.T) {
output := ProcessKillOutput{
ID: "proc-1",
Success: true,
Message: "Process killed",
}
if output.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
}
if !output.Success {
t.Error("Expected Success to be true")
}
}
// TestProcessListInput_Good verifies the ProcessListInput struct has expected fields.
func TestProcessListInput_Good(t *testing.T) {
input := ProcessListInput{
RunningOnly: true,
}
if !input.RunningOnly {
t.Error("Expected RunningOnly to be true")
}
}
// TestProcessListInput_Defaults verifies default values.
func TestProcessListInput_Defaults(t *testing.T) {
input := ProcessListInput{}
if input.RunningOnly {
t.Error("Expected RunningOnly to default to false")
}
}
// TestProcessListOutput_Good verifies the ProcessListOutput struct has expected fields.
func TestProcessListOutput_Good(t *testing.T) {
now := time.Now()
output := ProcessListOutput{
Processes: []ProcessInfo{
{
ID: "proc-1",
Command: "echo",
Args: []string{"hello"},
Dir: "/tmp",
Status: "running",
PID: 12345,
ExitCode: 0,
StartedAt: now,
Duration: 5 * time.Second,
},
},
Total: 1,
}
if len(output.Processes) != 1 {
t.Fatalf("Expected 1 process, got %d", len(output.Processes))
}
if output.Total != 1 {
t.Errorf("Expected total 1, got %d", output.Total)
}
proc := output.Processes[0]
if proc.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", proc.ID)
}
if proc.Status != "running" {
t.Errorf("Expected status 'running', got %q", proc.Status)
}
if proc.PID != 12345 {
t.Errorf("Expected PID 12345, got %d", proc.PID)
}
}
// TestProcessOutputInput_Good verifies the ProcessOutputInput struct has expected fields.
func TestProcessOutputInput_Good(t *testing.T) {
input := ProcessOutputInput{
ID: "proc-1",
}
if input.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", input.ID)
}
}
// TestProcessOutputOutput_Good verifies the ProcessOutputOutput struct has expected fields.
func TestProcessOutputOutput_Good(t *testing.T) {
output := ProcessOutputOutput{
ID: "proc-1",
Output: "hello world\n",
}
if output.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
}
if output.Output != "hello world\n" {
t.Errorf("Expected output 'hello world\\n', got %q", output.Output)
}
}
// TestProcessInputInput_Good verifies the ProcessInputInput struct has expected fields.
func TestProcessInputInput_Good(t *testing.T) {
input := ProcessInputInput{
ID: "proc-1",
Input: "test input\n",
}
if input.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", input.ID)
}
if input.Input != "test input\n" {
t.Errorf("Expected input 'test input\\n', got %q", input.Input)
}
}
// TestProcessInputOutput_Good verifies the ProcessInputOutput struct has expected fields.
func TestProcessInputOutput_Good(t *testing.T) {
output := ProcessInputOutput{
ID: "proc-1",
Success: true,
Message: "Input sent",
}
if output.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
}
if !output.Success {
t.Error("Expected Success to be true")
}
}
// TestProcessInfo_Good verifies the ProcessInfo struct has expected fields.
func TestProcessInfo_Good(t *testing.T) {
now := time.Now()
info := ProcessInfo{
ID: "proc-1",
Command: "echo",
Args: []string{"hello"},
Dir: "/tmp",
Status: "exited",
PID: 12345,
ExitCode: 0,
StartedAt: now,
Duration: 2 * time.Second,
}
if info.ID != "proc-1" {
t.Errorf("Expected ID 'proc-1', got %q", info.ID)
}
if info.Command != "echo" {
t.Errorf("Expected command 'echo', got %q", info.Command)
}
if info.Status != "exited" {
t.Errorf("Expected status 'exited', got %q", info.Status)
}
if info.ExitCode != 0 {
t.Errorf("Expected exit code 0, got %d", info.ExitCode)
}
if info.Duration != 2*time.Second {
t.Errorf("Expected duration 2s, got %v", info.Duration)
}
}
// TestWithProcessService_Good verifies the WithProcessService option.
func TestWithProcessService_Good(t *testing.T) {
// Note: We can't easily create a real process.Service here without Core,
// so we just verify the option doesn't panic with nil.
s, err := New(WithProcessService(nil))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.processService != nil {
t.Error("Expected processService to be nil when passed nil")
}
}

View file

@ -1,233 +0,0 @@
package mcp
import (
"context"
"errors"
"fmt"
"forge.lthn.ai/core/go-rag"
"forge.lthn.ai/core/go-log"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// Default values for RAG operations.
const (
DefaultRAGCollection = "hostuk-docs"
DefaultRAGTopK = 5
)
// RAGQueryInput contains parameters for querying the RAG vector database.
type RAGQueryInput struct {
Question string `json:"question"` // The question or search query
Collection string `json:"collection,omitempty"` // Collection name (default: hostuk-docs)
TopK int `json:"topK,omitempty"` // Number of results to return (default: 5)
}
// RAGQueryResult represents a single query result.
type RAGQueryResult struct {
Content string `json:"content"`
Source string `json:"source"`
Section string `json:"section,omitempty"`
Category string `json:"category,omitempty"`
ChunkIndex int `json:"chunkIndex,omitempty"`
Score float32 `json:"score"`
}
// RAGQueryOutput contains the results of a RAG query.
type RAGQueryOutput struct {
Results []RAGQueryResult `json:"results"`
Query string `json:"query"`
Collection string `json:"collection"`
Context string `json:"context"`
}
// RAGIngestInput contains parameters for ingesting documents into the RAG database.
type RAGIngestInput struct {
Path string `json:"path"` // File or directory path to ingest
Collection string `json:"collection,omitempty"` // Collection name (default: hostuk-docs)
Recreate bool `json:"recreate,omitempty"` // Whether to recreate the collection
}
// RAGIngestOutput contains the result of a RAG ingest operation.
type RAGIngestOutput struct {
Success bool `json:"success"`
Path string `json:"path"`
Collection string `json:"collection"`
Chunks int `json:"chunks"`
Message string `json:"message,omitempty"`
}
// RAGCollectionsInput contains parameters for listing collections.
type RAGCollectionsInput struct {
ShowStats bool `json:"show_stats,omitempty"` // Include collection stats (point count, status)
}
// CollectionInfo contains information about a collection.
type CollectionInfo struct {
Name string `json:"name"`
PointsCount uint64 `json:"points_count"`
Status string `json:"status"`
}
// RAGCollectionsOutput contains the list of available collections.
type RAGCollectionsOutput struct {
Collections []CollectionInfo `json:"collections"`
}
// registerRAGTools adds RAG tools to the MCP server.
func (s *Service) registerRAGTools(server *mcp.Server) {
mcp.AddTool(server, &mcp.Tool{
Name: "rag_query",
Description: "Query the RAG vector database for relevant documentation. Returns semantically similar content based on the query.",
}, s.ragQuery)
mcp.AddTool(server, &mcp.Tool{
Name: "rag_ingest",
Description: "Ingest documents into the RAG vector database. Supports both single files and directories.",
}, s.ragIngest)
mcp.AddTool(server, &mcp.Tool{
Name: "rag_collections",
Description: "List all available collections in the RAG vector database.",
}, s.ragCollections)
}
// ragQuery handles the rag_query tool call.
func (s *Service) ragQuery(ctx context.Context, req *mcp.CallToolRequest, input RAGQueryInput) (*mcp.CallToolResult, RAGQueryOutput, error) {
// Apply defaults
collection := input.Collection
if collection == "" {
collection = DefaultRAGCollection
}
topK := input.TopK
if topK <= 0 {
topK = DefaultRAGTopK
}
s.logger.Info("MCP tool execution", "tool", "rag_query", "question", input.Question, "collection", collection, "topK", topK, "user", log.Username())
// Validate input
if input.Question == "" {
return nil, RAGQueryOutput{}, errors.New("question cannot be empty")
}
// Call the RAG query function
results, err := rag.QueryDocs(ctx, input.Question, collection, topK)
if err != nil {
log.Error("mcp: rag query failed", "question", input.Question, "collection", collection, "err", err)
return nil, RAGQueryOutput{}, fmt.Errorf("failed to query RAG: %w", err)
}
// Convert results
output := RAGQueryOutput{
Results: make([]RAGQueryResult, len(results)),
Query: input.Question,
Collection: collection,
Context: rag.FormatResultsContext(results),
}
for i, r := range results {
output.Results[i] = RAGQueryResult{
Content: r.Text,
Source: r.Source,
Section: r.Section,
Category: r.Category,
ChunkIndex: r.ChunkIndex,
Score: r.Score,
}
}
return nil, output, nil
}
// ragIngest handles the rag_ingest tool call.
func (s *Service) ragIngest(ctx context.Context, req *mcp.CallToolRequest, input RAGIngestInput) (*mcp.CallToolResult, RAGIngestOutput, error) {
// Apply defaults
collection := input.Collection
if collection == "" {
collection = DefaultRAGCollection
}
s.logger.Security("MCP tool execution", "tool", "rag_ingest", "path", input.Path, "collection", collection, "recreate", input.Recreate, "user", log.Username())
// Validate input
if input.Path == "" {
return nil, RAGIngestOutput{}, errors.New("path cannot be empty")
}
// Check if path is a file or directory using the medium
info, err := s.medium.Stat(input.Path)
if err != nil {
log.Error("mcp: rag ingest stat failed", "path", input.Path, "err", err)
return nil, RAGIngestOutput{}, fmt.Errorf("failed to access path: %w", err)
}
var message string
var chunks int
if info.IsDir() {
// Ingest directory
err = rag.IngestDirectory(ctx, input.Path, collection, input.Recreate)
if err != nil {
log.Error("mcp: rag ingest directory failed", "path", input.Path, "collection", collection, "err", err)
return nil, RAGIngestOutput{}, fmt.Errorf("failed to ingest directory: %w", err)
}
message = fmt.Sprintf("Successfully ingested directory %s into collection %s", input.Path, collection)
} else {
// Ingest single file
chunks, err = rag.IngestSingleFile(ctx, input.Path, collection)
if err != nil {
log.Error("mcp: rag ingest file failed", "path", input.Path, "collection", collection, "err", err)
return nil, RAGIngestOutput{}, fmt.Errorf("failed to ingest file: %w", err)
}
message = fmt.Sprintf("Successfully ingested file %s (%d chunks) into collection %s", input.Path, chunks, collection)
}
return nil, RAGIngestOutput{
Success: true,
Path: input.Path,
Collection: collection,
Chunks: chunks,
Message: message,
}, nil
}
// ragCollections handles the rag_collections tool call.
func (s *Service) ragCollections(ctx context.Context, req *mcp.CallToolRequest, input RAGCollectionsInput) (*mcp.CallToolResult, RAGCollectionsOutput, error) {
s.logger.Info("MCP tool execution", "tool", "rag_collections", "show_stats", input.ShowStats, "user", log.Username())
// Create Qdrant client with default config
qdrantClient, err := rag.NewQdrantClient(rag.DefaultQdrantConfig())
if err != nil {
log.Error("mcp: rag collections connect failed", "err", err)
return nil, RAGCollectionsOutput{}, fmt.Errorf("failed to connect to Qdrant: %w", err)
}
defer func() { _ = qdrantClient.Close() }()
// List collections
collectionNames, err := qdrantClient.ListCollections(ctx)
if err != nil {
log.Error("mcp: rag collections list failed", "err", err)
return nil, RAGCollectionsOutput{}, fmt.Errorf("failed to list collections: %w", err)
}
// Build collection info list
collections := make([]CollectionInfo, len(collectionNames))
for i, name := range collectionNames {
collections[i] = CollectionInfo{Name: name}
// Fetch stats if requested
if input.ShowStats {
info, err := qdrantClient.CollectionInfo(ctx, name)
if err != nil {
log.Error("mcp: rag collection info failed", "collection", name, "err", err)
// Continue with defaults on error
continue
}
collections[i].PointsCount = info.PointCount
collections[i].Status = info.Status
}
}
return nil, RAGCollectionsOutput{
Collections: collections,
}, nil
}

View file

@ -1,181 +0,0 @@
package mcp
import (
"context"
"strings"
"testing"
)
// RAG tools use package-level functions (rag.QueryDocs, rag.IngestDirectory, etc.)
// which require live Qdrant + Ollama services. Since those are not injectable,
// we test handler input validation, default application, and struct behaviour
// at the MCP handler level without requiring live services.
// --- ragQuery validation ---
// TestRagQuery_Bad_EmptyQuestion verifies empty question returns error.
func TestRagQuery_Bad_EmptyQuestion(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx := context.Background()
_, _, err = s.ragQuery(ctx, nil, RAGQueryInput{})
if err == nil {
t.Fatal("Expected error for empty question")
}
if !strings.Contains(err.Error(), "question cannot be empty") {
t.Errorf("Unexpected error: %v", err)
}
}
// TestRagQuery_Good_DefaultsApplied verifies defaults are applied before validation.
// Because the handler applies defaults then validates, a non-empty question with
// zero Collection/TopK should have defaults applied. We cannot verify the actual
// query (needs live Qdrant), but we can verify it gets past validation.
func TestRagQuery_Good_DefaultsApplied(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx := context.Background()
// This will fail when it tries to connect to Qdrant, but AFTER applying defaults.
// The error should NOT be about empty question.
_, _, err = s.ragQuery(ctx, nil, RAGQueryInput{Question: "test query"})
if err == nil {
t.Skip("RAG query succeeded — live Qdrant available, skip default test")
}
// The error should be about connection failure, not validation
if strings.Contains(err.Error(), "question cannot be empty") {
t.Error("Defaults should have been applied before validation check")
}
}
// --- ragIngest validation ---
// TestRagIngest_Bad_EmptyPath verifies empty path returns error.
func TestRagIngest_Bad_EmptyPath(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx := context.Background()
_, _, err = s.ragIngest(ctx, nil, RAGIngestInput{})
if err == nil {
t.Fatal("Expected error for empty path")
}
if !strings.Contains(err.Error(), "path cannot be empty") {
t.Errorf("Unexpected error: %v", err)
}
}
// TestRagIngest_Bad_NonexistentPath verifies nonexistent path returns error.
func TestRagIngest_Bad_NonexistentPath(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx := context.Background()
_, _, err = s.ragIngest(ctx, nil, RAGIngestInput{
Path: "/nonexistent/path/that/does/not/exist/at/all",
})
if err == nil {
t.Fatal("Expected error for nonexistent path")
}
}
// TestRagIngest_Good_DefaultCollection verifies the default collection is applied.
func TestRagIngest_Good_DefaultCollection(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx := context.Background()
// Use a real but inaccessible path to trigger stat error (not validation error).
// The collection default should be applied first.
_, _, err = s.ragIngest(ctx, nil, RAGIngestInput{
Path: "/nonexistent/path/for/default/test",
})
if err == nil {
t.Skip("Ingest succeeded unexpectedly")
}
// The error should NOT be about empty path
if strings.Contains(err.Error(), "path cannot be empty") {
t.Error("Default collection should have been applied")
}
}
// --- ragCollections validation ---
// TestRagCollections_Bad_NoQdrant verifies graceful error when Qdrant is not available.
func TestRagCollections_Bad_NoQdrant(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx := context.Background()
_, _, err = s.ragCollections(ctx, nil, RAGCollectionsInput{})
if err == nil {
t.Skip("Qdrant is available — skip connection error test")
}
// Should get a connection error, not a panic
if !strings.Contains(err.Error(), "failed to connect") && !strings.Contains(err.Error(), "failed to list") {
t.Logf("Got error (expected connection failure): %v", err)
}
}
// --- Struct round-trip tests ---
// TestRAGQueryResult_Good_AllFields verifies all fields can be set and read.
func TestRAGQueryResult_Good_AllFields(t *testing.T) {
r := RAGQueryResult{
Content: "test content",
Source: "source.md",
Section: "Overview",
Category: "docs",
ChunkIndex: 3,
Score: 0.88,
}
if r.Content != "test content" {
t.Errorf("Expected content 'test content', got %q", r.Content)
}
if r.ChunkIndex != 3 {
t.Errorf("Expected chunkIndex 3, got %d", r.ChunkIndex)
}
if r.Score != 0.88 {
t.Errorf("Expected score 0.88, got %f", r.Score)
}
}
// TestCollectionInfo_Good_AllFields verifies CollectionInfo field access.
func TestCollectionInfo_Good_AllFields(t *testing.T) {
c := CollectionInfo{
Name: "test-collection",
PointsCount: 12345,
Status: "green",
}
if c.Name != "test-collection" {
t.Errorf("Expected name 'test-collection', got %q", c.Name)
}
if c.PointsCount != 12345 {
t.Errorf("Expected PointsCount 12345, got %d", c.PointsCount)
}
}
// TestRAGDefaults_Good verifies default constants are sensible.
func TestRAGDefaults_Good(t *testing.T) {
if DefaultRAGCollection != "hostuk-docs" {
t.Errorf("Expected default collection 'hostuk-docs', got %q", DefaultRAGCollection)
}
if DefaultRAGTopK != 5 {
t.Errorf("Expected default topK 5, got %d", DefaultRAGTopK)
}
}

View file

@ -1,173 +0,0 @@
package mcp
import (
"testing"
)
// TestRAGToolsRegistered_Good verifies that RAG tools are registered with the MCP server.
func TestRAGToolsRegistered_Good(t *testing.T) {
// Create a new MCP service - this should register all tools including RAG
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
// The server should have registered the RAG tools
// We verify by checking that the tool handlers exist on the service
// (The actual MCP registration is tested by the SDK)
if s.server == nil {
t.Fatal("Server should not be nil")
}
// Verify the service was created with expected defaults
if s.logger == nil {
t.Error("Logger should not be nil")
}
}
// TestRAGQueryInput_Good verifies the RAGQueryInput struct has expected fields.
func TestRAGQueryInput_Good(t *testing.T) {
input := RAGQueryInput{
Question: "test question",
Collection: "test-collection",
TopK: 10,
}
if input.Question != "test question" {
t.Errorf("Expected question 'test question', got %q", input.Question)
}
if input.Collection != "test-collection" {
t.Errorf("Expected collection 'test-collection', got %q", input.Collection)
}
if input.TopK != 10 {
t.Errorf("Expected topK 10, got %d", input.TopK)
}
}
// TestRAGQueryInput_Defaults verifies default values are handled correctly.
func TestRAGQueryInput_Defaults(t *testing.T) {
// Empty input should use defaults when processed
input := RAGQueryInput{
Question: "test",
}
// Defaults should be applied in the handler, not in the struct
if input.Collection != "" {
t.Errorf("Expected empty collection before defaults, got %q", input.Collection)
}
if input.TopK != 0 {
t.Errorf("Expected zero topK before defaults, got %d", input.TopK)
}
}
// TestRAGIngestInput_Good verifies the RAGIngestInput struct has expected fields.
func TestRAGIngestInput_Good(t *testing.T) {
input := RAGIngestInput{
Path: "/path/to/docs",
Collection: "my-collection",
Recreate: true,
}
if input.Path != "/path/to/docs" {
t.Errorf("Expected path '/path/to/docs', got %q", input.Path)
}
if input.Collection != "my-collection" {
t.Errorf("Expected collection 'my-collection', got %q", input.Collection)
}
if !input.Recreate {
t.Error("Expected recreate to be true")
}
}
// TestRAGCollectionsInput_Good verifies the RAGCollectionsInput struct exists.
func TestRAGCollectionsInput_Good(t *testing.T) {
// RAGCollectionsInput has optional ShowStats parameter
input := RAGCollectionsInput{}
if input.ShowStats {
t.Error("Expected ShowStats to default to false")
}
}
// TestRAGQueryOutput_Good verifies the RAGQueryOutput struct has expected fields.
func TestRAGQueryOutput_Good(t *testing.T) {
output := RAGQueryOutput{
Results: []RAGQueryResult{
{
Content: "some content",
Source: "doc.md",
Section: "Introduction",
Category: "docs",
Score: 0.95,
},
},
Query: "test query",
Collection: "test-collection",
Context: "<retrieved_context>...</retrieved_context>",
}
if len(output.Results) != 1 {
t.Fatalf("Expected 1 result, got %d", len(output.Results))
}
if output.Results[0].Content != "some content" {
t.Errorf("Expected content 'some content', got %q", output.Results[0].Content)
}
if output.Results[0].Score != 0.95 {
t.Errorf("Expected score 0.95, got %f", output.Results[0].Score)
}
if output.Context == "" {
t.Error("Expected context to be set")
}
}
// TestRAGIngestOutput_Good verifies the RAGIngestOutput struct has expected fields.
func TestRAGIngestOutput_Good(t *testing.T) {
output := RAGIngestOutput{
Success: true,
Path: "/path/to/docs",
Collection: "my-collection",
Chunks: 10,
Message: "Ingested successfully",
}
if !output.Success {
t.Error("Expected success to be true")
}
if output.Path != "/path/to/docs" {
t.Errorf("Expected path '/path/to/docs', got %q", output.Path)
}
if output.Chunks != 10 {
t.Errorf("Expected chunks 10, got %d", output.Chunks)
}
}
// TestRAGCollectionsOutput_Good verifies the RAGCollectionsOutput struct has expected fields.
func TestRAGCollectionsOutput_Good(t *testing.T) {
output := RAGCollectionsOutput{
Collections: []CollectionInfo{
{Name: "collection1", PointsCount: 100, Status: "green"},
{Name: "collection2", PointsCount: 200, Status: "green"},
},
}
if len(output.Collections) != 2 {
t.Fatalf("Expected 2 collections, got %d", len(output.Collections))
}
if output.Collections[0].Name != "collection1" {
t.Errorf("Expected 'collection1', got %q", output.Collections[0].Name)
}
if output.Collections[0].PointsCount != 100 {
t.Errorf("Expected PointsCount 100, got %d", output.Collections[0].PointsCount)
}
}
// TestRAGCollectionsInput_Good verifies the RAGCollectionsInput struct has expected fields.
func TestRAGCollectionsInput_ShowStats(t *testing.T) {
input := RAGCollectionsInput{
ShowStats: true,
}
if !input.ShowStats {
t.Error("Expected ShowStats to be true")
}
}

View file

@ -1,497 +0,0 @@
package mcp
import (
"context"
"encoding/base64"
"errors"
"fmt"
"time"
"forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-webview"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// webviewInstance holds the current webview connection.
// This is managed by the MCP service.
var webviewInstance *webview.Webview
// Sentinel errors for webview tools.
var (
errNotConnected = errors.New("not connected; use webview_connect first")
errSelectorRequired = errors.New("selector is required")
)
// WebviewConnectInput contains parameters for connecting to Chrome DevTools.
type WebviewConnectInput struct {
DebugURL string `json:"debug_url"` // Chrome DevTools URL (e.g., http://localhost:9222)
Timeout int `json:"timeout,omitempty"` // Default timeout in seconds (default: 30)
}
// WebviewConnectOutput contains the result of connecting to Chrome.
type WebviewConnectOutput struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
}
// WebviewNavigateInput contains parameters for navigating to a URL.
type WebviewNavigateInput struct {
URL string `json:"url"` // URL to navigate to
}
// WebviewNavigateOutput contains the result of navigation.
type WebviewNavigateOutput struct {
Success bool `json:"success"`
URL string `json:"url"`
}
// WebviewClickInput contains parameters for clicking an element.
type WebviewClickInput struct {
Selector string `json:"selector"` // CSS selector
}
// WebviewClickOutput contains the result of a click action.
type WebviewClickOutput struct {
Success bool `json:"success"`
}
// WebviewTypeInput contains parameters for typing text.
type WebviewTypeInput struct {
Selector string `json:"selector"` // CSS selector
Text string `json:"text"` // Text to type
}
// WebviewTypeOutput contains the result of a type action.
type WebviewTypeOutput struct {
Success bool `json:"success"`
}
// WebviewQueryInput contains parameters for querying an element.
type WebviewQueryInput struct {
Selector string `json:"selector"` // CSS selector
All bool `json:"all,omitempty"` // If true, return all matching elements
}
// WebviewQueryOutput contains the result of a query.
type WebviewQueryOutput struct {
Found bool `json:"found"`
Count int `json:"count"`
Elements []WebviewElementInfo `json:"elements,omitempty"`
}
// WebviewElementInfo represents information about a DOM element.
type WebviewElementInfo struct {
NodeID int `json:"nodeId"`
TagName string `json:"tagName"`
Attributes map[string]string `json:"attributes,omitempty"`
BoundingBox *webview.BoundingBox `json:"boundingBox,omitempty"`
}
// WebviewConsoleInput contains parameters for getting console output.
type WebviewConsoleInput struct {
Clear bool `json:"clear,omitempty"` // If true, clear console after getting messages
}
// WebviewConsoleOutput contains console messages.
type WebviewConsoleOutput struct {
Messages []WebviewConsoleMessage `json:"messages"`
Count int `json:"count"`
}
// WebviewConsoleMessage represents a console message.
type WebviewConsoleMessage struct {
Type string `json:"type"`
Text string `json:"text"`
Timestamp string `json:"timestamp"`
URL string `json:"url,omitempty"`
Line int `json:"line,omitempty"`
}
// WebviewEvalInput contains parameters for evaluating JavaScript.
type WebviewEvalInput struct {
Script string `json:"script"` // JavaScript to evaluate
}
// WebviewEvalOutput contains the result of JavaScript evaluation.
type WebviewEvalOutput struct {
Success bool `json:"success"`
Result any `json:"result,omitempty"`
Error string `json:"error,omitempty"`
}
// WebviewScreenshotInput contains parameters for taking a screenshot.
type WebviewScreenshotInput struct {
Format string `json:"format,omitempty"` // "png" or "jpeg" (default: png)
}
// WebviewScreenshotOutput contains the screenshot data.
type WebviewScreenshotOutput struct {
Success bool `json:"success"`
Data string `json:"data"` // Base64 encoded image
Format string `json:"format"`
}
// WebviewWaitInput contains parameters for waiting operations.
type WebviewWaitInput struct {
Selector string `json:"selector,omitempty"` // Wait for selector
Timeout int `json:"timeout,omitempty"` // Timeout in seconds
}
// WebviewWaitOutput contains the result of waiting.
type WebviewWaitOutput struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
}
// WebviewDisconnectInput contains parameters for disconnecting.
type WebviewDisconnectInput struct{}
// WebviewDisconnectOutput contains the result of disconnecting.
type WebviewDisconnectOutput struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
}
// registerWebviewTools adds webview tools to the MCP server.
func (s *Service) registerWebviewTools(server *mcp.Server) {
mcp.AddTool(server, &mcp.Tool{
Name: "webview_connect",
Description: "Connect to Chrome DevTools Protocol. Start Chrome with --remote-debugging-port=9222 first.",
}, s.webviewConnect)
mcp.AddTool(server, &mcp.Tool{
Name: "webview_disconnect",
Description: "Disconnect from Chrome DevTools.",
}, s.webviewDisconnect)
mcp.AddTool(server, &mcp.Tool{
Name: "webview_navigate",
Description: "Navigate the browser to a URL.",
}, s.webviewNavigate)
mcp.AddTool(server, &mcp.Tool{
Name: "webview_click",
Description: "Click on an element by CSS selector.",
}, s.webviewClick)
mcp.AddTool(server, &mcp.Tool{
Name: "webview_type",
Description: "Type text into an element by CSS selector.",
}, s.webviewType)
mcp.AddTool(server, &mcp.Tool{
Name: "webview_query",
Description: "Query DOM elements by CSS selector.",
}, s.webviewQuery)
mcp.AddTool(server, &mcp.Tool{
Name: "webview_console",
Description: "Get browser console output.",
}, s.webviewConsole)
mcp.AddTool(server, &mcp.Tool{
Name: "webview_eval",
Description: "Evaluate JavaScript in the browser context.",
}, s.webviewEval)
mcp.AddTool(server, &mcp.Tool{
Name: "webview_screenshot",
Description: "Capture a screenshot of the browser window.",
}, s.webviewScreenshot)
mcp.AddTool(server, &mcp.Tool{
Name: "webview_wait",
Description: "Wait for an element to appear by CSS selector.",
}, s.webviewWait)
}
// webviewConnect handles the webview_connect tool call.
func (s *Service) webviewConnect(ctx context.Context, req *mcp.CallToolRequest, input WebviewConnectInput) (*mcp.CallToolResult, WebviewConnectOutput, error) {
s.logger.Security("MCP tool execution", "tool", "webview_connect", "debug_url", input.DebugURL, "user", log.Username())
if input.DebugURL == "" {
return nil, WebviewConnectOutput{}, errors.New("debug_url is required")
}
// Close existing connection if any
if webviewInstance != nil {
_ = webviewInstance.Close()
webviewInstance = nil
}
// Set up options
opts := []webview.Option{
webview.WithDebugURL(input.DebugURL),
}
if input.Timeout > 0 {
opts = append(opts, webview.WithTimeout(time.Duration(input.Timeout)*time.Second))
}
// Create new webview instance
wv, err := webview.New(opts...)
if err != nil {
log.Error("mcp: webview connect failed", "debug_url", input.DebugURL, "err", err)
return nil, WebviewConnectOutput{}, fmt.Errorf("failed to connect: %w", err)
}
webviewInstance = wv
return nil, WebviewConnectOutput{
Success: true,
Message: fmt.Sprintf("Connected to Chrome DevTools at %s", input.DebugURL),
}, nil
}
// webviewDisconnect handles the webview_disconnect tool call.
func (s *Service) webviewDisconnect(ctx context.Context, req *mcp.CallToolRequest, input WebviewDisconnectInput) (*mcp.CallToolResult, WebviewDisconnectOutput, error) {
s.logger.Info("MCP tool execution", "tool", "webview_disconnect", "user", log.Username())
if webviewInstance == nil {
return nil, WebviewDisconnectOutput{
Success: true,
Message: "No active connection",
}, nil
}
if err := webviewInstance.Close(); err != nil {
log.Error("mcp: webview disconnect failed", "err", err)
return nil, WebviewDisconnectOutput{}, fmt.Errorf("failed to disconnect: %w", err)
}
webviewInstance = nil
return nil, WebviewDisconnectOutput{
Success: true,
Message: "Disconnected from Chrome DevTools",
}, nil
}
// webviewNavigate handles the webview_navigate tool call.
func (s *Service) webviewNavigate(ctx context.Context, req *mcp.CallToolRequest, input WebviewNavigateInput) (*mcp.CallToolResult, WebviewNavigateOutput, error) {
s.logger.Info("MCP tool execution", "tool", "webview_navigate", "url", input.URL, "user", log.Username())
if webviewInstance == nil {
return nil, WebviewNavigateOutput{}, errNotConnected
}
if input.URL == "" {
return nil, WebviewNavigateOutput{}, errors.New("url is required")
}
if err := webviewInstance.Navigate(input.URL); err != nil {
log.Error("mcp: webview navigate failed", "url", input.URL, "err", err)
return nil, WebviewNavigateOutput{}, fmt.Errorf("failed to navigate: %w", err)
}
return nil, WebviewNavigateOutput{
Success: true,
URL: input.URL,
}, nil
}
// webviewClick handles the webview_click tool call.
func (s *Service) webviewClick(ctx context.Context, req *mcp.CallToolRequest, input WebviewClickInput) (*mcp.CallToolResult, WebviewClickOutput, error) {
s.logger.Info("MCP tool execution", "tool", "webview_click", "selector", input.Selector, "user", log.Username())
if webviewInstance == nil {
return nil, WebviewClickOutput{}, errNotConnected
}
if input.Selector == "" {
return nil, WebviewClickOutput{}, errSelectorRequired
}
if err := webviewInstance.Click(input.Selector); err != nil {
log.Error("mcp: webview click failed", "selector", input.Selector, "err", err)
return nil, WebviewClickOutput{}, fmt.Errorf("failed to click: %w", err)
}
return nil, WebviewClickOutput{Success: true}, nil
}
// webviewType handles the webview_type tool call.
func (s *Service) webviewType(ctx context.Context, req *mcp.CallToolRequest, input WebviewTypeInput) (*mcp.CallToolResult, WebviewTypeOutput, error) {
s.logger.Info("MCP tool execution", "tool", "webview_type", "selector", input.Selector, "user", log.Username())
if webviewInstance == nil {
return nil, WebviewTypeOutput{}, errNotConnected
}
if input.Selector == "" {
return nil, WebviewTypeOutput{}, errSelectorRequired
}
if err := webviewInstance.Type(input.Selector, input.Text); err != nil {
log.Error("mcp: webview type failed", "selector", input.Selector, "err", err)
return nil, WebviewTypeOutput{}, fmt.Errorf("failed to type: %w", err)
}
return nil, WebviewTypeOutput{Success: true}, nil
}
// webviewQuery handles the webview_query tool call.
func (s *Service) webviewQuery(ctx context.Context, req *mcp.CallToolRequest, input WebviewQueryInput) (*mcp.CallToolResult, WebviewQueryOutput, error) {
s.logger.Info("MCP tool execution", "tool", "webview_query", "selector", input.Selector, "all", input.All, "user", log.Username())
if webviewInstance == nil {
return nil, WebviewQueryOutput{}, errNotConnected
}
if input.Selector == "" {
return nil, WebviewQueryOutput{}, errSelectorRequired
}
if input.All {
elements, err := webviewInstance.QuerySelectorAll(input.Selector)
if err != nil {
log.Error("mcp: webview query all failed", "selector", input.Selector, "err", err)
return nil, WebviewQueryOutput{}, fmt.Errorf("failed to query: %w", err)
}
output := WebviewQueryOutput{
Found: len(elements) > 0,
Count: len(elements),
Elements: make([]WebviewElementInfo, len(elements)),
}
for i, elem := range elements {
output.Elements[i] = WebviewElementInfo{
NodeID: elem.NodeID,
TagName: elem.TagName,
Attributes: elem.Attributes,
BoundingBox: elem.BoundingBox,
}
}
return nil, output, nil
}
elem, err := webviewInstance.QuerySelector(input.Selector)
if err != nil {
// Element not found is not necessarily an error
return nil, WebviewQueryOutput{
Found: false,
Count: 0,
}, nil
}
return nil, WebviewQueryOutput{
Found: true,
Count: 1,
Elements: []WebviewElementInfo{{
NodeID: elem.NodeID,
TagName: elem.TagName,
Attributes: elem.Attributes,
BoundingBox: elem.BoundingBox,
}},
}, nil
}
// webviewConsole handles the webview_console tool call.
func (s *Service) webviewConsole(ctx context.Context, req *mcp.CallToolRequest, input WebviewConsoleInput) (*mcp.CallToolResult, WebviewConsoleOutput, error) {
s.logger.Info("MCP tool execution", "tool", "webview_console", "clear", input.Clear, "user", log.Username())
if webviewInstance == nil {
return nil, WebviewConsoleOutput{}, errNotConnected
}
messages := webviewInstance.GetConsole()
output := WebviewConsoleOutput{
Messages: make([]WebviewConsoleMessage, len(messages)),
Count: len(messages),
}
for i, msg := range messages {
output.Messages[i] = WebviewConsoleMessage{
Type: msg.Type,
Text: msg.Text,
Timestamp: msg.Timestamp.Format(time.RFC3339),
URL: msg.URL,
Line: msg.Line,
}
}
if input.Clear {
webviewInstance.ClearConsole()
}
return nil, output, nil
}
// webviewEval handles the webview_eval tool call.
func (s *Service) webviewEval(ctx context.Context, req *mcp.CallToolRequest, input WebviewEvalInput) (*mcp.CallToolResult, WebviewEvalOutput, error) {
s.logger.Security("MCP tool execution", "tool", "webview_eval", "user", log.Username())
if webviewInstance == nil {
return nil, WebviewEvalOutput{}, errNotConnected
}
if input.Script == "" {
return nil, WebviewEvalOutput{}, errors.New("script is required")
}
result, err := webviewInstance.Evaluate(input.Script)
if err != nil {
log.Error("mcp: webview eval failed", "err", err)
return nil, WebviewEvalOutput{
Success: false,
Error: err.Error(),
}, nil
}
return nil, WebviewEvalOutput{
Success: true,
Result: result,
}, nil
}
// webviewScreenshot handles the webview_screenshot tool call.
func (s *Service) webviewScreenshot(ctx context.Context, req *mcp.CallToolRequest, input WebviewScreenshotInput) (*mcp.CallToolResult, WebviewScreenshotOutput, error) {
s.logger.Info("MCP tool execution", "tool", "webview_screenshot", "format", input.Format, "user", log.Username())
if webviewInstance == nil {
return nil, WebviewScreenshotOutput{}, errNotConnected
}
format := input.Format
if format == "" {
format = "png"
}
data, err := webviewInstance.Screenshot()
if err != nil {
log.Error("mcp: webview screenshot failed", "err", err)
return nil, WebviewScreenshotOutput{}, fmt.Errorf("failed to capture screenshot: %w", err)
}
return nil, WebviewScreenshotOutput{
Success: true,
Data: base64.StdEncoding.EncodeToString(data),
Format: format,
}, nil
}
// webviewWait handles the webview_wait tool call.
func (s *Service) webviewWait(ctx context.Context, req *mcp.CallToolRequest, input WebviewWaitInput) (*mcp.CallToolResult, WebviewWaitOutput, error) {
s.logger.Info("MCP tool execution", "tool", "webview_wait", "selector", input.Selector, "timeout", input.Timeout, "user", log.Username())
if webviewInstance == nil {
return nil, WebviewWaitOutput{}, errNotConnected
}
if input.Selector == "" {
return nil, WebviewWaitOutput{}, errSelectorRequired
}
if err := webviewInstance.WaitForSelector(input.Selector); err != nil {
log.Error("mcp: webview wait failed", "selector", input.Selector, "err", err)
return nil, WebviewWaitOutput{}, fmt.Errorf("failed to wait for selector: %w", err)
}
return nil, WebviewWaitOutput{
Success: true,
Message: fmt.Sprintf("Element found: %s", input.Selector),
}, nil
}

View file

@ -1,452 +0,0 @@
package mcp
import (
"testing"
"time"
"forge.lthn.ai/core/go-webview"
)
// skipIfShort skips webview tests in short mode (go test -short).
// Webview tool handlers require a running Chrome instance with
// --remote-debugging-port, which is not available in CI.
// Struct-level tests below are safe without Chrome, but any future
// tests that call webview tool handlers MUST use this guard.
func skipIfShort(t *testing.T) {
t.Helper()
if testing.Short() {
t.Skip("webview tests skipped in short mode (no Chrome available)")
}
}
// TestWebviewToolsRegistered_Good verifies that webview tools are registered with the MCP server.
func TestWebviewToolsRegistered_Good(t *testing.T) {
// Create a new MCP service - this should register all tools including webview
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
// The server should have registered the webview tools
if s.server == nil {
t.Fatal("Server should not be nil")
}
// Verify the service was created with expected defaults
if s.logger == nil {
t.Error("Logger should not be nil")
}
}
// TestWebviewToolHandlers_RequiresChrome demonstrates the CI guard
// for tests that would require a running Chrome instance. Any future
// test that calls webview tool handlers (webviewConnect, webviewNavigate,
// etc.) should call skipIfShort(t) at the top.
func TestWebviewToolHandlers_RequiresChrome(t *testing.T) {
skipIfShort(t)
// This test verifies that webview tool handlers correctly reject
// calls when not connected to Chrome.
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx := t.Context()
// webview_navigate should fail without a connection
_, _, err = s.webviewNavigate(ctx, nil, WebviewNavigateInput{URL: "https://example.com"})
if err == nil {
t.Error("Expected error when navigating without a webview connection")
}
// webview_click should fail without a connection
_, _, err = s.webviewClick(ctx, nil, WebviewClickInput{Selector: "#btn"})
if err == nil {
t.Error("Expected error when clicking without a webview connection")
}
// webview_eval should fail without a connection
_, _, err = s.webviewEval(ctx, nil, WebviewEvalInput{Script: "1+1"})
if err == nil {
t.Error("Expected error when evaluating without a webview connection")
}
// webview_connect with invalid URL should fail
_, _, err = s.webviewConnect(ctx, nil, WebviewConnectInput{DebugURL: ""})
if err == nil {
t.Error("Expected error when connecting with empty debug URL")
}
}
// TestWebviewConnectInput_Good verifies the WebviewConnectInput struct has expected fields.
func TestWebviewConnectInput_Good(t *testing.T) {
input := WebviewConnectInput{
DebugURL: "http://localhost:9222",
Timeout: 30,
}
if input.DebugURL != "http://localhost:9222" {
t.Errorf("Expected debug_url 'http://localhost:9222', got %q", input.DebugURL)
}
if input.Timeout != 30 {
t.Errorf("Expected timeout 30, got %d", input.Timeout)
}
}
// TestWebviewNavigateInput_Good verifies the WebviewNavigateInput struct has expected fields.
func TestWebviewNavigateInput_Good(t *testing.T) {
input := WebviewNavigateInput{
URL: "https://example.com",
}
if input.URL != "https://example.com" {
t.Errorf("Expected URL 'https://example.com', got %q", input.URL)
}
}
// TestWebviewClickInput_Good verifies the WebviewClickInput struct has expected fields.
func TestWebviewClickInput_Good(t *testing.T) {
input := WebviewClickInput{
Selector: "#submit-button",
}
if input.Selector != "#submit-button" {
t.Errorf("Expected selector '#submit-button', got %q", input.Selector)
}
}
// TestWebviewTypeInput_Good verifies the WebviewTypeInput struct has expected fields.
func TestWebviewTypeInput_Good(t *testing.T) {
input := WebviewTypeInput{
Selector: "#email-input",
Text: "test@example.com",
}
if input.Selector != "#email-input" {
t.Errorf("Expected selector '#email-input', got %q", input.Selector)
}
if input.Text != "test@example.com" {
t.Errorf("Expected text 'test@example.com', got %q", input.Text)
}
}
// TestWebviewQueryInput_Good verifies the WebviewQueryInput struct has expected fields.
func TestWebviewQueryInput_Good(t *testing.T) {
input := WebviewQueryInput{
Selector: "div.container",
All: true,
}
if input.Selector != "div.container" {
t.Errorf("Expected selector 'div.container', got %q", input.Selector)
}
if !input.All {
t.Error("Expected all to be true")
}
}
// TestWebviewQueryInput_Defaults verifies default values are handled correctly.
func TestWebviewQueryInput_Defaults(t *testing.T) {
input := WebviewQueryInput{
Selector: ".test",
}
if input.All {
t.Error("Expected all to default to false")
}
}
// TestWebviewConsoleInput_Good verifies the WebviewConsoleInput struct has expected fields.
func TestWebviewConsoleInput_Good(t *testing.T) {
input := WebviewConsoleInput{
Clear: true,
}
if !input.Clear {
t.Error("Expected clear to be true")
}
}
// TestWebviewEvalInput_Good verifies the WebviewEvalInput struct has expected fields.
func TestWebviewEvalInput_Good(t *testing.T) {
input := WebviewEvalInput{
Script: "document.title",
}
if input.Script != "document.title" {
t.Errorf("Expected script 'document.title', got %q", input.Script)
}
}
// TestWebviewScreenshotInput_Good verifies the WebviewScreenshotInput struct has expected fields.
func TestWebviewScreenshotInput_Good(t *testing.T) {
input := WebviewScreenshotInput{
Format: "png",
}
if input.Format != "png" {
t.Errorf("Expected format 'png', got %q", input.Format)
}
}
// TestWebviewScreenshotInput_Defaults verifies default values are handled correctly.
func TestWebviewScreenshotInput_Defaults(t *testing.T) {
input := WebviewScreenshotInput{}
if input.Format != "" {
t.Errorf("Expected format to default to empty, got %q", input.Format)
}
}
// TestWebviewWaitInput_Good verifies the WebviewWaitInput struct has expected fields.
func TestWebviewWaitInput_Good(t *testing.T) {
input := WebviewWaitInput{
Selector: "#loading",
Timeout: 10,
}
if input.Selector != "#loading" {
t.Errorf("Expected selector '#loading', got %q", input.Selector)
}
if input.Timeout != 10 {
t.Errorf("Expected timeout 10, got %d", input.Timeout)
}
}
// TestWebviewConnectOutput_Good verifies the WebviewConnectOutput struct has expected fields.
func TestWebviewConnectOutput_Good(t *testing.T) {
output := WebviewConnectOutput{
Success: true,
Message: "Connected to Chrome DevTools",
}
if !output.Success {
t.Error("Expected success to be true")
}
if output.Message == "" {
t.Error("Expected message to be set")
}
}
// TestWebviewNavigateOutput_Good verifies the WebviewNavigateOutput struct has expected fields.
func TestWebviewNavigateOutput_Good(t *testing.T) {
output := WebviewNavigateOutput{
Success: true,
URL: "https://example.com",
}
if !output.Success {
t.Error("Expected success to be true")
}
if output.URL != "https://example.com" {
t.Errorf("Expected URL 'https://example.com', got %q", output.URL)
}
}
// TestWebviewQueryOutput_Good verifies the WebviewQueryOutput struct has expected fields.
func TestWebviewQueryOutput_Good(t *testing.T) {
output := WebviewQueryOutput{
Found: true,
Count: 3,
Elements: []WebviewElementInfo{
{
NodeID: 1,
TagName: "DIV",
Attributes: map[string]string{
"class": "container",
},
},
},
}
if !output.Found {
t.Error("Expected found to be true")
}
if output.Count != 3 {
t.Errorf("Expected count 3, got %d", output.Count)
}
if len(output.Elements) != 1 {
t.Fatalf("Expected 1 element, got %d", len(output.Elements))
}
if output.Elements[0].TagName != "DIV" {
t.Errorf("Expected tagName 'DIV', got %q", output.Elements[0].TagName)
}
}
// TestWebviewConsoleOutput_Good verifies the WebviewConsoleOutput struct has expected fields.
func TestWebviewConsoleOutput_Good(t *testing.T) {
output := WebviewConsoleOutput{
Messages: []WebviewConsoleMessage{
{
Type: "log",
Text: "Hello, world!",
Timestamp: "2024-01-01T00:00:00Z",
},
{
Type: "error",
Text: "An error occurred",
Timestamp: "2024-01-01T00:00:01Z",
URL: "https://example.com/script.js",
Line: 42,
},
},
Count: 2,
}
if output.Count != 2 {
t.Errorf("Expected count 2, got %d", output.Count)
}
if len(output.Messages) != 2 {
t.Fatalf("Expected 2 messages, got %d", len(output.Messages))
}
if output.Messages[0].Type != "log" {
t.Errorf("Expected type 'log', got %q", output.Messages[0].Type)
}
if output.Messages[1].Line != 42 {
t.Errorf("Expected line 42, got %d", output.Messages[1].Line)
}
}
// TestWebviewEvalOutput_Good verifies the WebviewEvalOutput struct has expected fields.
func TestWebviewEvalOutput_Good(t *testing.T) {
output := WebviewEvalOutput{
Success: true,
Result: "Example Page",
}
if !output.Success {
t.Error("Expected success to be true")
}
if output.Result != "Example Page" {
t.Errorf("Expected result 'Example Page', got %v", output.Result)
}
}
// TestWebviewEvalOutput_Error verifies the WebviewEvalOutput struct handles errors.
func TestWebviewEvalOutput_Error(t *testing.T) {
output := WebviewEvalOutput{
Success: false,
Error: "ReferenceError: foo is not defined",
}
if output.Success {
t.Error("Expected success to be false")
}
if output.Error == "" {
t.Error("Expected error message to be set")
}
}
// TestWebviewScreenshotOutput_Good verifies the WebviewScreenshotOutput struct has expected fields.
func TestWebviewScreenshotOutput_Good(t *testing.T) {
output := WebviewScreenshotOutput{
Success: true,
Data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
Format: "png",
}
if !output.Success {
t.Error("Expected success to be true")
}
if output.Data == "" {
t.Error("Expected data to be set")
}
if output.Format != "png" {
t.Errorf("Expected format 'png', got %q", output.Format)
}
}
// TestWebviewElementInfo_Good verifies the WebviewElementInfo struct has expected fields.
func TestWebviewElementInfo_Good(t *testing.T) {
elem := WebviewElementInfo{
NodeID: 123,
TagName: "INPUT",
Attributes: map[string]string{
"type": "text",
"name": "email",
"class": "form-control",
},
BoundingBox: &webview.BoundingBox{
X: 100,
Y: 200,
Width: 300,
Height: 50,
},
}
if elem.NodeID != 123 {
t.Errorf("Expected nodeId 123, got %d", elem.NodeID)
}
if elem.TagName != "INPUT" {
t.Errorf("Expected tagName 'INPUT', got %q", elem.TagName)
}
if elem.Attributes["type"] != "text" {
t.Errorf("Expected type attribute 'text', got %q", elem.Attributes["type"])
}
if elem.BoundingBox == nil {
t.Fatal("Expected bounding box to be set")
}
if elem.BoundingBox.Width != 300 {
t.Errorf("Expected width 300, got %f", elem.BoundingBox.Width)
}
}
// TestWebviewConsoleMessage_Good verifies the WebviewConsoleMessage struct has expected fields.
func TestWebviewConsoleMessage_Good(t *testing.T) {
msg := WebviewConsoleMessage{
Type: "error",
Text: "Failed to load resource",
Timestamp: time.Now().Format(time.RFC3339),
URL: "https://example.com/api/data",
Line: 1,
}
if msg.Type != "error" {
t.Errorf("Expected type 'error', got %q", msg.Type)
}
if msg.Text == "" {
t.Error("Expected text to be set")
}
if msg.URL == "" {
t.Error("Expected URL to be set")
}
}
// TestWebviewDisconnectInput_Good verifies the WebviewDisconnectInput struct exists.
func TestWebviewDisconnectInput_Good(t *testing.T) {
// WebviewDisconnectInput has no fields
input := WebviewDisconnectInput{}
_ = input // Just verify the struct exists
}
// TestWebviewDisconnectOutput_Good verifies the WebviewDisconnectOutput struct has expected fields.
func TestWebviewDisconnectOutput_Good(t *testing.T) {
output := WebviewDisconnectOutput{
Success: true,
Message: "Disconnected from Chrome DevTools",
}
if !output.Success {
t.Error("Expected success to be true")
}
if output.Message == "" {
t.Error("Expected message to be set")
}
}
// TestWebviewWaitOutput_Good verifies the WebviewWaitOutput struct has expected fields.
func TestWebviewWaitOutput_Good(t *testing.T) {
output := WebviewWaitOutput{
Success: true,
Message: "Element found: #login-form",
}
if !output.Success {
t.Error("Expected success to be true")
}
if output.Message == "" {
t.Error("Expected message to be set")
}
}

View file

@ -1,142 +0,0 @@
package mcp
import (
"context"
"fmt"
"net"
"net/http"
"forge.lthn.ai/core/go-log"
"forge.lthn.ai/core/go-ws"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// WSStartInput contains parameters for starting the WebSocket server.
type WSStartInput struct {
Addr string `json:"addr,omitempty"` // Address to listen on (default: ":8080")
}
// WSStartOutput contains the result of starting the WebSocket server.
type WSStartOutput struct {
Success bool `json:"success"`
Addr string `json:"addr"`
Message string `json:"message,omitempty"`
}
// WSInfoInput contains parameters for getting WebSocket hub info.
type WSInfoInput struct{}
// WSInfoOutput contains WebSocket hub statistics.
type WSInfoOutput struct {
Clients int `json:"clients"`
Channels int `json:"channels"`
}
// registerWSTools adds WebSocket tools to the MCP server.
// Returns false if WebSocket hub is not available.
func (s *Service) registerWSTools(server *mcp.Server) bool {
if s.wsHub == nil {
return false
}
mcp.AddTool(server, &mcp.Tool{
Name: "ws_start",
Description: "Start the WebSocket server for real-time process output streaming.",
}, s.wsStart)
mcp.AddTool(server, &mcp.Tool{
Name: "ws_info",
Description: "Get WebSocket hub statistics (connected clients and active channels).",
}, s.wsInfo)
return true
}
// wsStart handles the ws_start tool call.
func (s *Service) wsStart(ctx context.Context, req *mcp.CallToolRequest, input WSStartInput) (*mcp.CallToolResult, WSStartOutput, error) {
addr := input.Addr
if addr == "" {
addr = ":8080"
}
s.logger.Security("MCP tool execution", "tool", "ws_start", "addr", addr, "user", log.Username())
// Check if server is already running
if s.wsServer != nil {
return nil, WSStartOutput{
Success: true,
Addr: s.wsAddr,
Message: "WebSocket server already running",
}, nil
}
// Create HTTP server with WebSocket handler
mux := http.NewServeMux()
mux.HandleFunc("/ws", s.wsHub.Handler())
server := &http.Server{
Addr: addr,
Handler: mux,
}
// Start listener to get actual address
ln, err := net.Listen("tcp", addr)
if err != nil {
log.Error("mcp: ws start listen failed", "addr", addr, "err", err)
return nil, WSStartOutput{}, fmt.Errorf("failed to listen on %s: %w", addr, err)
}
actualAddr := ln.Addr().String()
s.wsServer = server
s.wsAddr = actualAddr
// Start server in background
go func() {
if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
log.Error("mcp: ws server error", "err", err)
}
}()
return nil, WSStartOutput{
Success: true,
Addr: actualAddr,
Message: fmt.Sprintf("WebSocket server started at ws://%s/ws", actualAddr),
}, nil
}
// wsInfo handles the ws_info tool call.
func (s *Service) wsInfo(ctx context.Context, req *mcp.CallToolRequest, input WSInfoInput) (*mcp.CallToolResult, WSInfoOutput, error) {
s.logger.Info("MCP tool execution", "tool", "ws_info", "user", log.Username())
stats := s.wsHub.Stats()
return nil, WSInfoOutput{
Clients: stats.Clients,
Channels: stats.Channels,
}, nil
}
// ProcessEventCallback is a callback function for process events.
// It can be registered with the process service to forward events to WebSocket.
type ProcessEventCallback struct {
hub *ws.Hub
}
// NewProcessEventCallback creates a callback that forwards process events to WebSocket.
func NewProcessEventCallback(hub *ws.Hub) *ProcessEventCallback {
return &ProcessEventCallback{hub: hub}
}
// OnProcessOutput forwards process output to WebSocket subscribers.
func (c *ProcessEventCallback) OnProcessOutput(processID string, line string) {
if c.hub != nil {
_ = c.hub.SendProcessOutput(processID, line)
}
}
// OnProcessStatus forwards process status changes to WebSocket subscribers.
func (c *ProcessEventCallback) OnProcessStatus(processID string, status string, exitCode int) {
if c.hub != nil {
_ = c.hub.SendProcessStatus(processID, status, exitCode)
}
}

View file

@ -1,174 +0,0 @@
package mcp
import (
"testing"
"forge.lthn.ai/core/go-ws"
)
// TestWSToolsRegistered_Good verifies that WebSocket tools are registered when hub is available.
func TestWSToolsRegistered_Good(t *testing.T) {
// Create a new MCP service without ws hub - tools should not be registered
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.wsHub != nil {
t.Error("WS hub should be nil by default")
}
if s.server == nil {
t.Fatal("Server should not be nil")
}
}
// TestWSStartInput_Good verifies the WSStartInput struct has expected fields.
func TestWSStartInput_Good(t *testing.T) {
input := WSStartInput{
Addr: ":9090",
}
if input.Addr != ":9090" {
t.Errorf("Expected addr ':9090', got %q", input.Addr)
}
}
// TestWSStartInput_Defaults verifies default values.
func TestWSStartInput_Defaults(t *testing.T) {
input := WSStartInput{}
if input.Addr != "" {
t.Errorf("Expected addr to default to empty, got %q", input.Addr)
}
}
// TestWSStartOutput_Good verifies the WSStartOutput struct has expected fields.
func TestWSStartOutput_Good(t *testing.T) {
output := WSStartOutput{
Success: true,
Addr: "127.0.0.1:8080",
Message: "WebSocket server started",
}
if !output.Success {
t.Error("Expected Success to be true")
}
if output.Addr != "127.0.0.1:8080" {
t.Errorf("Expected addr '127.0.0.1:8080', got %q", output.Addr)
}
if output.Message != "WebSocket server started" {
t.Errorf("Expected message 'WebSocket server started', got %q", output.Message)
}
}
// TestWSInfoInput_Good verifies the WSInfoInput struct exists (it's empty).
func TestWSInfoInput_Good(t *testing.T) {
input := WSInfoInput{}
_ = input // Just verify it compiles
}
// TestWSInfoOutput_Good verifies the WSInfoOutput struct has expected fields.
func TestWSInfoOutput_Good(t *testing.T) {
output := WSInfoOutput{
Clients: 5,
Channels: 3,
}
if output.Clients != 5 {
t.Errorf("Expected clients 5, got %d", output.Clients)
}
if output.Channels != 3 {
t.Errorf("Expected channels 3, got %d", output.Channels)
}
}
// TestWithWSHub_Good verifies the WithWSHub option.
func TestWithWSHub_Good(t *testing.T) {
hub := ws.NewHub()
s, err := New(WithWSHub(hub))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.wsHub != hub {
t.Error("Expected wsHub to be set")
}
}
// TestWithWSHub_Nil verifies the WithWSHub option with nil.
func TestWithWSHub_Nil(t *testing.T) {
s, err := New(WithWSHub(nil))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.wsHub != nil {
t.Error("Expected wsHub to be nil when passed nil")
}
}
// TestProcessEventCallback_Good verifies the ProcessEventCallback struct.
func TestProcessEventCallback_Good(t *testing.T) {
hub := ws.NewHub()
callback := NewProcessEventCallback(hub)
if callback.hub != hub {
t.Error("Expected callback hub to be set")
}
// Test that methods don't panic
callback.OnProcessOutput("proc-1", "test output")
callback.OnProcessStatus("proc-1", "exited", 0)
}
// TestProcessEventCallback_NilHub verifies the ProcessEventCallback with nil hub doesn't panic.
func TestProcessEventCallback_NilHub(t *testing.T) {
callback := NewProcessEventCallback(nil)
if callback.hub != nil {
t.Error("Expected callback hub to be nil")
}
// Test that methods don't panic with nil hub
callback.OnProcessOutput("proc-1", "test output")
callback.OnProcessStatus("proc-1", "exited", 0)
}
// TestServiceWSHub_Good verifies the WSHub getter method.
func TestServiceWSHub_Good(t *testing.T) {
hub := ws.NewHub()
s, err := New(WithWSHub(hub))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.WSHub() != hub {
t.Error("Expected WSHub() to return the hub")
}
}
// TestServiceWSHub_Nil verifies the WSHub getter returns nil when not configured.
func TestServiceWSHub_Nil(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.WSHub() != nil {
t.Error("Expected WSHub() to return nil when not configured")
}
}
// TestServiceProcessService_Nil verifies the ProcessService getter returns nil when not configured.
func TestServiceProcessService_Nil(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
if s.ProcessService() != nil {
t.Error("Expected ProcessService() to return nil when not configured")
}
}

View file

@ -1,742 +0,0 @@
package mcp
import (
"bufio"
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"
"context"
)
// jsonRPCRequest builds a raw JSON-RPC 2.0 request string with newline delimiter.
func jsonRPCRequest(id int, method string, params any) string {
msg := map[string]any{
"jsonrpc": "2.0",
"id": id,
"method": method,
}
if params != nil {
msg["params"] = params
}
data, _ := json.Marshal(msg)
return string(data) + "\n"
}
// jsonRPCNotification builds a raw JSON-RPC 2.0 notification (no id).
func jsonRPCNotification(method string) string {
msg := map[string]any{
"jsonrpc": "2.0",
"method": method,
}
data, _ := json.Marshal(msg)
return string(data) + "\n"
}
// readJSONRPCResponse reads a single line-delimited JSON-RPC response and
// returns the decoded map. It handles the case where the server sends a
// ping request interleaved with responses (responds to it and keeps reading).
func readJSONRPCResponse(t *testing.T, scanner *bufio.Scanner, conn net.Conn) map[string]any {
t.Helper()
for {
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
t.Fatalf("scanner error: %v", err)
}
t.Fatal("unexpected EOF reading JSON-RPC response")
}
line := scanner.Text()
var msg map[string]any
if err := json.Unmarshal([]byte(line), &msg); err != nil {
t.Fatalf("failed to unmarshal response: %v\nraw: %s", err, line)
}
// If this is a server-initiated request (e.g. ping), respond and keep reading.
if method, ok := msg["method"]; ok {
if id, hasID := msg["id"]; hasID {
resp := map[string]any{
"jsonrpc": "2.0",
"id": id,
"result": map[string]any{},
}
data, _ := json.Marshal(resp)
_, _ = conn.Write(append(data, '\n'))
_ = method // consume
continue
}
// Notification from server — ignore and keep reading
continue
}
return msg
}
}
// --- TCP E2E Tests ---
func TestTCPTransport_E2E_FullRoundTrip(t *testing.T) {
// Create a temp workspace with a known file
tmpDir := t.TempDir()
testContent := "hello from tcp e2e test"
if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(testContent), 0644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start TCP server on a random port
errCh := make(chan error, 1)
go func() {
errCh <- s.ServeTCP(ctx, "127.0.0.1:0")
}()
// Wait for the server to start and get the actual address.
// ServeTCP creates its own listener internally, so we need to probe.
// We'll retry connecting for up to 2 seconds.
var conn net.Conn
deadline := time.Now().Add(2 * time.Second)
// Since ServeTCP binds :0, we can't predict the port. Instead, create
// our own listener to find a free port, close it, then pass that port
// to ServeTCP. This is a known race, but fine for tests.
cancel()
<-errCh
// Restart with a known port: find a free port first
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to find free port: %v", err)
}
addr := ln.Addr().String()
ln.Close()
ctx2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
errCh2 := make(chan error, 1)
go func() {
errCh2 <- s.ServeTCP(ctx2, addr)
}()
// Wait for server to accept connections
deadline = time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond)
if err == nil {
break
}
time.Sleep(50 * time.Millisecond)
}
if err != nil {
t.Fatalf("Failed to connect to TCP server at %s: %v", addr, err)
}
defer conn.Close()
// Set a read deadline to avoid hanging
conn.SetDeadline(time.Now().Add(10 * time.Second))
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
// Step 1: Send initialize request
initReq := jsonRPCRequest(1, "initialize", map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{},
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
})
if _, err := conn.Write([]byte(initReq)); err != nil {
t.Fatalf("Failed to send initialize: %v", err)
}
// Read initialize response
initResp := readJSONRPCResponse(t, scanner, conn)
if initResp["error"] != nil {
t.Fatalf("Initialize returned error: %v", initResp["error"])
}
result, ok := initResp["result"].(map[string]any)
if !ok {
t.Fatalf("Expected result object, got %T", initResp["result"])
}
serverInfo, _ := result["serverInfo"].(map[string]any)
if serverInfo["name"] != "core-cli" {
t.Errorf("Expected server name 'core-cli', got %v", serverInfo["name"])
}
// Step 2: Send notifications/initialized
if _, err := conn.Write([]byte(jsonRPCNotification("notifications/initialized"))); err != nil {
t.Fatalf("Failed to send initialized notification: %v", err)
}
// Step 3: Send tools/list
if _, err := conn.Write([]byte(jsonRPCRequest(2, "tools/list", nil))); err != nil {
t.Fatalf("Failed to send tools/list: %v", err)
}
toolsResp := readJSONRPCResponse(t, scanner, conn)
if toolsResp["error"] != nil {
t.Fatalf("tools/list returned error: %v", toolsResp["error"])
}
toolsResult, ok := toolsResp["result"].(map[string]any)
if !ok {
t.Fatalf("Expected result object for tools/list, got %T", toolsResp["result"])
}
tools, ok := toolsResult["tools"].([]any)
if !ok || len(tools) == 0 {
t.Fatal("Expected non-empty tools list")
}
// Verify file_read is among the tools
foundFileRead := false
for _, tool := range tools {
toolMap, _ := tool.(map[string]any)
if toolMap["name"] == "file_read" {
foundFileRead = true
break
}
}
if !foundFileRead {
t.Error("Expected file_read tool in tools/list response")
}
// Step 4: Call file_read
callReq := jsonRPCRequest(3, "tools/call", map[string]any{
"name": "file_read",
"arguments": map[string]any{"path": "test.txt"},
})
if _, err := conn.Write([]byte(callReq)); err != nil {
t.Fatalf("Failed to send tools/call: %v", err)
}
callResp := readJSONRPCResponse(t, scanner, conn)
if callResp["error"] != nil {
t.Fatalf("tools/call file_read returned error: %v", callResp["error"])
}
callResult, ok := callResp["result"].(map[string]any)
if !ok {
t.Fatalf("Expected result object for tools/call, got %T", callResp["result"])
}
// The MCP SDK wraps tool results in content array
content, ok := callResult["content"].([]any)
if !ok || len(content) == 0 {
t.Fatal("Expected non-empty content in tools/call response")
}
firstContent, _ := content[0].(map[string]any)
text, _ := firstContent["text"].(string)
if !strings.Contains(text, testContent) {
t.Errorf("Expected file content to contain %q, got %q", testContent, text)
}
// Graceful shutdown
cancel2()
err = <-errCh2
if err != nil {
t.Errorf("ServeTCP returned error: %v", err)
}
}
func TestTCPTransport_E2E_FileWrite(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Find free port
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to find free port: %v", err)
}
addr := ln.Addr().String()
ln.Close()
errCh := make(chan error, 1)
go func() {
errCh <- s.ServeTCP(ctx, addr)
}()
// Connect
var conn net.Conn
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond)
if err == nil {
break
}
time.Sleep(50 * time.Millisecond)
}
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
// Initialize handshake
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{},
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
})))
readJSONRPCResponse(t, scanner, conn)
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
// Write a file
writeContent := "written via tcp transport"
conn.Write([]byte(jsonRPCRequest(2, "tools/call", map[string]any{
"name": "file_write",
"arguments": map[string]any{"path": "tcp-written.txt", "content": writeContent},
})))
writeResp := readJSONRPCResponse(t, scanner, conn)
if writeResp["error"] != nil {
t.Fatalf("file_write returned error: %v", writeResp["error"])
}
// Verify file on disk
diskContent, err := os.ReadFile(filepath.Join(tmpDir, "tcp-written.txt"))
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(diskContent) != writeContent {
t.Errorf("Expected %q on disk, got %q", writeContent, string(diskContent))
}
cancel()
<-errCh
}
// --- Unix Socket E2E Tests ---
// shortSocketPath returns a Unix socket path under /tmp that fits within
// the macOS 104-byte sun_path limit. t.TempDir() paths on macOS are
// often too long (>104 bytes) for Unix sockets.
func shortSocketPath(t *testing.T, suffix string) string {
t.Helper()
path := fmt.Sprintf("/tmp/mcp-test-%s-%d.sock", suffix, os.Getpid())
t.Cleanup(func() { os.Remove(path) })
return path
}
func TestUnixTransport_E2E_FullRoundTrip(t *testing.T) {
// Create a temp workspace with a known file
tmpDir := t.TempDir()
testContent := "hello from unix e2e test"
if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(testContent), 0644); err != nil {
t.Fatalf("Failed to create test file: %v", err)
}
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Use a short socket path to avoid macOS 104-byte sun_path limit
socketPath := shortSocketPath(t, "full")
errCh := make(chan error, 1)
go func() {
errCh <- s.ServeUnix(ctx, socketPath)
}()
// Wait for socket to appear
var conn net.Conn
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
conn, err = net.DialTimeout("unix", socketPath, 200*time.Millisecond)
if err == nil {
break
}
time.Sleep(50 * time.Millisecond)
}
if err != nil {
t.Fatalf("Failed to connect to Unix socket at %s: %v", socketPath, err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
// Step 1: Initialize
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{},
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
})))
initResp := readJSONRPCResponse(t, scanner, conn)
if initResp["error"] != nil {
t.Fatalf("Initialize returned error: %v", initResp["error"])
}
// Step 2: Send initialized notification
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
// Step 3: tools/list
conn.Write([]byte(jsonRPCRequest(2, "tools/list", nil)))
toolsResp := readJSONRPCResponse(t, scanner, conn)
if toolsResp["error"] != nil {
t.Fatalf("tools/list returned error: %v", toolsResp["error"])
}
toolsResult, ok := toolsResp["result"].(map[string]any)
if !ok {
t.Fatalf("Expected result object, got %T", toolsResp["result"])
}
tools, ok := toolsResult["tools"].([]any)
if !ok || len(tools) == 0 {
t.Fatal("Expected non-empty tools list")
}
// Step 4: Call file_read
conn.Write([]byte(jsonRPCRequest(3, "tools/call", map[string]any{
"name": "file_read",
"arguments": map[string]any{"path": "test.txt"},
})))
callResp := readJSONRPCResponse(t, scanner, conn)
if callResp["error"] != nil {
t.Fatalf("tools/call file_read returned error: %v", callResp["error"])
}
callResult, ok := callResp["result"].(map[string]any)
if !ok {
t.Fatalf("Expected result object, got %T", callResp["result"])
}
content, ok := callResult["content"].([]any)
if !ok || len(content) == 0 {
t.Fatal("Expected non-empty content")
}
firstContent, _ := content[0].(map[string]any)
text, _ := firstContent["text"].(string)
if !strings.Contains(text, testContent) {
t.Errorf("Expected content to contain %q, got %q", testContent, text)
}
// Graceful shutdown
cancel()
err = <-errCh
if err != nil {
t.Errorf("ServeUnix returned error: %v", err)
}
// Verify socket file is cleaned up
if _, err := os.Stat(socketPath); !os.IsNotExist(err) {
t.Error("Expected socket file to be cleaned up after shutdown")
}
}
func TestUnixTransport_E2E_DirList(t *testing.T) {
tmpDir := t.TempDir()
// Create some files and dirs
os.MkdirAll(filepath.Join(tmpDir, "subdir"), 0755)
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("one"), 0644)
os.WriteFile(filepath.Join(tmpDir, "subdir", "file2.txt"), []byte("two"), 0644)
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
socketPath := shortSocketPath(t, "dir")
errCh := make(chan error, 1)
go func() {
errCh <- s.ServeUnix(ctx, socketPath)
}()
var conn net.Conn
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
conn, err = net.DialTimeout("unix", socketPath, 200*time.Millisecond)
if err == nil {
break
}
time.Sleep(50 * time.Millisecond)
}
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
// Initialize
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{},
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
})))
readJSONRPCResponse(t, scanner, conn)
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
// Call dir_list on root
conn.Write([]byte(jsonRPCRequest(2, "tools/call", map[string]any{
"name": "dir_list",
"arguments": map[string]any{"path": "."},
})))
dirResp := readJSONRPCResponse(t, scanner, conn)
if dirResp["error"] != nil {
t.Fatalf("dir_list returned error: %v", dirResp["error"])
}
dirResult, ok := dirResp["result"].(map[string]any)
if !ok {
t.Fatalf("Expected result object, got %T", dirResp["result"])
}
dirContent, ok := dirResult["content"].([]any)
if !ok || len(dirContent) == 0 {
t.Fatal("Expected non-empty content in dir_list response")
}
// The response content should mention our files
firstItem, _ := dirContent[0].(map[string]any)
text, _ := firstItem["text"].(string)
if !strings.Contains(text, "file1.txt") && !strings.Contains(text, "subdir") {
t.Errorf("Expected dir_list to contain file1.txt or subdir, got: %s", text)
}
cancel()
<-errCh
}
// --- Stdio Transport Tests ---
func TestStdioTransport_Documented_Skip(t *testing.T) {
// The MCP SDK's StdioTransport binds directly to os.Stdin and os.Stdout,
// with no way to inject custom io.Reader/io.Writer. Testing stdio transport
// would require spawning the binary as a subprocess and piping JSON-RPC
// messages through its stdin/stdout.
//
// Since the core MCP protocol handling is identical across all transports
// (the transport layer only handles framing), and we thoroughly test the
// protocol via TCP and Unix socket e2e tests, the stdio transport is
// effectively covered. The only untested code path is the StdioTransport
// adapter itself, which is a thin wrapper in the upstream SDK.
//
// If process-level testing is needed in the future, the approach would be:
// 1. Build the binary: `go build -o /tmp/mcp-test ./cmd/...`
// 2. Spawn it: exec.Command("/tmp/mcp-test", "mcp", "serve")
// 3. Pipe JSON-RPC to stdin, read from stdout
// 4. Verify responses match expected protocol
t.Skip("stdio transport requires process spawning; protocol is covered by TCP and Unix e2e tests")
}
// --- Helper: verify a specific tool exists in tools/list response ---
func assertToolExists(t *testing.T, tools []any, name string) {
t.Helper()
for _, tool := range tools {
toolMap, _ := tool.(map[string]any)
if toolMap["name"] == name {
return
}
}
toolNames := make([]string, 0, len(tools))
for _, tool := range tools {
toolMap, _ := tool.(map[string]any)
if n, ok := toolMap["name"].(string); ok {
toolNames = append(toolNames, n)
}
}
t.Errorf("Expected tool %q in list, got: %v", name, toolNames)
}
func TestTCPTransport_E2E_ToolsDiscovery(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to find free port: %v", err)
}
addr := ln.Addr().String()
ln.Close()
errCh := make(chan error, 1)
go func() {
errCh <- s.ServeTCP(ctx, addr)
}()
var conn net.Conn
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond)
if err == nil {
break
}
time.Sleep(50 * time.Millisecond)
}
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
// Initialize
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{},
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
})))
readJSONRPCResponse(t, scanner, conn)
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
// Get tools list
conn.Write([]byte(jsonRPCRequest(2, "tools/list", nil)))
toolsResp := readJSONRPCResponse(t, scanner, conn)
if toolsResp["error"] != nil {
t.Fatalf("tools/list error: %v", toolsResp["error"])
}
toolsResult, _ := toolsResp["result"].(map[string]any)
tools, _ := toolsResult["tools"].([]any)
// Verify all core tools are registered
expectedTools := []string{
"file_read", "file_write", "file_delete", "file_rename",
"file_exists", "file_edit", "dir_list", "dir_create",
"lang_detect", "lang_list",
}
for _, name := range expectedTools {
assertToolExists(t, tools, name)
}
// Log total tool count for visibility
t.Logf("Server registered %d tools", len(tools))
cancel()
<-errCh
}
func TestTCPTransport_E2E_ErrorHandling(t *testing.T) {
tmpDir := t.TempDir()
s, err := New(WithWorkspaceRoot(tmpDir))
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to find free port: %v", err)
}
addr := ln.Addr().String()
ln.Close()
errCh := make(chan error, 1)
go func() {
errCh <- s.ServeTCP(ctx, addr)
}()
var conn net.Conn
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond)
if err == nil {
break
}
time.Sleep(50 * time.Millisecond)
}
if err != nil {
t.Fatalf("Failed to connect: %v", err)
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(10 * time.Second))
scanner := bufio.NewScanner(conn)
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
// Initialize
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
"protocolVersion": "2024-11-05",
"capabilities": map[string]any{},
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
})))
readJSONRPCResponse(t, scanner, conn)
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
// Try to read a nonexistent file
conn.Write([]byte(jsonRPCRequest(2, "tools/call", map[string]any{
"name": "file_read",
"arguments": map[string]any{"path": "nonexistent.txt"},
})))
errResp := readJSONRPCResponse(t, scanner, conn)
// The MCP SDK wraps tool errors as isError content, not JSON-RPC errors.
// Check both possibilities.
if errResp["error"] != nil {
// JSON-RPC level error — this is acceptable
t.Logf("Got JSON-RPC error for nonexistent file: %v", errResp["error"])
} else {
errResult, _ := errResp["result"].(map[string]any)
isError, _ := errResult["isError"].(bool)
if !isError {
// Check content for error indicator
content, _ := errResult["content"].([]any)
if len(content) > 0 {
firstContent, _ := content[0].(map[string]any)
text, _ := firstContent["text"].(string)
t.Logf("Tool response for nonexistent file: %s", text)
}
}
}
// Verify tools/call without params returns an error
conn.Write([]byte(jsonRPCRequest(3, "tools/call", nil)))
noParamsResp := readJSONRPCResponse(t, scanner, conn)
if noParamsResp["error"] == nil {
t.Log("tools/call without params did not return JSON-RPC error (SDK may handle differently)")
} else {
errObj, _ := noParamsResp["error"].(map[string]any)
code, _ := errObj["code"].(float64)
if code != -32600 {
t.Logf("tools/call without params returned error code: %v", code)
}
}
cancel()
<-errCh
}
// Suppress "unused import" for fmt — used in helpers
var _ = fmt.Sprintf

View file

@ -1,15 +0,0 @@
package mcp
import (
"context"
"forge.lthn.ai/core/go-log"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// ServeStdio starts the MCP server over stdin/stdout.
// This is the default transport for CLI integrations.
func (s *Service) ServeStdio(ctx context.Context) error {
s.logger.Info("MCP Stdio server starting", "user", log.Username())
return s.server.Run(ctx, &mcp.StdioTransport{})
}

View file

@ -1,177 +0,0 @@
package mcp
import (
"bufio"
"context"
"fmt"
"io"
"net"
"os"
"sync"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// DefaultTCPAddr is the default address for the MCP TCP server.
const DefaultTCPAddr = "127.0.0.1:9100"
// diagMu protects diagWriter from concurrent access across tests and goroutines.
var diagMu sync.Mutex
// diagWriter is the destination for warning and diagnostic messages.
// Use diagPrintf to write to it safely.
var diagWriter io.Writer = os.Stderr
// diagPrintf writes a formatted message to diagWriter under the mutex.
func diagPrintf(format string, args ...any) {
diagMu.Lock()
defer diagMu.Unlock()
fmt.Fprintf(diagWriter, format, args...)
}
// setDiagWriter swaps the diagnostic writer and returns the previous one.
// Used by tests to capture output without racing.
func setDiagWriter(w io.Writer) io.Writer {
diagMu.Lock()
defer diagMu.Unlock()
old := diagWriter
diagWriter = w
return old
}
// maxMCPMessageSize is the maximum size for MCP JSON-RPC messages (10 MB).
const maxMCPMessageSize = 10 * 1024 * 1024
// TCPTransport manages a TCP listener for MCP.
type TCPTransport struct {
addr string
listener net.Listener
}
// NewTCPTransport creates a new TCP transport listener.
// It listens on the provided address (e.g. "localhost:9100").
// Defaults to 127.0.0.1 when the host component is empty (e.g. ":9100").
// Emits a security warning when explicitly binding to 0.0.0.0 (all interfaces).
func NewTCPTransport(addr string) (*TCPTransport, error) {
host, port, _ := net.SplitHostPort(addr)
if host == "" {
addr = net.JoinHostPort("127.0.0.1", port)
} else if host == "0.0.0.0" {
diagPrintf("WARNING: MCP TCP server binding to all interfaces (%s). Use 127.0.0.1 for local-only access.\n", addr)
}
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
return &TCPTransport{addr: addr, listener: listener}, nil
}
// ServeTCP starts a TCP server for the MCP service.
// It accepts connections and spawns a new MCP server session for each connection.
func (s *Service) ServeTCP(ctx context.Context, addr string) error {
t, err := NewTCPTransport(addr)
if err != nil {
return err
}
defer func() { _ = t.listener.Close() }()
// Close listener when context is cancelled to unblock Accept
go func() {
<-ctx.Done()
_ = t.listener.Close()
}()
if addr == "" {
addr = t.listener.Addr().String()
}
diagPrintf("MCP TCP server listening on %s\n", addr)
for {
conn, err := t.listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
diagPrintf("Accept error: %v\n", err)
continue
}
}
go s.handleConnection(ctx, conn)
}
}
func (s *Service) handleConnection(ctx context.Context, conn net.Conn) {
// Note: We don't defer conn.Close() here because it's closed by the Server/Transport
// Create new server instance for this connection
impl := &mcp.Implementation{
Name: "core-cli",
Version: "0.1.0",
}
server := mcp.NewServer(impl, nil)
s.registerTools(server)
// Create transport for this connection
transport := &connTransport{conn: conn}
// Run server (blocks until connection closed)
// Server.Run calls Connect, then Read loop.
if err := server.Run(ctx, transport); err != nil {
diagPrintf("Connection error: %v\n", err)
}
}
// connTransport adapts net.Conn to mcp.Transport
type connTransport struct {
conn net.Conn
}
func (t *connTransport) Connect(ctx context.Context) (mcp.Connection, error) {
scanner := bufio.NewScanner(t.conn)
scanner.Buffer(make([]byte, 64*1024), maxMCPMessageSize)
return &connConnection{
conn: t.conn,
scanner: scanner,
}, nil
}
// connConnection implements mcp.Connection
type connConnection struct {
conn net.Conn
scanner *bufio.Scanner
}
func (c *connConnection) Read(ctx context.Context) (jsonrpc.Message, error) {
// Blocks until line is read
if !c.scanner.Scan() {
if err := c.scanner.Err(); err != nil {
return nil, err
}
// EOF - connection closed cleanly
return nil, io.EOF
}
line := c.scanner.Bytes()
return jsonrpc.DecodeMessage(line)
}
func (c *connConnection) Write(ctx context.Context, msg jsonrpc.Message) error {
data, err := jsonrpc.EncodeMessage(msg)
if err != nil {
return err
}
// Append newline for line-delimited JSON
data = append(data, '\n')
_, err = c.conn.Write(data)
return err
}
func (c *connConnection) Close() error {
return c.conn.Close()
}
func (c *connConnection) SessionID() string {
return "tcp-session" // Unique ID might be better, but optional
}

View file

@ -1,184 +0,0 @@
package mcp
import (
"bytes"
"context"
"net"
"os"
"strings"
"testing"
"time"
)
func TestNewTCPTransport_Defaults(t *testing.T) {
// Test that empty string gets replaced with default address constant
// Note: We can't actually bind to 9100 as it may be in use,
// so we verify the address is set correctly before Listen is called
if DefaultTCPAddr != "127.0.0.1:9100" {
t.Errorf("Expected default constant 127.0.0.1:9100, got %s", DefaultTCPAddr)
}
// Test with a dynamic port to verify transport creation works
tr, err := NewTCPTransport("127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create transport with dynamic port: %v", err)
}
defer tr.listener.Close()
// Verify we got a valid address
if tr.addr != "127.0.0.1:0" {
t.Errorf("Expected address to be set, got %s", tr.addr)
}
}
func TestNewTCPTransport_Warning(t *testing.T) {
// Capture warning output via setDiagWriter (mutex-protected, no race).
var buf bytes.Buffer
old := setDiagWriter(&buf)
defer setDiagWriter(old)
// Trigger warning
tr, err := NewTCPTransport("0.0.0.0:9101")
if err != nil {
t.Fatalf("Failed to create transport: %v", err)
}
defer tr.listener.Close()
output := buf.String()
if !strings.Contains(output, "WARNING") {
t.Error("Expected warning for binding to 0.0.0.0, but didn't find it in stderr")
}
}
func TestServeTCP_Connection(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Use a random port for testing to avoid collisions
addr := "127.0.0.1:0"
// Create transport first to get the actual address if we use :0
tr, err := NewTCPTransport(addr)
if err != nil {
t.Fatalf("Failed to create transport: %v", err)
}
actualAddr := tr.listener.Addr().String()
tr.listener.Close() // Close it so ServeTCP can re-open it or use the same address
// Start server in background
errCh := make(chan error, 1)
go func() {
errCh <- s.ServeTCP(ctx, actualAddr)
}()
// Give it a moment to start
time.Sleep(100 * time.Millisecond)
// Connect to the server
conn, err := net.Dial("tcp", actualAddr)
if err != nil {
t.Fatalf("Failed to connect to server: %v", err)
}
defer conn.Close()
// Verify we can write to it
_, err = conn.Write([]byte("{}\n"))
if err != nil {
t.Errorf("Failed to write to connection: %v", err)
}
// Shutdown server
cancel()
err = <-errCh
if err != nil {
t.Errorf("ServeTCP returned error: %v", err)
}
}
func TestRun_TCPTrigger(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Set MCP_ADDR to empty to trigger default TCP
os.Setenv("MCP_ADDR", "")
defer os.Unsetenv("MCP_ADDR")
// We use a random port for testing, but Run will try to use 127.0.0.1:9100 by default if we don't override.
// Since 9100 might be in use, we'll set MCP_ADDR to use :0 (random port)
os.Setenv("MCP_ADDR", "127.0.0.1:0")
errCh := make(chan error, 1)
go func() {
errCh <- s.Run(ctx)
}()
// Give it a moment to start
time.Sleep(100 * time.Millisecond)
// Since we can't easily get the actual port used by Run (it's internal),
// we just verify it didn't immediately fail.
select {
case err := <-errCh:
t.Fatalf("Run failed immediately: %v", err)
default:
// still running, which is good
}
cancel()
_ = <-errCh
}
func TestServeTCP_MultipleConnections(t *testing.T) {
s, err := New()
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
addr := "127.0.0.1:0"
tr, err := NewTCPTransport(addr)
if err != nil {
t.Fatalf("Failed to create transport: %v", err)
}
actualAddr := tr.listener.Addr().String()
tr.listener.Close()
errCh := make(chan error, 1)
go func() {
errCh <- s.ServeTCP(ctx, actualAddr)
}()
time.Sleep(100 * time.Millisecond)
// Connect multiple clients
const numClients = 3
for i := range numClients {
conn, err := net.Dial("tcp", actualAddr)
if err != nil {
t.Fatalf("Client %d failed to connect: %v", i, err)
}
defer conn.Close()
_, err = conn.Write([]byte("{}\n"))
if err != nil {
t.Errorf("Client %d failed to write: %v", i, err)
}
}
cancel()
err = <-errCh
if err != nil {
t.Errorf("ServeTCP returned error: %v", err)
}
}

View file

@ -1,52 +0,0 @@
package mcp
import (
"context"
"net"
"os"
"forge.lthn.ai/core/go-log"
)
// ServeUnix starts a Unix domain socket server for the MCP service.
// The socket file is created at the given path and removed on shutdown.
// It accepts connections and spawns a new MCP server session for each connection.
func (s *Service) ServeUnix(ctx context.Context, socketPath string) error {
// Clean up any stale socket file
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
s.logger.Warn("Failed to remove stale socket", "path", socketPath, "err", err)
}
listener, err := net.Listen("unix", socketPath)
if err != nil {
return err
}
defer func() {
_ = listener.Close()
_ = os.Remove(socketPath)
}()
// Close listener when context is cancelled to unblock Accept
go func() {
<-ctx.Done()
_ = listener.Close()
}()
s.logger.Security("MCP Unix server listening", "path", socketPath, "user", log.Username())
for {
conn, err := listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
s.logger.Error("MCP Unix accept error", "err", err, "user", log.Username())
continue
}
}
s.logger.Security("MCP Unix connection accepted", "user", log.Username())
go s.handleConnection(ctx, conn)
}
}