refactor: extract MCP server to core/mcp
Move mcp/, cmd/mcpcmd/, cmd/brain-seed/ to the new core/mcp repo. Update daemon import to use forge.lthn.ai/core/mcp/pkg/mcp. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
413c637d26
commit
0202bec84a
43 changed files with 1 additions and 10067 deletions
|
|
@ -1,502 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
// brain-seed imports Claude Code MEMORY.md files into the OpenBrain knowledge
|
||||
// store via the MCP HTTP API (brain_remember tool). The Laravel app handles
|
||||
// embedding, Qdrant storage, and MariaDB dual-write internally.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// go run ./cmd/brain-seed -api-key YOUR_KEY
|
||||
// go run ./cmd/brain-seed -api-key YOUR_KEY -api https://lthn.sh/api/v1/mcp
|
||||
// go run ./cmd/brain-seed -api-key YOUR_KEY -dry-run
|
||||
// go run ./cmd/brain-seed -api-key YOUR_KEY -plans
|
||||
// go run ./cmd/brain-seed -api-key YOUR_KEY -claude-md # Also import CLAUDE.md files
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
apiURL = flag.String("api", "https://lthn.sh/api/v1/mcp", "MCP API base URL")
|
||||
apiKey = flag.String("api-key", "", "MCP API key (Bearer token)")
|
||||
server = flag.String("server", "hosthub-agent", "MCP server ID")
|
||||
agent = flag.String("agent", "charon", "Agent ID for attribution")
|
||||
dryRun = flag.Bool("dry-run", false, "Preview without storing")
|
||||
plans = flag.Bool("plans", false, "Also import plan documents")
|
||||
claudeMd = flag.Bool("claude-md", false, "Also import CLAUDE.md files")
|
||||
memoryPath = flag.String("memory-path", "", "Override memory scan path (default: ~/.claude/projects/*/memory/)")
|
||||
planPath = flag.String("plan-path", "", "Override plan scan path (default: ~/Code/*/docs/plans/)")
|
||||
codePath = flag.String("code-path", "", "Override code root for CLAUDE.md scan (default: ~/Code)")
|
||||
maxChars = flag.Int("max-chars", 3800, "Max chars per section (embeddinggemma limit ~4000)")
|
||||
)
|
||||
|
||||
// httpClient with TLS skip for non-public TLDs (.lthn.sh has real certs, but
|
||||
// allow .lan/.local if someone has legacy config).
|
||||
var httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
||||
},
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
fmt.Println("OpenBrain Seed — MCP API Client")
|
||||
fmt.Println(strings.Repeat("=", 55))
|
||||
|
||||
if *apiKey == "" && !*dryRun {
|
||||
fmt.Println("ERROR: -api-key is required (or use -dry-run)")
|
||||
fmt.Println(" Generate one at: https://lthn.sh/admin/mcp/api-keys")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *dryRun {
|
||||
fmt.Println("[DRY RUN] — no data will be stored")
|
||||
}
|
||||
|
||||
fmt.Printf("API: %s\n", *apiURL)
|
||||
fmt.Printf("Server: %s | Agent: %s\n", *server, *agent)
|
||||
|
||||
// Discover memory files
|
||||
memPath := *memoryPath
|
||||
if memPath == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
memPath = filepath.Join(home, ".claude", "projects", "*", "memory")
|
||||
}
|
||||
memFiles, _ := filepath.Glob(filepath.Join(memPath, "*.md"))
|
||||
fmt.Printf("\nFound %d memory files\n", len(memFiles))
|
||||
|
||||
// Discover plan files
|
||||
var planFiles []string
|
||||
if *plans {
|
||||
pPath := *planPath
|
||||
if pPath == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
pPath = filepath.Join(home, "Code", "*", "docs", "plans")
|
||||
}
|
||||
planFiles, _ = filepath.Glob(filepath.Join(pPath, "*.md"))
|
||||
// Also check nested dirs (completed/, etc.)
|
||||
nested, _ := filepath.Glob(filepath.Join(pPath, "*", "*.md"))
|
||||
planFiles = append(planFiles, nested...)
|
||||
|
||||
// Also check host-uk nested repos
|
||||
home, _ := os.UserHomeDir()
|
||||
hostUkPath := filepath.Join(home, "Code", "host-uk", "*", "docs", "plans")
|
||||
hostUkFiles, _ := filepath.Glob(filepath.Join(hostUkPath, "*.md"))
|
||||
planFiles = append(planFiles, hostUkFiles...)
|
||||
hostUkNested, _ := filepath.Glob(filepath.Join(hostUkPath, "*", "*.md"))
|
||||
planFiles = append(planFiles, hostUkNested...)
|
||||
|
||||
fmt.Printf("Found %d plan files\n", len(planFiles))
|
||||
}
|
||||
|
||||
// Discover CLAUDE.md files
|
||||
var claudeFiles []string
|
||||
if *claudeMd {
|
||||
cPath := *codePath
|
||||
if cPath == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
cPath = filepath.Join(home, "Code")
|
||||
}
|
||||
claudeFiles = discoverClaudeMdFiles(cPath)
|
||||
fmt.Printf("Found %d CLAUDE.md files\n", len(claudeFiles))
|
||||
}
|
||||
|
||||
imported := 0
|
||||
skipped := 0
|
||||
errors := 0
|
||||
|
||||
// Process memory files
|
||||
fmt.Println("\n--- Memory Files ---")
|
||||
for _, f := range memFiles {
|
||||
project := extractProject(f)
|
||||
sections := parseMarkdownSections(f)
|
||||
filename := strings.TrimSuffix(filepath.Base(f), ".md")
|
||||
|
||||
if len(sections) == 0 {
|
||||
fmt.Printf(" skip %s/%s (no sections)\n", project, filename)
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sec := range sections {
|
||||
content := sec.heading + "\n\n" + sec.content
|
||||
if strings.TrimSpace(sec.content) == "" {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
memType := inferType(sec.heading, sec.content, "memory")
|
||||
tags := buildTags(filename, "memory", project)
|
||||
confidence := confidenceForSource("memory")
|
||||
|
||||
// Truncate to embedding model limit
|
||||
content = truncate(content, *maxChars)
|
||||
|
||||
if *dryRun {
|
||||
fmt.Printf(" [DRY] %s/%s :: %s (%s) — %d chars\n",
|
||||
project, filename, sec.heading, memType, len(content))
|
||||
imported++
|
||||
continue
|
||||
}
|
||||
|
||||
if err := callBrainRemember(content, memType, tags, project, confidence); err != nil {
|
||||
fmt.Printf(" FAIL %s/%s :: %s — %v\n", project, filename, sec.heading, err)
|
||||
errors++
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" ok %s/%s :: %s (%s)\n", project, filename, sec.heading, memType)
|
||||
imported++
|
||||
}
|
||||
}
|
||||
|
||||
// Process plan files
|
||||
if *plans && len(planFiles) > 0 {
|
||||
fmt.Println("\n--- Plan Documents ---")
|
||||
for _, f := range planFiles {
|
||||
project := extractProjectFromPlan(f)
|
||||
sections := parseMarkdownSections(f)
|
||||
filename := strings.TrimSuffix(filepath.Base(f), ".md")
|
||||
|
||||
if len(sections) == 0 {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sec := range sections {
|
||||
content := sec.heading + "\n\n" + sec.content
|
||||
if strings.TrimSpace(sec.content) == "" {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
tags := buildTags(filename, "plans", project)
|
||||
content = truncate(content, *maxChars)
|
||||
|
||||
if *dryRun {
|
||||
fmt.Printf(" [DRY] %s :: %s / %s (plan) — %d chars\n",
|
||||
project, filename, sec.heading, len(content))
|
||||
imported++
|
||||
continue
|
||||
}
|
||||
|
||||
if err := callBrainRemember(content, "plan", tags, project, 0.6); err != nil {
|
||||
fmt.Printf(" FAIL %s :: %s / %s — %v\n", project, filename, sec.heading, err)
|
||||
errors++
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" ok %s :: %s / %s (plan)\n", project, filename, sec.heading)
|
||||
imported++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process CLAUDE.md files
|
||||
if *claudeMd && len(claudeFiles) > 0 {
|
||||
fmt.Println("\n--- CLAUDE.md Files ---")
|
||||
for _, f := range claudeFiles {
|
||||
project := extractProjectFromClaudeMd(f)
|
||||
sections := parseMarkdownSections(f)
|
||||
|
||||
if len(sections) == 0 {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sec := range sections {
|
||||
content := sec.heading + "\n\n" + sec.content
|
||||
if strings.TrimSpace(sec.content) == "" {
|
||||
skipped++
|
||||
continue
|
||||
}
|
||||
|
||||
tags := buildTags("CLAUDE", "claude-md", project)
|
||||
content = truncate(content, *maxChars)
|
||||
|
||||
if *dryRun {
|
||||
fmt.Printf(" [DRY] %s :: CLAUDE.md / %s (convention) — %d chars\n",
|
||||
project, sec.heading, len(content))
|
||||
imported++
|
||||
continue
|
||||
}
|
||||
|
||||
if err := callBrainRemember(content, "convention", tags, project, 0.9); err != nil {
|
||||
fmt.Printf(" FAIL %s :: CLAUDE.md / %s — %v\n", project, sec.heading, err)
|
||||
errors++
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" ok %s :: CLAUDE.md / %s (convention)\n", project, sec.heading)
|
||||
imported++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("\n%s\n", strings.Repeat("=", 55))
|
||||
prefix := ""
|
||||
if *dryRun {
|
||||
prefix = "[DRY RUN] "
|
||||
}
|
||||
fmt.Printf("%sImported: %d | Skipped: %d | Errors: %d\n", prefix, imported, skipped, errors)
|
||||
}
|
||||
|
||||
// callBrainRemember sends a memory to the MCP API via brain_remember tool.
|
||||
func callBrainRemember(content, memType string, tags []string, project string, confidence float64) error {
|
||||
args := map[string]any{
|
||||
"content": content,
|
||||
"type": memType,
|
||||
"tags": tags,
|
||||
"confidence": confidence,
|
||||
}
|
||||
if project != "" && project != "unknown" {
|
||||
args["project"] = project
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"server": *server,
|
||||
"tool": "brain_remember",
|
||||
"arguments": args,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", *apiURL+"/tools/call", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+*apiKey)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return fmt.Errorf("decode: %w", err)
|
||||
}
|
||||
if !result.Success {
|
||||
return fmt.Errorf("API: %s", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// truncate caps content to maxLen chars, appending an ellipsis if truncated.
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
// Find last space before limit to avoid splitting mid-word
|
||||
cut := maxLen
|
||||
if idx := strings.LastIndex(s[:maxLen], " "); idx > maxLen-200 {
|
||||
cut = idx
|
||||
}
|
||||
return s[:cut] + "…"
|
||||
}
|
||||
|
||||
// discoverClaudeMdFiles finds CLAUDE.md files across a code directory.
|
||||
func discoverClaudeMdFiles(codePath string) []string {
|
||||
var files []string
|
||||
|
||||
// Walk up to 4 levels deep, skip node_modules/vendor/.claude
|
||||
_ = filepath.WalkDir(codePath, func(path string, d os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
name := d.Name()
|
||||
if name == "node_modules" || name == "vendor" || name == ".claude" {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
// Limit depth
|
||||
rel, _ := filepath.Rel(codePath, path)
|
||||
if strings.Count(rel, string(os.PathSeparator)) > 3 {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if d.Name() == "CLAUDE.md" {
|
||||
files = append(files, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return files
|
||||
}
|
||||
|
||||
// section is a parsed markdown section.
|
||||
type section struct {
|
||||
heading string
|
||||
content string
|
||||
}
|
||||
|
||||
var headingRe = regexp.MustCompile(`^#{1,3}\s+(.+)$`)
|
||||
|
||||
// parseMarkdownSections splits a markdown file by headings.
|
||||
func parseMarkdownSections(path string) []section {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var sections []section
|
||||
lines := strings.Split(string(data), "\n")
|
||||
var curHeading string
|
||||
var curContent []string
|
||||
|
||||
for _, line := range lines {
|
||||
if m := headingRe.FindStringSubmatch(line); m != nil {
|
||||
if curHeading != "" && len(curContent) > 0 {
|
||||
text := strings.TrimSpace(strings.Join(curContent, "\n"))
|
||||
if text != "" {
|
||||
sections = append(sections, section{curHeading, text})
|
||||
}
|
||||
}
|
||||
curHeading = strings.TrimSpace(m[1])
|
||||
curContent = nil
|
||||
} else {
|
||||
curContent = append(curContent, line)
|
||||
}
|
||||
}
|
||||
|
||||
// Flush last section
|
||||
if curHeading != "" && len(curContent) > 0 {
|
||||
text := strings.TrimSpace(strings.Join(curContent, "\n"))
|
||||
if text != "" {
|
||||
sections = append(sections, section{curHeading, text})
|
||||
}
|
||||
}
|
||||
|
||||
// If no headings found, treat entire file as one section
|
||||
if len(sections) == 0 && strings.TrimSpace(string(data)) != "" {
|
||||
sections = append(sections, section{
|
||||
heading: strings.TrimSuffix(filepath.Base(path), ".md"),
|
||||
content: strings.TrimSpace(string(data)),
|
||||
})
|
||||
}
|
||||
|
||||
return sections
|
||||
}
|
||||
|
||||
// extractProject derives a project name from a Claude memory path.
|
||||
// ~/.claude/projects/-Users-snider-Code-eaas/memory/MEMORY.md → "eaas"
|
||||
func extractProject(path string) string {
|
||||
re := regexp.MustCompile(`projects/[^/]*-([^-/]+)/memory/`)
|
||||
if m := re.FindStringSubmatch(path); m != nil {
|
||||
return m[1]
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// extractProjectFromPlan derives a project name from a plan path.
|
||||
// ~/Code/eaas/docs/plans/foo.md → "eaas"
|
||||
// ~/Code/host-uk/core/docs/plans/foo.md → "core"
|
||||
func extractProjectFromPlan(path string) string {
|
||||
// Check host-uk nested repos first
|
||||
re := regexp.MustCompile(`Code/host-uk/([^/]+)/docs/plans/`)
|
||||
if m := re.FindStringSubmatch(path); m != nil {
|
||||
return m[1]
|
||||
}
|
||||
re = regexp.MustCompile(`Code/([^/]+)/docs/plans/`)
|
||||
if m := re.FindStringSubmatch(path); m != nil {
|
||||
return m[1]
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// extractProjectFromClaudeMd derives a project name from a CLAUDE.md path.
|
||||
// ~/Code/host-uk/core/CLAUDE.md → "core"
|
||||
// ~/Code/eaas/CLAUDE.md → "eaas"
|
||||
func extractProjectFromClaudeMd(path string) string {
|
||||
re := regexp.MustCompile(`Code/host-uk/([^/]+)/`)
|
||||
if m := re.FindStringSubmatch(path); m != nil {
|
||||
return m[1]
|
||||
}
|
||||
re = regexp.MustCompile(`Code/([^/]+)/`)
|
||||
if m := re.FindStringSubmatch(path); m != nil {
|
||||
return m[1]
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// inferType guesses the memory type from heading + content keywords.
|
||||
func inferType(heading, content, source string) string {
|
||||
// Source-specific defaults (match PHP BrainIngestCommand behaviour)
|
||||
if source == "plans" {
|
||||
return "plan"
|
||||
}
|
||||
if source == "claude-md" {
|
||||
return "convention"
|
||||
}
|
||||
|
||||
lower := strings.ToLower(heading + " " + content)
|
||||
patterns := map[string][]string{
|
||||
"architecture": {"architecture", "stack", "infrastructure", "layer", "service mesh"},
|
||||
"convention": {"convention", "standard", "naming", "pattern", "rule", "coding"},
|
||||
"decision": {"decision", "chose", "strategy", "approach", "domain"},
|
||||
"bug": {"bug", "fix", "broken", "error", "issue", "lesson"},
|
||||
"plan": {"plan", "todo", "roadmap", "milestone", "phase", "task"},
|
||||
"research": {"research", "finding", "discovery", "analysis", "rfc"},
|
||||
}
|
||||
for t, keywords := range patterns {
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(lower, kw) {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
return "observation"
|
||||
}
|
||||
|
||||
// buildTags creates the tag list for a memory.
|
||||
func buildTags(filename, source, project string) []string {
|
||||
tags := []string{"source:" + source}
|
||||
if project != "" && project != "unknown" {
|
||||
tags = append(tags, "project:"+project)
|
||||
}
|
||||
if filename != "MEMORY" && filename != "CLAUDE" {
|
||||
tags = append(tags, strings.ReplaceAll(strings.ReplaceAll(filename, "-", " "), "_", " "))
|
||||
}
|
||||
return tags
|
||||
}
|
||||
|
||||
// confidenceForSource returns a default confidence level matching the PHP ingest command.
|
||||
func confidenceForSource(source string) float64 {
|
||||
switch source {
|
||||
case "claude-md":
|
||||
return 0.9
|
||||
case "memory":
|
||||
return 0.8
|
||||
case "plans":
|
||||
return 0.6
|
||||
default:
|
||||
return 0.5
|
||||
}
|
||||
}
|
||||
|
|
@ -8,7 +8,7 @@ import (
|
|||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/go-ai/mcp"
|
||||
"forge.lthn.ai/core/mcp/pkg/mcp"
|
||||
"forge.lthn.ai/core/go-log"
|
||||
"forge.lthn.ai/core/go-process"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,92 +0,0 @@
|
|||
// Package mcpcmd provides the MCP server command.
|
||||
//
|
||||
// Commands:
|
||||
// - mcp serve: Start the MCP server for AI tool integration
|
||||
package mcpcmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/go-ai/mcp"
|
||||
)
|
||||
|
||||
var workspaceFlag string
|
||||
|
||||
var mcpCmd = &cli.Command{
|
||||
Use: "mcp",
|
||||
Short: "MCP server for AI tool integration",
|
||||
Long: "Model Context Protocol (MCP) server providing file operations, RAG, and metrics tools.",
|
||||
}
|
||||
|
||||
var serveCmd = &cli.Command{
|
||||
Use: "serve",
|
||||
Short: "Start the MCP server",
|
||||
Long: `Start the MCP server on stdio (default) or TCP.
|
||||
|
||||
The server provides file operations, RAG tools, and metrics tools for AI assistants.
|
||||
|
||||
Environment variables:
|
||||
MCP_ADDR TCP address to listen on (e.g., "localhost:9999")
|
||||
If not set, uses stdio transport.
|
||||
|
||||
Examples:
|
||||
# Start with stdio transport (for Claude Code integration)
|
||||
core mcp serve
|
||||
|
||||
# Start with workspace restriction
|
||||
core mcp serve --workspace /path/to/project
|
||||
|
||||
# Start TCP server
|
||||
MCP_ADDR=localhost:9999 core mcp serve`,
|
||||
RunE: func(cmd *cli.Command, args []string) error {
|
||||
return runServe()
|
||||
},
|
||||
}
|
||||
|
||||
func initFlags() {
|
||||
cli.StringFlag(serveCmd, &workspaceFlag, "workspace", "w", "", "Restrict file operations to this directory (empty = unrestricted)")
|
||||
}
|
||||
|
||||
// AddMCPCommands registers the 'mcp' command and all subcommands.
|
||||
func AddMCPCommands(root *cli.Command) {
|
||||
initFlags()
|
||||
mcpCmd.AddCommand(serveCmd)
|
||||
root.AddCommand(mcpCmd)
|
||||
}
|
||||
|
||||
func runServe() error {
|
||||
// Build MCP service options
|
||||
var opts []mcp.Option
|
||||
|
||||
if workspaceFlag != "" {
|
||||
opts = append(opts, mcp.WithWorkspaceRoot(workspaceFlag))
|
||||
} else {
|
||||
// Explicitly unrestricted when no workspace specified
|
||||
opts = append(opts, mcp.WithWorkspaceRoot(""))
|
||||
}
|
||||
|
||||
// Create the MCP service
|
||||
svc, err := mcp.New(opts...)
|
||||
if err != nil {
|
||||
return cli.Wrap(err, "create MCP service")
|
||||
}
|
||||
|
||||
// Set up signal handling for clean shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-sigCh
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Run the server (blocks until context cancelled or error)
|
||||
return svc.Run(ctx)
|
||||
}
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
// Package brain provides an MCP subsystem that proxies OpenBrain knowledge
|
||||
// store operations to the Laravel php-agentic backend via the IDE bridge.
|
||||
package brain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/mcp/ide"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// errBridgeNotAvailable is returned when a tool requires the Laravel bridge
|
||||
// but it has not been initialised (headless mode).
|
||||
var errBridgeNotAvailable = errors.New("brain: bridge not available")
|
||||
|
||||
// Subsystem implements mcp.Subsystem for OpenBrain knowledge store operations.
|
||||
// It proxies brain_* tool calls to the Laravel backend via the shared IDE bridge.
|
||||
type Subsystem struct {
|
||||
bridge *ide.Bridge
|
||||
}
|
||||
|
||||
// New creates a brain subsystem that uses the given IDE bridge for Laravel communication.
|
||||
// Pass nil if headless (tools will return errBridgeNotAvailable).
|
||||
func New(bridge *ide.Bridge) *Subsystem {
|
||||
return &Subsystem{bridge: bridge}
|
||||
}
|
||||
|
||||
// Name implements mcp.Subsystem.
|
||||
func (s *Subsystem) Name() string { return "brain" }
|
||||
|
||||
// RegisterTools implements mcp.Subsystem.
|
||||
func (s *Subsystem) RegisterTools(server *mcp.Server) {
|
||||
s.registerBrainTools(server)
|
||||
}
|
||||
|
||||
// Shutdown implements mcp.SubsystemWithShutdown.
|
||||
func (s *Subsystem) Shutdown(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,229 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package brain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Nil bridge tests (headless mode) ---
|
||||
|
||||
func TestBrainRemember_Bad_NilBridge(t *testing.T) {
|
||||
sub := New(nil)
|
||||
_, _, err := sub.brainRemember(context.Background(), nil, RememberInput{
|
||||
Content: "test memory",
|
||||
Type: "observation",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrainRecall_Bad_NilBridge(t *testing.T) {
|
||||
sub := New(nil)
|
||||
_, _, err := sub.brainRecall(context.Background(), nil, RecallInput{
|
||||
Query: "how does scoring work?",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrainForget_Bad_NilBridge(t *testing.T) {
|
||||
sub := New(nil)
|
||||
_, _, err := sub.brainForget(context.Background(), nil, ForgetInput{
|
||||
ID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBrainList_Bad_NilBridge(t *testing.T) {
|
||||
sub := New(nil)
|
||||
_, _, err := sub.brainList(context.Background(), nil, ListInput{
|
||||
Project: "eaas",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Subsystem interface tests ---
|
||||
|
||||
func TestSubsystem_Good_Name(t *testing.T) {
|
||||
sub := New(nil)
|
||||
if sub.Name() != "brain" {
|
||||
t.Errorf("expected Name() = 'brain', got %q", sub.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubsystem_Good_ShutdownNoop(t *testing.T) {
|
||||
sub := New(nil)
|
||||
if err := sub.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("Shutdown failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Struct round-trip tests ---
|
||||
|
||||
func TestRememberInput_Good_RoundTrip(t *testing.T) {
|
||||
in := RememberInput{
|
||||
Content: "LEM scoring was blind to negative emotions",
|
||||
Type: "bug",
|
||||
Tags: []string{"scoring", "lem"},
|
||||
Project: "eaas",
|
||||
Confidence: 0.95,
|
||||
Supersedes: "550e8400-e29b-41d4-a716-446655440000",
|
||||
ExpiresIn: 24,
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out RememberInput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.Content != in.Content || out.Type != in.Type {
|
||||
t.Errorf("round-trip mismatch: content or type")
|
||||
}
|
||||
if len(out.Tags) != 2 || out.Tags[0] != "scoring" {
|
||||
t.Errorf("round-trip mismatch: tags")
|
||||
}
|
||||
if out.Confidence != 0.95 {
|
||||
t.Errorf("round-trip mismatch: confidence %f != 0.95", out.Confidence)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRememberOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := RememberOutput{
|
||||
Success: true,
|
||||
MemoryID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
Timestamp: time.Now().Truncate(time.Second),
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out RememberOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if !out.Success || out.MemoryID != in.MemoryID {
|
||||
t.Errorf("round-trip mismatch: %+v != %+v", out, in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecallInput_Good_RoundTrip(t *testing.T) {
|
||||
in := RecallInput{
|
||||
Query: "how does verdict classification work?",
|
||||
TopK: 5,
|
||||
Filter: RecallFilter{
|
||||
Project: "eaas",
|
||||
MinConfidence: 0.5,
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out RecallInput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.Query != in.Query || out.TopK != 5 {
|
||||
t.Errorf("round-trip mismatch: query or topK")
|
||||
}
|
||||
if out.Filter.Project != "eaas" || out.Filter.MinConfidence != 0.5 {
|
||||
t.Errorf("round-trip mismatch: filter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemory_Good_RoundTrip(t *testing.T) {
|
||||
in := Memory{
|
||||
ID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
AgentID: "virgil",
|
||||
Type: "decision",
|
||||
Content: "Use Qdrant for vector search",
|
||||
Tags: []string{"architecture", "openbrain"},
|
||||
Project: "php-agentic",
|
||||
Confidence: 0.9,
|
||||
CreatedAt: "2026-03-03T12:00:00+00:00",
|
||||
UpdatedAt: "2026-03-03T12:00:00+00:00",
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out Memory
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.ID != in.ID || out.AgentID != "virgil" || out.Type != "decision" {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForgetInput_Good_RoundTrip(t *testing.T) {
|
||||
in := ForgetInput{
|
||||
ID: "550e8400-e29b-41d4-a716-446655440000",
|
||||
Reason: "Superseded by new approach",
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out ForgetInput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.ID != in.ID || out.Reason != in.Reason {
|
||||
t.Errorf("round-trip mismatch: %+v != %+v", out, in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListInput_Good_RoundTrip(t *testing.T) {
|
||||
in := ListInput{
|
||||
Project: "eaas",
|
||||
Type: "decision",
|
||||
AgentID: "charon",
|
||||
Limit: 20,
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out ListInput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.Project != "eaas" || out.Type != "decision" || out.AgentID != "charon" || out.Limit != 20 {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := ListOutput{
|
||||
Success: true,
|
||||
Count: 2,
|
||||
Memories: []Memory{
|
||||
{ID: "id-1", AgentID: "virgil", Type: "decision", Content: "memory 1", Confidence: 0.9, CreatedAt: "2026-03-03T12:00:00+00:00", UpdatedAt: "2026-03-03T12:00:00+00:00"},
|
||||
{ID: "id-2", AgentID: "charon", Type: "bug", Content: "memory 2", Confidence: 0.8, CreatedAt: "2026-03-03T13:00:00+00:00", UpdatedAt: "2026-03-03T13:00:00+00:00"},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out ListOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if !out.Success || out.Count != 2 || len(out.Memories) != 2 {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,220 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package brain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/mcp/ide"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// -- Input/Output types -------------------------------------------------------
|
||||
|
||||
// RememberInput is the input for brain_remember.
|
||||
type RememberInput struct {
|
||||
Content string `json:"content"`
|
||||
Type string `json:"type"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Project string `json:"project,omitempty"`
|
||||
Confidence float64 `json:"confidence,omitempty"`
|
||||
Supersedes string `json:"supersedes,omitempty"`
|
||||
ExpiresIn int `json:"expires_in,omitempty"`
|
||||
}
|
||||
|
||||
// RememberOutput is the output for brain_remember.
|
||||
type RememberOutput struct {
|
||||
Success bool `json:"success"`
|
||||
MemoryID string `json:"memoryId,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// RecallInput is the input for brain_recall.
|
||||
type RecallInput struct {
|
||||
Query string `json:"query"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Filter RecallFilter `json:"filter,omitempty"`
|
||||
}
|
||||
|
||||
// RecallFilter holds optional filter criteria for brain_recall.
|
||||
type RecallFilter struct {
|
||||
Project string `json:"project,omitempty"`
|
||||
Type any `json:"type,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
MinConfidence float64 `json:"min_confidence,omitempty"`
|
||||
}
|
||||
|
||||
// RecallOutput is the output for brain_recall.
|
||||
type RecallOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Count int `json:"count"`
|
||||
Memories []Memory `json:"memories"`
|
||||
}
|
||||
|
||||
// Memory is a single memory entry returned by recall or list.
|
||||
type Memory struct {
|
||||
ID string `json:"id"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Project string `json:"project,omitempty"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
SupersedesID string `json:"supersedes_id,omitempty"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ForgetInput is the input for brain_forget.
|
||||
type ForgetInput struct {
|
||||
ID string `json:"id"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// ForgetOutput is the output for brain_forget.
|
||||
type ForgetOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Forgotten string `json:"forgotten"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// ListInput is the input for brain_list.
|
||||
type ListInput struct {
|
||||
Project string `json:"project,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
// ListOutput is the output for brain_list.
|
||||
type ListOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Count int `json:"count"`
|
||||
Memories []Memory `json:"memories"`
|
||||
}
|
||||
|
||||
// -- Tool registration --------------------------------------------------------
|
||||
|
||||
func (s *Subsystem) registerBrainTools(server *mcp.Server) {
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "brain_remember",
|
||||
Description: "Store a memory in the shared OpenBrain knowledge store. Persists decisions, observations, conventions, research, plans, bugs, or architecture knowledge for other agents.",
|
||||
}, s.brainRemember)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "brain_recall",
|
||||
Description: "Semantic search across the shared OpenBrain knowledge store. Returns memories ranked by similarity to your query, with optional filtering.",
|
||||
}, s.brainRecall)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "brain_forget",
|
||||
Description: "Remove a memory from the shared OpenBrain knowledge store. Permanently deletes from both database and vector index.",
|
||||
}, s.brainForget)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "brain_list",
|
||||
Description: "List memories in the shared OpenBrain knowledge store. Supports filtering by project, type, and agent. No vector search -- use brain_recall for semantic queries.",
|
||||
}, s.brainList)
|
||||
}
|
||||
|
||||
// -- Tool handlers ------------------------------------------------------------
|
||||
|
||||
func (s *Subsystem) brainRemember(_ context.Context, _ *mcp.CallToolRequest, input RememberInput) (*mcp.CallToolResult, RememberOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, RememberOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
|
||||
err := s.bridge.Send(ide.BridgeMessage{
|
||||
Type: "brain_remember",
|
||||
Data: map[string]any{
|
||||
"content": input.Content,
|
||||
"type": input.Type,
|
||||
"tags": input.Tags,
|
||||
"project": input.Project,
|
||||
"confidence": input.Confidence,
|
||||
"supersedes": input.Supersedes,
|
||||
"expires_in": input.ExpiresIn,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, RememberOutput{}, fmt.Errorf("failed to send brain_remember: %w", err)
|
||||
}
|
||||
|
||||
return nil, RememberOutput{
|
||||
Success: true,
|
||||
Timestamp: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Subsystem) brainRecall(_ context.Context, _ *mcp.CallToolRequest, input RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, RecallOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
|
||||
err := s.bridge.Send(ide.BridgeMessage{
|
||||
Type: "brain_recall",
|
||||
Data: map[string]any{
|
||||
"query": input.Query,
|
||||
"top_k": input.TopK,
|
||||
"filter": input.Filter,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, RecallOutput{}, fmt.Errorf("failed to send brain_recall: %w", err)
|
||||
}
|
||||
|
||||
return nil, RecallOutput{
|
||||
Success: true,
|
||||
Memories: []Memory{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Subsystem) brainForget(_ context.Context, _ *mcp.CallToolRequest, input ForgetInput) (*mcp.CallToolResult, ForgetOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, ForgetOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
|
||||
err := s.bridge.Send(ide.BridgeMessage{
|
||||
Type: "brain_forget",
|
||||
Data: map[string]any{
|
||||
"id": input.ID,
|
||||
"reason": input.Reason,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, ForgetOutput{}, fmt.Errorf("failed to send brain_forget: %w", err)
|
||||
}
|
||||
|
||||
return nil, ForgetOutput{
|
||||
Success: true,
|
||||
Forgotten: input.ID,
|
||||
Timestamp: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Subsystem) brainList(_ context.Context, _ *mcp.CallToolRequest, input ListInput) (*mcp.CallToolResult, ListOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, ListOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
|
||||
err := s.bridge.Send(ide.BridgeMessage{
|
||||
Type: "brain_list",
|
||||
Data: map[string]any{
|
||||
"project": input.Project,
|
||||
"type": input.Type,
|
||||
"agent_id": input.AgentID,
|
||||
"limit": input.Limit,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, ListOutput{}, fmt.Errorf("failed to send brain_list: %w", err)
|
||||
}
|
||||
|
||||
return nil, ListOutput{
|
||||
Success: true,
|
||||
Memories: []Memory{},
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
api "forge.lthn.ai/core/go-api"
|
||||
)
|
||||
|
||||
// maxBodySize is the maximum request body size accepted by bridged tool endpoints.
|
||||
const maxBodySize = 10 << 20 // 10 MB
|
||||
|
||||
// BridgeToAPI populates a go-api ToolBridge from recorded MCP tools.
|
||||
// Each tool becomes a POST endpoint that reads a JSON body, dispatches
|
||||
// to the tool's RESTHandler (which knows the concrete input type), and
|
||||
// wraps the result in the standard api.Response envelope.
|
||||
func BridgeToAPI(svc *Service, bridge *api.ToolBridge) {
|
||||
for rec := range svc.ToolsSeq() {
|
||||
desc := api.ToolDescriptor{
|
||||
Name: rec.Name,
|
||||
Description: rec.Description,
|
||||
Group: rec.Group,
|
||||
InputSchema: rec.InputSchema,
|
||||
OutputSchema: rec.OutputSchema,
|
||||
}
|
||||
|
||||
// Capture the handler for the closure.
|
||||
handler := rec.RESTHandler
|
||||
|
||||
bridge.Add(desc, func(c *gin.Context) {
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
var err error
|
||||
body, err = io.ReadAll(io.LimitReader(c.Request.Body, maxBodySize))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, api.Fail("invalid_request", "Failed to read request body"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
result, err := handler(c.Request.Context(), body)
|
||||
if err != nil {
|
||||
// Classify JSON parse errors as client errors (400),
|
||||
// everything else as server errors (500).
|
||||
var syntaxErr *json.SyntaxError
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) {
|
||||
c.JSON(http.StatusBadRequest, api.Fail("invalid_input", "Malformed JSON in request body"))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, api.Fail("tool_error", "Tool execution failed"))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, api.OK(result))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1,250 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
api "forge.lthn.ai/core/go-api"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
func TestBridgeToAPI_Good_AllTools(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
BridgeToAPI(svc, bridge)
|
||||
|
||||
svcCount := len(svc.Tools())
|
||||
bridgeCount := len(bridge.Tools())
|
||||
|
||||
if svcCount == 0 {
|
||||
t.Fatal("expected non-zero tool count from service")
|
||||
}
|
||||
if bridgeCount != svcCount {
|
||||
t.Fatalf("bridge tool count %d != service tool count %d", bridgeCount, svcCount)
|
||||
}
|
||||
|
||||
// Verify names match.
|
||||
svcNames := make(map[string]bool)
|
||||
for _, tr := range svc.Tools() {
|
||||
svcNames[tr.Name] = true
|
||||
}
|
||||
for _, td := range bridge.Tools() {
|
||||
if !svcNames[td.Name] {
|
||||
t.Errorf("bridge has tool %q not found in service", td.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeToAPI_Good_DescribableGroup(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
BridgeToAPI(svc, bridge)
|
||||
|
||||
// ToolBridge implements DescribableGroup.
|
||||
var dg api.DescribableGroup = bridge
|
||||
descs := dg.Describe()
|
||||
|
||||
if len(descs) != len(svc.Tools()) {
|
||||
t.Fatalf("expected %d descriptions, got %d", len(svc.Tools()), len(descs))
|
||||
}
|
||||
|
||||
for _, d := range descs {
|
||||
if d.Method != "POST" {
|
||||
t.Errorf("expected Method=POST for %s, got %q", d.Path, d.Method)
|
||||
}
|
||||
if d.Summary == "" {
|
||||
t.Errorf("expected non-empty Summary for %s", d.Path)
|
||||
}
|
||||
if len(d.Tags) == 0 {
|
||||
t.Errorf("expected non-empty Tags for %s", d.Path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeToAPI_Good_FileRead(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create a test file in the workspace.
|
||||
testContent := "hello from bridge test"
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(testContent), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
svc, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
BridgeToAPI(svc, bridge)
|
||||
|
||||
// Register with a Gin engine and make a request.
|
||||
engine := gin.New()
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
body := `{"path":"test.txt"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Parse the response envelope.
|
||||
var resp api.Response[ReadFileOutput]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if !resp.Success {
|
||||
t.Fatalf("expected Success=true, got error: %+v", resp.Error)
|
||||
}
|
||||
if resp.Data.Content != testContent {
|
||||
t.Fatalf("expected content %q, got %q", testContent, resp.Data.Content)
|
||||
}
|
||||
if resp.Data.Path != "test.txt" {
|
||||
t.Fatalf("expected path %q, got %q", "test.txt", resp.Data.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeToAPI_Bad_InvalidJSON(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
BridgeToAPI(svc, bridge)
|
||||
|
||||
engine := gin.New()
|
||||
rg := engine.Group(bridge.BasePath())
|
||||
bridge.RegisterRoutes(rg)
|
||||
|
||||
// Send malformed JSON.
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", strings.NewReader("{bad json"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
// The handler unmarshals via RESTHandler which returns an error,
|
||||
// but since it's a JSON parse error it ends up as tool_error.
|
||||
// Check we get a non-200 with an error envelope.
|
||||
if w.Code == http.StatusOK {
|
||||
t.Fatalf("expected non-200 for invalid JSON, got 200")
|
||||
}
|
||||
}
|
||||
|
||||
var resp api.Response[any]
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if resp.Success {
|
||||
t.Fatal("expected Success=false for invalid JSON")
|
||||
}
|
||||
if resp.Error == nil {
|
||||
t.Fatal("expected error in response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeToAPI_Good_EndToEnd(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bridge := api.NewToolBridge("/tools")
|
||||
BridgeToAPI(svc, bridge)
|
||||
|
||||
// Create an api.Engine with the bridge registered and Swagger enabled.
|
||||
e, err := api.New(
|
||||
api.WithSwagger("MCP Bridge Test", "Testing MCP-to-REST bridge", "0.1.0"),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e.Register(bridge)
|
||||
|
||||
// Use a real test server because gin-swagger reads RequestURI
|
||||
// which is not populated by httptest.NewRecorder.
|
||||
srv := httptest.NewServer(e.Handler())
|
||||
defer srv.Close()
|
||||
|
||||
// Verify the health endpoint still works.
|
||||
resp, err := http.Get(srv.URL + "/health")
|
||||
if err != nil {
|
||||
t.Fatalf("health request failed: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for /health, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Verify a tool endpoint is reachable through the engine.
|
||||
resp2, err := http.Post(srv.URL+"/tools/lang_list", "application/json", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("lang_list request failed: %v", err)
|
||||
}
|
||||
defer resp2.Body.Close()
|
||||
if resp2.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for /tools/lang_list, got %d", resp2.StatusCode)
|
||||
}
|
||||
|
||||
var langResp api.Response[GetSupportedLanguagesOutput]
|
||||
if err := json.NewDecoder(resp2.Body).Decode(&langResp); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if !langResp.Success {
|
||||
t.Fatalf("expected Success=true, got error: %+v", langResp.Error)
|
||||
}
|
||||
if len(langResp.Data.Languages) == 0 {
|
||||
t.Fatal("expected non-empty languages list")
|
||||
}
|
||||
|
||||
// Verify Swagger endpoint contains tool paths.
|
||||
resp3, err := http.Get(srv.URL + "/swagger/doc.json")
|
||||
if err != nil {
|
||||
t.Fatalf("swagger request failed: %v", err)
|
||||
}
|
||||
defer resp3.Body.Close()
|
||||
if resp3.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected 200 for /swagger/doc.json, got %d", resp3.StatusCode)
|
||||
}
|
||||
|
||||
var specDoc map[string]any
|
||||
if err := json.NewDecoder(resp3.Body).Decode(&specDoc); err != nil {
|
||||
t.Fatalf("swagger unmarshal error: %v", err)
|
||||
}
|
||||
paths, ok := specDoc["paths"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected 'paths' in swagger spec")
|
||||
}
|
||||
if _, ok := paths["/tools/file_read"]; !ok {
|
||||
t.Error("expected /tools/file_read in swagger paths")
|
||||
}
|
||||
if _, ok := paths["/tools/lang_list"]; !ok {
|
||||
t.Error("expected /tools/lang_list in swagger paths")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,191 +0,0 @@
|
|||
package ide
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-ws"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// BridgeMessage is the wire format between the IDE and Laravel.
|
||||
type BridgeMessage struct {
|
||||
Type string `json:"type"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
Data any `json:"data,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Bridge maintains a WebSocket connection to the Laravel core-agentic
|
||||
// backend and forwards responses to a local ws.Hub.
|
||||
type Bridge struct {
|
||||
cfg Config
|
||||
hub *ws.Hub
|
||||
conn *websocket.Conn
|
||||
|
||||
mu sync.Mutex
|
||||
connected bool
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewBridge creates a bridge that will connect to the Laravel backend and
|
||||
// forward incoming messages to the provided ws.Hub channels.
|
||||
func NewBridge(hub *ws.Hub, cfg Config) *Bridge {
|
||||
return &Bridge{cfg: cfg, hub: hub}
|
||||
}
|
||||
|
||||
// Start begins the connection loop in a background goroutine.
|
||||
// Call Shutdown to stop it.
|
||||
func (b *Bridge) Start(ctx context.Context) {
|
||||
ctx, b.cancel = context.WithCancel(ctx)
|
||||
go b.connectLoop(ctx)
|
||||
}
|
||||
|
||||
// Shutdown cleanly closes the bridge.
|
||||
func (b *Bridge) Shutdown() {
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.conn != nil {
|
||||
b.conn.Close()
|
||||
b.conn = nil
|
||||
}
|
||||
b.connected = false
|
||||
}
|
||||
|
||||
// Connected reports whether the bridge has an active connection.
|
||||
func (b *Bridge) Connected() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.connected
|
||||
}
|
||||
|
||||
// Send sends a message to the Laravel backend.
|
||||
func (b *Bridge) Send(msg BridgeMessage) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.conn == nil {
|
||||
return errors.New("bridge: not connected")
|
||||
}
|
||||
msg.Timestamp = time.Now()
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bridge: marshal failed: %w", err)
|
||||
}
|
||||
return b.conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
// connectLoop reconnects to Laravel with exponential backoff.
|
||||
func (b *Bridge) connectLoop(ctx context.Context) {
|
||||
delay := b.cfg.ReconnectInterval
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err := b.dial(ctx); err != nil {
|
||||
log.Printf("ide bridge: connect failed: %v", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(delay):
|
||||
}
|
||||
delay = min(delay*2, b.cfg.MaxReconnectInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
// Reset backoff on successful connection
|
||||
delay = b.cfg.ReconnectInterval
|
||||
b.readLoop(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bridge) dial(ctx context.Context) error {
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
var header http.Header
|
||||
if b.cfg.Token != "" {
|
||||
header = http.Header{}
|
||||
header.Set("Authorization", "Bearer "+b.cfg.Token)
|
||||
}
|
||||
|
||||
conn, _, err := dialer.DialContext(ctx, b.cfg.LaravelWSURL, header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
b.conn = conn
|
||||
b.connected = true
|
||||
b.mu.Unlock()
|
||||
|
||||
log.Printf("ide bridge: connected to %s", b.cfg.LaravelWSURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bridge) readLoop(ctx context.Context) {
|
||||
defer func() {
|
||||
b.mu.Lock()
|
||||
if b.conn != nil {
|
||||
b.conn.Close()
|
||||
}
|
||||
b.connected = false
|
||||
b.mu.Unlock()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
_, data, err := b.conn.ReadMessage()
|
||||
if err != nil {
|
||||
log.Printf("ide bridge: read error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var msg BridgeMessage
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
log.Printf("ide bridge: unmarshal error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
b.dispatch(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// dispatch routes an incoming message to the appropriate ws.Hub channel.
|
||||
func (b *Bridge) dispatch(msg BridgeMessage) {
|
||||
if b.hub == nil {
|
||||
return
|
||||
}
|
||||
|
||||
wsMsg := ws.Message{
|
||||
Type: ws.TypeEvent,
|
||||
Data: msg.Data,
|
||||
}
|
||||
|
||||
channel := msg.Channel
|
||||
if channel == "" {
|
||||
channel = "ide:" + msg.Type
|
||||
}
|
||||
|
||||
if err := b.hub.SendToChannel(channel, wsMsg); err != nil {
|
||||
log.Printf("ide bridge: dispatch to %s failed: %v", channel, err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,442 +0,0 @@
|
|||
package ide
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-ws"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var testUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// echoServer creates a test WebSocket server that echoes messages back.
|
||||
func echoServer(t *testing.T) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if err := conn.WriteMessage(mt, data); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func wsURL(ts *httptest.Server) string {
|
||||
return "ws" + strings.TrimPrefix(ts.URL, "http")
|
||||
}
|
||||
|
||||
// waitConnected polls bridge.Connected() until true or timeout.
|
||||
func waitConnected(t *testing.T, bridge *Bridge, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for !bridge.Connected() && time.Now().Before(deadline) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
if !bridge.Connected() {
|
||||
t.Fatal("bridge did not connect within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Good_ConnectAndSend(t *testing.T) {
|
||||
ts := echoServer(t)
|
||||
defer ts.Close()
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.LaravelWSURL = wsURL(ts)
|
||||
cfg.ReconnectInterval = 100 * time.Millisecond
|
||||
|
||||
bridge := NewBridge(hub, cfg)
|
||||
bridge.Start(ctx)
|
||||
|
||||
waitConnected(t, bridge, 2*time.Second)
|
||||
|
||||
err := bridge.Send(BridgeMessage{
|
||||
Type: "test",
|
||||
Data: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Send() failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Good_Shutdown(t *testing.T) {
|
||||
ts := echoServer(t)
|
||||
defer ts.Close()
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.LaravelWSURL = wsURL(ts)
|
||||
cfg.ReconnectInterval = 100 * time.Millisecond
|
||||
|
||||
bridge := NewBridge(hub, cfg)
|
||||
bridge.Start(ctx)
|
||||
|
||||
waitConnected(t, bridge, 2*time.Second)
|
||||
|
||||
bridge.Shutdown()
|
||||
if bridge.Connected() {
|
||||
t.Error("bridge should be disconnected after Shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Bad_SendWithoutConnection(t *testing.T) {
|
||||
hub := ws.NewHub()
|
||||
cfg := DefaultConfig()
|
||||
bridge := NewBridge(hub, cfg)
|
||||
|
||||
err := bridge.Send(BridgeMessage{Type: "test"})
|
||||
if err == nil {
|
||||
t.Error("expected error when sending without connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Good_MessageDispatch(t *testing.T) {
|
||||
// Server that sends a message to the bridge on connect.
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
msg := BridgeMessage{
|
||||
Type: "chat_response",
|
||||
Channel: "chat:session-1",
|
||||
Data: "hello from laravel",
|
||||
}
|
||||
data, _ := json.Marshal(msg)
|
||||
conn.WriteMessage(websocket.TextMessage, data)
|
||||
|
||||
// Keep connection open
|
||||
for {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.LaravelWSURL = wsURL(ts)
|
||||
cfg.ReconnectInterval = 100 * time.Millisecond
|
||||
|
||||
bridge := NewBridge(hub, cfg)
|
||||
bridge.Start(ctx)
|
||||
|
||||
waitConnected(t, bridge, 2*time.Second)
|
||||
|
||||
// Give time for the dispatched message to be processed.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify hub stats — the message was dispatched (even without subscribers).
|
||||
// This confirms the dispatch path ran without error.
|
||||
}
|
||||
|
||||
func TestBridge_Good_Reconnect(t *testing.T) {
|
||||
// Use atomic counter to avoid data race between HTTP handler goroutine
|
||||
// and the test goroutine.
|
||||
var callCount atomic.Int32
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
n := callCount.Add(1)
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Close immediately on first connection to force reconnect
|
||||
if n == 1 {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.LaravelWSURL = wsURL(ts)
|
||||
cfg.ReconnectInterval = 100 * time.Millisecond
|
||||
cfg.MaxReconnectInterval = 200 * time.Millisecond
|
||||
|
||||
bridge := NewBridge(hub, cfg)
|
||||
bridge.Start(ctx)
|
||||
|
||||
waitConnected(t, bridge, 3*time.Second)
|
||||
|
||||
if callCount.Load() < 2 {
|
||||
t.Errorf("expected at least 2 connection attempts, got %d", callCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Good_ExponentialBackoff(t *testing.T) {
|
||||
// Track timestamps of dial attempts to verify backoff behaviour.
|
||||
// The server rejects the WebSocket upgrade with HTTP 403, so dial()
|
||||
// returns an error and the exponential backoff path fires.
|
||||
var attempts []time.Time
|
||||
var mu sync.Mutex
|
||||
var attemptCount atomic.Int32
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
attempts = append(attempts, time.Now())
|
||||
mu.Unlock()
|
||||
attemptCount.Add(1)
|
||||
|
||||
// Reject the upgrade — this makes dial() fail, triggering backoff.
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.LaravelWSURL = wsURL(ts)
|
||||
cfg.ReconnectInterval = 100 * time.Millisecond
|
||||
cfg.MaxReconnectInterval = 400 * time.Millisecond
|
||||
|
||||
bridge := NewBridge(hub, cfg)
|
||||
bridge.Start(ctx)
|
||||
|
||||
// Wait for at least 4 dial attempts.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for attemptCount.Load() < 4 && time.Now().Before(deadline) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
bridge.Shutdown()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if len(attempts) < 4 {
|
||||
t.Fatalf("expected at least 4 connection attempts, got %d", len(attempts))
|
||||
}
|
||||
|
||||
// Verify exponential backoff: gap between attempts should increase.
|
||||
// Expected delays: ~100ms, ~200ms, ~400ms (capped).
|
||||
// Allow generous tolerance since timing is non-deterministic.
|
||||
for i := 1; i < len(attempts) && i <= 3; i++ {
|
||||
gap := attempts[i].Sub(attempts[i-1])
|
||||
// Minimum expected delay doubles each time: 100, 200, 400.
|
||||
// We check a lower bound (50% of expected) to be resilient.
|
||||
expectedMin := time.Duration(50*(1<<(i-1))) * time.Millisecond
|
||||
if gap < expectedMin {
|
||||
t.Errorf("attempt %d->%d gap %v < expected minimum %v", i-1, i, gap, expectedMin)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the backoff caps at MaxReconnectInterval.
|
||||
if len(attempts) >= 5 {
|
||||
gap := attempts[4].Sub(attempts[3])
|
||||
// After cap is hit, delay should not exceed MaxReconnectInterval + tolerance.
|
||||
maxExpected := cfg.MaxReconnectInterval + 200*time.Millisecond
|
||||
if gap > maxExpected {
|
||||
t.Errorf("attempt 3->4 gap %v exceeded max backoff %v", gap, maxExpected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Good_ReconnectDetectsServerShutdown(t *testing.T) {
|
||||
// Start a server that closes the WS connection on demand, then close
|
||||
// the server entirely so the bridge cannot reconnect.
|
||||
closeConn := make(chan struct{}, 1)
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Wait for signal to close
|
||||
<-closeConn
|
||||
conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseGoingAway, "shutdown"))
|
||||
}))
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.LaravelWSURL = wsURL(ts)
|
||||
// Use long reconnect so bridge stays disconnected after server dies.
|
||||
cfg.ReconnectInterval = 5 * time.Second
|
||||
cfg.MaxReconnectInterval = 5 * time.Second
|
||||
|
||||
bridge := NewBridge(hub, cfg)
|
||||
bridge.Start(ctx)
|
||||
|
||||
waitConnected(t, bridge, 2*time.Second)
|
||||
|
||||
// Signal server handler to close the WS connection, then shut down
|
||||
// the server so the reconnect dial() also fails.
|
||||
closeConn <- struct{}{}
|
||||
ts.Close()
|
||||
|
||||
// Wait for disconnection.
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for bridge.Connected() && time.Now().Before(deadline) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
if bridge.Connected() {
|
||||
t.Error("expected bridge to detect server-side connection close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Good_AuthHeader(t *testing.T) {
|
||||
// Server that checks for the Authorization header on upgrade.
|
||||
var receivedAuth atomic.Value
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuth.Store(r.Header.Get("Authorization"))
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.LaravelWSURL = wsURL(ts)
|
||||
cfg.ReconnectInterval = 100 * time.Millisecond
|
||||
cfg.Token = "test-secret-token-42"
|
||||
|
||||
bridge := NewBridge(hub, cfg)
|
||||
bridge.Start(ctx)
|
||||
|
||||
waitConnected(t, bridge, 2*time.Second)
|
||||
|
||||
auth, ok := receivedAuth.Load().(string)
|
||||
if !ok || auth == "" {
|
||||
t.Fatal("server did not receive Authorization header")
|
||||
}
|
||||
|
||||
expected := "Bearer test-secret-token-42"
|
||||
if auth != expected {
|
||||
t.Errorf("expected auth header %q, got %q", expected, auth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Good_NoAuthHeaderWhenTokenEmpty(t *testing.T) {
|
||||
// Verify that no Authorization header is sent when Token is empty.
|
||||
var receivedAuth atomic.Value
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedAuth.Store(r.Header.Get("Authorization"))
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.LaravelWSURL = wsURL(ts)
|
||||
cfg.ReconnectInterval = 100 * time.Millisecond
|
||||
// Token intentionally left empty
|
||||
|
||||
bridge := NewBridge(hub, cfg)
|
||||
bridge.Start(ctx)
|
||||
|
||||
waitConnected(t, bridge, 2*time.Second)
|
||||
|
||||
auth, _ := receivedAuth.Load().(string)
|
||||
if auth != "" {
|
||||
t.Errorf("expected no Authorization header when token is empty, got %q", auth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridge_Good_WithTokenOption(t *testing.T) {
|
||||
// Verify the WithToken option function works.
|
||||
cfg := DefaultConfig()
|
||||
opt := WithToken("my-token")
|
||||
opt(&cfg)
|
||||
|
||||
if cfg.Token != "my-token" {
|
||||
t.Errorf("expected token 'my-token', got %q", cfg.Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubsystem_Good_Name(t *testing.T) {
|
||||
sub := New(nil)
|
||||
if sub.Name() != "ide" {
|
||||
t.Errorf("expected name 'ide', got %q", sub.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubsystem_Good_NilHub(t *testing.T) {
|
||||
sub := New(nil)
|
||||
if sub.Bridge() != nil {
|
||||
t.Error("expected nil bridge when hub is nil")
|
||||
}
|
||||
// Shutdown should not panic
|
||||
if err := sub.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("Shutdown with nil bridge failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
// Package ide provides an MCP subsystem that bridges the desktop IDE to
|
||||
// a Laravel core-agentic backend over WebSocket.
|
||||
package ide
|
||||
|
||||
import "time"
|
||||
|
||||
// Config holds connection and workspace settings for the IDE subsystem.
|
||||
type Config struct {
|
||||
// LaravelWSURL is the WebSocket endpoint for the Laravel core-agentic backend.
|
||||
LaravelWSURL string
|
||||
|
||||
// WorkspaceRoot is the local path used as the default workspace context.
|
||||
WorkspaceRoot string
|
||||
|
||||
// Token is the Bearer token sent in the Authorization header during
|
||||
// WebSocket upgrade. When empty, no auth header is sent.
|
||||
Token string
|
||||
|
||||
// ReconnectInterval controls how long to wait between reconnect attempts.
|
||||
ReconnectInterval time.Duration
|
||||
|
||||
// MaxReconnectInterval caps exponential backoff for reconnection.
|
||||
MaxReconnectInterval time.Duration
|
||||
}
|
||||
|
||||
// DefaultConfig returns sensible defaults for local development.
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
LaravelWSURL: "ws://localhost:9876/ws",
|
||||
WorkspaceRoot: ".",
|
||||
ReconnectInterval: 2 * time.Second,
|
||||
MaxReconnectInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures the IDE subsystem.
|
||||
type Option func(*Config)
|
||||
|
||||
// WithLaravelURL sets the Laravel WebSocket endpoint.
|
||||
func WithLaravelURL(url string) Option {
|
||||
return func(c *Config) { c.LaravelWSURL = url }
|
||||
}
|
||||
|
||||
// WithWorkspaceRoot sets the workspace root directory.
|
||||
func WithWorkspaceRoot(root string) Option {
|
||||
return func(c *Config) { c.WorkspaceRoot = root }
|
||||
}
|
||||
|
||||
// WithReconnectInterval sets the base reconnect interval.
|
||||
func WithReconnectInterval(d time.Duration) Option {
|
||||
return func(c *Config) { c.ReconnectInterval = d }
|
||||
}
|
||||
|
||||
// WithToken sets the Bearer token for WebSocket authentication.
|
||||
func WithToken(token string) Option {
|
||||
return func(c *Config) { c.Token = token }
|
||||
}
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
package ide
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"forge.lthn.ai/core/go-ws"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// errBridgeNotAvailable is returned when a tool requires the Laravel bridge
|
||||
// but it has not been initialised (headless mode).
|
||||
var errBridgeNotAvailable = errors.New("bridge not available")
|
||||
|
||||
// Subsystem implements mcp.Subsystem and mcp.SubsystemWithShutdown for the IDE.
|
||||
type Subsystem struct {
|
||||
cfg Config
|
||||
bridge *Bridge
|
||||
hub *ws.Hub
|
||||
}
|
||||
|
||||
// New creates an IDE subsystem. The ws.Hub is used for real-time forwarding;
|
||||
// pass nil if headless (tools still work but real-time streaming is disabled).
|
||||
func New(hub *ws.Hub, opts ...Option) *Subsystem {
|
||||
cfg := DefaultConfig()
|
||||
for _, opt := range opts {
|
||||
opt(&cfg)
|
||||
}
|
||||
var bridge *Bridge
|
||||
if hub != nil {
|
||||
bridge = NewBridge(hub, cfg)
|
||||
}
|
||||
return &Subsystem{cfg: cfg, bridge: bridge, hub: hub}
|
||||
}
|
||||
|
||||
// Name implements mcp.Subsystem.
|
||||
func (s *Subsystem) Name() string { return "ide" }
|
||||
|
||||
// RegisterTools implements mcp.Subsystem.
|
||||
func (s *Subsystem) RegisterTools(server *mcp.Server) {
|
||||
s.registerChatTools(server)
|
||||
s.registerBuildTools(server)
|
||||
s.registerDashboardTools(server)
|
||||
}
|
||||
|
||||
// Shutdown implements mcp.SubsystemWithShutdown.
|
||||
func (s *Subsystem) Shutdown(_ context.Context) error {
|
||||
if s.bridge != nil {
|
||||
s.bridge.Shutdown()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bridge returns the Laravel WebSocket bridge (may be nil in headless mode).
|
||||
func (s *Subsystem) Bridge() *Bridge { return s.bridge }
|
||||
|
||||
// StartBridge begins the background connection to the Laravel backend.
|
||||
func (s *Subsystem) StartBridge(ctx context.Context) {
|
||||
if s.bridge != nil {
|
||||
s.bridge.Start(ctx)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,114 +0,0 @@
|
|||
package ide
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Build tool input/output types.
|
||||
|
||||
// BuildStatusInput is the input for ide_build_status.
|
||||
type BuildStatusInput struct {
|
||||
BuildID string `json:"buildId"`
|
||||
}
|
||||
|
||||
// BuildInfo represents a single build.
|
||||
type BuildInfo struct {
|
||||
ID string `json:"id"`
|
||||
Repo string `json:"repo"`
|
||||
Branch string `json:"branch"`
|
||||
Status string `json:"status"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
}
|
||||
|
||||
// BuildStatusOutput is the output for ide_build_status.
|
||||
type BuildStatusOutput struct {
|
||||
Build BuildInfo `json:"build"`
|
||||
}
|
||||
|
||||
// BuildListInput is the input for ide_build_list.
|
||||
type BuildListInput struct {
|
||||
Repo string `json:"repo,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
// BuildListOutput is the output for ide_build_list.
|
||||
type BuildListOutput struct {
|
||||
Builds []BuildInfo `json:"builds"`
|
||||
}
|
||||
|
||||
// BuildLogsInput is the input for ide_build_logs.
|
||||
type BuildLogsInput struct {
|
||||
BuildID string `json:"buildId"`
|
||||
Tail int `json:"tail,omitempty"`
|
||||
}
|
||||
|
||||
// BuildLogsOutput is the output for ide_build_logs.
|
||||
type BuildLogsOutput struct {
|
||||
BuildID string `json:"buildId"`
|
||||
Lines []string `json:"lines"`
|
||||
}
|
||||
|
||||
func (s *Subsystem) registerBuildTools(server *mcp.Server) {
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_build_status",
|
||||
Description: "Get the status of a specific build",
|
||||
}, s.buildStatus)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_build_list",
|
||||
Description: "List recent builds, optionally filtered by repository",
|
||||
}, s.buildList)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_build_logs",
|
||||
Description: "Retrieve log output for a build",
|
||||
}, s.buildLogs)
|
||||
}
|
||||
|
||||
// buildStatus requests build status from the Laravel backend.
|
||||
// Stub implementation: sends request via bridge, returns "unknown" status. Awaiting Laravel backend.
|
||||
func (s *Subsystem) buildStatus(_ context.Context, _ *mcp.CallToolRequest, input BuildStatusInput) (*mcp.CallToolResult, BuildStatusOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, BuildStatusOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
_ = s.bridge.Send(BridgeMessage{
|
||||
Type: "build_status",
|
||||
Data: map[string]any{"buildId": input.BuildID},
|
||||
})
|
||||
return nil, BuildStatusOutput{
|
||||
Build: BuildInfo{ID: input.BuildID, Status: "unknown"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildList requests a list of builds from the Laravel backend.
|
||||
// Stub implementation: sends request via bridge, returns empty list. Awaiting Laravel backend.
|
||||
func (s *Subsystem) buildList(_ context.Context, _ *mcp.CallToolRequest, input BuildListInput) (*mcp.CallToolResult, BuildListOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, BuildListOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
_ = s.bridge.Send(BridgeMessage{
|
||||
Type: "build_list",
|
||||
Data: map[string]any{"repo": input.Repo, "limit": input.Limit},
|
||||
})
|
||||
return nil, BuildListOutput{Builds: []BuildInfo{}}, nil
|
||||
}
|
||||
|
||||
// buildLogs requests build log output from the Laravel backend.
|
||||
// Stub implementation: sends request via bridge, returns empty lines. Awaiting Laravel backend.
|
||||
func (s *Subsystem) buildLogs(_ context.Context, _ *mcp.CallToolRequest, input BuildLogsInput) (*mcp.CallToolResult, BuildLogsOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, BuildLogsOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
_ = s.bridge.Send(BridgeMessage{
|
||||
Type: "build_logs",
|
||||
Data: map[string]any{"buildId": input.BuildID, "tail": input.Tail},
|
||||
})
|
||||
return nil, BuildLogsOutput{
|
||||
BuildID: input.BuildID,
|
||||
Lines: []string{},
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,201 +0,0 @@
|
|||
package ide
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Chat tool input/output types.
|
||||
|
||||
// ChatSendInput is the input for ide_chat_send.
|
||||
type ChatSendInput struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ChatSendOutput is the output for ide_chat_send.
|
||||
type ChatSendOutput struct {
|
||||
Sent bool `json:"sent"`
|
||||
SessionID string `json:"sessionId"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// ChatHistoryInput is the input for ide_chat_history.
|
||||
type ChatHistoryInput struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
// ChatMessage represents a single message in history.
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// ChatHistoryOutput is the output for ide_chat_history.
|
||||
type ChatHistoryOutput struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
// SessionListInput is the input for ide_session_list.
|
||||
type SessionListInput struct{}
|
||||
|
||||
// Session represents an agent session.
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// SessionListOutput is the output for ide_session_list.
|
||||
type SessionListOutput struct {
|
||||
Sessions []Session `json:"sessions"`
|
||||
}
|
||||
|
||||
// SessionCreateInput is the input for ide_session_create.
|
||||
type SessionCreateInput struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// SessionCreateOutput is the output for ide_session_create.
|
||||
type SessionCreateOutput struct {
|
||||
Session Session `json:"session"`
|
||||
}
|
||||
|
||||
// PlanStatusInput is the input for ide_plan_status.
|
||||
type PlanStatusInput struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
}
|
||||
|
||||
// PlanStep is a single step in an agent plan.
|
||||
type PlanStep struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// PlanStatusOutput is the output for ide_plan_status.
|
||||
type PlanStatusOutput struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
Status string `json:"status"`
|
||||
Steps []PlanStep `json:"steps"`
|
||||
}
|
||||
|
||||
func (s *Subsystem) registerChatTools(server *mcp.Server) {
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_chat_send",
|
||||
Description: "Send a message to an agent chat session",
|
||||
}, s.chatSend)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_chat_history",
|
||||
Description: "Retrieve message history for a chat session",
|
||||
}, s.chatHistory)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_session_list",
|
||||
Description: "List active agent sessions",
|
||||
}, s.sessionList)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_session_create",
|
||||
Description: "Create a new agent session",
|
||||
}, s.sessionCreate)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_plan_status",
|
||||
Description: "Get the current plan status for a session",
|
||||
}, s.planStatus)
|
||||
}
|
||||
|
||||
// chatSend forwards a chat message to the Laravel backend via bridge.
|
||||
// Stub implementation: delegates to bridge, real response arrives via WebSocket subscription.
|
||||
func (s *Subsystem) chatSend(_ context.Context, _ *mcp.CallToolRequest, input ChatSendInput) (*mcp.CallToolResult, ChatSendOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, ChatSendOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
err := s.bridge.Send(BridgeMessage{
|
||||
Type: "chat_send",
|
||||
Channel: "chat:" + input.SessionID,
|
||||
SessionID: input.SessionID,
|
||||
Data: input.Message,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, ChatSendOutput{}, fmt.Errorf("failed to send message: %w", err)
|
||||
}
|
||||
return nil, ChatSendOutput{
|
||||
Sent: true,
|
||||
SessionID: input.SessionID,
|
||||
Timestamp: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// chatHistory requests message history from the Laravel backend.
|
||||
// Stub implementation: sends request via bridge, returns empty messages. Real data arrives via WebSocket.
|
||||
func (s *Subsystem) chatHistory(_ context.Context, _ *mcp.CallToolRequest, input ChatHistoryInput) (*mcp.CallToolResult, ChatHistoryOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, ChatHistoryOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
// Request history via bridge; for now return placeholder indicating the
|
||||
// request was forwarded. Real data arrives via WebSocket subscription.
|
||||
_ = s.bridge.Send(BridgeMessage{
|
||||
Type: "chat_history",
|
||||
SessionID: input.SessionID,
|
||||
Data: map[string]any{"limit": input.Limit},
|
||||
})
|
||||
return nil, ChatHistoryOutput{
|
||||
SessionID: input.SessionID,
|
||||
Messages: []ChatMessage{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sessionList requests the session list from the Laravel backend.
|
||||
// Stub implementation: sends request via bridge, returns empty sessions. Awaiting Laravel backend.
|
||||
func (s *Subsystem) sessionList(_ context.Context, _ *mcp.CallToolRequest, _ SessionListInput) (*mcp.CallToolResult, SessionListOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, SessionListOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
_ = s.bridge.Send(BridgeMessage{Type: "session_list"})
|
||||
return nil, SessionListOutput{Sessions: []Session{}}, nil
|
||||
}
|
||||
|
||||
// sessionCreate requests a new session from the Laravel backend.
|
||||
// Stub implementation: sends request via bridge, returns placeholder session. Awaiting Laravel backend.
|
||||
func (s *Subsystem) sessionCreate(_ context.Context, _ *mcp.CallToolRequest, input SessionCreateInput) (*mcp.CallToolResult, SessionCreateOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, SessionCreateOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
_ = s.bridge.Send(BridgeMessage{
|
||||
Type: "session_create",
|
||||
Data: map[string]any{"name": input.Name},
|
||||
})
|
||||
return nil, SessionCreateOutput{
|
||||
Session: Session{
|
||||
Name: input.Name,
|
||||
Status: "creating",
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// planStatus requests plan status from the Laravel backend.
|
||||
// Stub implementation: sends request via bridge, returns "unknown" status. Awaiting Laravel backend.
|
||||
func (s *Subsystem) planStatus(_ context.Context, _ *mcp.CallToolRequest, input PlanStatusInput) (*mcp.CallToolResult, PlanStatusOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, PlanStatusOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
_ = s.bridge.Send(BridgeMessage{
|
||||
Type: "plan_status",
|
||||
SessionID: input.SessionID,
|
||||
})
|
||||
return nil, PlanStatusOutput{
|
||||
SessionID: input.SessionID,
|
||||
Status: "unknown",
|
||||
Steps: []PlanStep{},
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
package ide
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Dashboard tool input/output types.
|
||||
|
||||
// DashboardOverviewInput is the input for ide_dashboard_overview.
|
||||
type DashboardOverviewInput struct{}
|
||||
|
||||
// DashboardOverview contains high-level platform stats.
|
||||
type DashboardOverview struct {
|
||||
Repos int `json:"repos"`
|
||||
Services int `json:"services"`
|
||||
ActiveSessions int `json:"activeSessions"`
|
||||
RecentBuilds int `json:"recentBuilds"`
|
||||
BridgeOnline bool `json:"bridgeOnline"`
|
||||
}
|
||||
|
||||
// DashboardOverviewOutput is the output for ide_dashboard_overview.
|
||||
type DashboardOverviewOutput struct {
|
||||
Overview DashboardOverview `json:"overview"`
|
||||
}
|
||||
|
||||
// DashboardActivityInput is the input for ide_dashboard_activity.
|
||||
type DashboardActivityInput struct {
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
// ActivityEvent represents a single activity feed item.
|
||||
type ActivityEvent struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// DashboardActivityOutput is the output for ide_dashboard_activity.
|
||||
type DashboardActivityOutput struct {
|
||||
Events []ActivityEvent `json:"events"`
|
||||
}
|
||||
|
||||
// DashboardMetricsInput is the input for ide_dashboard_metrics.
|
||||
type DashboardMetricsInput struct {
|
||||
Period string `json:"period,omitempty"` // "1h", "24h", "7d"
|
||||
}
|
||||
|
||||
// DashboardMetrics contains aggregate metrics.
|
||||
type DashboardMetrics struct {
|
||||
BuildsTotal int `json:"buildsTotal"`
|
||||
BuildsSuccess int `json:"buildsSuccess"`
|
||||
BuildsFailed int `json:"buildsFailed"`
|
||||
AvgBuildTime string `json:"avgBuildTime"`
|
||||
AgentSessions int `json:"agentSessions"`
|
||||
MessagesTotal int `json:"messagesTotal"`
|
||||
SuccessRate float64 `json:"successRate"`
|
||||
}
|
||||
|
||||
// DashboardMetricsOutput is the output for ide_dashboard_metrics.
|
||||
type DashboardMetricsOutput struct {
|
||||
Period string `json:"period"`
|
||||
Metrics DashboardMetrics `json:"metrics"`
|
||||
}
|
||||
|
||||
func (s *Subsystem) registerDashboardTools(server *mcp.Server) {
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_dashboard_overview",
|
||||
Description: "Get a high-level overview of the platform (repos, services, sessions, builds)",
|
||||
}, s.dashboardOverview)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_dashboard_activity",
|
||||
Description: "Get the recent activity feed",
|
||||
}, s.dashboardActivity)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ide_dashboard_metrics",
|
||||
Description: "Get aggregate build and agent metrics for a time period",
|
||||
}, s.dashboardMetrics)
|
||||
}
|
||||
|
||||
// dashboardOverview returns a platform overview with bridge status.
|
||||
// Stub implementation: only BridgeOnline is live; other fields return zero values. Awaiting Laravel backend.
|
||||
func (s *Subsystem) dashboardOverview(_ context.Context, _ *mcp.CallToolRequest, _ DashboardOverviewInput) (*mcp.CallToolResult, DashboardOverviewOutput, error) {
|
||||
connected := s.bridge != nil && s.bridge.Connected()
|
||||
|
||||
if s.bridge != nil {
|
||||
_ = s.bridge.Send(BridgeMessage{Type: "dashboard_overview"})
|
||||
}
|
||||
|
||||
return nil, DashboardOverviewOutput{
|
||||
Overview: DashboardOverview{
|
||||
BridgeOnline: connected,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dashboardActivity requests the activity feed from the Laravel backend.
|
||||
// Stub implementation: sends request via bridge, returns empty events. Awaiting Laravel backend.
|
||||
func (s *Subsystem) dashboardActivity(_ context.Context, _ *mcp.CallToolRequest, input DashboardActivityInput) (*mcp.CallToolResult, DashboardActivityOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, DashboardActivityOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
_ = s.bridge.Send(BridgeMessage{
|
||||
Type: "dashboard_activity",
|
||||
Data: map[string]any{"limit": input.Limit},
|
||||
})
|
||||
return nil, DashboardActivityOutput{Events: []ActivityEvent{}}, nil
|
||||
}
|
||||
|
||||
// dashboardMetrics requests aggregate metrics from the Laravel backend.
|
||||
// Stub implementation: sends request via bridge, returns zero metrics. Awaiting Laravel backend.
|
||||
func (s *Subsystem) dashboardMetrics(_ context.Context, _ *mcp.CallToolRequest, input DashboardMetricsInput) (*mcp.CallToolResult, DashboardMetricsOutput, error) {
|
||||
if s.bridge == nil {
|
||||
return nil, DashboardMetricsOutput{}, errBridgeNotAvailable
|
||||
}
|
||||
period := input.Period
|
||||
if period == "" {
|
||||
period = "24h"
|
||||
}
|
||||
_ = s.bridge.Send(BridgeMessage{
|
||||
Type: "dashboard_metrics",
|
||||
Data: map[string]any{"period": period},
|
||||
})
|
||||
return nil, DashboardMetricsOutput{
|
||||
Period: period,
|
||||
Metrics: DashboardMetrics{},
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,781 +0,0 @@
|
|||
package ide
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-ws"
|
||||
)
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
// newNilBridgeSubsystem returns a Subsystem with no hub/bridge (headless mode).
|
||||
func newNilBridgeSubsystem() *Subsystem {
|
||||
return New(nil)
|
||||
}
|
||||
|
||||
// newConnectedSubsystem returns a Subsystem with a connected bridge and a
|
||||
// running echo WS server. Caller must cancel ctx and close server when done.
|
||||
func newConnectedSubsystem(t *testing.T) (*Subsystem, context.CancelFunc, *httptest.Server) {
|
||||
t.Helper()
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
_ = conn.WriteMessage(mt, data)
|
||||
}
|
||||
}))
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go hub.Run(ctx)
|
||||
|
||||
sub := New(hub,
|
||||
WithLaravelURL(wsURL(ts)),
|
||||
WithReconnectInterval(50*time.Millisecond),
|
||||
)
|
||||
sub.StartBridge(ctx)
|
||||
|
||||
waitConnected(t, sub.Bridge(), 2*time.Second)
|
||||
return sub, cancel, ts
|
||||
}
|
||||
|
||||
// --- 4.3: Chat tool tests ---
|
||||
|
||||
// TestChatSend_Bad_NilBridge verifies chatSend returns error without a bridge.
|
||||
func TestChatSend_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.chatSend(context.Background(), nil, ChatSendInput{
|
||||
SessionID: "s1",
|
||||
Message: "hello",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatSend_Good_Connected verifies chatSend succeeds with a connected bridge.
|
||||
func TestChatSend_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.chatSend(context.Background(), nil, ChatSendInput{
|
||||
SessionID: "sess-42",
|
||||
Message: "hello",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("chatSend failed: %v", err)
|
||||
}
|
||||
if !out.Sent {
|
||||
t.Error("expected Sent=true")
|
||||
}
|
||||
if out.SessionID != "sess-42" {
|
||||
t.Errorf("expected sessionId 'sess-42', got %q", out.SessionID)
|
||||
}
|
||||
if out.Timestamp.IsZero() {
|
||||
t.Error("expected non-zero timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatHistory_Bad_NilBridge verifies chatHistory returns error without a bridge.
|
||||
func TestChatHistory_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.chatHistory(context.Background(), nil, ChatHistoryInput{
|
||||
SessionID: "s1",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatHistory_Good_Connected verifies chatHistory succeeds and returns empty messages.
|
||||
func TestChatHistory_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.chatHistory(context.Background(), nil, ChatHistoryInput{
|
||||
SessionID: "sess-1",
|
||||
Limit: 50,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("chatHistory failed: %v", err)
|
||||
}
|
||||
if out.SessionID != "sess-1" {
|
||||
t.Errorf("expected sessionId 'sess-1', got %q", out.SessionID)
|
||||
}
|
||||
if out.Messages == nil {
|
||||
t.Error("expected non-nil messages slice")
|
||||
}
|
||||
if len(out.Messages) != 0 {
|
||||
t.Errorf("expected 0 messages (stub), got %d", len(out.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionList_Bad_NilBridge verifies sessionList returns error without a bridge.
|
||||
func TestSessionList_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.sessionList(context.Background(), nil, SessionListInput{})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionList_Good_Connected verifies sessionList returns empty sessions.
|
||||
func TestSessionList_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.sessionList(context.Background(), nil, SessionListInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("sessionList failed: %v", err)
|
||||
}
|
||||
if out.Sessions == nil {
|
||||
t.Error("expected non-nil sessions slice")
|
||||
}
|
||||
if len(out.Sessions) != 0 {
|
||||
t.Errorf("expected 0 sessions (stub), got %d", len(out.Sessions))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionCreate_Bad_NilBridge verifies sessionCreate returns error without a bridge.
|
||||
func TestSessionCreate_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.sessionCreate(context.Background(), nil, SessionCreateInput{
|
||||
Name: "test",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionCreate_Good_Connected verifies sessionCreate returns a session stub.
|
||||
func TestSessionCreate_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.sessionCreate(context.Background(), nil, SessionCreateInput{
|
||||
Name: "my-session",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("sessionCreate failed: %v", err)
|
||||
}
|
||||
if out.Session.Name != "my-session" {
|
||||
t.Errorf("expected name 'my-session', got %q", out.Session.Name)
|
||||
}
|
||||
if out.Session.Status != "creating" {
|
||||
t.Errorf("expected status 'creating', got %q", out.Session.Status)
|
||||
}
|
||||
if out.Session.CreatedAt.IsZero() {
|
||||
t.Error("expected non-zero CreatedAt")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlanStatus_Bad_NilBridge verifies planStatus returns error without a bridge.
|
||||
func TestPlanStatus_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.planStatus(context.Background(), nil, PlanStatusInput{
|
||||
SessionID: "s1",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlanStatus_Good_Connected verifies planStatus returns a stub status.
|
||||
func TestPlanStatus_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.planStatus(context.Background(), nil, PlanStatusInput{
|
||||
SessionID: "sess-7",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("planStatus failed: %v", err)
|
||||
}
|
||||
if out.SessionID != "sess-7" {
|
||||
t.Errorf("expected sessionId 'sess-7', got %q", out.SessionID)
|
||||
}
|
||||
if out.Status != "unknown" {
|
||||
t.Errorf("expected status 'unknown', got %q", out.Status)
|
||||
}
|
||||
if out.Steps == nil {
|
||||
t.Error("expected non-nil steps slice")
|
||||
}
|
||||
}
|
||||
|
||||
// --- 4.3: Build tool tests ---
|
||||
|
||||
// TestBuildStatus_Bad_NilBridge verifies buildStatus returns error without a bridge.
|
||||
func TestBuildStatus_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.buildStatus(context.Background(), nil, BuildStatusInput{
|
||||
BuildID: "b1",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildStatus_Good_Connected verifies buildStatus returns a stub.
|
||||
func TestBuildStatus_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.buildStatus(context.Background(), nil, BuildStatusInput{
|
||||
BuildID: "build-99",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildStatus failed: %v", err)
|
||||
}
|
||||
if out.Build.ID != "build-99" {
|
||||
t.Errorf("expected build ID 'build-99', got %q", out.Build.ID)
|
||||
}
|
||||
if out.Build.Status != "unknown" {
|
||||
t.Errorf("expected status 'unknown', got %q", out.Build.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildList_Bad_NilBridge verifies buildList returns error without a bridge.
|
||||
func TestBuildList_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.buildList(context.Background(), nil, BuildListInput{
|
||||
Repo: "core-php",
|
||||
Limit: 10,
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildList_Good_Connected verifies buildList returns an empty list.
|
||||
func TestBuildList_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.buildList(context.Background(), nil, BuildListInput{
|
||||
Repo: "core-php",
|
||||
Limit: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildList failed: %v", err)
|
||||
}
|
||||
if out.Builds == nil {
|
||||
t.Error("expected non-nil builds slice")
|
||||
}
|
||||
if len(out.Builds) != 0 {
|
||||
t.Errorf("expected 0 builds (stub), got %d", len(out.Builds))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildLogs_Bad_NilBridge verifies buildLogs returns error without a bridge.
|
||||
func TestBuildLogs_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.buildLogs(context.Background(), nil, BuildLogsInput{
|
||||
BuildID: "b1",
|
||||
Tail: 100,
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildLogs_Good_Connected verifies buildLogs returns empty lines.
|
||||
func TestBuildLogs_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.buildLogs(context.Background(), nil, BuildLogsInput{
|
||||
BuildID: "build-55",
|
||||
Tail: 50,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("buildLogs failed: %v", err)
|
||||
}
|
||||
if out.BuildID != "build-55" {
|
||||
t.Errorf("expected buildId 'build-55', got %q", out.BuildID)
|
||||
}
|
||||
if out.Lines == nil {
|
||||
t.Error("expected non-nil lines slice")
|
||||
}
|
||||
if len(out.Lines) != 0 {
|
||||
t.Errorf("expected 0 lines (stub), got %d", len(out.Lines))
|
||||
}
|
||||
}
|
||||
|
||||
// --- 4.3: Dashboard tool tests ---
|
||||
|
||||
// TestDashboardOverview_Good_NilBridge verifies dashboardOverview works without bridge
|
||||
// (it does not return error — it reports BridgeOnline=false).
|
||||
func TestDashboardOverview_Good_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, out, err := sub.dashboardOverview(context.Background(), nil, DashboardOverviewInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("dashboardOverview failed: %v", err)
|
||||
}
|
||||
if out.Overview.BridgeOnline {
|
||||
t.Error("expected BridgeOnline=false when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDashboardOverview_Good_Connected verifies dashboardOverview reports bridge online.
|
||||
func TestDashboardOverview_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.dashboardOverview(context.Background(), nil, DashboardOverviewInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("dashboardOverview failed: %v", err)
|
||||
}
|
||||
if !out.Overview.BridgeOnline {
|
||||
t.Error("expected BridgeOnline=true when bridge is connected")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDashboardActivity_Bad_NilBridge verifies dashboardActivity returns error without bridge.
|
||||
func TestDashboardActivity_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.dashboardActivity(context.Background(), nil, DashboardActivityInput{
|
||||
Limit: 10,
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDashboardActivity_Good_Connected verifies dashboardActivity returns empty events.
|
||||
func TestDashboardActivity_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.dashboardActivity(context.Background(), nil, DashboardActivityInput{
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("dashboardActivity failed: %v", err)
|
||||
}
|
||||
if out.Events == nil {
|
||||
t.Error("expected non-nil events slice")
|
||||
}
|
||||
if len(out.Events) != 0 {
|
||||
t.Errorf("expected 0 events (stub), got %d", len(out.Events))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDashboardMetrics_Bad_NilBridge verifies dashboardMetrics returns error without bridge.
|
||||
func TestDashboardMetrics_Bad_NilBridge(t *testing.T) {
|
||||
sub := newNilBridgeSubsystem()
|
||||
_, _, err := sub.dashboardMetrics(context.Background(), nil, DashboardMetricsInput{
|
||||
Period: "1h",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error when bridge is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDashboardMetrics_Good_Connected verifies dashboardMetrics returns empty metrics.
|
||||
func TestDashboardMetrics_Good_Connected(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.dashboardMetrics(context.Background(), nil, DashboardMetricsInput{
|
||||
Period: "7d",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("dashboardMetrics failed: %v", err)
|
||||
}
|
||||
if out.Period != "7d" {
|
||||
t.Errorf("expected period '7d', got %q", out.Period)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDashboardMetrics_Good_DefaultPeriod verifies the default period is "24h".
|
||||
func TestDashboardMetrics_Good_DefaultPeriod(t *testing.T) {
|
||||
sub, cancel, ts := newConnectedSubsystem(t)
|
||||
defer cancel()
|
||||
defer ts.Close()
|
||||
|
||||
_, out, err := sub.dashboardMetrics(context.Background(), nil, DashboardMetricsInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("dashboardMetrics failed: %v", err)
|
||||
}
|
||||
if out.Period != "24h" {
|
||||
t.Errorf("expected default period '24h', got %q", out.Period)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Struct serialisation round-trip tests ---
|
||||
|
||||
// TestChatSendInput_Good_RoundTrip verifies JSON serialisation of ChatSendInput.
|
||||
func TestChatSendInput_Good_RoundTrip(t *testing.T) {
|
||||
in := ChatSendInput{SessionID: "s1", Message: "hello"}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out ChatSendInput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out != in {
|
||||
t.Errorf("round-trip mismatch: %+v != %+v", out, in)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatSendOutput_Good_RoundTrip verifies JSON serialisation of ChatSendOutput.
|
||||
func TestChatSendOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := ChatSendOutput{Sent: true, SessionID: "s1", Timestamp: time.Now().Truncate(time.Second)}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out ChatSendOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.Sent != in.Sent || out.SessionID != in.SessionID {
|
||||
t.Errorf("round-trip mismatch: %+v != %+v", out, in)
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatHistoryOutput_Good_RoundTrip verifies ChatHistoryOutput JSON round-trip.
|
||||
func TestChatHistoryOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := ChatHistoryOutput{
|
||||
SessionID: "s1",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: "hello", Timestamp: time.Now().Truncate(time.Second)},
|
||||
{Role: "assistant", Content: "hi", Timestamp: time.Now().Truncate(time.Second)},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out ChatHistoryOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.SessionID != in.SessionID {
|
||||
t.Errorf("sessionId mismatch: %q != %q", out.SessionID, in.SessionID)
|
||||
}
|
||||
if len(out.Messages) != 2 {
|
||||
t.Errorf("expected 2 messages, got %d", len(out.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionListOutput_Good_RoundTrip verifies SessionListOutput JSON round-trip.
|
||||
func TestSessionListOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := SessionListOutput{
|
||||
Sessions: []Session{
|
||||
{ID: "s1", Name: "test", Status: "active", CreatedAt: time.Now().Truncate(time.Second)},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out SessionListOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if len(out.Sessions) != 1 || out.Sessions[0].ID != "s1" {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlanStatusOutput_Good_RoundTrip verifies PlanStatusOutput JSON round-trip.
|
||||
func TestPlanStatusOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := PlanStatusOutput{
|
||||
SessionID: "s1",
|
||||
Status: "running",
|
||||
Steps: []PlanStep{{Name: "step1", Status: "done"}, {Name: "step2", Status: "pending"}},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out PlanStatusOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.SessionID != "s1" || len(out.Steps) != 2 {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildStatusOutput_Good_RoundTrip verifies BuildStatusOutput JSON round-trip.
|
||||
func TestBuildStatusOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := BuildStatusOutput{
|
||||
Build: BuildInfo{
|
||||
ID: "b1",
|
||||
Repo: "core-php",
|
||||
Branch: "main",
|
||||
Status: "success",
|
||||
Duration: "2m30s",
|
||||
StartedAt: time.Now().Truncate(time.Second),
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out BuildStatusOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.Build.ID != "b1" || out.Build.Status != "success" {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildListOutput_Good_RoundTrip verifies BuildListOutput JSON round-trip.
|
||||
func TestBuildListOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := BuildListOutput{
|
||||
Builds: []BuildInfo{
|
||||
{ID: "b1", Repo: "core-php", Branch: "main", Status: "success"},
|
||||
{ID: "b2", Repo: "core-admin", Branch: "dev", Status: "failed"},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out BuildListOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if len(out.Builds) != 2 {
|
||||
t.Errorf("expected 2 builds, got %d", len(out.Builds))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildLogsOutput_Good_RoundTrip verifies BuildLogsOutput JSON round-trip.
|
||||
func TestBuildLogsOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := BuildLogsOutput{
|
||||
BuildID: "b1",
|
||||
Lines: []string{"line1", "line2", "line3"},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out BuildLogsOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.BuildID != "b1" || len(out.Lines) != 3 {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDashboardOverviewOutput_Good_RoundTrip verifies DashboardOverviewOutput JSON round-trip.
|
||||
func TestDashboardOverviewOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := DashboardOverviewOutput{
|
||||
Overview: DashboardOverview{
|
||||
Repos: 18,
|
||||
Services: 5,
|
||||
ActiveSessions: 3,
|
||||
RecentBuilds: 12,
|
||||
BridgeOnline: true,
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out DashboardOverviewOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.Overview.Repos != 18 || !out.Overview.BridgeOnline {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDashboardActivityOutput_Good_RoundTrip verifies DashboardActivityOutput JSON round-trip.
|
||||
func TestDashboardActivityOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := DashboardActivityOutput{
|
||||
Events: []ActivityEvent{
|
||||
{Type: "deploy", Message: "deployed v1.2", Timestamp: time.Now().Truncate(time.Second)},
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out DashboardActivityOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if len(out.Events) != 1 || out.Events[0].Type != "deploy" {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDashboardMetricsOutput_Good_RoundTrip verifies DashboardMetricsOutput JSON round-trip.
|
||||
func TestDashboardMetricsOutput_Good_RoundTrip(t *testing.T) {
|
||||
in := DashboardMetricsOutput{
|
||||
Period: "24h",
|
||||
Metrics: DashboardMetrics{
|
||||
BuildsTotal: 100,
|
||||
BuildsSuccess: 90,
|
||||
BuildsFailed: 10,
|
||||
AvgBuildTime: "3m",
|
||||
AgentSessions: 5,
|
||||
MessagesTotal: 500,
|
||||
SuccessRate: 0.9,
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out DashboardMetricsOutput
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.Period != "24h" || out.Metrics.BuildsTotal != 100 || out.Metrics.SuccessRate != 0.9 {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBridgeMessage_Good_RoundTrip verifies BridgeMessage JSON round-trip.
|
||||
func TestBridgeMessage_Good_RoundTrip(t *testing.T) {
|
||||
in := BridgeMessage{
|
||||
Type: "test",
|
||||
Channel: "ch1",
|
||||
SessionID: "s1",
|
||||
Data: "payload",
|
||||
Timestamp: time.Now().Truncate(time.Second),
|
||||
}
|
||||
data, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var out BridgeMessage
|
||||
if err := json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if out.Type != "test" || out.Channel != "ch1" || out.SessionID != "s1" {
|
||||
t.Errorf("round-trip mismatch: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Subsystem integration tests ---
|
||||
|
||||
// TestSubsystem_Good_RegisterTools verifies RegisterTools does not panic.
|
||||
func TestSubsystem_Good_RegisterTools(t *testing.T) {
|
||||
// RegisterTools requires a real mcp.Server which is complex to construct
|
||||
// in isolation. This test verifies the Subsystem can be created and
|
||||
// the Bridge/Shutdown path works end-to-end.
|
||||
sub := New(nil)
|
||||
if sub.Bridge() != nil {
|
||||
t.Error("expected nil bridge with nil hub")
|
||||
}
|
||||
if err := sub.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("Shutdown failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubsystem_Good_StartBridgeNilHub verifies StartBridge is a no-op with nil hub.
|
||||
func TestSubsystem_Good_StartBridgeNilHub(t *testing.T) {
|
||||
sub := New(nil)
|
||||
// Should not panic
|
||||
sub.StartBridge(context.Background())
|
||||
}
|
||||
|
||||
// TestSubsystem_Good_WithOptions verifies all config options apply correctly.
|
||||
func TestSubsystem_Good_WithOptions(t *testing.T) {
|
||||
hub := ws.NewHub()
|
||||
sub := New(hub,
|
||||
WithLaravelURL("ws://custom:1234/ws"),
|
||||
WithWorkspaceRoot("/tmp/test"),
|
||||
WithReconnectInterval(5*time.Second),
|
||||
WithToken("secret-123"),
|
||||
)
|
||||
|
||||
if sub.cfg.LaravelWSURL != "ws://custom:1234/ws" {
|
||||
t.Errorf("expected custom URL, got %q", sub.cfg.LaravelWSURL)
|
||||
}
|
||||
if sub.cfg.WorkspaceRoot != "/tmp/test" {
|
||||
t.Errorf("expected workspace '/tmp/test', got %q", sub.cfg.WorkspaceRoot)
|
||||
}
|
||||
if sub.cfg.ReconnectInterval != 5*time.Second {
|
||||
t.Errorf("expected 5s reconnect interval, got %v", sub.cfg.ReconnectInterval)
|
||||
}
|
||||
if sub.cfg.Token != "secret-123" {
|
||||
t.Errorf("expected token 'secret-123', got %q", sub.cfg.Token)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tool sends correct bridge message type ---
|
||||
|
||||
// TestChatSend_Good_BridgeMessageType verifies the bridge receives the correct message type.
|
||||
func TestChatSend_Good_BridgeMessageType(t *testing.T) {
|
||||
msgCh := make(chan BridgeMessage, 1)
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := testUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
_, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var msg BridgeMessage
|
||||
json.Unmarshal(data, &msg)
|
||||
msgCh <- msg
|
||||
// Keep alive
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
hub := ws.NewHub()
|
||||
ctx := t.Context()
|
||||
go hub.Run(ctx)
|
||||
|
||||
sub := New(hub, WithLaravelURL(wsURL(ts)), WithReconnectInterval(50*time.Millisecond))
|
||||
sub.StartBridge(ctx)
|
||||
waitConnected(t, sub.Bridge(), 2*time.Second)
|
||||
|
||||
sub.chatSend(ctx, nil, ChatSendInput{SessionID: "s1", Message: "test"})
|
||||
|
||||
select {
|
||||
case received := <-msgCh:
|
||||
if received.Type != "chat_send" {
|
||||
t.Errorf("expected bridge message type 'chat_send', got %q", received.Type)
|
||||
}
|
||||
if received.Channel != "chat:s1" {
|
||||
t.Errorf("expected channel 'chat:s1', got %q", received.Channel)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for bridge message")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIntegration_FileTools(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. Test file_write
|
||||
writeInput := WriteFileInput{
|
||||
Path: "test.txt",
|
||||
Content: "hello world",
|
||||
}
|
||||
_, writeOutput, err := s.writeFile(ctx, nil, writeInput)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, writeOutput.Success)
|
||||
assert.Equal(t, "test.txt", writeOutput.Path)
|
||||
|
||||
// Verify on disk
|
||||
content, _ := os.ReadFile(filepath.Join(tmpDir, "test.txt"))
|
||||
assert.Equal(t, "hello world", string(content))
|
||||
|
||||
// 2. Test file_read
|
||||
readInput := ReadFileInput{
|
||||
Path: "test.txt",
|
||||
}
|
||||
_, readOutput, err := s.readFile(ctx, nil, readInput)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "hello world", readOutput.Content)
|
||||
assert.Equal(t, "plaintext", readOutput.Language)
|
||||
|
||||
// 3. Test file_edit (replace_all=false)
|
||||
editInput := EditDiffInput{
|
||||
Path: "test.txt",
|
||||
OldString: "world",
|
||||
NewString: "mcp",
|
||||
}
|
||||
_, editOutput, err := s.editDiff(ctx, nil, editInput)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, editOutput.Success)
|
||||
assert.Equal(t, 1, editOutput.Replacements)
|
||||
|
||||
// Verify change
|
||||
_, readOutput, _ = s.readFile(ctx, nil, readInput)
|
||||
assert.Equal(t, "hello mcp", readOutput.Content)
|
||||
|
||||
// 4. Test file_edit (replace_all=true)
|
||||
_ = s.medium.Write("multi.txt", "abc abc abc")
|
||||
editInputMulti := EditDiffInput{
|
||||
Path: "multi.txt",
|
||||
OldString: "abc",
|
||||
NewString: "xyz",
|
||||
ReplaceAll: true,
|
||||
}
|
||||
_, editOutput, err = s.editDiff(ctx, nil, editInputMulti)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, editOutput.Replacements)
|
||||
|
||||
content, _ = os.ReadFile(filepath.Join(tmpDir, "multi.txt"))
|
||||
assert.Equal(t, "xyz xyz xyz", string(content))
|
||||
|
||||
// 5. Test dir_list
|
||||
_ = s.medium.EnsureDir("subdir")
|
||||
_ = s.medium.Write("subdir/file1.txt", "content1")
|
||||
|
||||
listInput := ListDirectoryInput{
|
||||
Path: "subdir",
|
||||
}
|
||||
_, listOutput, err := s.listDirectory(ctx, nil, listInput)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, listOutput.Entries, 1)
|
||||
assert.Equal(t, "file1.txt", listOutput.Entries[0].Name)
|
||||
assert.False(t, listOutput.Entries[0].IsDir)
|
||||
}
|
||||
|
||||
func TestIntegration_ErrorPaths(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Read nonexistent file
|
||||
_, _, err = s.readFile(ctx, nil, ReadFileInput{Path: "nonexistent.txt"})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Edit nonexistent file
|
||||
_, _, err = s.editDiff(ctx, nil, EditDiffInput{
|
||||
Path: "nonexistent.txt",
|
||||
OldString: "foo",
|
||||
NewString: "bar",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Edit with empty old_string
|
||||
_, _, err = s.editDiff(ctx, nil, EditDiffInput{
|
||||
Path: "test.txt",
|
||||
OldString: "",
|
||||
NewString: "bar",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Edit with old_string not found
|
||||
_ = s.medium.Write("test.txt", "hello")
|
||||
_, _, err = s.editDiff(ctx, nil, EditDiffInput{
|
||||
Path: "test.txt",
|
||||
OldString: "missing",
|
||||
NewString: "bar",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestService_Iterators(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test ToolsSeq
|
||||
tools := slices.Collect(svc.ToolsSeq())
|
||||
if len(tools) == 0 {
|
||||
t.Error("expected non-empty ToolsSeq")
|
||||
}
|
||||
if len(tools) != len(svc.Tools()) {
|
||||
t.Errorf("ToolsSeq length %d != Tools() length %d", len(tools), len(svc.Tools()))
|
||||
}
|
||||
|
||||
// Test SubsystemsSeq
|
||||
subsystems := slices.Collect(svc.SubsystemsSeq())
|
||||
if len(subsystems) != len(svc.Subsystems()) {
|
||||
t.Errorf("SubsystemsSeq length %d != Subsystems() length %d", len(subsystems), len(svc.Subsystems()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_SplitTagSeq(t *testing.T) {
|
||||
tag := "name,omitempty,json"
|
||||
parts := slices.Collect(splitTagSeq(tag))
|
||||
expected := []string{"name", "omitempty", "json"}
|
||||
|
||||
if !slices.Equal(parts, expected) {
|
||||
t.Errorf("expected %v, got %v", expected, parts)
|
||||
}
|
||||
}
|
||||
580
mcp/mcp.go
580
mcp/mcp.go
|
|
@ -1,580 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
// Package mcp provides a lightweight MCP (Model Context Protocol) server for CLI use.
|
||||
// For full GUI integration (display, webview, process management), see core-gui/pkg/mcp.
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"iter"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/go-io"
|
||||
"forge.lthn.ai/core/go-log"
|
||||
"forge.lthn.ai/core/go-process"
|
||||
"forge.lthn.ai/core/go-ws"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Service provides a lightweight MCP server with file operations only.
|
||||
// For full GUI features, use the core-gui package.
|
||||
type Service struct {
|
||||
server *mcp.Server
|
||||
workspaceRoot string // Root directory for file operations (empty = unrestricted)
|
||||
medium io.Medium // Filesystem medium for sandboxed operations
|
||||
subsystems []Subsystem // Additional subsystems registered via WithSubsystem
|
||||
logger *log.Logger // Logger for tool execution auditing
|
||||
processService *process.Service // Process management service (optional)
|
||||
wsHub *ws.Hub // WebSocket hub for real-time streaming (optional)
|
||||
wsServer *http.Server // WebSocket HTTP server (optional)
|
||||
wsAddr string // WebSocket server address
|
||||
tools []ToolRecord // Parallel tool registry for REST bridge
|
||||
}
|
||||
|
||||
// Option configures a Service.
|
||||
type Option func(*Service) error
|
||||
|
||||
// WithWorkspaceRoot restricts file operations to the given directory.
|
||||
// All paths are validated to be within this directory.
|
||||
// An empty string disables the restriction (not recommended).
|
||||
func WithWorkspaceRoot(root string) Option {
|
||||
return func(s *Service) error {
|
||||
if root == "" {
|
||||
// Explicitly disable restriction - use unsandboxed global
|
||||
s.workspaceRoot = ""
|
||||
s.medium = io.Local
|
||||
return nil
|
||||
}
|
||||
// Create sandboxed medium for this workspace
|
||||
abs, err := filepath.Abs(root)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid workspace root: %w", err)
|
||||
}
|
||||
m, err := io.NewSandboxed(abs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create workspace medium: %w", err)
|
||||
}
|
||||
s.workspaceRoot = abs
|
||||
s.medium = m
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a new MCP service with file operations.
|
||||
// By default, restricts file access to the current working directory.
|
||||
// Use WithWorkspaceRoot("") to disable restrictions (not recommended).
|
||||
// Returns an error if initialization fails.
|
||||
func New(opts ...Option) (*Service, error) {
|
||||
impl := &mcp.Implementation{
|
||||
Name: "core-cli",
|
||||
Version: "0.1.0",
|
||||
}
|
||||
|
||||
server := mcp.NewServer(impl, nil)
|
||||
s := &Service{
|
||||
server: server,
|
||||
logger: log.Default(),
|
||||
}
|
||||
|
||||
// Default to current working directory with sandboxed medium
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get working directory: %w", err)
|
||||
}
|
||||
s.workspaceRoot = cwd
|
||||
m, err := io.NewSandboxed(cwd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create sandboxed medium: %w", err)
|
||||
}
|
||||
s.medium = m
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
if err := opt(s); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply option: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.registerTools(s.server)
|
||||
|
||||
// Register subsystem tools.
|
||||
for _, sub := range s.subsystems {
|
||||
sub.RegisterTools(s.server)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Subsystems returns the registered subsystems.
|
||||
func (s *Service) Subsystems() []Subsystem {
|
||||
return s.subsystems
|
||||
}
|
||||
|
||||
// SubsystemsSeq returns an iterator over the registered subsystems.
|
||||
func (s *Service) SubsystemsSeq() iter.Seq[Subsystem] {
|
||||
return slices.Values(s.subsystems)
|
||||
}
|
||||
|
||||
// Tools returns all recorded tool metadata.
|
||||
func (s *Service) Tools() []ToolRecord {
|
||||
return s.tools
|
||||
}
|
||||
|
||||
// ToolsSeq returns an iterator over all recorded tool metadata.
|
||||
func (s *Service) ToolsSeq() iter.Seq[ToolRecord] {
|
||||
return slices.Values(s.tools)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down all subsystems that support it.
|
||||
func (s *Service) Shutdown(ctx context.Context) error {
|
||||
for _, sub := range s.subsystems {
|
||||
if sh, ok := sub.(SubsystemWithShutdown); ok {
|
||||
if err := sh.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("shutdown %s: %w", sub.Name(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithProcessService configures the process management service.
|
||||
func WithProcessService(ps *process.Service) Option {
|
||||
return func(s *Service) error {
|
||||
s.processService = ps
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithWSHub configures the WebSocket hub for real-time streaming.
|
||||
func WithWSHub(hub *ws.Hub) Option {
|
||||
return func(s *Service) error {
|
||||
s.wsHub = hub
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WSHub returns the WebSocket hub.
|
||||
func (s *Service) WSHub() *ws.Hub {
|
||||
return s.wsHub
|
||||
}
|
||||
|
||||
// ProcessService returns the process service.
|
||||
func (s *Service) ProcessService() *process.Service {
|
||||
return s.processService
|
||||
}
|
||||
|
||||
// registerTools adds file operation tools to the MCP server.
|
||||
func (s *Service) registerTools(server *mcp.Server) {
|
||||
// File operations
|
||||
addToolRecorded(s, server, "files", &mcp.Tool{
|
||||
Name: "file_read",
|
||||
Description: "Read the contents of a file",
|
||||
}, s.readFile)
|
||||
|
||||
addToolRecorded(s, server, "files", &mcp.Tool{
|
||||
Name: "file_write",
|
||||
Description: "Write content to a file",
|
||||
}, s.writeFile)
|
||||
|
||||
addToolRecorded(s, server, "files", &mcp.Tool{
|
||||
Name: "file_delete",
|
||||
Description: "Delete a file or empty directory",
|
||||
}, s.deleteFile)
|
||||
|
||||
addToolRecorded(s, server, "files", &mcp.Tool{
|
||||
Name: "file_rename",
|
||||
Description: "Rename or move a file",
|
||||
}, s.renameFile)
|
||||
|
||||
addToolRecorded(s, server, "files", &mcp.Tool{
|
||||
Name: "file_exists",
|
||||
Description: "Check if a file or directory exists",
|
||||
}, s.fileExists)
|
||||
|
||||
addToolRecorded(s, server, "files", &mcp.Tool{
|
||||
Name: "file_edit",
|
||||
Description: "Edit a file by replacing old_string with new_string. Use replace_all=true to replace all occurrences.",
|
||||
}, s.editDiff)
|
||||
|
||||
// Directory operations
|
||||
addToolRecorded(s, server, "files", &mcp.Tool{
|
||||
Name: "dir_list",
|
||||
Description: "List contents of a directory",
|
||||
}, s.listDirectory)
|
||||
|
||||
addToolRecorded(s, server, "files", &mcp.Tool{
|
||||
Name: "dir_create",
|
||||
Description: "Create a new directory",
|
||||
}, s.createDirectory)
|
||||
|
||||
// Language detection
|
||||
addToolRecorded(s, server, "language", &mcp.Tool{
|
||||
Name: "lang_detect",
|
||||
Description: "Detect the programming language of a file",
|
||||
}, s.detectLanguage)
|
||||
|
||||
addToolRecorded(s, server, "language", &mcp.Tool{
|
||||
Name: "lang_list",
|
||||
Description: "Get list of supported programming languages",
|
||||
}, s.getSupportedLanguages)
|
||||
}
|
||||
|
||||
// Tool input/output types for MCP file operations.
|
||||
|
||||
// ReadFileInput contains parameters for reading a file.
|
||||
type ReadFileInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// ReadFileOutput contains the result of reading a file.
|
||||
type ReadFileOutput struct {
|
||||
Content string `json:"content"`
|
||||
Language string `json:"language"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// WriteFileInput contains parameters for writing a file.
|
||||
type WriteFileInput struct {
|
||||
Path string `json:"path"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// WriteFileOutput contains the result of writing a file.
|
||||
type WriteFileOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// ListDirectoryInput contains parameters for listing a directory.
|
||||
type ListDirectoryInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// ListDirectoryOutput contains the result of listing a directory.
|
||||
type ListDirectoryOutput struct {
|
||||
Entries []DirectoryEntry `json:"entries"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// DirectoryEntry represents a single entry in a directory listing.
|
||||
type DirectoryEntry struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
IsDir bool `json:"isDir"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// CreateDirectoryInput contains parameters for creating a directory.
|
||||
type CreateDirectoryInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// CreateDirectoryOutput contains the result of creating a directory.
|
||||
type CreateDirectoryOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// DeleteFileInput contains parameters for deleting a file.
|
||||
type DeleteFileInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// DeleteFileOutput contains the result of deleting a file.
|
||||
type DeleteFileOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// RenameFileInput contains parameters for renaming a file.
|
||||
type RenameFileInput struct {
|
||||
OldPath string `json:"oldPath"`
|
||||
NewPath string `json:"newPath"`
|
||||
}
|
||||
|
||||
// RenameFileOutput contains the result of renaming a file.
|
||||
type RenameFileOutput struct {
|
||||
Success bool `json:"success"`
|
||||
OldPath string `json:"oldPath"`
|
||||
NewPath string `json:"newPath"`
|
||||
}
|
||||
|
||||
// FileExistsInput contains parameters for checking file existence.
|
||||
type FileExistsInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// FileExistsOutput contains the result of checking file existence.
|
||||
type FileExistsOutput struct {
|
||||
Exists bool `json:"exists"`
|
||||
IsDir bool `json:"isDir"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// DetectLanguageInput contains parameters for detecting file language.
|
||||
type DetectLanguageInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// DetectLanguageOutput contains the detected programming language.
|
||||
type DetectLanguageOutput struct {
|
||||
Language string `json:"language"`
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
// GetSupportedLanguagesInput is an empty struct for the languages query.
|
||||
type GetSupportedLanguagesInput struct{}
|
||||
|
||||
// GetSupportedLanguagesOutput contains the list of supported languages.
|
||||
type GetSupportedLanguagesOutput struct {
|
||||
Languages []LanguageInfo `json:"languages"`
|
||||
}
|
||||
|
||||
// LanguageInfo describes a supported programming language.
|
||||
type LanguageInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Extensions []string `json:"extensions"`
|
||||
}
|
||||
|
||||
// EditDiffInput contains parameters for editing a file via diff.
|
||||
type EditDiffInput struct {
|
||||
Path string `json:"path"`
|
||||
OldString string `json:"old_string"`
|
||||
NewString string `json:"new_string"`
|
||||
ReplaceAll bool `json:"replace_all,omitempty"`
|
||||
}
|
||||
|
||||
// EditDiffOutput contains the result of a diff-based edit operation.
|
||||
type EditDiffOutput struct {
|
||||
Path string `json:"path"`
|
||||
Success bool `json:"success"`
|
||||
Replacements int `json:"replacements"`
|
||||
}
|
||||
|
||||
// Tool handlers
|
||||
|
||||
func (s *Service) readFile(ctx context.Context, req *mcp.CallToolRequest, input ReadFileInput) (*mcp.CallToolResult, ReadFileOutput, error) {
|
||||
content, err := s.medium.Read(input.Path)
|
||||
if err != nil {
|
||||
return nil, ReadFileOutput{}, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
return nil, ReadFileOutput{
|
||||
Content: content,
|
||||
Language: detectLanguageFromPath(input.Path),
|
||||
Path: input.Path,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) writeFile(ctx context.Context, req *mcp.CallToolRequest, input WriteFileInput) (*mcp.CallToolResult, WriteFileOutput, error) {
|
||||
// Medium.Write creates parent directories automatically
|
||||
if err := s.medium.Write(input.Path, input.Content); err != nil {
|
||||
return nil, WriteFileOutput{}, fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
return nil, WriteFileOutput{Success: true, Path: input.Path}, nil
|
||||
}
|
||||
|
||||
func (s *Service) listDirectory(ctx context.Context, req *mcp.CallToolRequest, input ListDirectoryInput) (*mcp.CallToolResult, ListDirectoryOutput, error) {
|
||||
entries, err := s.medium.List(input.Path)
|
||||
if err != nil {
|
||||
return nil, ListDirectoryOutput{}, fmt.Errorf("failed to list directory: %w", err)
|
||||
}
|
||||
result := make([]DirectoryEntry, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
info, _ := e.Info()
|
||||
var size int64
|
||||
if info != nil {
|
||||
size = info.Size()
|
||||
}
|
||||
result = append(result, DirectoryEntry{
|
||||
Name: e.Name(),
|
||||
Path: filepath.Join(input.Path, e.Name()), // Note: This might be relative path, client might expect absolute?
|
||||
// Issue 103 says "Replace ... with local.Medium sandboxing".
|
||||
// Previous code returned `filepath.Join(input.Path, e.Name())`.
|
||||
// If input.Path is relative, this preserves it.
|
||||
IsDir: e.IsDir(),
|
||||
Size: size,
|
||||
})
|
||||
}
|
||||
return nil, ListDirectoryOutput{Entries: result, Path: input.Path}, nil
|
||||
}
|
||||
|
||||
func (s *Service) createDirectory(ctx context.Context, req *mcp.CallToolRequest, input CreateDirectoryInput) (*mcp.CallToolResult, CreateDirectoryOutput, error) {
|
||||
if err := s.medium.EnsureDir(input.Path); err != nil {
|
||||
return nil, CreateDirectoryOutput{}, fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
return nil, CreateDirectoryOutput{Success: true, Path: input.Path}, nil
|
||||
}
|
||||
|
||||
func (s *Service) deleteFile(ctx context.Context, req *mcp.CallToolRequest, input DeleteFileInput) (*mcp.CallToolResult, DeleteFileOutput, error) {
|
||||
if err := s.medium.Delete(input.Path); err != nil {
|
||||
return nil, DeleteFileOutput{}, fmt.Errorf("failed to delete file: %w", err)
|
||||
}
|
||||
return nil, DeleteFileOutput{Success: true, Path: input.Path}, nil
|
||||
}
|
||||
|
||||
func (s *Service) renameFile(ctx context.Context, req *mcp.CallToolRequest, input RenameFileInput) (*mcp.CallToolResult, RenameFileOutput, error) {
|
||||
if err := s.medium.Rename(input.OldPath, input.NewPath); err != nil {
|
||||
return nil, RenameFileOutput{}, fmt.Errorf("failed to rename file: %w", err)
|
||||
}
|
||||
return nil, RenameFileOutput{Success: true, OldPath: input.OldPath, NewPath: input.NewPath}, nil
|
||||
}
|
||||
|
||||
func (s *Service) fileExists(ctx context.Context, req *mcp.CallToolRequest, input FileExistsInput) (*mcp.CallToolResult, FileExistsOutput, error) {
|
||||
exists := s.medium.IsFile(input.Path)
|
||||
if exists {
|
||||
return nil, FileExistsOutput{Exists: true, IsDir: false, Path: input.Path}, nil
|
||||
}
|
||||
// Check if it's a directory by attempting to list it
|
||||
// List might fail if it's a file too (but we checked IsFile) or if doesn't exist.
|
||||
_, err := s.medium.List(input.Path)
|
||||
isDir := err == nil
|
||||
|
||||
// If List failed, it might mean it doesn't exist OR it's a special file or permissions.
|
||||
// Assuming if List works, it's a directory.
|
||||
|
||||
// Refinement: If it doesn't exist, List returns error.
|
||||
|
||||
return nil, FileExistsOutput{Exists: isDir, IsDir: isDir, Path: input.Path}, nil
|
||||
}
|
||||
|
||||
func (s *Service) detectLanguage(ctx context.Context, req *mcp.CallToolRequest, input DetectLanguageInput) (*mcp.CallToolResult, DetectLanguageOutput, error) {
|
||||
lang := detectLanguageFromPath(input.Path)
|
||||
return nil, DetectLanguageOutput{Language: lang, Path: input.Path}, nil
|
||||
}
|
||||
|
||||
func (s *Service) getSupportedLanguages(ctx context.Context, req *mcp.CallToolRequest, input GetSupportedLanguagesInput) (*mcp.CallToolResult, GetSupportedLanguagesOutput, error) {
|
||||
languages := []LanguageInfo{
|
||||
{ID: "typescript", Name: "TypeScript", Extensions: []string{".ts", ".tsx"}},
|
||||
{ID: "javascript", Name: "JavaScript", Extensions: []string{".js", ".jsx"}},
|
||||
{ID: "go", Name: "Go", Extensions: []string{".go"}},
|
||||
{ID: "python", Name: "Python", Extensions: []string{".py"}},
|
||||
{ID: "rust", Name: "Rust", Extensions: []string{".rs"}},
|
||||
{ID: "java", Name: "Java", Extensions: []string{".java"}},
|
||||
{ID: "php", Name: "PHP", Extensions: []string{".php"}},
|
||||
{ID: "ruby", Name: "Ruby", Extensions: []string{".rb"}},
|
||||
{ID: "html", Name: "HTML", Extensions: []string{".html", ".htm"}},
|
||||
{ID: "css", Name: "CSS", Extensions: []string{".css"}},
|
||||
{ID: "json", Name: "JSON", Extensions: []string{".json"}},
|
||||
{ID: "yaml", Name: "YAML", Extensions: []string{".yaml", ".yml"}},
|
||||
{ID: "markdown", Name: "Markdown", Extensions: []string{".md", ".markdown"}},
|
||||
{ID: "sql", Name: "SQL", Extensions: []string{".sql"}},
|
||||
{ID: "shell", Name: "Shell", Extensions: []string{".sh", ".bash"}},
|
||||
}
|
||||
return nil, GetSupportedLanguagesOutput{Languages: languages}, nil
|
||||
}
|
||||
|
||||
func (s *Service) editDiff(ctx context.Context, req *mcp.CallToolRequest, input EditDiffInput) (*mcp.CallToolResult, EditDiffOutput, error) {
|
||||
if input.OldString == "" {
|
||||
return nil, EditDiffOutput{}, errors.New("old_string cannot be empty")
|
||||
}
|
||||
|
||||
content, err := s.medium.Read(input.Path)
|
||||
if err != nil {
|
||||
return nil, EditDiffOutput{}, fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
|
||||
if input.ReplaceAll {
|
||||
count = strings.Count(content, input.OldString)
|
||||
if count == 0 {
|
||||
return nil, EditDiffOutput{}, errors.New("old_string not found in file")
|
||||
}
|
||||
content = strings.ReplaceAll(content, input.OldString, input.NewString)
|
||||
} else {
|
||||
if !strings.Contains(content, input.OldString) {
|
||||
return nil, EditDiffOutput{}, errors.New("old_string not found in file")
|
||||
}
|
||||
content = strings.Replace(content, input.OldString, input.NewString, 1)
|
||||
count = 1
|
||||
}
|
||||
|
||||
if err := s.medium.Write(input.Path, content); err != nil {
|
||||
return nil, EditDiffOutput{}, fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
return nil, EditDiffOutput{
|
||||
Path: input.Path,
|
||||
Success: true,
|
||||
Replacements: count,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// detectLanguageFromPath maps file extensions to language IDs.
|
||||
func detectLanguageFromPath(path string) string {
|
||||
ext := filepath.Ext(path)
|
||||
switch ext {
|
||||
case ".ts", ".tsx":
|
||||
return "typescript"
|
||||
case ".js", ".jsx":
|
||||
return "javascript"
|
||||
case ".go":
|
||||
return "go"
|
||||
case ".py":
|
||||
return "python"
|
||||
case ".rs":
|
||||
return "rust"
|
||||
case ".rb":
|
||||
return "ruby"
|
||||
case ".java":
|
||||
return "java"
|
||||
case ".php":
|
||||
return "php"
|
||||
case ".c", ".h":
|
||||
return "c"
|
||||
case ".cpp", ".hpp", ".cc", ".cxx":
|
||||
return "cpp"
|
||||
case ".cs":
|
||||
return "csharp"
|
||||
case ".html", ".htm":
|
||||
return "html"
|
||||
case ".css":
|
||||
return "css"
|
||||
case ".scss":
|
||||
return "scss"
|
||||
case ".json":
|
||||
return "json"
|
||||
case ".yaml", ".yml":
|
||||
return "yaml"
|
||||
case ".xml":
|
||||
return "xml"
|
||||
case ".md", ".markdown":
|
||||
return "markdown"
|
||||
case ".sql":
|
||||
return "sql"
|
||||
case ".sh", ".bash":
|
||||
return "shell"
|
||||
case ".swift":
|
||||
return "swift"
|
||||
case ".kt", ".kts":
|
||||
return "kotlin"
|
||||
default:
|
||||
if filepath.Base(path) == "Dockerfile" {
|
||||
return "dockerfile"
|
||||
}
|
||||
return "plaintext"
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the MCP server.
|
||||
// If MCP_ADDR is set, it starts a TCP server.
|
||||
// Otherwise, it starts a Stdio server.
|
||||
func (s *Service) Run(ctx context.Context) error {
|
||||
addr := os.Getenv("MCP_ADDR")
|
||||
if addr != "" {
|
||||
return s.ServeTCP(ctx, addr)
|
||||
}
|
||||
return s.server.Run(ctx, &mcp.StdioTransport{})
|
||||
}
|
||||
|
||||
// Server returns the underlying MCP server for advanced configuration.
|
||||
func (s *Service) Server() *mcp.Server {
|
||||
return s.server
|
||||
}
|
||||
180
mcp/mcp_test.go
180
mcp/mcp_test.go
|
|
@ -1,180 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew_Good_DefaultWorkspace(t *testing.T) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get working directory: %v", err)
|
||||
}
|
||||
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.workspaceRoot != cwd {
|
||||
t.Errorf("Expected default workspace root %s, got %s", cwd, s.workspaceRoot)
|
||||
}
|
||||
if s.medium == nil {
|
||||
t.Error("Expected medium to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_Good_CustomWorkspace(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.workspaceRoot != tmpDir {
|
||||
t.Errorf("Expected workspace root %s, got %s", tmpDir, s.workspaceRoot)
|
||||
}
|
||||
if s.medium == nil {
|
||||
t.Error("Expected medium to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew_Good_NoRestriction(t *testing.T) {
|
||||
s, err := New(WithWorkspaceRoot(""))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.workspaceRoot != "" {
|
||||
t.Errorf("Expected empty workspace root, got %s", s.workspaceRoot)
|
||||
}
|
||||
if s.medium == nil {
|
||||
t.Error("Expected medium to be set (unsandboxed)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMedium_Good_ReadWrite(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
// Write a file
|
||||
testContent := "hello world"
|
||||
err = s.medium.Write("test.txt", testContent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write file: %v", err)
|
||||
}
|
||||
|
||||
// Read it back
|
||||
content, err := s.medium.Read("test.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read file: %v", err)
|
||||
}
|
||||
if content != testContent {
|
||||
t.Errorf("Expected content %q, got %q", testContent, content)
|
||||
}
|
||||
|
||||
// Verify file exists on disk
|
||||
diskPath := filepath.Join(tmpDir, "test.txt")
|
||||
if _, err := os.Stat(diskPath); os.IsNotExist(err) {
|
||||
t.Error("File should exist on disk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMedium_Good_EnsureDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
err = s.medium.EnsureDir("subdir/nested")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
// Verify directory exists
|
||||
diskPath := filepath.Join(tmpDir, "subdir", "nested")
|
||||
info, err := os.Stat(diskPath)
|
||||
if os.IsNotExist(err) {
|
||||
t.Error("Directory should exist on disk")
|
||||
}
|
||||
if err == nil && !info.IsDir() {
|
||||
t.Error("Path should be a directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMedium_Good_IsFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
// File doesn't exist yet
|
||||
if s.medium.IsFile("test.txt") {
|
||||
t.Error("File should not exist yet")
|
||||
}
|
||||
|
||||
// Create the file
|
||||
_ = s.medium.Write("test.txt", "content")
|
||||
|
||||
// Now it should exist
|
||||
if !s.medium.IsFile("test.txt") {
|
||||
t.Error("File should exist after write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSandboxing_Traversal_Sanitized(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
// Path traversal is sanitized (.. becomes .), so ../secret.txt becomes
|
||||
// ./secret.txt in the workspace. Since that file doesn't exist, we get
|
||||
// a file not found error (not a traversal error).
|
||||
_, err = s.medium.Read("../secret.txt")
|
||||
if err == nil {
|
||||
t.Error("Expected error (file not found)")
|
||||
}
|
||||
|
||||
// Absolute paths are allowed through - they access the real filesystem.
|
||||
// This is intentional for full filesystem access. Callers wanting sandboxing
|
||||
// should validate inputs before calling Medium.
|
||||
}
|
||||
|
||||
func TestSandboxing_Symlinks_Blocked(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
outsideDir := t.TempDir()
|
||||
|
||||
// Create a target file outside workspace
|
||||
targetFile := filepath.Join(outsideDir, "secret.txt")
|
||||
if err := os.WriteFile(targetFile, []byte("secret"), 0644); err != nil {
|
||||
t.Fatalf("Failed to create target file: %v", err)
|
||||
}
|
||||
|
||||
// Create symlink inside workspace pointing outside
|
||||
symlinkPath := filepath.Join(tmpDir, "link")
|
||||
if err := os.Symlink(targetFile, symlinkPath); err != nil {
|
||||
t.Skipf("Symlinks not supported: %v", err)
|
||||
}
|
||||
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
// Symlinks pointing outside the sandbox root are blocked (security feature).
|
||||
// The sandbox resolves the symlink target and rejects it because it escapes
|
||||
// the workspace boundary.
|
||||
_, err = s.medium.Read("link")
|
||||
if err == nil {
|
||||
t.Error("Expected permission denied for symlink escaping sandbox, but read succeeded")
|
||||
}
|
||||
}
|
||||
149
mcp/registry.go
149
mcp/registry.go
|
|
@ -1,149 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"iter"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// RESTHandler handles a tool call from a REST endpoint.
|
||||
// It receives raw JSON input and returns the typed output or an error.
|
||||
type RESTHandler func(ctx context.Context, body []byte) (any, error)
|
||||
|
||||
// ToolRecord captures metadata about a registered MCP tool.
|
||||
type ToolRecord struct {
|
||||
Name string // Tool name, e.g. "file_read"
|
||||
Description string // Human-readable description
|
||||
Group string // Subsystem group name, e.g. "files", "rag"
|
||||
InputSchema map[string]any // JSON Schema from Go struct reflection
|
||||
OutputSchema map[string]any // JSON Schema from Go struct reflection
|
||||
RESTHandler RESTHandler // REST-callable handler created at registration time
|
||||
}
|
||||
|
||||
// addToolRecorded registers a tool with the MCP server AND records its metadata.
|
||||
// This is a generic function that captures the In/Out types for schema extraction.
|
||||
// It also creates a RESTHandler closure that can unmarshal JSON to the correct
|
||||
// input type and call the handler directly, enabling the MCP-to-REST bridge.
|
||||
func addToolRecorded[In, Out any](s *Service, server *mcp.Server, group string, t *mcp.Tool, h mcp.ToolHandlerFor[In, Out]) {
|
||||
mcp.AddTool(server, t, h)
|
||||
|
||||
restHandler := func(ctx context.Context, body []byte) (any, error) {
|
||||
var input In
|
||||
if len(body) > 0 {
|
||||
if err := json.Unmarshal(body, &input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// nil: REST callers have no MCP request context.
|
||||
// Tool handlers called via REST must not dereference CallToolRequest.
|
||||
_, output, err := h(ctx, nil, input)
|
||||
return output, err
|
||||
}
|
||||
|
||||
s.tools = append(s.tools, ToolRecord{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Group: group,
|
||||
InputSchema: structSchema(new(In)),
|
||||
OutputSchema: structSchema(new(Out)),
|
||||
RESTHandler: restHandler,
|
||||
})
|
||||
}
|
||||
|
||||
// structSchema builds a simple JSON Schema from a struct's json tags via reflection.
|
||||
// Returns nil for non-struct types or empty structs.
|
||||
func structSchema(v any) map[string]any {
|
||||
t := reflect.TypeOf(v)
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
if t.NumField() == 0 {
|
||||
return map[string]any{"type": "object", "properties": map[string]any{}}
|
||||
}
|
||||
|
||||
properties := make(map[string]any)
|
||||
required := make([]string, 0)
|
||||
|
||||
for f := range t.Fields() {
|
||||
f := f
|
||||
if !f.IsExported() {
|
||||
continue
|
||||
}
|
||||
jsonTag := f.Tag.Get("json")
|
||||
if jsonTag == "-" {
|
||||
continue
|
||||
}
|
||||
name := f.Name
|
||||
isOptional := false
|
||||
if jsonTag != "" {
|
||||
parts := splitTag(jsonTag)
|
||||
name = parts[0]
|
||||
for _, p := range parts[1:] {
|
||||
if p == "omitempty" {
|
||||
isOptional = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prop := map[string]any{
|
||||
"type": goTypeToJSONType(f.Type),
|
||||
}
|
||||
properties[name] = prop
|
||||
|
||||
if !isOptional {
|
||||
required = append(required, name)
|
||||
}
|
||||
}
|
||||
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
}
|
||||
if len(required) > 0 {
|
||||
schema["required"] = required
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
// splitTag splits a struct tag value by commas.
|
||||
func splitTag(tag string) []string {
|
||||
return strings.Split(tag, ",")
|
||||
}
|
||||
|
||||
// splitTagSeq returns an iterator over the tag parts.
|
||||
func splitTagSeq(tag string) iter.Seq[string] {
|
||||
return strings.SplitSeq(tag, ",")
|
||||
}
|
||||
|
||||
// goTypeToJSONType maps Go types to JSON Schema types.
|
||||
func goTypeToJSONType(t reflect.Type) string {
|
||||
switch t.Kind() {
|
||||
case reflect.String:
|
||||
return "string"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return "integer"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "number"
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Slice, reflect.Array:
|
||||
return "array"
|
||||
case reflect.Map, reflect.Struct:
|
||||
return "object"
|
||||
default:
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,150 +0,0 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestToolRegistry_Good_RecordsTools(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tools := svc.Tools()
|
||||
if len(tools) == 0 {
|
||||
t.Fatal("expected non-empty tool registry")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, tr := range tools {
|
||||
if tr.Name == "file_read" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected file_read in tool registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Good_SchemaExtraction(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var record ToolRecord
|
||||
for _, tr := range svc.Tools() {
|
||||
if tr.Name == "file_read" {
|
||||
record = tr
|
||||
break
|
||||
}
|
||||
}
|
||||
if record.Name == "" {
|
||||
t.Fatal("file_read not found in registry")
|
||||
}
|
||||
|
||||
if record.InputSchema == nil {
|
||||
t.Fatal("expected non-nil InputSchema for file_read")
|
||||
}
|
||||
|
||||
props, ok := record.InputSchema["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected properties map in InputSchema")
|
||||
}
|
||||
|
||||
if _, ok := props["path"]; !ok {
|
||||
t.Error("expected 'path' property in file_read InputSchema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Good_ToolCount(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tools := svc.Tools()
|
||||
// Built-in tools: file_read, file_write, file_delete, file_rename,
|
||||
// file_exists, file_edit, dir_list, dir_create, lang_detect, lang_list
|
||||
const expectedCount = 10
|
||||
if len(tools) != expectedCount {
|
||||
t.Errorf("expected %d tools, got %d", expectedCount, len(tools))
|
||||
for _, tr := range tools {
|
||||
t.Logf(" - %s (%s)", tr.Name, tr.Group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Good_GroupAssignment(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fileTools := []string{"file_read", "file_write", "file_delete", "file_rename", "file_exists", "file_edit", "dir_list", "dir_create"}
|
||||
langTools := []string{"lang_detect", "lang_list"}
|
||||
|
||||
byName := make(map[string]ToolRecord)
|
||||
for _, tr := range svc.Tools() {
|
||||
byName[tr.Name] = tr
|
||||
}
|
||||
|
||||
for _, name := range fileTools {
|
||||
tr, ok := byName[name]
|
||||
if !ok {
|
||||
t.Errorf("tool %s not found in registry", name)
|
||||
continue
|
||||
}
|
||||
if tr.Group != "files" {
|
||||
t.Errorf("tool %s: expected group 'files', got %q", name, tr.Group)
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range langTools {
|
||||
tr, ok := byName[name]
|
||||
if !ok {
|
||||
t.Errorf("tool %s not found in registry", name)
|
||||
continue
|
||||
}
|
||||
if tr.Group != "language" {
|
||||
t.Errorf("tool %s: expected group 'language', got %q", name, tr.Group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolRegistry_Good_ToolRecordFields(t *testing.T) {
|
||||
svc, err := New(WithWorkspaceRoot(t.TempDir()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var record ToolRecord
|
||||
for _, tr := range svc.Tools() {
|
||||
if tr.Name == "file_write" {
|
||||
record = tr
|
||||
break
|
||||
}
|
||||
}
|
||||
if record.Name == "" {
|
||||
t.Fatal("file_write not found in registry")
|
||||
}
|
||||
|
||||
if record.Name != "file_write" {
|
||||
t.Errorf("expected Name 'file_write', got %q", record.Name)
|
||||
}
|
||||
if record.Description == "" {
|
||||
t.Error("expected non-empty Description")
|
||||
}
|
||||
if record.Group == "" {
|
||||
t.Error("expected non-empty Group")
|
||||
}
|
||||
if record.InputSchema == nil {
|
||||
t.Error("expected non-nil InputSchema")
|
||||
}
|
||||
if record.OutputSchema == nil {
|
||||
t.Error("expected non-nil OutputSchema")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Subsystem registers additional MCP tools at startup.
|
||||
// Implementations should be safe to call concurrently.
|
||||
type Subsystem interface {
|
||||
// Name returns a human-readable identifier for logging.
|
||||
Name() string
|
||||
|
||||
// RegisterTools adds tools to the MCP server during initialisation.
|
||||
RegisterTools(server *mcp.Server)
|
||||
}
|
||||
|
||||
// SubsystemWithShutdown extends Subsystem with graceful cleanup.
|
||||
type SubsystemWithShutdown interface {
|
||||
Subsystem
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
|
||||
// WithSubsystem registers a subsystem whose tools will be added
|
||||
// after the built-in tools during New().
|
||||
func WithSubsystem(sub Subsystem) Option {
|
||||
return func(s *Service) error {
|
||||
s.subsystems = append(s.subsystems, sub)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
@ -1,114 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// stubSubsystem is a minimal Subsystem for testing.
|
||||
type stubSubsystem struct {
|
||||
name string
|
||||
toolsRegistered bool
|
||||
}
|
||||
|
||||
func (s *stubSubsystem) Name() string { return s.name }
|
||||
|
||||
func (s *stubSubsystem) RegisterTools(server *mcp.Server) {
|
||||
s.toolsRegistered = true
|
||||
}
|
||||
|
||||
// shutdownSubsystem tracks Shutdown calls.
|
||||
type shutdownSubsystem struct {
|
||||
stubSubsystem
|
||||
shutdownCalled bool
|
||||
shutdownErr error
|
||||
}
|
||||
|
||||
func (s *shutdownSubsystem) Shutdown(_ context.Context) error {
|
||||
s.shutdownCalled = true
|
||||
return s.shutdownErr
|
||||
}
|
||||
|
||||
func TestWithSubsystem_Good_Registration(t *testing.T) {
|
||||
sub := &stubSubsystem{name: "test-sub"}
|
||||
svc, err := New(WithSubsystem(sub))
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(svc.Subsystems()) != 1 {
|
||||
t.Fatalf("expected 1 subsystem, got %d", len(svc.Subsystems()))
|
||||
}
|
||||
if svc.Subsystems()[0].Name() != "test-sub" {
|
||||
t.Errorf("expected name 'test-sub', got %q", svc.Subsystems()[0].Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSubsystem_Good_ToolsRegistered(t *testing.T) {
|
||||
sub := &stubSubsystem{name: "tools-sub"}
|
||||
_, err := New(WithSubsystem(sub))
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
if !sub.toolsRegistered {
|
||||
t.Error("expected RegisterTools to have been called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSubsystem_Good_MultipleSubsystems(t *testing.T) {
|
||||
sub1 := &stubSubsystem{name: "sub-1"}
|
||||
sub2 := &stubSubsystem{name: "sub-2"}
|
||||
svc, err := New(WithSubsystem(sub1), WithSubsystem(sub2))
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
if len(svc.Subsystems()) != 2 {
|
||||
t.Fatalf("expected 2 subsystems, got %d", len(svc.Subsystems()))
|
||||
}
|
||||
if !sub1.toolsRegistered || !sub2.toolsRegistered {
|
||||
t.Error("expected all subsystems to have RegisterTools called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubsystemShutdown_Good(t *testing.T) {
|
||||
sub := &shutdownSubsystem{stubSubsystem: stubSubsystem{name: "shutdown-sub"}}
|
||||
svc, err := New(WithSubsystem(sub))
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
if err := svc.Shutdown(context.Background()); err != nil {
|
||||
t.Fatalf("Shutdown() failed: %v", err)
|
||||
}
|
||||
if !sub.shutdownCalled {
|
||||
t.Error("expected Shutdown to have been called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubsystemShutdown_Bad_Error(t *testing.T) {
|
||||
sub := &shutdownSubsystem{
|
||||
stubSubsystem: stubSubsystem{name: "fail-sub"},
|
||||
shutdownErr: context.DeadlineExceeded,
|
||||
}
|
||||
svc, err := New(WithSubsystem(sub))
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
err = svc.Shutdown(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected error from Shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubsystemShutdown_Good_NoShutdownInterface(t *testing.T) {
|
||||
// A plain Subsystem (without Shutdown) should not cause errors.
|
||||
sub := &stubSubsystem{name: "plain-sub"}
|
||||
svc, err := New(WithSubsystem(sub))
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
if err := svc.Shutdown(context.Background()); err != nil {
|
||||
t.Fatalf("Shutdown() should succeed for non-shutdown subsystem: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/ai"
|
||||
"forge.lthn.ai/core/go-log"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Default values for metrics operations.
|
||||
const (
|
||||
DefaultMetricsSince = "7d"
|
||||
DefaultMetricsLimit = 10
|
||||
)
|
||||
|
||||
// MetricsRecordInput contains parameters for recording a metrics event.
|
||||
type MetricsRecordInput struct {
|
||||
Type string `json:"type"` // Event type (required)
|
||||
AgentID string `json:"agent_id,omitempty"` // Agent identifier
|
||||
Repo string `json:"repo,omitempty"` // Repository name
|
||||
Data map[string]any `json:"data,omitempty"` // Additional event data
|
||||
}
|
||||
|
||||
// MetricsRecordOutput contains the result of recording a metrics event.
|
||||
type MetricsRecordOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// MetricsQueryInput contains parameters for querying metrics.
|
||||
type MetricsQueryInput struct {
|
||||
Since string `json:"since,omitempty"` // Time range like "7d", "24h", "30m" (default: "7d")
|
||||
}
|
||||
|
||||
// MetricsQueryOutput contains the results of a metrics query.
|
||||
type MetricsQueryOutput struct {
|
||||
Total int `json:"total"`
|
||||
ByType []MetricCount `json:"by_type"`
|
||||
ByRepo []MetricCount `json:"by_repo"`
|
||||
ByAgent []MetricCount `json:"by_agent"`
|
||||
Events []MetricEventBrief `json:"events"` // Most recent 10 events
|
||||
}
|
||||
|
||||
// MetricCount represents a count for a specific key.
|
||||
type MetricCount struct {
|
||||
Key string `json:"key"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
// MetricEventBrief represents a brief summary of an event.
|
||||
type MetricEventBrief struct {
|
||||
Type string `json:"type"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
Repo string `json:"repo,omitempty"`
|
||||
}
|
||||
|
||||
// registerMetricsTools adds metrics tools to the MCP server.
|
||||
func (s *Service) registerMetricsTools(server *mcp.Server) {
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "metrics_record",
|
||||
Description: "Record a metrics event for AI/security tracking. Events are stored in daily JSONL files.",
|
||||
}, s.metricsRecord)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "metrics_query",
|
||||
Description: "Query metrics events and get aggregated statistics by type, repo, and agent.",
|
||||
}, s.metricsQuery)
|
||||
}
|
||||
|
||||
// metricsRecord handles the metrics_record tool call.
|
||||
func (s *Service) metricsRecord(ctx context.Context, req *mcp.CallToolRequest, input MetricsRecordInput) (*mcp.CallToolResult, MetricsRecordOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "metrics_record", "type", input.Type, "agent_id", input.AgentID, "repo", input.Repo, "user", log.Username())
|
||||
|
||||
// Validate input
|
||||
if input.Type == "" {
|
||||
return nil, MetricsRecordOutput{}, errors.New("type cannot be empty")
|
||||
}
|
||||
|
||||
// Create the event
|
||||
event := ai.Event{
|
||||
Type: input.Type,
|
||||
Timestamp: time.Now(),
|
||||
AgentID: input.AgentID,
|
||||
Repo: input.Repo,
|
||||
Data: input.Data,
|
||||
}
|
||||
|
||||
// Record the event
|
||||
if err := ai.Record(event); err != nil {
|
||||
log.Error("mcp: metrics record failed", "type", input.Type, "err", err)
|
||||
return nil, MetricsRecordOutput{}, fmt.Errorf("failed to record metrics: %w", err)
|
||||
}
|
||||
|
||||
return nil, MetricsRecordOutput{
|
||||
Success: true,
|
||||
Timestamp: event.Timestamp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// metricsQuery handles the metrics_query tool call.
|
||||
func (s *Service) metricsQuery(ctx context.Context, req *mcp.CallToolRequest, input MetricsQueryInput) (*mcp.CallToolResult, MetricsQueryOutput, error) {
|
||||
// Apply defaults
|
||||
since := input.Since
|
||||
if since == "" {
|
||||
since = DefaultMetricsSince
|
||||
}
|
||||
|
||||
s.logger.Info("MCP tool execution", "tool", "metrics_query", "since", since, "user", log.Username())
|
||||
|
||||
// Parse the duration
|
||||
duration, err := parseDuration(since)
|
||||
if err != nil {
|
||||
return nil, MetricsQueryOutput{}, fmt.Errorf("invalid since value: %w", err)
|
||||
}
|
||||
|
||||
sinceTime := time.Now().Add(-duration)
|
||||
|
||||
// Read events
|
||||
events, err := ai.ReadEvents(sinceTime)
|
||||
if err != nil {
|
||||
log.Error("mcp: metrics query failed", "since", since, "err", err)
|
||||
return nil, MetricsQueryOutput{}, fmt.Errorf("failed to read metrics: %w", err)
|
||||
}
|
||||
|
||||
// Get summary
|
||||
summary := ai.Summary(events)
|
||||
|
||||
// Build output
|
||||
output := MetricsQueryOutput{
|
||||
Total: summary["total"].(int),
|
||||
ByType: convertMetricCounts(summary["by_type"]),
|
||||
ByRepo: convertMetricCounts(summary["by_repo"]),
|
||||
ByAgent: convertMetricCounts(summary["by_agent"]),
|
||||
Events: make([]MetricEventBrief, 0, DefaultMetricsLimit),
|
||||
}
|
||||
|
||||
// Get recent events (last 10, most recent first)
|
||||
startIdx := max(len(events)-DefaultMetricsLimit, 0)
|
||||
for i := len(events) - 1; i >= startIdx; i-- {
|
||||
ev := events[i]
|
||||
output.Events = append(output.Events, MetricEventBrief{
|
||||
Type: ev.Type,
|
||||
Timestamp: ev.Timestamp,
|
||||
AgentID: ev.AgentID,
|
||||
Repo: ev.Repo,
|
||||
})
|
||||
}
|
||||
|
||||
return nil, output, nil
|
||||
}
|
||||
|
||||
// convertMetricCounts converts the summary map format to MetricCount slice.
|
||||
func convertMetricCounts(data any) []MetricCount {
|
||||
if data == nil {
|
||||
return []MetricCount{}
|
||||
}
|
||||
|
||||
items, ok := data.([]map[string]any)
|
||||
if !ok {
|
||||
return []MetricCount{}
|
||||
}
|
||||
|
||||
result := make([]MetricCount, len(items))
|
||||
for i, item := range items {
|
||||
key, _ := item["key"].(string)
|
||||
count, _ := item["count"].(int)
|
||||
result[i] = MetricCount{Key: key, Count: count}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseDuration parses a duration string like "7d", "24h", "30m".
|
||||
func parseDuration(s string) (time.Duration, error) {
|
||||
if s == "" {
|
||||
return 0, errors.New("duration cannot be empty")
|
||||
}
|
||||
|
||||
s = strings.TrimSpace(s)
|
||||
if len(s) < 2 {
|
||||
return 0, fmt.Errorf("invalid duration format: %q", s)
|
||||
}
|
||||
|
||||
// Get the numeric part and unit
|
||||
unit := s[len(s)-1]
|
||||
numStr := s[:len(s)-1]
|
||||
|
||||
num, err := strconv.Atoi(numStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid duration number: %q", numStr)
|
||||
}
|
||||
|
||||
if num <= 0 {
|
||||
return 0, fmt.Errorf("duration must be positive: %d", num)
|
||||
}
|
||||
|
||||
switch unit {
|
||||
case 'd':
|
||||
return time.Duration(num) * 24 * time.Hour, nil
|
||||
case 'h':
|
||||
return time.Duration(num) * time.Hour, nil
|
||||
case 'm':
|
||||
return time.Duration(num) * time.Minute, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid duration unit: %q (expected d, h, or m)", string(unit))
|
||||
}
|
||||
}
|
||||
|
|
@ -1,207 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestMetricsToolsRegistered_Good verifies that metrics tools are registered with the MCP server.
|
||||
func TestMetricsToolsRegistered_Good(t *testing.T) {
|
||||
// Create a new MCP service - this should register all tools including metrics
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
// The server should have registered the metrics tools
|
||||
// We verify by checking that the server and logger exist
|
||||
if s.server == nil {
|
||||
t.Fatal("Server should not be nil")
|
||||
}
|
||||
|
||||
if s.logger == nil {
|
||||
t.Error("Logger should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetricsRecordInput_Good verifies the MetricsRecordInput struct has expected fields.
|
||||
func TestMetricsRecordInput_Good(t *testing.T) {
|
||||
input := MetricsRecordInput{
|
||||
Type: "tool_call",
|
||||
AgentID: "agent-123",
|
||||
Repo: "host-uk/core",
|
||||
Data: map[string]any{"tool": "file_read", "duration_ms": 150},
|
||||
}
|
||||
|
||||
if input.Type != "tool_call" {
|
||||
t.Errorf("Expected type 'tool_call', got %q", input.Type)
|
||||
}
|
||||
if input.AgentID != "agent-123" {
|
||||
t.Errorf("Expected agent_id 'agent-123', got %q", input.AgentID)
|
||||
}
|
||||
if input.Repo != "host-uk/core" {
|
||||
t.Errorf("Expected repo 'host-uk/core', got %q", input.Repo)
|
||||
}
|
||||
if input.Data["tool"] != "file_read" {
|
||||
t.Errorf("Expected data[tool] 'file_read', got %v", input.Data["tool"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetricsRecordOutput_Good verifies the MetricsRecordOutput struct has expected fields.
|
||||
func TestMetricsRecordOutput_Good(t *testing.T) {
|
||||
ts := time.Now()
|
||||
output := MetricsRecordOutput{
|
||||
Success: true,
|
||||
Timestamp: ts,
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("Expected success to be true")
|
||||
}
|
||||
if output.Timestamp != ts {
|
||||
t.Errorf("Expected timestamp %v, got %v", ts, output.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetricsQueryInput_Good verifies the MetricsQueryInput struct has expected fields.
|
||||
func TestMetricsQueryInput_Good(t *testing.T) {
|
||||
input := MetricsQueryInput{
|
||||
Since: "7d",
|
||||
}
|
||||
|
||||
if input.Since != "7d" {
|
||||
t.Errorf("Expected since '7d', got %q", input.Since)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetricsQueryInput_Defaults verifies default values are handled correctly.
|
||||
func TestMetricsQueryInput_Defaults(t *testing.T) {
|
||||
input := MetricsQueryInput{}
|
||||
|
||||
// Empty since should use default when processed
|
||||
if input.Since != "" {
|
||||
t.Errorf("Expected empty since before defaults, got %q", input.Since)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetricsQueryOutput_Good verifies the MetricsQueryOutput struct has expected fields.
|
||||
func TestMetricsQueryOutput_Good(t *testing.T) {
|
||||
output := MetricsQueryOutput{
|
||||
Total: 100,
|
||||
ByType: []MetricCount{
|
||||
{Key: "tool_call", Count: 50},
|
||||
{Key: "query", Count: 30},
|
||||
},
|
||||
ByRepo: []MetricCount{
|
||||
{Key: "host-uk/core", Count: 40},
|
||||
},
|
||||
ByAgent: []MetricCount{
|
||||
{Key: "agent-123", Count: 25},
|
||||
},
|
||||
Events: []MetricEventBrief{
|
||||
{Type: "tool_call", Timestamp: time.Now(), AgentID: "agent-1", Repo: "host-uk/core"},
|
||||
},
|
||||
}
|
||||
|
||||
if output.Total != 100 {
|
||||
t.Errorf("Expected total 100, got %d", output.Total)
|
||||
}
|
||||
if len(output.ByType) != 2 {
|
||||
t.Errorf("Expected 2 ByType entries, got %d", len(output.ByType))
|
||||
}
|
||||
if output.ByType[0].Key != "tool_call" {
|
||||
t.Errorf("Expected ByType[0].Key 'tool_call', got %q", output.ByType[0].Key)
|
||||
}
|
||||
if output.ByType[0].Count != 50 {
|
||||
t.Errorf("Expected ByType[0].Count 50, got %d", output.ByType[0].Count)
|
||||
}
|
||||
if len(output.Events) != 1 {
|
||||
t.Errorf("Expected 1 event, got %d", len(output.Events))
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetricCount_Good verifies the MetricCount struct has expected fields.
|
||||
func TestMetricCount_Good(t *testing.T) {
|
||||
mc := MetricCount{
|
||||
Key: "tool_call",
|
||||
Count: 42,
|
||||
}
|
||||
|
||||
if mc.Key != "tool_call" {
|
||||
t.Errorf("Expected key 'tool_call', got %q", mc.Key)
|
||||
}
|
||||
if mc.Count != 42 {
|
||||
t.Errorf("Expected count 42, got %d", mc.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMetricEventBrief_Good verifies the MetricEventBrief struct has expected fields.
|
||||
func TestMetricEventBrief_Good(t *testing.T) {
|
||||
ts := time.Now()
|
||||
ev := MetricEventBrief{
|
||||
Type: "tool_call",
|
||||
Timestamp: ts,
|
||||
AgentID: "agent-123",
|
||||
Repo: "host-uk/core",
|
||||
}
|
||||
|
||||
if ev.Type != "tool_call" {
|
||||
t.Errorf("Expected type 'tool_call', got %q", ev.Type)
|
||||
}
|
||||
if ev.Timestamp != ts {
|
||||
t.Errorf("Expected timestamp %v, got %v", ts, ev.Timestamp)
|
||||
}
|
||||
if ev.AgentID != "agent-123" {
|
||||
t.Errorf("Expected agent_id 'agent-123', got %q", ev.AgentID)
|
||||
}
|
||||
if ev.Repo != "host-uk/core" {
|
||||
t.Errorf("Expected repo 'host-uk/core', got %q", ev.Repo)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseDuration_Good verifies the parseDuration helper handles various formats.
|
||||
func TestParseDuration_Good(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected time.Duration
|
||||
}{
|
||||
{"7d", 7 * 24 * time.Hour},
|
||||
{"24h", 24 * time.Hour},
|
||||
{"30m", 30 * time.Minute},
|
||||
{"1d", 24 * time.Hour},
|
||||
{"14d", 14 * 24 * time.Hour},
|
||||
{"1h", time.Hour},
|
||||
{"10m", 10 * time.Minute},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
d, err := parseDuration(tc.input)
|
||||
if err != nil {
|
||||
t.Fatalf("parseDuration(%q) returned error: %v", tc.input, err)
|
||||
}
|
||||
if d != tc.expected {
|
||||
t.Errorf("parseDuration(%q) = %v, want %v", tc.input, d, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseDuration_Bad verifies parseDuration returns errors for invalid input.
|
||||
func TestParseDuration_Bad(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"abc",
|
||||
"7x",
|
||||
"-7d",
|
||||
}
|
||||
|
||||
for _, input := range tests {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
_, err := parseDuration(input)
|
||||
if err == nil {
|
||||
t.Errorf("parseDuration(%q) should return error", input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
290
mcp/tools_ml.go
290
mcp/tools_ml.go
|
|
@ -1,290 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"forge.lthn.ai/core/go-ml"
|
||||
"forge.lthn.ai/core/go-log"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// MLSubsystem exposes ML inference and scoring tools via MCP.
|
||||
type MLSubsystem struct {
|
||||
service *ml.Service
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewMLSubsystem creates an MCP subsystem for ML tools.
|
||||
func NewMLSubsystem(svc *ml.Service) *MLSubsystem {
|
||||
return &MLSubsystem{
|
||||
service: svc,
|
||||
logger: log.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MLSubsystem) Name() string { return "ml" }
|
||||
|
||||
// RegisterTools adds ML tools to the MCP server.
|
||||
func (m *MLSubsystem) RegisterTools(server *mcp.Server) {
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ml_generate",
|
||||
Description: "Generate text via a configured ML inference backend.",
|
||||
}, m.mlGenerate)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ml_score",
|
||||
Description: "Score a prompt/response pair using heuristic and LLM judge suites.",
|
||||
}, m.mlScore)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ml_probe",
|
||||
Description: "Run capability probes against an inference backend.",
|
||||
}, m.mlProbe)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ml_status",
|
||||
Description: "Show training and generation progress from InfluxDB.",
|
||||
}, m.mlStatus)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ml_backends",
|
||||
Description: "List available inference backends and their status.",
|
||||
}, m.mlBackends)
|
||||
}
|
||||
|
||||
// --- Input/Output types ---
|
||||
|
||||
// MLGenerateInput contains parameters for text generation.
|
||||
type MLGenerateInput struct {
|
||||
Prompt string `json:"prompt"` // The prompt to generate from
|
||||
Backend string `json:"backend,omitempty"` // Backend name (default: service default)
|
||||
Model string `json:"model,omitempty"` // Model override
|
||||
Temperature float64 `json:"temperature,omitempty"` // Sampling temperature
|
||||
MaxTokens int `json:"max_tokens,omitempty"` // Maximum tokens to generate
|
||||
}
|
||||
|
||||
// MLGenerateOutput contains the generation result.
|
||||
type MLGenerateOutput struct {
|
||||
Response string `json:"response"`
|
||||
Backend string `json:"backend"`
|
||||
Model string `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
// MLScoreInput contains parameters for scoring a response.
|
||||
type MLScoreInput struct {
|
||||
Prompt string `json:"prompt"` // The original prompt
|
||||
Response string `json:"response"` // The model response to score
|
||||
Suites string `json:"suites,omitempty"` // Comma-separated suites (default: heuristic)
|
||||
}
|
||||
|
||||
// MLScoreOutput contains the scoring result.
|
||||
type MLScoreOutput struct {
|
||||
Heuristic *ml.HeuristicScores `json:"heuristic,omitempty"`
|
||||
Semantic *ml.SemanticScores `json:"semantic,omitempty"`
|
||||
Content *ml.ContentScores `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
// MLProbeInput contains parameters for running probes.
|
||||
type MLProbeInput struct {
|
||||
Backend string `json:"backend,omitempty"` // Backend name
|
||||
Categories string `json:"categories,omitempty"` // Comma-separated categories to run
|
||||
}
|
||||
|
||||
// MLProbeOutput contains probe results.
|
||||
type MLProbeOutput struct {
|
||||
Total int `json:"total"`
|
||||
Results []MLProbeResultItem `json:"results"`
|
||||
}
|
||||
|
||||
// MLProbeResultItem is a single probe result.
|
||||
type MLProbeResultItem struct {
|
||||
ID string `json:"id"`
|
||||
Category string `json:"category"`
|
||||
Response string `json:"response"`
|
||||
}
|
||||
|
||||
// MLStatusInput contains parameters for the status query.
|
||||
type MLStatusInput struct {
|
||||
InfluxURL string `json:"influx_url,omitempty"` // InfluxDB URL override
|
||||
InfluxDB string `json:"influx_db,omitempty"` // InfluxDB database override
|
||||
}
|
||||
|
||||
// MLStatusOutput contains pipeline status.
|
||||
type MLStatusOutput struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// MLBackendsInput is empty — lists all backends.
|
||||
type MLBackendsInput struct{}
|
||||
|
||||
// MLBackendsOutput lists available backends.
|
||||
type MLBackendsOutput struct {
|
||||
Backends []MLBackendInfo `json:"backends"`
|
||||
Default string `json:"default"`
|
||||
}
|
||||
|
||||
// MLBackendInfo describes a single backend.
|
||||
type MLBackendInfo struct {
|
||||
Name string `json:"name"`
|
||||
Available bool `json:"available"`
|
||||
}
|
||||
|
||||
// --- Tool handlers ---
|
||||
|
||||
// mlGenerate delegates to go-ml.Service.Generate, which internally uses
|
||||
// InferenceAdapter to route generation through an inference.TextModel.
|
||||
// Flow: go-ai → go-ml.Service.Generate → InferenceAdapter → inference.TextModel.
|
||||
func (m *MLSubsystem) mlGenerate(ctx context.Context, req *mcp.CallToolRequest, input MLGenerateInput) (*mcp.CallToolResult, MLGenerateOutput, error) {
|
||||
m.logger.Info("MCP tool execution", "tool", "ml_generate", "backend", input.Backend, "user", log.Username())
|
||||
|
||||
if input.Prompt == "" {
|
||||
return nil, MLGenerateOutput{}, errors.New("prompt cannot be empty")
|
||||
}
|
||||
|
||||
opts := ml.GenOpts{
|
||||
Temperature: input.Temperature,
|
||||
MaxTokens: input.MaxTokens,
|
||||
Model: input.Model,
|
||||
}
|
||||
|
||||
result, err := m.service.Generate(ctx, input.Backend, input.Prompt, opts)
|
||||
if err != nil {
|
||||
return nil, MLGenerateOutput{}, fmt.Errorf("generate: %w", err)
|
||||
}
|
||||
|
||||
return nil, MLGenerateOutput{
|
||||
Response: result.Text,
|
||||
Backend: input.Backend,
|
||||
Model: input.Model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MLSubsystem) mlScore(ctx context.Context, req *mcp.CallToolRequest, input MLScoreInput) (*mcp.CallToolResult, MLScoreOutput, error) {
|
||||
m.logger.Info("MCP tool execution", "tool", "ml_score", "suites", input.Suites, "user", log.Username())
|
||||
|
||||
if input.Prompt == "" || input.Response == "" {
|
||||
return nil, MLScoreOutput{}, errors.New("prompt and response cannot be empty")
|
||||
}
|
||||
|
||||
suites := input.Suites
|
||||
if suites == "" {
|
||||
suites = "heuristic"
|
||||
}
|
||||
|
||||
output := MLScoreOutput{}
|
||||
|
||||
for suite := range strings.SplitSeq(suites, ",") {
|
||||
suite = strings.TrimSpace(suite)
|
||||
switch suite {
|
||||
case "heuristic":
|
||||
output.Heuristic = ml.ScoreHeuristic(input.Response)
|
||||
case "semantic":
|
||||
judge := m.service.Judge()
|
||||
if judge == nil {
|
||||
return nil, MLScoreOutput{}, errors.New("semantic scoring requires a judge backend")
|
||||
}
|
||||
s, err := judge.ScoreSemantic(ctx, input.Prompt, input.Response)
|
||||
if err != nil {
|
||||
return nil, MLScoreOutput{}, fmt.Errorf("semantic score: %w", err)
|
||||
}
|
||||
output.Semantic = s
|
||||
case "content":
|
||||
return nil, MLScoreOutput{}, errors.New("content scoring requires a ContentProbe — use ml_probe instead")
|
||||
}
|
||||
}
|
||||
|
||||
return nil, output, nil
|
||||
}
|
||||
|
||||
// mlProbe runs capability probes by generating responses via go-ml.Service.
|
||||
// Flow: go-ai → go-ml.Service.Generate → InferenceAdapter → inference.TextModel.
|
||||
func (m *MLSubsystem) mlProbe(ctx context.Context, req *mcp.CallToolRequest, input MLProbeInput) (*mcp.CallToolResult, MLProbeOutput, error) {
|
||||
m.logger.Info("MCP tool execution", "tool", "ml_probe", "backend", input.Backend, "user", log.Username())
|
||||
|
||||
// Filter probes by category if specified.
|
||||
probes := ml.CapabilityProbes
|
||||
if input.Categories != "" {
|
||||
cats := make(map[string]bool)
|
||||
for c := range strings.SplitSeq(input.Categories, ",") {
|
||||
cats[strings.TrimSpace(c)] = true
|
||||
}
|
||||
var filtered []ml.Probe
|
||||
for _, p := range probes {
|
||||
if cats[p.Category] {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
probes = filtered
|
||||
}
|
||||
|
||||
var results []MLProbeResultItem
|
||||
for _, probe := range probes {
|
||||
result, err := m.service.Generate(ctx, input.Backend, probe.Prompt, ml.GenOpts{Temperature: 0.7, MaxTokens: 2048})
|
||||
respText := result.Text
|
||||
if err != nil {
|
||||
respText = fmt.Sprintf("error: %v", err)
|
||||
}
|
||||
results = append(results, MLProbeResultItem{
|
||||
ID: probe.ID,
|
||||
Category: probe.Category,
|
||||
Response: respText,
|
||||
})
|
||||
}
|
||||
|
||||
return nil, MLProbeOutput{
|
||||
Total: len(results),
|
||||
Results: results,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MLSubsystem) mlStatus(ctx context.Context, req *mcp.CallToolRequest, input MLStatusInput) (*mcp.CallToolResult, MLStatusOutput, error) {
|
||||
m.logger.Info("MCP tool execution", "tool", "ml_status", "user", log.Username())
|
||||
|
||||
url := input.InfluxURL
|
||||
db := input.InfluxDB
|
||||
if url == "" {
|
||||
url = "http://localhost:8086"
|
||||
}
|
||||
if db == "" {
|
||||
db = "lem"
|
||||
}
|
||||
|
||||
influx := ml.NewInfluxClient(url, db)
|
||||
var buf strings.Builder
|
||||
if err := ml.PrintStatus(influx, &buf); err != nil {
|
||||
return nil, MLStatusOutput{}, fmt.Errorf("status: %w", err)
|
||||
}
|
||||
|
||||
return nil, MLStatusOutput{Status: buf.String()}, nil
|
||||
}
|
||||
|
||||
// mlBackends enumerates registered backends via the go-inference registry,
|
||||
// bypassing go-ml.Service entirely. This is the canonical source of truth
|
||||
// for backend availability since all backends register with inference.Register().
|
||||
func (m *MLSubsystem) mlBackends(ctx context.Context, req *mcp.CallToolRequest, input MLBackendsInput) (*mcp.CallToolResult, MLBackendsOutput, error) {
|
||||
m.logger.Info("MCP tool execution", "tool", "ml_backends", "user", log.Username())
|
||||
|
||||
names := inference.List()
|
||||
backends := make([]MLBackendInfo, 0, len(names))
|
||||
for _, name := range names {
|
||||
b, ok := inference.Get(name)
|
||||
backends = append(backends, MLBackendInfo{
|
||||
Name: name,
|
||||
Available: ok && b.Available(),
|
||||
})
|
||||
}
|
||||
|
||||
defaultName := ""
|
||||
if db, err := inference.Default(); err == nil {
|
||||
defaultName = db.Name()
|
||||
}
|
||||
|
||||
return nil, MLBackendsOutput{
|
||||
Backends: backends,
|
||||
Default: defaultName,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,479 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"forge.lthn.ai/core/go-ml"
|
||||
"forge.lthn.ai/core/go/pkg/core"
|
||||
"forge.lthn.ai/core/go-log"
|
||||
)
|
||||
|
||||
// --- Mock backend for inference registry ---
|
||||
|
||||
// mockInferenceBackend implements inference.Backend for CI testing of ml_backends.
|
||||
type mockInferenceBackend struct {
|
||||
name string
|
||||
available bool
|
||||
}
|
||||
|
||||
func (m *mockInferenceBackend) Name() string { return m.name }
|
||||
func (m *mockInferenceBackend) Available() bool { return m.available }
|
||||
func (m *mockInferenceBackend) LoadModel(_ string, _ ...inference.LoadOption) (inference.TextModel, error) {
|
||||
return nil, fmt.Errorf("mock backend: LoadModel not implemented")
|
||||
}
|
||||
|
||||
// --- Mock ml.Backend for Generate ---
|
||||
|
||||
// mockMLBackend implements ml.Backend for CI testing.
|
||||
type mockMLBackend struct {
|
||||
name string
|
||||
available bool
|
||||
generateResp string
|
||||
generateErr error
|
||||
}
|
||||
|
||||
func (m *mockMLBackend) Name() string { return m.name }
|
||||
func (m *mockMLBackend) Available() bool { return m.available }
|
||||
|
||||
func (m *mockMLBackend) Generate(_ context.Context, _ string, _ ml.GenOpts) (ml.Result, error) {
|
||||
return ml.Result{Text: m.generateResp}, m.generateErr
|
||||
}
|
||||
|
||||
func (m *mockMLBackend) Chat(_ context.Context, _ []ml.Message, _ ml.GenOpts) (ml.Result, error) {
|
||||
return ml.Result{Text: m.generateResp}, m.generateErr
|
||||
}
|
||||
|
||||
// newTestMLSubsystem creates an MLSubsystem with a real ml.Service for testing.
|
||||
func newTestMLSubsystem(t *testing.T, backends ...ml.Backend) *MLSubsystem {
|
||||
t.Helper()
|
||||
c, err := core.New(
|
||||
core.WithName("ml", ml.NewService(ml.Options{})),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create framework core: %v", err)
|
||||
}
|
||||
svc, err := core.ServiceFor[*ml.Service](c, "ml")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get ML service: %v", err)
|
||||
}
|
||||
// Register mock backends
|
||||
for _, b := range backends {
|
||||
svc.RegisterBackend(b.Name(), b)
|
||||
}
|
||||
return &MLSubsystem{
|
||||
service: svc,
|
||||
logger: log.Default(),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Input/Output struct tests ---
|
||||
|
||||
// TestMLGenerateInput_Good verifies all fields can be set.
|
||||
func TestMLGenerateInput_Good(t *testing.T) {
|
||||
input := MLGenerateInput{
|
||||
Prompt: "Hello world",
|
||||
Backend: "test",
|
||||
Model: "test-model",
|
||||
Temperature: 0.7,
|
||||
MaxTokens: 100,
|
||||
}
|
||||
if input.Prompt != "Hello world" {
|
||||
t.Errorf("Expected prompt 'Hello world', got %q", input.Prompt)
|
||||
}
|
||||
if input.Temperature != 0.7 {
|
||||
t.Errorf("Expected temperature 0.7, got %f", input.Temperature)
|
||||
}
|
||||
if input.MaxTokens != 100 {
|
||||
t.Errorf("Expected max_tokens 100, got %d", input.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLScoreInput_Good verifies all fields can be set.
|
||||
func TestMLScoreInput_Good(t *testing.T) {
|
||||
input := MLScoreInput{
|
||||
Prompt: "test prompt",
|
||||
Response: "test response",
|
||||
Suites: "heuristic,semantic",
|
||||
}
|
||||
if input.Prompt != "test prompt" {
|
||||
t.Errorf("Expected prompt 'test prompt', got %q", input.Prompt)
|
||||
}
|
||||
if input.Response != "test response" {
|
||||
t.Errorf("Expected response 'test response', got %q", input.Response)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLProbeInput_Good verifies all fields can be set.
|
||||
func TestMLProbeInput_Good(t *testing.T) {
|
||||
input := MLProbeInput{
|
||||
Backend: "test",
|
||||
Categories: "reasoning,code",
|
||||
}
|
||||
if input.Backend != "test" {
|
||||
t.Errorf("Expected backend 'test', got %q", input.Backend)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLStatusInput_Good verifies all fields can be set.
|
||||
func TestMLStatusInput_Good(t *testing.T) {
|
||||
input := MLStatusInput{
|
||||
InfluxURL: "http://localhost:8086",
|
||||
InfluxDB: "lem",
|
||||
}
|
||||
if input.InfluxURL != "http://localhost:8086" {
|
||||
t.Errorf("Expected InfluxURL, got %q", input.InfluxURL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLBackendsInput_Good verifies empty struct.
|
||||
func TestMLBackendsInput_Good(t *testing.T) {
|
||||
_ = MLBackendsInput{}
|
||||
}
|
||||
|
||||
// TestMLBackendsOutput_Good verifies struct fields.
|
||||
func TestMLBackendsOutput_Good(t *testing.T) {
|
||||
output := MLBackendsOutput{
|
||||
Backends: []MLBackendInfo{
|
||||
{Name: "ollama", Available: true},
|
||||
{Name: "llama", Available: false},
|
||||
},
|
||||
Default: "ollama",
|
||||
}
|
||||
if len(output.Backends) != 2 {
|
||||
t.Fatalf("Expected 2 backends, got %d", len(output.Backends))
|
||||
}
|
||||
if output.Default != "ollama" {
|
||||
t.Errorf("Expected default 'ollama', got %q", output.Default)
|
||||
}
|
||||
if !output.Backends[0].Available {
|
||||
t.Error("Expected first backend to be available")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLProbeOutput_Good verifies struct fields.
|
||||
func TestMLProbeOutput_Good(t *testing.T) {
|
||||
output := MLProbeOutput{
|
||||
Total: 2,
|
||||
Results: []MLProbeResultItem{
|
||||
{ID: "probe-1", Category: "reasoning", Response: "test"},
|
||||
{ID: "probe-2", Category: "code", Response: "test2"},
|
||||
},
|
||||
}
|
||||
if output.Total != 2 {
|
||||
t.Errorf("Expected total 2, got %d", output.Total)
|
||||
}
|
||||
if output.Results[0].ID != "probe-1" {
|
||||
t.Errorf("Expected ID 'probe-1', got %q", output.Results[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLStatusOutput_Good verifies struct fields.
|
||||
func TestMLStatusOutput_Good(t *testing.T) {
|
||||
output := MLStatusOutput{Status: "OK: 5 training runs"}
|
||||
if output.Status != "OK: 5 training runs" {
|
||||
t.Errorf("Unexpected status: %q", output.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLGenerateOutput_Good verifies struct fields.
|
||||
func TestMLGenerateOutput_Good(t *testing.T) {
|
||||
output := MLGenerateOutput{
|
||||
Response: "Generated text here",
|
||||
Backend: "ollama",
|
||||
Model: "qwen3:8b",
|
||||
}
|
||||
if output.Response != "Generated text here" {
|
||||
t.Errorf("Unexpected response: %q", output.Response)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLScoreOutput_Good verifies struct fields.
|
||||
func TestMLScoreOutput_Good(t *testing.T) {
|
||||
output := MLScoreOutput{
|
||||
Heuristic: &ml.HeuristicScores{},
|
||||
}
|
||||
if output.Heuristic == nil {
|
||||
t.Error("Expected Heuristic to be set")
|
||||
}
|
||||
if output.Semantic != nil {
|
||||
t.Error("Expected Semantic to be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Handler validation tests ---
|
||||
|
||||
// TestMLGenerate_Bad_EmptyPrompt verifies empty prompt returns error.
|
||||
func TestMLGenerate_Bad_EmptyPrompt(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := m.mlGenerate(ctx, nil, MLGenerateInput{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty prompt")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "prompt cannot be empty") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLGenerate_Good_WithMockBackend verifies generate works with a mock backend.
|
||||
func TestMLGenerate_Good_WithMockBackend(t *testing.T) {
|
||||
mock := &mockMLBackend{
|
||||
name: "test-mock",
|
||||
available: true,
|
||||
generateResp: "mock response",
|
||||
}
|
||||
m := newTestMLSubsystem(t, mock)
|
||||
ctx := context.Background()
|
||||
|
||||
_, out, err := m.mlGenerate(ctx, nil, MLGenerateInput{
|
||||
Prompt: "test",
|
||||
Backend: "test-mock",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("mlGenerate failed: %v", err)
|
||||
}
|
||||
if out.Response != "mock response" {
|
||||
t.Errorf("Expected 'mock response', got %q", out.Response)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLGenerate_Bad_NoBackend verifies generate fails gracefully without a backend.
|
||||
func TestMLGenerate_Bad_NoBackend(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := m.mlGenerate(ctx, nil, MLGenerateInput{
|
||||
Prompt: "test",
|
||||
Backend: "nonexistent",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for missing backend")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no backend available") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLScore_Bad_EmptyPrompt verifies empty prompt returns error.
|
||||
func TestMLScore_Bad_EmptyPrompt(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := m.mlScore(ctx, nil, MLScoreInput{Response: "some"})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty prompt")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLScore_Bad_EmptyResponse verifies empty response returns error.
|
||||
func TestMLScore_Bad_EmptyResponse(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := m.mlScore(ctx, nil, MLScoreInput{Prompt: "some"})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLScore_Good_Heuristic verifies heuristic scoring without live services.
|
||||
func TestMLScore_Good_Heuristic(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, out, err := m.mlScore(ctx, nil, MLScoreInput{
|
||||
Prompt: "What is Go?",
|
||||
Response: "Go is a statically typed, compiled programming language designed at Google.",
|
||||
Suites: "heuristic",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("mlScore failed: %v", err)
|
||||
}
|
||||
if out.Heuristic == nil {
|
||||
t.Fatal("Expected heuristic scores to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLScore_Good_DefaultSuite verifies default suite is heuristic.
|
||||
func TestMLScore_Good_DefaultSuite(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, out, err := m.mlScore(ctx, nil, MLScoreInput{
|
||||
Prompt: "What is Go?",
|
||||
Response: "Go is a statically typed, compiled programming language designed at Google.",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("mlScore failed: %v", err)
|
||||
}
|
||||
if out.Heuristic == nil {
|
||||
t.Fatal("Expected heuristic scores (default suite)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLScore_Bad_SemanticNoJudge verifies semantic scoring fails without a judge.
|
||||
func TestMLScore_Bad_SemanticNoJudge(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := m.mlScore(ctx, nil, MLScoreInput{
|
||||
Prompt: "test",
|
||||
Response: "test",
|
||||
Suites: "semantic",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for semantic scoring without judge")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "requires a judge") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLScore_Bad_ContentSuite verifies content suite redirects to ml_probe.
|
||||
func TestMLScore_Bad_ContentSuite(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := m.mlScore(ctx, nil, MLScoreInput{
|
||||
Prompt: "test",
|
||||
Response: "test",
|
||||
Suites: "content",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for content suite")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "ContentProbe") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLProbe_Good_WithMockBackend verifies probes run with mock backend.
|
||||
func TestMLProbe_Good_WithMockBackend(t *testing.T) {
|
||||
mock := &mockMLBackend{
|
||||
name: "probe-mock",
|
||||
available: true,
|
||||
generateResp: "probe response",
|
||||
}
|
||||
m := newTestMLSubsystem(t, mock)
|
||||
ctx := context.Background()
|
||||
|
||||
_, out, err := m.mlProbe(ctx, nil, MLProbeInput{
|
||||
Backend: "probe-mock",
|
||||
Categories: "reasoning",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("mlProbe failed: %v", err)
|
||||
}
|
||||
// Should have run probes in the "reasoning" category
|
||||
for _, r := range out.Results {
|
||||
if r.Category != "reasoning" {
|
||||
t.Errorf("Expected category 'reasoning', got %q", r.Category)
|
||||
}
|
||||
if r.Response != "probe response" {
|
||||
t.Errorf("Expected 'probe response', got %q", r.Response)
|
||||
}
|
||||
}
|
||||
if out.Total != len(out.Results) {
|
||||
t.Errorf("Expected total %d, got %d", len(out.Results), out.Total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLProbe_Good_NoCategory verifies all probes run without category filter.
|
||||
func TestMLProbe_Good_NoCategory(t *testing.T) {
|
||||
mock := &mockMLBackend{
|
||||
name: "all-probe-mock",
|
||||
available: true,
|
||||
generateResp: "ok",
|
||||
}
|
||||
m := newTestMLSubsystem(t, mock)
|
||||
ctx := context.Background()
|
||||
|
||||
_, out, err := m.mlProbe(ctx, nil, MLProbeInput{Backend: "all-probe-mock"})
|
||||
if err != nil {
|
||||
t.Fatalf("mlProbe failed: %v", err)
|
||||
}
|
||||
// Should run all 23 probes
|
||||
if out.Total != len(ml.CapabilityProbes) {
|
||||
t.Errorf("Expected %d probes, got %d", len(ml.CapabilityProbes), out.Total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLBackends_Good_EmptyRegistry verifies empty result when no backends registered.
|
||||
func TestMLBackends_Good_EmptyRegistry(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Note: inference.List() returns global registry state.
|
||||
// This test verifies the handler runs without panic.
|
||||
_, out, err := m.mlBackends(ctx, nil, MLBackendsInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("mlBackends failed: %v", err)
|
||||
}
|
||||
// We can't guarantee what's in the global registry, but it should not panic
|
||||
_ = out
|
||||
}
|
||||
|
||||
// TestMLBackends_Good_WithMockInferenceBackend verifies registered backend appears.
|
||||
func TestMLBackends_Good_WithMockInferenceBackend(t *testing.T) {
|
||||
// Register a mock backend in the global inference registry
|
||||
mock := &mockInferenceBackend{name: "test-ci-mock", available: true}
|
||||
inference.Register(mock)
|
||||
|
||||
m := newTestMLSubsystem(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, out, err := m.mlBackends(ctx, nil, MLBackendsInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("mlBackends failed: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, b := range out.Backends {
|
||||
if b.Name == "test-ci-mock" {
|
||||
found = true
|
||||
if !b.Available {
|
||||
t.Error("Expected mock backend to be available")
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected to find 'test-ci-mock' in backends list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMLSubsystem_Good_Name verifies subsystem name.
|
||||
func TestMLSubsystem_Good_Name(t *testing.T) {
|
||||
m := newTestMLSubsystem(t)
|
||||
if m.Name() != "ml" {
|
||||
t.Errorf("Expected name 'ml', got %q", m.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewMLSubsystem_Good verifies constructor.
|
||||
func TestNewMLSubsystem_Good(t *testing.T) {
|
||||
c, err := core.New(
|
||||
core.WithName("ml", ml.NewService(ml.Options{})),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create core: %v", err)
|
||||
}
|
||||
svc, err := core.ServiceFor[*ml.Service](c, "ml")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get service: %v", err)
|
||||
}
|
||||
sub := NewMLSubsystem(svc)
|
||||
if sub == nil {
|
||||
t.Fatal("Expected non-nil subsystem")
|
||||
}
|
||||
if sub.service != svc {
|
||||
t.Error("Expected service to be set")
|
||||
}
|
||||
if sub.logger == nil {
|
||||
t.Error("Expected logger to be set")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,305 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-log"
|
||||
"forge.lthn.ai/core/go-process"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// errIDEmpty is returned when a process tool call omits the required ID.
|
||||
var errIDEmpty = errors.New("id cannot be empty")
|
||||
|
||||
// ProcessStartInput contains parameters for starting a new process.
|
||||
type ProcessStartInput struct {
|
||||
Command string `json:"command"` // The command to run
|
||||
Args []string `json:"args,omitempty"` // Command arguments
|
||||
Dir string `json:"dir,omitempty"` // Working directory
|
||||
Env []string `json:"env,omitempty"` // Environment variables (KEY=VALUE format)
|
||||
}
|
||||
|
||||
// ProcessStartOutput contains the result of starting a process.
|
||||
type ProcessStartOutput struct {
|
||||
ID string `json:"id"`
|
||||
PID int `json:"pid"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
}
|
||||
|
||||
// ProcessStopInput contains parameters for gracefully stopping a process.
|
||||
type ProcessStopInput struct {
|
||||
ID string `json:"id"` // Process ID to stop
|
||||
}
|
||||
|
||||
// ProcessStopOutput contains the result of stopping a process.
|
||||
type ProcessStopOutput struct {
|
||||
ID string `json:"id"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessKillInput contains parameters for force killing a process.
|
||||
type ProcessKillInput struct {
|
||||
ID string `json:"id"` // Process ID to kill
|
||||
}
|
||||
|
||||
// ProcessKillOutput contains the result of killing a process.
|
||||
type ProcessKillOutput struct {
|
||||
ID string `json:"id"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessListInput contains parameters for listing processes.
|
||||
type ProcessListInput struct {
|
||||
RunningOnly bool `json:"running_only,omitempty"` // If true, only return running processes
|
||||
}
|
||||
|
||||
// ProcessListOutput contains the list of processes.
|
||||
type ProcessListOutput struct {
|
||||
Processes []ProcessInfo `json:"processes"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
// ProcessInfo represents information about a process.
|
||||
type ProcessInfo struct {
|
||||
ID string `json:"id"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Dir string `json:"dir"`
|
||||
Status string `json:"status"`
|
||||
PID int `json:"pid"`
|
||||
ExitCode int `json:"exitCode"`
|
||||
StartedAt time.Time `json:"startedAt"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
}
|
||||
|
||||
// ProcessOutputInput contains parameters for getting process output.
|
||||
type ProcessOutputInput struct {
|
||||
ID string `json:"id"` // Process ID
|
||||
}
|
||||
|
||||
// ProcessOutputOutput contains the captured output of a process.
|
||||
type ProcessOutputOutput struct {
|
||||
ID string `json:"id"`
|
||||
Output string `json:"output"`
|
||||
}
|
||||
|
||||
// ProcessInputInput contains parameters for sending input to a process.
|
||||
type ProcessInputInput struct {
|
||||
ID string `json:"id"` // Process ID
|
||||
Input string `json:"input"` // Input to send to stdin
|
||||
}
|
||||
|
||||
// ProcessInputOutput contains the result of sending input to a process.
|
||||
type ProcessInputOutput struct {
|
||||
ID string `json:"id"`
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// registerProcessTools adds process management tools to the MCP server.
|
||||
// Returns false if process service is not available.
|
||||
func (s *Service) registerProcessTools(server *mcp.Server) bool {
|
||||
if s.processService == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "process_start",
|
||||
Description: "Start a new external process. Returns process ID for tracking.",
|
||||
}, s.processStart)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "process_stop",
|
||||
Description: "Gracefully stop a running process by ID.",
|
||||
}, s.processStop)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "process_kill",
|
||||
Description: "Force kill a process by ID. Use when process_stop doesn't work.",
|
||||
}, s.processKill)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "process_list",
|
||||
Description: "List all managed processes. Use running_only=true for only active processes.",
|
||||
}, s.processList)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "process_output",
|
||||
Description: "Get the captured output of a process by ID.",
|
||||
}, s.processOutput)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "process_input",
|
||||
Description: "Send input to a running process stdin.",
|
||||
}, s.processInput)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// processStart handles the process_start tool call.
|
||||
func (s *Service) processStart(ctx context.Context, req *mcp.CallToolRequest, input ProcessStartInput) (*mcp.CallToolResult, ProcessStartOutput, error) {
|
||||
s.logger.Security("MCP tool execution", "tool", "process_start", "command", input.Command, "args", input.Args, "dir", input.Dir, "user", log.Username())
|
||||
|
||||
if input.Command == "" {
|
||||
return nil, ProcessStartOutput{}, errors.New("command cannot be empty")
|
||||
}
|
||||
|
||||
opts := process.RunOptions{
|
||||
Command: input.Command,
|
||||
Args: input.Args,
|
||||
Dir: input.Dir,
|
||||
Env: input.Env,
|
||||
}
|
||||
|
||||
proc, err := s.processService.StartWithOptions(ctx, opts)
|
||||
if err != nil {
|
||||
log.Error("mcp: process start failed", "command", input.Command, "err", err)
|
||||
return nil, ProcessStartOutput{}, fmt.Errorf("failed to start process: %w", err)
|
||||
}
|
||||
|
||||
info := proc.Info()
|
||||
return nil, ProcessStartOutput{
|
||||
ID: proc.ID,
|
||||
PID: info.PID,
|
||||
Command: proc.Command,
|
||||
Args: proc.Args,
|
||||
StartedAt: proc.StartedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processStop handles the process_stop tool call.
|
||||
func (s *Service) processStop(ctx context.Context, req *mcp.CallToolRequest, input ProcessStopInput) (*mcp.CallToolResult, ProcessStopOutput, error) {
|
||||
s.logger.Security("MCP tool execution", "tool", "process_stop", "id", input.ID, "user", log.Username())
|
||||
|
||||
if input.ID == "" {
|
||||
return nil, ProcessStopOutput{}, errIDEmpty
|
||||
}
|
||||
|
||||
proc, err := s.processService.Get(input.ID)
|
||||
if err != nil {
|
||||
log.Error("mcp: process stop failed", "id", input.ID, "err", err)
|
||||
return nil, ProcessStopOutput{}, fmt.Errorf("process not found: %w", err)
|
||||
}
|
||||
|
||||
// For graceful stop, we use Kill() which sends SIGKILL
|
||||
// A more sophisticated implementation could use SIGTERM first
|
||||
if err := proc.Kill(); err != nil {
|
||||
log.Error("mcp: process stop kill failed", "id", input.ID, "err", err)
|
||||
return nil, ProcessStopOutput{}, fmt.Errorf("failed to stop process: %w", err)
|
||||
}
|
||||
|
||||
return nil, ProcessStopOutput{
|
||||
ID: input.ID,
|
||||
Success: true,
|
||||
Message: "Process stop signal sent",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processKill handles the process_kill tool call.
|
||||
func (s *Service) processKill(ctx context.Context, req *mcp.CallToolRequest, input ProcessKillInput) (*mcp.CallToolResult, ProcessKillOutput, error) {
|
||||
s.logger.Security("MCP tool execution", "tool", "process_kill", "id", input.ID, "user", log.Username())
|
||||
|
||||
if input.ID == "" {
|
||||
return nil, ProcessKillOutput{}, errIDEmpty
|
||||
}
|
||||
|
||||
if err := s.processService.Kill(input.ID); err != nil {
|
||||
log.Error("mcp: process kill failed", "id", input.ID, "err", err)
|
||||
return nil, ProcessKillOutput{}, fmt.Errorf("failed to kill process: %w", err)
|
||||
}
|
||||
|
||||
return nil, ProcessKillOutput{
|
||||
ID: input.ID,
|
||||
Success: true,
|
||||
Message: "Process killed",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processList handles the process_list tool call.
|
||||
func (s *Service) processList(ctx context.Context, req *mcp.CallToolRequest, input ProcessListInput) (*mcp.CallToolResult, ProcessListOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "process_list", "running_only", input.RunningOnly, "user", log.Username())
|
||||
|
||||
var procs []*process.Process
|
||||
if input.RunningOnly {
|
||||
procs = s.processService.Running()
|
||||
} else {
|
||||
procs = s.processService.List()
|
||||
}
|
||||
|
||||
result := make([]ProcessInfo, len(procs))
|
||||
for i, p := range procs {
|
||||
info := p.Info()
|
||||
result[i] = ProcessInfo{
|
||||
ID: info.ID,
|
||||
Command: info.Command,
|
||||
Args: info.Args,
|
||||
Dir: info.Dir,
|
||||
Status: string(info.Status),
|
||||
PID: info.PID,
|
||||
ExitCode: info.ExitCode,
|
||||
StartedAt: info.StartedAt,
|
||||
Duration: info.Duration,
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ProcessListOutput{
|
||||
Processes: result,
|
||||
Total: len(result),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processOutput handles the process_output tool call.
|
||||
func (s *Service) processOutput(ctx context.Context, req *mcp.CallToolRequest, input ProcessOutputInput) (*mcp.CallToolResult, ProcessOutputOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "process_output", "id", input.ID, "user", log.Username())
|
||||
|
||||
if input.ID == "" {
|
||||
return nil, ProcessOutputOutput{}, errIDEmpty
|
||||
}
|
||||
|
||||
output, err := s.processService.Output(input.ID)
|
||||
if err != nil {
|
||||
log.Error("mcp: process output failed", "id", input.ID, "err", err)
|
||||
return nil, ProcessOutputOutput{}, fmt.Errorf("failed to get process output: %w", err)
|
||||
}
|
||||
|
||||
return nil, ProcessOutputOutput{
|
||||
ID: input.ID,
|
||||
Output: output,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processInput handles the process_input tool call.
|
||||
func (s *Service) processInput(ctx context.Context, req *mcp.CallToolRequest, input ProcessInputInput) (*mcp.CallToolResult, ProcessInputOutput, error) {
|
||||
s.logger.Security("MCP tool execution", "tool", "process_input", "id", input.ID, "user", log.Username())
|
||||
|
||||
if input.ID == "" {
|
||||
return nil, ProcessInputOutput{}, errIDEmpty
|
||||
}
|
||||
if input.Input == "" {
|
||||
return nil, ProcessInputOutput{}, errors.New("input cannot be empty")
|
||||
}
|
||||
|
||||
proc, err := s.processService.Get(input.ID)
|
||||
if err != nil {
|
||||
log.Error("mcp: process input get failed", "id", input.ID, "err", err)
|
||||
return nil, ProcessInputOutput{}, fmt.Errorf("process not found: %w", err)
|
||||
}
|
||||
|
||||
if err := proc.SendInput(input.Input); err != nil {
|
||||
log.Error("mcp: process input send failed", "id", input.ID, "err", err)
|
||||
return nil, ProcessInputOutput{}, fmt.Errorf("failed to send input: %w", err)
|
||||
}
|
||||
|
||||
return nil, ProcessInputOutput{
|
||||
ID: input.ID,
|
||||
Success: true,
|
||||
Message: "Input sent successfully",
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,515 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go/pkg/core"
|
||||
"forge.lthn.ai/core/go-process"
|
||||
)
|
||||
|
||||
// newTestProcessService creates a real process.Service backed by a core.Core for CI tests.
|
||||
func newTestProcessService(t *testing.T) *process.Service {
|
||||
t.Helper()
|
||||
c, err := core.New(
|
||||
core.WithName("process", process.NewService(process.Options{})),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create framework core: %v", err)
|
||||
}
|
||||
svc, err := core.ServiceFor[*process.Service](c, "process")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get process service: %v", err)
|
||||
}
|
||||
// Start services (calls OnStartup)
|
||||
if err := c.ServiceStartup(context.Background(), nil); err != nil {
|
||||
t.Fatalf("Failed to start core: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = c.ServiceShutdown(context.Background())
|
||||
})
|
||||
return svc
|
||||
}
|
||||
|
||||
// newTestMCPWithProcess creates an MCP Service wired to a real process.Service.
|
||||
func newTestMCPWithProcess(t *testing.T) (*Service, *process.Service) {
|
||||
t.Helper()
|
||||
ps := newTestProcessService(t)
|
||||
s, err := New(WithProcessService(ps))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create MCP service: %v", err)
|
||||
}
|
||||
return s, ps
|
||||
}
|
||||
|
||||
// --- CI-safe handler tests ---
|
||||
|
||||
// TestProcessStart_Good_Echo starts "echo hello" and verifies the output.
|
||||
func TestProcessStart_Good_Echo(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, out, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "echo",
|
||||
Args: []string{"hello"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
if out.ID == "" {
|
||||
t.Error("Expected non-empty process ID")
|
||||
}
|
||||
if out.Command != "echo" {
|
||||
t.Errorf("Expected command 'echo', got %q", out.Command)
|
||||
}
|
||||
if out.PID <= 0 {
|
||||
t.Errorf("Expected positive PID, got %d", out.PID)
|
||||
}
|
||||
if out.StartedAt.IsZero() {
|
||||
t.Error("Expected non-zero StartedAt")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStart_Bad_EmptyCommand verifies empty command returns an error.
|
||||
func TestProcessStart_Bad_EmptyCommand(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processStart(ctx, nil, ProcessStartInput{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty command")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "command cannot be empty") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStart_Bad_NonexistentCommand verifies an invalid command returns an error.
|
||||
func TestProcessStart_Bad_NonexistentCommand(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "/nonexistent/binary/that/does/not/exist",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nonexistent command")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessList_Good_Empty verifies list is empty initially.
|
||||
func TestProcessList_Good_Empty(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, out, err := s.processList(ctx, nil, ProcessListInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("processList failed: %v", err)
|
||||
}
|
||||
if out.Total != 0 {
|
||||
t.Errorf("Expected 0 processes, got %d", out.Total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessList_Good_AfterStart verifies a started process appears in list.
|
||||
func TestProcessList_Good_AfterStart(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start a short-lived process
|
||||
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "echo",
|
||||
Args: []string{"listing"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
|
||||
// Give it a moment to register
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// List all processes (including exited)
|
||||
_, listOut, err := s.processList(ctx, nil, ProcessListInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("processList failed: %v", err)
|
||||
}
|
||||
if listOut.Total < 1 {
|
||||
t.Fatalf("Expected at least 1 process, got %d", listOut.Total)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, p := range listOut.Processes {
|
||||
if p.ID == startOut.ID {
|
||||
found = true
|
||||
if p.Command != "echo" {
|
||||
t.Errorf("Expected command 'echo', got %q", p.Command)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Process %s not found in list", startOut.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessList_Good_RunningOnly verifies filtering for running-only processes.
|
||||
func TestProcessList_Good_RunningOnly(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start a process that exits quickly
|
||||
_, _, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "echo",
|
||||
Args: []string{"done"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
|
||||
// Wait for it to exit
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Running-only should be empty now
|
||||
_, listOut, err := s.processList(ctx, nil, ProcessListInput{RunningOnly: true})
|
||||
if err != nil {
|
||||
t.Fatalf("processList failed: %v", err)
|
||||
}
|
||||
if listOut.Total != 0 {
|
||||
t.Errorf("Expected 0 running processes after echo exits, got %d", listOut.Total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessOutput_Good_Echo verifies output capture from echo.
|
||||
func TestProcessOutput_Good_Echo(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "echo",
|
||||
Args: []string{"output_test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
|
||||
// Wait for process to complete and output to be captured
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
_, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("processOutput failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(outputOut.Output, "output_test") {
|
||||
t.Errorf("Expected output to contain 'output_test', got %q", outputOut.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessOutput_Bad_EmptyID verifies empty ID returns error.
|
||||
func TestProcessOutput_Bad_EmptyID(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processOutput(ctx, nil, ProcessOutputInput{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty ID")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "id cannot be empty") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessOutput_Bad_NotFound verifies nonexistent ID returns error.
|
||||
func TestProcessOutput_Bad_NotFound(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: "nonexistent-id"})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nonexistent ID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStop_Good_LongRunning starts a sleep, stops it, and verifies.
|
||||
func TestProcessStop_Good_LongRunning(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start a process that sleeps for a while
|
||||
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "sleep",
|
||||
Args: []string{"10"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's running
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_, listOut, _ := s.processList(ctx, nil, ProcessListInput{RunningOnly: true})
|
||||
if listOut.Total < 1 {
|
||||
t.Fatal("Expected at least 1 running process")
|
||||
}
|
||||
|
||||
// Stop it
|
||||
_, stopOut, err := s.processStop(ctx, nil, ProcessStopInput{ID: startOut.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("processStop failed: %v", err)
|
||||
}
|
||||
if !stopOut.Success {
|
||||
t.Error("Expected stop to succeed")
|
||||
}
|
||||
if stopOut.ID != startOut.ID {
|
||||
t.Errorf("Expected ID %q, got %q", startOut.ID, stopOut.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStop_Bad_EmptyID verifies empty ID returns error.
|
||||
func TestProcessStop_Bad_EmptyID(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processStop(ctx, nil, ProcessStopInput{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty ID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStop_Bad_NotFound verifies nonexistent ID returns error.
|
||||
func TestProcessStop_Bad_NotFound(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processStop(ctx, nil, ProcessStopInput{ID: "nonexistent-id"})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nonexistent ID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessKill_Good_LongRunning starts a sleep, kills it, and verifies.
|
||||
func TestProcessKill_Good_LongRunning(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "sleep",
|
||||
Args: []string{"10"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
_, killOut, err := s.processKill(ctx, nil, ProcessKillInput{ID: startOut.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("processKill failed: %v", err)
|
||||
}
|
||||
if !killOut.Success {
|
||||
t.Error("Expected kill to succeed")
|
||||
}
|
||||
if killOut.Message != "Process killed" {
|
||||
t.Errorf("Expected message 'Process killed', got %q", killOut.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessKill_Bad_EmptyID verifies empty ID returns error.
|
||||
func TestProcessKill_Bad_EmptyID(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processKill(ctx, nil, ProcessKillInput{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty ID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessKill_Bad_NotFound verifies nonexistent ID returns error.
|
||||
func TestProcessKill_Bad_NotFound(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processKill(ctx, nil, ProcessKillInput{ID: "nonexistent-id"})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nonexistent ID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessInput_Bad_EmptyID verifies empty ID returns error.
|
||||
func TestProcessInput_Bad_EmptyID(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processInput(ctx, nil, ProcessInputInput{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty ID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessInput_Bad_EmptyInput verifies empty input string returns error.
|
||||
func TestProcessInput_Bad_EmptyInput(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processInput(ctx, nil, ProcessInputInput{ID: "some-id"})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty input")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessInput_Bad_NotFound verifies nonexistent process ID returns error.
|
||||
func TestProcessInput_Bad_NotFound(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err := s.processInput(ctx, nil, ProcessInputInput{
|
||||
ID: "nonexistent-id",
|
||||
Input: "hello\n",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nonexistent ID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessInput_Good_Cat sends input to cat and reads it back.
|
||||
func TestProcessInput_Good_Cat(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Start cat which reads stdin and echoes to stdout
|
||||
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "cat",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Send input
|
||||
_, inputOut, err := s.processInput(ctx, nil, ProcessInputInput{
|
||||
ID: startOut.ID,
|
||||
Input: "stdin_test\n",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processInput failed: %v", err)
|
||||
}
|
||||
if !inputOut.Success {
|
||||
t.Error("Expected input to succeed")
|
||||
}
|
||||
|
||||
// Wait for output capture
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Read output
|
||||
_, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("processOutput failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(outputOut.Output, "stdin_test") {
|
||||
t.Errorf("Expected output to contain 'stdin_test', got %q", outputOut.Output)
|
||||
}
|
||||
|
||||
// Kill the cat process (it's still running)
|
||||
_, _, _ = s.processKill(ctx, nil, ProcessKillInput{ID: startOut.ID})
|
||||
}
|
||||
|
||||
// TestProcessStart_Good_WithDir verifies working directory is respected.
|
||||
func TestProcessStart_Good_WithDir(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
dir := t.TempDir()
|
||||
|
||||
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "pwd",
|
||||
Dir: dir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
_, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("processOutput failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(outputOut.Output, dir) {
|
||||
t.Errorf("Expected output to contain dir %q, got %q", dir, outputOut.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStart_Good_WithEnv verifies environment variables are passed.
|
||||
func TestProcessStart_Good_WithEnv(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "env",
|
||||
Env: []string{"TEST_MCP_VAR=hello_from_test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
_, outputOut, err := s.processOutput(ctx, nil, ProcessOutputInput{ID: startOut.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("processOutput failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(outputOut.Output, "TEST_MCP_VAR=hello_from_test") {
|
||||
t.Errorf("Expected output to contain env var, got %q", outputOut.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessToolsRegistered_Good_WithService verifies tools are registered when service is provided.
|
||||
func TestProcessToolsRegistered_Good_WithService(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
if s.processService == nil {
|
||||
t.Error("Expected process service to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessFullLifecycle_Good tests the start → list → output → kill → list cycle.
|
||||
func TestProcessFullLifecycle_Good(t *testing.T) {
|
||||
s, _ := newTestMCPWithProcess(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. Start
|
||||
_, startOut, err := s.processStart(ctx, nil, ProcessStartInput{
|
||||
Command: "sleep",
|
||||
Args: []string{"10"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processStart failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 2. List (should be running)
|
||||
_, listOut, _ := s.processList(ctx, nil, ProcessListInput{RunningOnly: true})
|
||||
if listOut.Total < 1 {
|
||||
t.Fatal("Expected at least 1 running process")
|
||||
}
|
||||
|
||||
// 3. Kill
|
||||
_, killOut, err := s.processKill(ctx, nil, ProcessKillInput{ID: startOut.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("processKill failed: %v", err)
|
||||
}
|
||||
if !killOut.Success {
|
||||
t.Error("Expected kill to succeed")
|
||||
}
|
||||
|
||||
// 4. Wait for exit
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 5. Should not be running anymore
|
||||
_, listOut, _ = s.processList(ctx, nil, ProcessListInput{RunningOnly: true})
|
||||
for _, p := range listOut.Processes {
|
||||
if p.ID == startOut.ID {
|
||||
t.Errorf("Process %s should not be running after kill", startOut.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,290 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestProcessToolsRegistered_Good verifies that process tools are registered when process service is available.
|
||||
func TestProcessToolsRegistered_Good(t *testing.T) {
|
||||
// Create a new MCP service without process service - tools should not be registered
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.processService != nil {
|
||||
t.Error("Process service should be nil by default")
|
||||
}
|
||||
|
||||
if s.server == nil {
|
||||
t.Fatal("Server should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStartInput_Good verifies the ProcessStartInput struct has expected fields.
|
||||
func TestProcessStartInput_Good(t *testing.T) {
|
||||
input := ProcessStartInput{
|
||||
Command: "echo",
|
||||
Args: []string{"hello", "world"},
|
||||
Dir: "/tmp",
|
||||
Env: []string{"FOO=bar"},
|
||||
}
|
||||
|
||||
if input.Command != "echo" {
|
||||
t.Errorf("Expected command 'echo', got %q", input.Command)
|
||||
}
|
||||
if len(input.Args) != 2 {
|
||||
t.Errorf("Expected 2 args, got %d", len(input.Args))
|
||||
}
|
||||
if input.Dir != "/tmp" {
|
||||
t.Errorf("Expected dir '/tmp', got %q", input.Dir)
|
||||
}
|
||||
if len(input.Env) != 1 {
|
||||
t.Errorf("Expected 1 env var, got %d", len(input.Env))
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStartOutput_Good verifies the ProcessStartOutput struct has expected fields.
|
||||
func TestProcessStartOutput_Good(t *testing.T) {
|
||||
now := time.Now()
|
||||
output := ProcessStartOutput{
|
||||
ID: "proc-1",
|
||||
PID: 12345,
|
||||
Command: "echo",
|
||||
Args: []string{"hello"},
|
||||
StartedAt: now,
|
||||
}
|
||||
|
||||
if output.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
|
||||
}
|
||||
if output.PID != 12345 {
|
||||
t.Errorf("Expected PID 12345, got %d", output.PID)
|
||||
}
|
||||
if output.Command != "echo" {
|
||||
t.Errorf("Expected command 'echo', got %q", output.Command)
|
||||
}
|
||||
if !output.StartedAt.Equal(now) {
|
||||
t.Errorf("Expected StartedAt %v, got %v", now, output.StartedAt)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStopInput_Good verifies the ProcessStopInput struct has expected fields.
|
||||
func TestProcessStopInput_Good(t *testing.T) {
|
||||
input := ProcessStopInput{
|
||||
ID: "proc-1",
|
||||
}
|
||||
|
||||
if input.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", input.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessStopOutput_Good verifies the ProcessStopOutput struct has expected fields.
|
||||
func TestProcessStopOutput_Good(t *testing.T) {
|
||||
output := ProcessStopOutput{
|
||||
ID: "proc-1",
|
||||
Success: true,
|
||||
Message: "Process stopped",
|
||||
}
|
||||
|
||||
if output.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
|
||||
}
|
||||
if !output.Success {
|
||||
t.Error("Expected Success to be true")
|
||||
}
|
||||
if output.Message != "Process stopped" {
|
||||
t.Errorf("Expected message 'Process stopped', got %q", output.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessKillInput_Good verifies the ProcessKillInput struct has expected fields.
|
||||
func TestProcessKillInput_Good(t *testing.T) {
|
||||
input := ProcessKillInput{
|
||||
ID: "proc-1",
|
||||
}
|
||||
|
||||
if input.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", input.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessKillOutput_Good verifies the ProcessKillOutput struct has expected fields.
|
||||
func TestProcessKillOutput_Good(t *testing.T) {
|
||||
output := ProcessKillOutput{
|
||||
ID: "proc-1",
|
||||
Success: true,
|
||||
Message: "Process killed",
|
||||
}
|
||||
|
||||
if output.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
|
||||
}
|
||||
if !output.Success {
|
||||
t.Error("Expected Success to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessListInput_Good verifies the ProcessListInput struct has expected fields.
|
||||
func TestProcessListInput_Good(t *testing.T) {
|
||||
input := ProcessListInput{
|
||||
RunningOnly: true,
|
||||
}
|
||||
|
||||
if !input.RunningOnly {
|
||||
t.Error("Expected RunningOnly to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessListInput_Defaults verifies default values.
|
||||
func TestProcessListInput_Defaults(t *testing.T) {
|
||||
input := ProcessListInput{}
|
||||
|
||||
if input.RunningOnly {
|
||||
t.Error("Expected RunningOnly to default to false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessListOutput_Good verifies the ProcessListOutput struct has expected fields.
|
||||
func TestProcessListOutput_Good(t *testing.T) {
|
||||
now := time.Now()
|
||||
output := ProcessListOutput{
|
||||
Processes: []ProcessInfo{
|
||||
{
|
||||
ID: "proc-1",
|
||||
Command: "echo",
|
||||
Args: []string{"hello"},
|
||||
Dir: "/tmp",
|
||||
Status: "running",
|
||||
PID: 12345,
|
||||
ExitCode: 0,
|
||||
StartedAt: now,
|
||||
Duration: 5 * time.Second,
|
||||
},
|
||||
},
|
||||
Total: 1,
|
||||
}
|
||||
|
||||
if len(output.Processes) != 1 {
|
||||
t.Fatalf("Expected 1 process, got %d", len(output.Processes))
|
||||
}
|
||||
if output.Total != 1 {
|
||||
t.Errorf("Expected total 1, got %d", output.Total)
|
||||
}
|
||||
|
||||
proc := output.Processes[0]
|
||||
if proc.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", proc.ID)
|
||||
}
|
||||
if proc.Status != "running" {
|
||||
t.Errorf("Expected status 'running', got %q", proc.Status)
|
||||
}
|
||||
if proc.PID != 12345 {
|
||||
t.Errorf("Expected PID 12345, got %d", proc.PID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessOutputInput_Good verifies the ProcessOutputInput struct has expected fields.
|
||||
func TestProcessOutputInput_Good(t *testing.T) {
|
||||
input := ProcessOutputInput{
|
||||
ID: "proc-1",
|
||||
}
|
||||
|
||||
if input.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", input.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessOutputOutput_Good verifies the ProcessOutputOutput struct has expected fields.
|
||||
func TestProcessOutputOutput_Good(t *testing.T) {
|
||||
output := ProcessOutputOutput{
|
||||
ID: "proc-1",
|
||||
Output: "hello world\n",
|
||||
}
|
||||
|
||||
if output.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
|
||||
}
|
||||
if output.Output != "hello world\n" {
|
||||
t.Errorf("Expected output 'hello world\\n', got %q", output.Output)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessInputInput_Good verifies the ProcessInputInput struct has expected fields.
|
||||
func TestProcessInputInput_Good(t *testing.T) {
|
||||
input := ProcessInputInput{
|
||||
ID: "proc-1",
|
||||
Input: "test input\n",
|
||||
}
|
||||
|
||||
if input.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", input.ID)
|
||||
}
|
||||
if input.Input != "test input\n" {
|
||||
t.Errorf("Expected input 'test input\\n', got %q", input.Input)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessInputOutput_Good verifies the ProcessInputOutput struct has expected fields.
|
||||
func TestProcessInputOutput_Good(t *testing.T) {
|
||||
output := ProcessInputOutput{
|
||||
ID: "proc-1",
|
||||
Success: true,
|
||||
Message: "Input sent",
|
||||
}
|
||||
|
||||
if output.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", output.ID)
|
||||
}
|
||||
if !output.Success {
|
||||
t.Error("Expected Success to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessInfo_Good verifies the ProcessInfo struct has expected fields.
|
||||
func TestProcessInfo_Good(t *testing.T) {
|
||||
now := time.Now()
|
||||
info := ProcessInfo{
|
||||
ID: "proc-1",
|
||||
Command: "echo",
|
||||
Args: []string{"hello"},
|
||||
Dir: "/tmp",
|
||||
Status: "exited",
|
||||
PID: 12345,
|
||||
ExitCode: 0,
|
||||
StartedAt: now,
|
||||
Duration: 2 * time.Second,
|
||||
}
|
||||
|
||||
if info.ID != "proc-1" {
|
||||
t.Errorf("Expected ID 'proc-1', got %q", info.ID)
|
||||
}
|
||||
if info.Command != "echo" {
|
||||
t.Errorf("Expected command 'echo', got %q", info.Command)
|
||||
}
|
||||
if info.Status != "exited" {
|
||||
t.Errorf("Expected status 'exited', got %q", info.Status)
|
||||
}
|
||||
if info.ExitCode != 0 {
|
||||
t.Errorf("Expected exit code 0, got %d", info.ExitCode)
|
||||
}
|
||||
if info.Duration != 2*time.Second {
|
||||
t.Errorf("Expected duration 2s, got %v", info.Duration)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWithProcessService_Good verifies the WithProcessService option.
|
||||
func TestWithProcessService_Good(t *testing.T) {
|
||||
// Note: We can't easily create a real process.Service here without Core,
|
||||
// so we just verify the option doesn't panic with nil.
|
||||
s, err := New(WithProcessService(nil))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.processService != nil {
|
||||
t.Error("Expected processService to be nil when passed nil")
|
||||
}
|
||||
}
|
||||
233
mcp/tools_rag.go
233
mcp/tools_rag.go
|
|
@ -1,233 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"forge.lthn.ai/core/go-rag"
|
||||
"forge.lthn.ai/core/go-log"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// Default values for RAG operations.
|
||||
const (
|
||||
DefaultRAGCollection = "hostuk-docs"
|
||||
DefaultRAGTopK = 5
|
||||
)
|
||||
|
||||
// RAGQueryInput contains parameters for querying the RAG vector database.
|
||||
type RAGQueryInput struct {
|
||||
Question string `json:"question"` // The question or search query
|
||||
Collection string `json:"collection,omitempty"` // Collection name (default: hostuk-docs)
|
||||
TopK int `json:"topK,omitempty"` // Number of results to return (default: 5)
|
||||
}
|
||||
|
||||
// RAGQueryResult represents a single query result.
|
||||
type RAGQueryResult struct {
|
||||
Content string `json:"content"`
|
||||
Source string `json:"source"`
|
||||
Section string `json:"section,omitempty"`
|
||||
Category string `json:"category,omitempty"`
|
||||
ChunkIndex int `json:"chunkIndex,omitempty"`
|
||||
Score float32 `json:"score"`
|
||||
}
|
||||
|
||||
// RAGQueryOutput contains the results of a RAG query.
|
||||
type RAGQueryOutput struct {
|
||||
Results []RAGQueryResult `json:"results"`
|
||||
Query string `json:"query"`
|
||||
Collection string `json:"collection"`
|
||||
Context string `json:"context"`
|
||||
}
|
||||
|
||||
// RAGIngestInput contains parameters for ingesting documents into the RAG database.
|
||||
type RAGIngestInput struct {
|
||||
Path string `json:"path"` // File or directory path to ingest
|
||||
Collection string `json:"collection,omitempty"` // Collection name (default: hostuk-docs)
|
||||
Recreate bool `json:"recreate,omitempty"` // Whether to recreate the collection
|
||||
}
|
||||
|
||||
// RAGIngestOutput contains the result of a RAG ingest operation.
|
||||
type RAGIngestOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Path string `json:"path"`
|
||||
Collection string `json:"collection"`
|
||||
Chunks int `json:"chunks"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// RAGCollectionsInput contains parameters for listing collections.
|
||||
type RAGCollectionsInput struct {
|
||||
ShowStats bool `json:"show_stats,omitempty"` // Include collection stats (point count, status)
|
||||
}
|
||||
|
||||
// CollectionInfo contains information about a collection.
|
||||
type CollectionInfo struct {
|
||||
Name string `json:"name"`
|
||||
PointsCount uint64 `json:"points_count"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// RAGCollectionsOutput contains the list of available collections.
|
||||
type RAGCollectionsOutput struct {
|
||||
Collections []CollectionInfo `json:"collections"`
|
||||
}
|
||||
|
||||
// registerRAGTools adds RAG tools to the MCP server.
|
||||
func (s *Service) registerRAGTools(server *mcp.Server) {
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "rag_query",
|
||||
Description: "Query the RAG vector database for relevant documentation. Returns semantically similar content based on the query.",
|
||||
}, s.ragQuery)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "rag_ingest",
|
||||
Description: "Ingest documents into the RAG vector database. Supports both single files and directories.",
|
||||
}, s.ragIngest)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "rag_collections",
|
||||
Description: "List all available collections in the RAG vector database.",
|
||||
}, s.ragCollections)
|
||||
}
|
||||
|
||||
// ragQuery handles the rag_query tool call.
|
||||
func (s *Service) ragQuery(ctx context.Context, req *mcp.CallToolRequest, input RAGQueryInput) (*mcp.CallToolResult, RAGQueryOutput, error) {
|
||||
// Apply defaults
|
||||
collection := input.Collection
|
||||
if collection == "" {
|
||||
collection = DefaultRAGCollection
|
||||
}
|
||||
topK := input.TopK
|
||||
if topK <= 0 {
|
||||
topK = DefaultRAGTopK
|
||||
}
|
||||
|
||||
s.logger.Info("MCP tool execution", "tool", "rag_query", "question", input.Question, "collection", collection, "topK", topK, "user", log.Username())
|
||||
|
||||
// Validate input
|
||||
if input.Question == "" {
|
||||
return nil, RAGQueryOutput{}, errors.New("question cannot be empty")
|
||||
}
|
||||
|
||||
// Call the RAG query function
|
||||
results, err := rag.QueryDocs(ctx, input.Question, collection, topK)
|
||||
if err != nil {
|
||||
log.Error("mcp: rag query failed", "question", input.Question, "collection", collection, "err", err)
|
||||
return nil, RAGQueryOutput{}, fmt.Errorf("failed to query RAG: %w", err)
|
||||
}
|
||||
|
||||
// Convert results
|
||||
output := RAGQueryOutput{
|
||||
Results: make([]RAGQueryResult, len(results)),
|
||||
Query: input.Question,
|
||||
Collection: collection,
|
||||
Context: rag.FormatResultsContext(results),
|
||||
}
|
||||
for i, r := range results {
|
||||
output.Results[i] = RAGQueryResult{
|
||||
Content: r.Text,
|
||||
Source: r.Source,
|
||||
Section: r.Section,
|
||||
Category: r.Category,
|
||||
ChunkIndex: r.ChunkIndex,
|
||||
Score: r.Score,
|
||||
}
|
||||
}
|
||||
|
||||
return nil, output, nil
|
||||
}
|
||||
|
||||
// ragIngest handles the rag_ingest tool call.
|
||||
func (s *Service) ragIngest(ctx context.Context, req *mcp.CallToolRequest, input RAGIngestInput) (*mcp.CallToolResult, RAGIngestOutput, error) {
|
||||
// Apply defaults
|
||||
collection := input.Collection
|
||||
if collection == "" {
|
||||
collection = DefaultRAGCollection
|
||||
}
|
||||
|
||||
s.logger.Security("MCP tool execution", "tool", "rag_ingest", "path", input.Path, "collection", collection, "recreate", input.Recreate, "user", log.Username())
|
||||
|
||||
// Validate input
|
||||
if input.Path == "" {
|
||||
return nil, RAGIngestOutput{}, errors.New("path cannot be empty")
|
||||
}
|
||||
|
||||
// Check if path is a file or directory using the medium
|
||||
info, err := s.medium.Stat(input.Path)
|
||||
if err != nil {
|
||||
log.Error("mcp: rag ingest stat failed", "path", input.Path, "err", err)
|
||||
return nil, RAGIngestOutput{}, fmt.Errorf("failed to access path: %w", err)
|
||||
}
|
||||
|
||||
var message string
|
||||
var chunks int
|
||||
if info.IsDir() {
|
||||
// Ingest directory
|
||||
err = rag.IngestDirectory(ctx, input.Path, collection, input.Recreate)
|
||||
if err != nil {
|
||||
log.Error("mcp: rag ingest directory failed", "path", input.Path, "collection", collection, "err", err)
|
||||
return nil, RAGIngestOutput{}, fmt.Errorf("failed to ingest directory: %w", err)
|
||||
}
|
||||
message = fmt.Sprintf("Successfully ingested directory %s into collection %s", input.Path, collection)
|
||||
} else {
|
||||
// Ingest single file
|
||||
chunks, err = rag.IngestSingleFile(ctx, input.Path, collection)
|
||||
if err != nil {
|
||||
log.Error("mcp: rag ingest file failed", "path", input.Path, "collection", collection, "err", err)
|
||||
return nil, RAGIngestOutput{}, fmt.Errorf("failed to ingest file: %w", err)
|
||||
}
|
||||
message = fmt.Sprintf("Successfully ingested file %s (%d chunks) into collection %s", input.Path, chunks, collection)
|
||||
}
|
||||
|
||||
return nil, RAGIngestOutput{
|
||||
Success: true,
|
||||
Path: input.Path,
|
||||
Collection: collection,
|
||||
Chunks: chunks,
|
||||
Message: message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ragCollections handles the rag_collections tool call.
|
||||
func (s *Service) ragCollections(ctx context.Context, req *mcp.CallToolRequest, input RAGCollectionsInput) (*mcp.CallToolResult, RAGCollectionsOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "rag_collections", "show_stats", input.ShowStats, "user", log.Username())
|
||||
|
||||
// Create Qdrant client with default config
|
||||
qdrantClient, err := rag.NewQdrantClient(rag.DefaultQdrantConfig())
|
||||
if err != nil {
|
||||
log.Error("mcp: rag collections connect failed", "err", err)
|
||||
return nil, RAGCollectionsOutput{}, fmt.Errorf("failed to connect to Qdrant: %w", err)
|
||||
}
|
||||
defer func() { _ = qdrantClient.Close() }()
|
||||
|
||||
// List collections
|
||||
collectionNames, err := qdrantClient.ListCollections(ctx)
|
||||
if err != nil {
|
||||
log.Error("mcp: rag collections list failed", "err", err)
|
||||
return nil, RAGCollectionsOutput{}, fmt.Errorf("failed to list collections: %w", err)
|
||||
}
|
||||
|
||||
// Build collection info list
|
||||
collections := make([]CollectionInfo, len(collectionNames))
|
||||
for i, name := range collectionNames {
|
||||
collections[i] = CollectionInfo{Name: name}
|
||||
|
||||
// Fetch stats if requested
|
||||
if input.ShowStats {
|
||||
info, err := qdrantClient.CollectionInfo(ctx, name)
|
||||
if err != nil {
|
||||
log.Error("mcp: rag collection info failed", "collection", name, "err", err)
|
||||
// Continue with defaults on error
|
||||
continue
|
||||
}
|
||||
collections[i].PointsCount = info.PointCount
|
||||
collections[i].Status = info.Status
|
||||
}
|
||||
}
|
||||
|
||||
return nil, RAGCollectionsOutput{
|
||||
Collections: collections,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,181 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// RAG tools use package-level functions (rag.QueryDocs, rag.IngestDirectory, etc.)
|
||||
// which require live Qdrant + Ollama services. Since those are not injectable,
|
||||
// we test handler input validation, default application, and struct behaviour
|
||||
// at the MCP handler level without requiring live services.
|
||||
|
||||
// --- ragQuery validation ---
|
||||
|
||||
// TestRagQuery_Bad_EmptyQuestion verifies empty question returns error.
|
||||
func TestRagQuery_Bad_EmptyQuestion(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err = s.ragQuery(ctx, nil, RAGQueryInput{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty question")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "question cannot be empty") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRagQuery_Good_DefaultsApplied verifies defaults are applied before validation.
|
||||
// Because the handler applies defaults then validates, a non-empty question with
|
||||
// zero Collection/TopK should have defaults applied. We cannot verify the actual
|
||||
// query (needs live Qdrant), but we can verify it gets past validation.
|
||||
func TestRagQuery_Good_DefaultsApplied(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
// This will fail when it tries to connect to Qdrant, but AFTER applying defaults.
|
||||
// The error should NOT be about empty question.
|
||||
_, _, err = s.ragQuery(ctx, nil, RAGQueryInput{Question: "test query"})
|
||||
if err == nil {
|
||||
t.Skip("RAG query succeeded — live Qdrant available, skip default test")
|
||||
}
|
||||
// The error should be about connection failure, not validation
|
||||
if strings.Contains(err.Error(), "question cannot be empty") {
|
||||
t.Error("Defaults should have been applied before validation check")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ragIngest validation ---
|
||||
|
||||
// TestRagIngest_Bad_EmptyPath verifies empty path returns error.
|
||||
func TestRagIngest_Bad_EmptyPath(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err = s.ragIngest(ctx, nil, RAGIngestInput{})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for empty path")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "path cannot be empty") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRagIngest_Bad_NonexistentPath verifies nonexistent path returns error.
|
||||
func TestRagIngest_Bad_NonexistentPath(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err = s.ragIngest(ctx, nil, RAGIngestInput{
|
||||
Path: "/nonexistent/path/that/does/not/exist/at/all",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for nonexistent path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRagIngest_Good_DefaultCollection verifies the default collection is applied.
|
||||
func TestRagIngest_Good_DefaultCollection(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
// Use a real but inaccessible path to trigger stat error (not validation error).
|
||||
// The collection default should be applied first.
|
||||
_, _, err = s.ragIngest(ctx, nil, RAGIngestInput{
|
||||
Path: "/nonexistent/path/for/default/test",
|
||||
})
|
||||
if err == nil {
|
||||
t.Skip("Ingest succeeded unexpectedly")
|
||||
}
|
||||
// The error should NOT be about empty path
|
||||
if strings.Contains(err.Error(), "path cannot be empty") {
|
||||
t.Error("Default collection should have been applied")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ragCollections validation ---
|
||||
|
||||
// TestRagCollections_Bad_NoQdrant verifies graceful error when Qdrant is not available.
|
||||
func TestRagCollections_Bad_NoQdrant(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
_, _, err = s.ragCollections(ctx, nil, RAGCollectionsInput{})
|
||||
if err == nil {
|
||||
t.Skip("Qdrant is available — skip connection error test")
|
||||
}
|
||||
// Should get a connection error, not a panic
|
||||
if !strings.Contains(err.Error(), "failed to connect") && !strings.Contains(err.Error(), "failed to list") {
|
||||
t.Logf("Got error (expected connection failure): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Struct round-trip tests ---
|
||||
|
||||
// TestRAGQueryResult_Good_AllFields verifies all fields can be set and read.
|
||||
func TestRAGQueryResult_Good_AllFields(t *testing.T) {
|
||||
r := RAGQueryResult{
|
||||
Content: "test content",
|
||||
Source: "source.md",
|
||||
Section: "Overview",
|
||||
Category: "docs",
|
||||
ChunkIndex: 3,
|
||||
Score: 0.88,
|
||||
}
|
||||
|
||||
if r.Content != "test content" {
|
||||
t.Errorf("Expected content 'test content', got %q", r.Content)
|
||||
}
|
||||
if r.ChunkIndex != 3 {
|
||||
t.Errorf("Expected chunkIndex 3, got %d", r.ChunkIndex)
|
||||
}
|
||||
if r.Score != 0.88 {
|
||||
t.Errorf("Expected score 0.88, got %f", r.Score)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectionInfo_Good_AllFields verifies CollectionInfo field access.
|
||||
func TestCollectionInfo_Good_AllFields(t *testing.T) {
|
||||
c := CollectionInfo{
|
||||
Name: "test-collection",
|
||||
PointsCount: 12345,
|
||||
Status: "green",
|
||||
}
|
||||
|
||||
if c.Name != "test-collection" {
|
||||
t.Errorf("Expected name 'test-collection', got %q", c.Name)
|
||||
}
|
||||
if c.PointsCount != 12345 {
|
||||
t.Errorf("Expected PointsCount 12345, got %d", c.PointsCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRAGDefaults_Good verifies default constants are sensible.
|
||||
func TestRAGDefaults_Good(t *testing.T) {
|
||||
if DefaultRAGCollection != "hostuk-docs" {
|
||||
t.Errorf("Expected default collection 'hostuk-docs', got %q", DefaultRAGCollection)
|
||||
}
|
||||
if DefaultRAGTopK != 5 {
|
||||
t.Errorf("Expected default topK 5, got %d", DefaultRAGTopK)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,173 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestRAGToolsRegistered_Good verifies that RAG tools are registered with the MCP server.
|
||||
func TestRAGToolsRegistered_Good(t *testing.T) {
|
||||
// Create a new MCP service - this should register all tools including RAG
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
// The server should have registered the RAG tools
|
||||
// We verify by checking that the tool handlers exist on the service
|
||||
// (The actual MCP registration is tested by the SDK)
|
||||
|
||||
if s.server == nil {
|
||||
t.Fatal("Server should not be nil")
|
||||
}
|
||||
|
||||
// Verify the service was created with expected defaults
|
||||
if s.logger == nil {
|
||||
t.Error("Logger should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRAGQueryInput_Good verifies the RAGQueryInput struct has expected fields.
|
||||
func TestRAGQueryInput_Good(t *testing.T) {
|
||||
input := RAGQueryInput{
|
||||
Question: "test question",
|
||||
Collection: "test-collection",
|
||||
TopK: 10,
|
||||
}
|
||||
|
||||
if input.Question != "test question" {
|
||||
t.Errorf("Expected question 'test question', got %q", input.Question)
|
||||
}
|
||||
if input.Collection != "test-collection" {
|
||||
t.Errorf("Expected collection 'test-collection', got %q", input.Collection)
|
||||
}
|
||||
if input.TopK != 10 {
|
||||
t.Errorf("Expected topK 10, got %d", input.TopK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRAGQueryInput_Defaults verifies default values are handled correctly.
|
||||
func TestRAGQueryInput_Defaults(t *testing.T) {
|
||||
// Empty input should use defaults when processed
|
||||
input := RAGQueryInput{
|
||||
Question: "test",
|
||||
}
|
||||
|
||||
// Defaults should be applied in the handler, not in the struct
|
||||
if input.Collection != "" {
|
||||
t.Errorf("Expected empty collection before defaults, got %q", input.Collection)
|
||||
}
|
||||
if input.TopK != 0 {
|
||||
t.Errorf("Expected zero topK before defaults, got %d", input.TopK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRAGIngestInput_Good verifies the RAGIngestInput struct has expected fields.
|
||||
func TestRAGIngestInput_Good(t *testing.T) {
|
||||
input := RAGIngestInput{
|
||||
Path: "/path/to/docs",
|
||||
Collection: "my-collection",
|
||||
Recreate: true,
|
||||
}
|
||||
|
||||
if input.Path != "/path/to/docs" {
|
||||
t.Errorf("Expected path '/path/to/docs', got %q", input.Path)
|
||||
}
|
||||
if input.Collection != "my-collection" {
|
||||
t.Errorf("Expected collection 'my-collection', got %q", input.Collection)
|
||||
}
|
||||
if !input.Recreate {
|
||||
t.Error("Expected recreate to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRAGCollectionsInput_Good verifies the RAGCollectionsInput struct exists.
|
||||
func TestRAGCollectionsInput_Good(t *testing.T) {
|
||||
// RAGCollectionsInput has optional ShowStats parameter
|
||||
input := RAGCollectionsInput{}
|
||||
if input.ShowStats {
|
||||
t.Error("Expected ShowStats to default to false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRAGQueryOutput_Good verifies the RAGQueryOutput struct has expected fields.
|
||||
func TestRAGQueryOutput_Good(t *testing.T) {
|
||||
output := RAGQueryOutput{
|
||||
Results: []RAGQueryResult{
|
||||
{
|
||||
Content: "some content",
|
||||
Source: "doc.md",
|
||||
Section: "Introduction",
|
||||
Category: "docs",
|
||||
Score: 0.95,
|
||||
},
|
||||
},
|
||||
Query: "test query",
|
||||
Collection: "test-collection",
|
||||
Context: "<retrieved_context>...</retrieved_context>",
|
||||
}
|
||||
|
||||
if len(output.Results) != 1 {
|
||||
t.Fatalf("Expected 1 result, got %d", len(output.Results))
|
||||
}
|
||||
if output.Results[0].Content != "some content" {
|
||||
t.Errorf("Expected content 'some content', got %q", output.Results[0].Content)
|
||||
}
|
||||
if output.Results[0].Score != 0.95 {
|
||||
t.Errorf("Expected score 0.95, got %f", output.Results[0].Score)
|
||||
}
|
||||
if output.Context == "" {
|
||||
t.Error("Expected context to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRAGIngestOutput_Good verifies the RAGIngestOutput struct has expected fields.
|
||||
func TestRAGIngestOutput_Good(t *testing.T) {
|
||||
output := RAGIngestOutput{
|
||||
Success: true,
|
||||
Path: "/path/to/docs",
|
||||
Collection: "my-collection",
|
||||
Chunks: 10,
|
||||
Message: "Ingested successfully",
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("Expected success to be true")
|
||||
}
|
||||
if output.Path != "/path/to/docs" {
|
||||
t.Errorf("Expected path '/path/to/docs', got %q", output.Path)
|
||||
}
|
||||
if output.Chunks != 10 {
|
||||
t.Errorf("Expected chunks 10, got %d", output.Chunks)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRAGCollectionsOutput_Good verifies the RAGCollectionsOutput struct has expected fields.
|
||||
func TestRAGCollectionsOutput_Good(t *testing.T) {
|
||||
output := RAGCollectionsOutput{
|
||||
Collections: []CollectionInfo{
|
||||
{Name: "collection1", PointsCount: 100, Status: "green"},
|
||||
{Name: "collection2", PointsCount: 200, Status: "green"},
|
||||
},
|
||||
}
|
||||
|
||||
if len(output.Collections) != 2 {
|
||||
t.Fatalf("Expected 2 collections, got %d", len(output.Collections))
|
||||
}
|
||||
if output.Collections[0].Name != "collection1" {
|
||||
t.Errorf("Expected 'collection1', got %q", output.Collections[0].Name)
|
||||
}
|
||||
if output.Collections[0].PointsCount != 100 {
|
||||
t.Errorf("Expected PointsCount 100, got %d", output.Collections[0].PointsCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRAGCollectionsInput_Good verifies the RAGCollectionsInput struct has expected fields.
|
||||
func TestRAGCollectionsInput_ShowStats(t *testing.T) {
|
||||
input := RAGCollectionsInput{
|
||||
ShowStats: true,
|
||||
}
|
||||
|
||||
if !input.ShowStats {
|
||||
t.Error("Expected ShowStats to be true")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,497 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-log"
|
||||
"forge.lthn.ai/core/go-webview"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// webviewInstance holds the current webview connection.
|
||||
// This is managed by the MCP service.
|
||||
var webviewInstance *webview.Webview
|
||||
|
||||
// Sentinel errors for webview tools.
|
||||
var (
|
||||
errNotConnected = errors.New("not connected; use webview_connect first")
|
||||
errSelectorRequired = errors.New("selector is required")
|
||||
)
|
||||
|
||||
// WebviewConnectInput contains parameters for connecting to Chrome DevTools.
|
||||
type WebviewConnectInput struct {
|
||||
DebugURL string `json:"debug_url"` // Chrome DevTools URL (e.g., http://localhost:9222)
|
||||
Timeout int `json:"timeout,omitempty"` // Default timeout in seconds (default: 30)
|
||||
}
|
||||
|
||||
// WebviewConnectOutput contains the result of connecting to Chrome.
|
||||
type WebviewConnectOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// WebviewNavigateInput contains parameters for navigating to a URL.
|
||||
type WebviewNavigateInput struct {
|
||||
URL string `json:"url"` // URL to navigate to
|
||||
}
|
||||
|
||||
// WebviewNavigateOutput contains the result of navigation.
|
||||
type WebviewNavigateOutput struct {
|
||||
Success bool `json:"success"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// WebviewClickInput contains parameters for clicking an element.
|
||||
type WebviewClickInput struct {
|
||||
Selector string `json:"selector"` // CSS selector
|
||||
}
|
||||
|
||||
// WebviewClickOutput contains the result of a click action.
|
||||
type WebviewClickOutput struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// WebviewTypeInput contains parameters for typing text.
|
||||
type WebviewTypeInput struct {
|
||||
Selector string `json:"selector"` // CSS selector
|
||||
Text string `json:"text"` // Text to type
|
||||
}
|
||||
|
||||
// WebviewTypeOutput contains the result of a type action.
|
||||
type WebviewTypeOutput struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// WebviewQueryInput contains parameters for querying an element.
|
||||
type WebviewQueryInput struct {
|
||||
Selector string `json:"selector"` // CSS selector
|
||||
All bool `json:"all,omitempty"` // If true, return all matching elements
|
||||
}
|
||||
|
||||
// WebviewQueryOutput contains the result of a query.
|
||||
type WebviewQueryOutput struct {
|
||||
Found bool `json:"found"`
|
||||
Count int `json:"count"`
|
||||
Elements []WebviewElementInfo `json:"elements,omitempty"`
|
||||
}
|
||||
|
||||
// WebviewElementInfo represents information about a DOM element.
|
||||
type WebviewElementInfo struct {
|
||||
NodeID int `json:"nodeId"`
|
||||
TagName string `json:"tagName"`
|
||||
Attributes map[string]string `json:"attributes,omitempty"`
|
||||
BoundingBox *webview.BoundingBox `json:"boundingBox,omitempty"`
|
||||
}
|
||||
|
||||
// WebviewConsoleInput contains parameters for getting console output.
|
||||
type WebviewConsoleInput struct {
|
||||
Clear bool `json:"clear,omitempty"` // If true, clear console after getting messages
|
||||
}
|
||||
|
||||
// WebviewConsoleOutput contains console messages.
|
||||
type WebviewConsoleOutput struct {
|
||||
Messages []WebviewConsoleMessage `json:"messages"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
// WebviewConsoleMessage represents a console message.
|
||||
type WebviewConsoleMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Line int `json:"line,omitempty"`
|
||||
}
|
||||
|
||||
// WebviewEvalInput contains parameters for evaluating JavaScript.
|
||||
type WebviewEvalInput struct {
|
||||
Script string `json:"script"` // JavaScript to evaluate
|
||||
}
|
||||
|
||||
// WebviewEvalOutput contains the result of JavaScript evaluation.
|
||||
type WebviewEvalOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Result any `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// WebviewScreenshotInput contains parameters for taking a screenshot.
|
||||
type WebviewScreenshotInput struct {
|
||||
Format string `json:"format,omitempty"` // "png" or "jpeg" (default: png)
|
||||
}
|
||||
|
||||
// WebviewScreenshotOutput contains the screenshot data.
|
||||
type WebviewScreenshotOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Data string `json:"data"` // Base64 encoded image
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
// WebviewWaitInput contains parameters for waiting operations.
|
||||
type WebviewWaitInput struct {
|
||||
Selector string `json:"selector,omitempty"` // Wait for selector
|
||||
Timeout int `json:"timeout,omitempty"` // Timeout in seconds
|
||||
}
|
||||
|
||||
// WebviewWaitOutput contains the result of waiting.
|
||||
type WebviewWaitOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// WebviewDisconnectInput contains parameters for disconnecting.
|
||||
type WebviewDisconnectInput struct{}
|
||||
|
||||
// WebviewDisconnectOutput contains the result of disconnecting.
|
||||
type WebviewDisconnectOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// registerWebviewTools adds webview tools to the MCP server.
|
||||
func (s *Service) registerWebviewTools(server *mcp.Server) {
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_connect",
|
||||
Description: "Connect to Chrome DevTools Protocol. Start Chrome with --remote-debugging-port=9222 first.",
|
||||
}, s.webviewConnect)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_disconnect",
|
||||
Description: "Disconnect from Chrome DevTools.",
|
||||
}, s.webviewDisconnect)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_navigate",
|
||||
Description: "Navigate the browser to a URL.",
|
||||
}, s.webviewNavigate)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_click",
|
||||
Description: "Click on an element by CSS selector.",
|
||||
}, s.webviewClick)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_type",
|
||||
Description: "Type text into an element by CSS selector.",
|
||||
}, s.webviewType)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_query",
|
||||
Description: "Query DOM elements by CSS selector.",
|
||||
}, s.webviewQuery)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_console",
|
||||
Description: "Get browser console output.",
|
||||
}, s.webviewConsole)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_eval",
|
||||
Description: "Evaluate JavaScript in the browser context.",
|
||||
}, s.webviewEval)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_screenshot",
|
||||
Description: "Capture a screenshot of the browser window.",
|
||||
}, s.webviewScreenshot)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "webview_wait",
|
||||
Description: "Wait for an element to appear by CSS selector.",
|
||||
}, s.webviewWait)
|
||||
}
|
||||
|
||||
// webviewConnect handles the webview_connect tool call.
|
||||
func (s *Service) webviewConnect(ctx context.Context, req *mcp.CallToolRequest, input WebviewConnectInput) (*mcp.CallToolResult, WebviewConnectOutput, error) {
|
||||
s.logger.Security("MCP tool execution", "tool", "webview_connect", "debug_url", input.DebugURL, "user", log.Username())
|
||||
|
||||
if input.DebugURL == "" {
|
||||
return nil, WebviewConnectOutput{}, errors.New("debug_url is required")
|
||||
}
|
||||
|
||||
// Close existing connection if any
|
||||
if webviewInstance != nil {
|
||||
_ = webviewInstance.Close()
|
||||
webviewInstance = nil
|
||||
}
|
||||
|
||||
// Set up options
|
||||
opts := []webview.Option{
|
||||
webview.WithDebugURL(input.DebugURL),
|
||||
}
|
||||
|
||||
if input.Timeout > 0 {
|
||||
opts = append(opts, webview.WithTimeout(time.Duration(input.Timeout)*time.Second))
|
||||
}
|
||||
|
||||
// Create new webview instance
|
||||
wv, err := webview.New(opts...)
|
||||
if err != nil {
|
||||
log.Error("mcp: webview connect failed", "debug_url", input.DebugURL, "err", err)
|
||||
return nil, WebviewConnectOutput{}, fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
webviewInstance = wv
|
||||
|
||||
return nil, WebviewConnectOutput{
|
||||
Success: true,
|
||||
Message: fmt.Sprintf("Connected to Chrome DevTools at %s", input.DebugURL),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// webviewDisconnect handles the webview_disconnect tool call.
|
||||
func (s *Service) webviewDisconnect(ctx context.Context, req *mcp.CallToolRequest, input WebviewDisconnectInput) (*mcp.CallToolResult, WebviewDisconnectOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "webview_disconnect", "user", log.Username())
|
||||
|
||||
if webviewInstance == nil {
|
||||
return nil, WebviewDisconnectOutput{
|
||||
Success: true,
|
||||
Message: "No active connection",
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := webviewInstance.Close(); err != nil {
|
||||
log.Error("mcp: webview disconnect failed", "err", err)
|
||||
return nil, WebviewDisconnectOutput{}, fmt.Errorf("failed to disconnect: %w", err)
|
||||
}
|
||||
|
||||
webviewInstance = nil
|
||||
|
||||
return nil, WebviewDisconnectOutput{
|
||||
Success: true,
|
||||
Message: "Disconnected from Chrome DevTools",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// webviewNavigate handles the webview_navigate tool call.
|
||||
func (s *Service) webviewNavigate(ctx context.Context, req *mcp.CallToolRequest, input WebviewNavigateInput) (*mcp.CallToolResult, WebviewNavigateOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "webview_navigate", "url", input.URL, "user", log.Username())
|
||||
|
||||
if webviewInstance == nil {
|
||||
return nil, WebviewNavigateOutput{}, errNotConnected
|
||||
}
|
||||
|
||||
if input.URL == "" {
|
||||
return nil, WebviewNavigateOutput{}, errors.New("url is required")
|
||||
}
|
||||
|
||||
if err := webviewInstance.Navigate(input.URL); err != nil {
|
||||
log.Error("mcp: webview navigate failed", "url", input.URL, "err", err)
|
||||
return nil, WebviewNavigateOutput{}, fmt.Errorf("failed to navigate: %w", err)
|
||||
}
|
||||
|
||||
return nil, WebviewNavigateOutput{
|
||||
Success: true,
|
||||
URL: input.URL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// webviewClick handles the webview_click tool call.
|
||||
func (s *Service) webviewClick(ctx context.Context, req *mcp.CallToolRequest, input WebviewClickInput) (*mcp.CallToolResult, WebviewClickOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "webview_click", "selector", input.Selector, "user", log.Username())
|
||||
|
||||
if webviewInstance == nil {
|
||||
return nil, WebviewClickOutput{}, errNotConnected
|
||||
}
|
||||
|
||||
if input.Selector == "" {
|
||||
return nil, WebviewClickOutput{}, errSelectorRequired
|
||||
}
|
||||
|
||||
if err := webviewInstance.Click(input.Selector); err != nil {
|
||||
log.Error("mcp: webview click failed", "selector", input.Selector, "err", err)
|
||||
return nil, WebviewClickOutput{}, fmt.Errorf("failed to click: %w", err)
|
||||
}
|
||||
|
||||
return nil, WebviewClickOutput{Success: true}, nil
|
||||
}
|
||||
|
||||
// webviewType handles the webview_type tool call.
|
||||
func (s *Service) webviewType(ctx context.Context, req *mcp.CallToolRequest, input WebviewTypeInput) (*mcp.CallToolResult, WebviewTypeOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "webview_type", "selector", input.Selector, "user", log.Username())
|
||||
|
||||
if webviewInstance == nil {
|
||||
return nil, WebviewTypeOutput{}, errNotConnected
|
||||
}
|
||||
|
||||
if input.Selector == "" {
|
||||
return nil, WebviewTypeOutput{}, errSelectorRequired
|
||||
}
|
||||
|
||||
if err := webviewInstance.Type(input.Selector, input.Text); err != nil {
|
||||
log.Error("mcp: webview type failed", "selector", input.Selector, "err", err)
|
||||
return nil, WebviewTypeOutput{}, fmt.Errorf("failed to type: %w", err)
|
||||
}
|
||||
|
||||
return nil, WebviewTypeOutput{Success: true}, nil
|
||||
}
|
||||
|
||||
// webviewQuery handles the webview_query tool call.
|
||||
func (s *Service) webviewQuery(ctx context.Context, req *mcp.CallToolRequest, input WebviewQueryInput) (*mcp.CallToolResult, WebviewQueryOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "webview_query", "selector", input.Selector, "all", input.All, "user", log.Username())
|
||||
|
||||
if webviewInstance == nil {
|
||||
return nil, WebviewQueryOutput{}, errNotConnected
|
||||
}
|
||||
|
||||
if input.Selector == "" {
|
||||
return nil, WebviewQueryOutput{}, errSelectorRequired
|
||||
}
|
||||
|
||||
if input.All {
|
||||
elements, err := webviewInstance.QuerySelectorAll(input.Selector)
|
||||
if err != nil {
|
||||
log.Error("mcp: webview query all failed", "selector", input.Selector, "err", err)
|
||||
return nil, WebviewQueryOutput{}, fmt.Errorf("failed to query: %w", err)
|
||||
}
|
||||
|
||||
output := WebviewQueryOutput{
|
||||
Found: len(elements) > 0,
|
||||
Count: len(elements),
|
||||
Elements: make([]WebviewElementInfo, len(elements)),
|
||||
}
|
||||
|
||||
for i, elem := range elements {
|
||||
output.Elements[i] = WebviewElementInfo{
|
||||
NodeID: elem.NodeID,
|
||||
TagName: elem.TagName,
|
||||
Attributes: elem.Attributes,
|
||||
BoundingBox: elem.BoundingBox,
|
||||
}
|
||||
}
|
||||
|
||||
return nil, output, nil
|
||||
}
|
||||
|
||||
elem, err := webviewInstance.QuerySelector(input.Selector)
|
||||
if err != nil {
|
||||
// Element not found is not necessarily an error
|
||||
return nil, WebviewQueryOutput{
|
||||
Found: false,
|
||||
Count: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, WebviewQueryOutput{
|
||||
Found: true,
|
||||
Count: 1,
|
||||
Elements: []WebviewElementInfo{{
|
||||
NodeID: elem.NodeID,
|
||||
TagName: elem.TagName,
|
||||
Attributes: elem.Attributes,
|
||||
BoundingBox: elem.BoundingBox,
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// webviewConsole handles the webview_console tool call.
|
||||
func (s *Service) webviewConsole(ctx context.Context, req *mcp.CallToolRequest, input WebviewConsoleInput) (*mcp.CallToolResult, WebviewConsoleOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "webview_console", "clear", input.Clear, "user", log.Username())
|
||||
|
||||
if webviewInstance == nil {
|
||||
return nil, WebviewConsoleOutput{}, errNotConnected
|
||||
}
|
||||
|
||||
messages := webviewInstance.GetConsole()
|
||||
|
||||
output := WebviewConsoleOutput{
|
||||
Messages: make([]WebviewConsoleMessage, len(messages)),
|
||||
Count: len(messages),
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
output.Messages[i] = WebviewConsoleMessage{
|
||||
Type: msg.Type,
|
||||
Text: msg.Text,
|
||||
Timestamp: msg.Timestamp.Format(time.RFC3339),
|
||||
URL: msg.URL,
|
||||
Line: msg.Line,
|
||||
}
|
||||
}
|
||||
|
||||
if input.Clear {
|
||||
webviewInstance.ClearConsole()
|
||||
}
|
||||
|
||||
return nil, output, nil
|
||||
}
|
||||
|
||||
// webviewEval handles the webview_eval tool call.
|
||||
func (s *Service) webviewEval(ctx context.Context, req *mcp.CallToolRequest, input WebviewEvalInput) (*mcp.CallToolResult, WebviewEvalOutput, error) {
|
||||
s.logger.Security("MCP tool execution", "tool", "webview_eval", "user", log.Username())
|
||||
|
||||
if webviewInstance == nil {
|
||||
return nil, WebviewEvalOutput{}, errNotConnected
|
||||
}
|
||||
|
||||
if input.Script == "" {
|
||||
return nil, WebviewEvalOutput{}, errors.New("script is required")
|
||||
}
|
||||
|
||||
result, err := webviewInstance.Evaluate(input.Script)
|
||||
if err != nil {
|
||||
log.Error("mcp: webview eval failed", "err", err)
|
||||
return nil, WebviewEvalOutput{
|
||||
Success: false,
|
||||
Error: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, WebviewEvalOutput{
|
||||
Success: true,
|
||||
Result: result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// webviewScreenshot handles the webview_screenshot tool call.
|
||||
func (s *Service) webviewScreenshot(ctx context.Context, req *mcp.CallToolRequest, input WebviewScreenshotInput) (*mcp.CallToolResult, WebviewScreenshotOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "webview_screenshot", "format", input.Format, "user", log.Username())
|
||||
|
||||
if webviewInstance == nil {
|
||||
return nil, WebviewScreenshotOutput{}, errNotConnected
|
||||
}
|
||||
|
||||
format := input.Format
|
||||
if format == "" {
|
||||
format = "png"
|
||||
}
|
||||
|
||||
data, err := webviewInstance.Screenshot()
|
||||
if err != nil {
|
||||
log.Error("mcp: webview screenshot failed", "err", err)
|
||||
return nil, WebviewScreenshotOutput{}, fmt.Errorf("failed to capture screenshot: %w", err)
|
||||
}
|
||||
|
||||
return nil, WebviewScreenshotOutput{
|
||||
Success: true,
|
||||
Data: base64.StdEncoding.EncodeToString(data),
|
||||
Format: format,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// webviewWait handles the webview_wait tool call.
|
||||
func (s *Service) webviewWait(ctx context.Context, req *mcp.CallToolRequest, input WebviewWaitInput) (*mcp.CallToolResult, WebviewWaitOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "webview_wait", "selector", input.Selector, "timeout", input.Timeout, "user", log.Username())
|
||||
|
||||
if webviewInstance == nil {
|
||||
return nil, WebviewWaitOutput{}, errNotConnected
|
||||
}
|
||||
|
||||
if input.Selector == "" {
|
||||
return nil, WebviewWaitOutput{}, errSelectorRequired
|
||||
}
|
||||
|
||||
if err := webviewInstance.WaitForSelector(input.Selector); err != nil {
|
||||
log.Error("mcp: webview wait failed", "selector", input.Selector, "err", err)
|
||||
return nil, WebviewWaitOutput{}, fmt.Errorf("failed to wait for selector: %w", err)
|
||||
}
|
||||
|
||||
return nil, WebviewWaitOutput{
|
||||
Success: true,
|
||||
Message: fmt.Sprintf("Element found: %s", input.Selector),
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -1,452 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-webview"
|
||||
)
|
||||
|
||||
// skipIfShort skips webview tests in short mode (go test -short).
|
||||
// Webview tool handlers require a running Chrome instance with
|
||||
// --remote-debugging-port, which is not available in CI.
|
||||
// Struct-level tests below are safe without Chrome, but any future
|
||||
// tests that call webview tool handlers MUST use this guard.
|
||||
func skipIfShort(t *testing.T) {
|
||||
t.Helper()
|
||||
if testing.Short() {
|
||||
t.Skip("webview tests skipped in short mode (no Chrome available)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewToolsRegistered_Good verifies that webview tools are registered with the MCP server.
|
||||
func TestWebviewToolsRegistered_Good(t *testing.T) {
|
||||
// Create a new MCP service - this should register all tools including webview
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
// The server should have registered the webview tools
|
||||
if s.server == nil {
|
||||
t.Fatal("Server should not be nil")
|
||||
}
|
||||
|
||||
// Verify the service was created with expected defaults
|
||||
if s.logger == nil {
|
||||
t.Error("Logger should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewToolHandlers_RequiresChrome demonstrates the CI guard
|
||||
// for tests that would require a running Chrome instance. Any future
|
||||
// test that calls webview tool handlers (webviewConnect, webviewNavigate,
|
||||
// etc.) should call skipIfShort(t) at the top.
|
||||
func TestWebviewToolHandlers_RequiresChrome(t *testing.T) {
|
||||
skipIfShort(t)
|
||||
|
||||
// This test verifies that webview tool handlers correctly reject
|
||||
// calls when not connected to Chrome.
|
||||
tmpDir := t.TempDir()
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
// webview_navigate should fail without a connection
|
||||
_, _, err = s.webviewNavigate(ctx, nil, WebviewNavigateInput{URL: "https://example.com"})
|
||||
if err == nil {
|
||||
t.Error("Expected error when navigating without a webview connection")
|
||||
}
|
||||
|
||||
// webview_click should fail without a connection
|
||||
_, _, err = s.webviewClick(ctx, nil, WebviewClickInput{Selector: "#btn"})
|
||||
if err == nil {
|
||||
t.Error("Expected error when clicking without a webview connection")
|
||||
}
|
||||
|
||||
// webview_eval should fail without a connection
|
||||
_, _, err = s.webviewEval(ctx, nil, WebviewEvalInput{Script: "1+1"})
|
||||
if err == nil {
|
||||
t.Error("Expected error when evaluating without a webview connection")
|
||||
}
|
||||
|
||||
// webview_connect with invalid URL should fail
|
||||
_, _, err = s.webviewConnect(ctx, nil, WebviewConnectInput{DebugURL: ""})
|
||||
if err == nil {
|
||||
t.Error("Expected error when connecting with empty debug URL")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewConnectInput_Good verifies the WebviewConnectInput struct has expected fields.
|
||||
func TestWebviewConnectInput_Good(t *testing.T) {
|
||||
input := WebviewConnectInput{
|
||||
DebugURL: "http://localhost:9222",
|
||||
Timeout: 30,
|
||||
}
|
||||
|
||||
if input.DebugURL != "http://localhost:9222" {
|
||||
t.Errorf("Expected debug_url 'http://localhost:9222', got %q", input.DebugURL)
|
||||
}
|
||||
if input.Timeout != 30 {
|
||||
t.Errorf("Expected timeout 30, got %d", input.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewNavigateInput_Good verifies the WebviewNavigateInput struct has expected fields.
|
||||
func TestWebviewNavigateInput_Good(t *testing.T) {
|
||||
input := WebviewNavigateInput{
|
||||
URL: "https://example.com",
|
||||
}
|
||||
|
||||
if input.URL != "https://example.com" {
|
||||
t.Errorf("Expected URL 'https://example.com', got %q", input.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewClickInput_Good verifies the WebviewClickInput struct has expected fields.
|
||||
func TestWebviewClickInput_Good(t *testing.T) {
|
||||
input := WebviewClickInput{
|
||||
Selector: "#submit-button",
|
||||
}
|
||||
|
||||
if input.Selector != "#submit-button" {
|
||||
t.Errorf("Expected selector '#submit-button', got %q", input.Selector)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewTypeInput_Good verifies the WebviewTypeInput struct has expected fields.
|
||||
func TestWebviewTypeInput_Good(t *testing.T) {
|
||||
input := WebviewTypeInput{
|
||||
Selector: "#email-input",
|
||||
Text: "test@example.com",
|
||||
}
|
||||
|
||||
if input.Selector != "#email-input" {
|
||||
t.Errorf("Expected selector '#email-input', got %q", input.Selector)
|
||||
}
|
||||
if input.Text != "test@example.com" {
|
||||
t.Errorf("Expected text 'test@example.com', got %q", input.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewQueryInput_Good verifies the WebviewQueryInput struct has expected fields.
|
||||
func TestWebviewQueryInput_Good(t *testing.T) {
|
||||
input := WebviewQueryInput{
|
||||
Selector: "div.container",
|
||||
All: true,
|
||||
}
|
||||
|
||||
if input.Selector != "div.container" {
|
||||
t.Errorf("Expected selector 'div.container', got %q", input.Selector)
|
||||
}
|
||||
if !input.All {
|
||||
t.Error("Expected all to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewQueryInput_Defaults verifies default values are handled correctly.
|
||||
func TestWebviewQueryInput_Defaults(t *testing.T) {
|
||||
input := WebviewQueryInput{
|
||||
Selector: ".test",
|
||||
}
|
||||
|
||||
if input.All {
|
||||
t.Error("Expected all to default to false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewConsoleInput_Good verifies the WebviewConsoleInput struct has expected fields.
|
||||
func TestWebviewConsoleInput_Good(t *testing.T) {
|
||||
input := WebviewConsoleInput{
|
||||
Clear: true,
|
||||
}
|
||||
|
||||
if !input.Clear {
|
||||
t.Error("Expected clear to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewEvalInput_Good verifies the WebviewEvalInput struct has expected fields.
|
||||
func TestWebviewEvalInput_Good(t *testing.T) {
|
||||
input := WebviewEvalInput{
|
||||
Script: "document.title",
|
||||
}
|
||||
|
||||
if input.Script != "document.title" {
|
||||
t.Errorf("Expected script 'document.title', got %q", input.Script)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewScreenshotInput_Good verifies the WebviewScreenshotInput struct has expected fields.
|
||||
func TestWebviewScreenshotInput_Good(t *testing.T) {
|
||||
input := WebviewScreenshotInput{
|
||||
Format: "png",
|
||||
}
|
||||
|
||||
if input.Format != "png" {
|
||||
t.Errorf("Expected format 'png', got %q", input.Format)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewScreenshotInput_Defaults verifies default values are handled correctly.
|
||||
func TestWebviewScreenshotInput_Defaults(t *testing.T) {
|
||||
input := WebviewScreenshotInput{}
|
||||
|
||||
if input.Format != "" {
|
||||
t.Errorf("Expected format to default to empty, got %q", input.Format)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewWaitInput_Good verifies the WebviewWaitInput struct has expected fields.
|
||||
func TestWebviewWaitInput_Good(t *testing.T) {
|
||||
input := WebviewWaitInput{
|
||||
Selector: "#loading",
|
||||
Timeout: 10,
|
||||
}
|
||||
|
||||
if input.Selector != "#loading" {
|
||||
t.Errorf("Expected selector '#loading', got %q", input.Selector)
|
||||
}
|
||||
if input.Timeout != 10 {
|
||||
t.Errorf("Expected timeout 10, got %d", input.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewConnectOutput_Good verifies the WebviewConnectOutput struct has expected fields.
|
||||
func TestWebviewConnectOutput_Good(t *testing.T) {
|
||||
output := WebviewConnectOutput{
|
||||
Success: true,
|
||||
Message: "Connected to Chrome DevTools",
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("Expected success to be true")
|
||||
}
|
||||
if output.Message == "" {
|
||||
t.Error("Expected message to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewNavigateOutput_Good verifies the WebviewNavigateOutput struct has expected fields.
|
||||
func TestWebviewNavigateOutput_Good(t *testing.T) {
|
||||
output := WebviewNavigateOutput{
|
||||
Success: true,
|
||||
URL: "https://example.com",
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("Expected success to be true")
|
||||
}
|
||||
if output.URL != "https://example.com" {
|
||||
t.Errorf("Expected URL 'https://example.com', got %q", output.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewQueryOutput_Good verifies the WebviewQueryOutput struct has expected fields.
|
||||
func TestWebviewQueryOutput_Good(t *testing.T) {
|
||||
output := WebviewQueryOutput{
|
||||
Found: true,
|
||||
Count: 3,
|
||||
Elements: []WebviewElementInfo{
|
||||
{
|
||||
NodeID: 1,
|
||||
TagName: "DIV",
|
||||
Attributes: map[string]string{
|
||||
"class": "container",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if !output.Found {
|
||||
t.Error("Expected found to be true")
|
||||
}
|
||||
if output.Count != 3 {
|
||||
t.Errorf("Expected count 3, got %d", output.Count)
|
||||
}
|
||||
if len(output.Elements) != 1 {
|
||||
t.Fatalf("Expected 1 element, got %d", len(output.Elements))
|
||||
}
|
||||
if output.Elements[0].TagName != "DIV" {
|
||||
t.Errorf("Expected tagName 'DIV', got %q", output.Elements[0].TagName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewConsoleOutput_Good verifies the WebviewConsoleOutput struct has expected fields.
|
||||
func TestWebviewConsoleOutput_Good(t *testing.T) {
|
||||
output := WebviewConsoleOutput{
|
||||
Messages: []WebviewConsoleMessage{
|
||||
{
|
||||
Type: "log",
|
||||
Text: "Hello, world!",
|
||||
Timestamp: "2024-01-01T00:00:00Z",
|
||||
},
|
||||
{
|
||||
Type: "error",
|
||||
Text: "An error occurred",
|
||||
Timestamp: "2024-01-01T00:00:01Z",
|
||||
URL: "https://example.com/script.js",
|
||||
Line: 42,
|
||||
},
|
||||
},
|
||||
Count: 2,
|
||||
}
|
||||
|
||||
if output.Count != 2 {
|
||||
t.Errorf("Expected count 2, got %d", output.Count)
|
||||
}
|
||||
if len(output.Messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d", len(output.Messages))
|
||||
}
|
||||
if output.Messages[0].Type != "log" {
|
||||
t.Errorf("Expected type 'log', got %q", output.Messages[0].Type)
|
||||
}
|
||||
if output.Messages[1].Line != 42 {
|
||||
t.Errorf("Expected line 42, got %d", output.Messages[1].Line)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewEvalOutput_Good verifies the WebviewEvalOutput struct has expected fields.
|
||||
func TestWebviewEvalOutput_Good(t *testing.T) {
|
||||
output := WebviewEvalOutput{
|
||||
Success: true,
|
||||
Result: "Example Page",
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("Expected success to be true")
|
||||
}
|
||||
if output.Result != "Example Page" {
|
||||
t.Errorf("Expected result 'Example Page', got %v", output.Result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewEvalOutput_Error verifies the WebviewEvalOutput struct handles errors.
|
||||
func TestWebviewEvalOutput_Error(t *testing.T) {
|
||||
output := WebviewEvalOutput{
|
||||
Success: false,
|
||||
Error: "ReferenceError: foo is not defined",
|
||||
}
|
||||
|
||||
if output.Success {
|
||||
t.Error("Expected success to be false")
|
||||
}
|
||||
if output.Error == "" {
|
||||
t.Error("Expected error message to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewScreenshotOutput_Good verifies the WebviewScreenshotOutput struct has expected fields.
|
||||
func TestWebviewScreenshotOutput_Good(t *testing.T) {
|
||||
output := WebviewScreenshotOutput{
|
||||
Success: true,
|
||||
Data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
|
||||
Format: "png",
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("Expected success to be true")
|
||||
}
|
||||
if output.Data == "" {
|
||||
t.Error("Expected data to be set")
|
||||
}
|
||||
if output.Format != "png" {
|
||||
t.Errorf("Expected format 'png', got %q", output.Format)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewElementInfo_Good verifies the WebviewElementInfo struct has expected fields.
|
||||
func TestWebviewElementInfo_Good(t *testing.T) {
|
||||
elem := WebviewElementInfo{
|
||||
NodeID: 123,
|
||||
TagName: "INPUT",
|
||||
Attributes: map[string]string{
|
||||
"type": "text",
|
||||
"name": "email",
|
||||
"class": "form-control",
|
||||
},
|
||||
BoundingBox: &webview.BoundingBox{
|
||||
X: 100,
|
||||
Y: 200,
|
||||
Width: 300,
|
||||
Height: 50,
|
||||
},
|
||||
}
|
||||
|
||||
if elem.NodeID != 123 {
|
||||
t.Errorf("Expected nodeId 123, got %d", elem.NodeID)
|
||||
}
|
||||
if elem.TagName != "INPUT" {
|
||||
t.Errorf("Expected tagName 'INPUT', got %q", elem.TagName)
|
||||
}
|
||||
if elem.Attributes["type"] != "text" {
|
||||
t.Errorf("Expected type attribute 'text', got %q", elem.Attributes["type"])
|
||||
}
|
||||
if elem.BoundingBox == nil {
|
||||
t.Fatal("Expected bounding box to be set")
|
||||
}
|
||||
if elem.BoundingBox.Width != 300 {
|
||||
t.Errorf("Expected width 300, got %f", elem.BoundingBox.Width)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewConsoleMessage_Good verifies the WebviewConsoleMessage struct has expected fields.
|
||||
func TestWebviewConsoleMessage_Good(t *testing.T) {
|
||||
msg := WebviewConsoleMessage{
|
||||
Type: "error",
|
||||
Text: "Failed to load resource",
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
URL: "https://example.com/api/data",
|
||||
Line: 1,
|
||||
}
|
||||
|
||||
if msg.Type != "error" {
|
||||
t.Errorf("Expected type 'error', got %q", msg.Type)
|
||||
}
|
||||
if msg.Text == "" {
|
||||
t.Error("Expected text to be set")
|
||||
}
|
||||
if msg.URL == "" {
|
||||
t.Error("Expected URL to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewDisconnectInput_Good verifies the WebviewDisconnectInput struct exists.
|
||||
func TestWebviewDisconnectInput_Good(t *testing.T) {
|
||||
// WebviewDisconnectInput has no fields
|
||||
input := WebviewDisconnectInput{}
|
||||
_ = input // Just verify the struct exists
|
||||
}
|
||||
|
||||
// TestWebviewDisconnectOutput_Good verifies the WebviewDisconnectOutput struct has expected fields.
|
||||
func TestWebviewDisconnectOutput_Good(t *testing.T) {
|
||||
output := WebviewDisconnectOutput{
|
||||
Success: true,
|
||||
Message: "Disconnected from Chrome DevTools",
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("Expected success to be true")
|
||||
}
|
||||
if output.Message == "" {
|
||||
t.Error("Expected message to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebviewWaitOutput_Good verifies the WebviewWaitOutput struct has expected fields.
|
||||
func TestWebviewWaitOutput_Good(t *testing.T) {
|
||||
output := WebviewWaitOutput{
|
||||
Success: true,
|
||||
Message: "Element found: #login-form",
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("Expected success to be true")
|
||||
}
|
||||
if output.Message == "" {
|
||||
t.Error("Expected message to be set")
|
||||
}
|
||||
}
|
||||
142
mcp/tools_ws.go
142
mcp/tools_ws.go
|
|
@ -1,142 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"forge.lthn.ai/core/go-log"
|
||||
"forge.lthn.ai/core/go-ws"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// WSStartInput contains parameters for starting the WebSocket server.
|
||||
type WSStartInput struct {
|
||||
Addr string `json:"addr,omitempty"` // Address to listen on (default: ":8080")
|
||||
}
|
||||
|
||||
// WSStartOutput contains the result of starting the WebSocket server.
|
||||
type WSStartOutput struct {
|
||||
Success bool `json:"success"`
|
||||
Addr string `json:"addr"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// WSInfoInput contains parameters for getting WebSocket hub info.
|
||||
type WSInfoInput struct{}
|
||||
|
||||
// WSInfoOutput contains WebSocket hub statistics.
|
||||
type WSInfoOutput struct {
|
||||
Clients int `json:"clients"`
|
||||
Channels int `json:"channels"`
|
||||
}
|
||||
|
||||
// registerWSTools adds WebSocket tools to the MCP server.
|
||||
// Returns false if WebSocket hub is not available.
|
||||
func (s *Service) registerWSTools(server *mcp.Server) bool {
|
||||
if s.wsHub == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ws_start",
|
||||
Description: "Start the WebSocket server for real-time process output streaming.",
|
||||
}, s.wsStart)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "ws_info",
|
||||
Description: "Get WebSocket hub statistics (connected clients and active channels).",
|
||||
}, s.wsInfo)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// wsStart handles the ws_start tool call.
|
||||
func (s *Service) wsStart(ctx context.Context, req *mcp.CallToolRequest, input WSStartInput) (*mcp.CallToolResult, WSStartOutput, error) {
|
||||
addr := input.Addr
|
||||
if addr == "" {
|
||||
addr = ":8080"
|
||||
}
|
||||
|
||||
s.logger.Security("MCP tool execution", "tool", "ws_start", "addr", addr, "user", log.Username())
|
||||
|
||||
// Check if server is already running
|
||||
if s.wsServer != nil {
|
||||
return nil, WSStartOutput{
|
||||
Success: true,
|
||||
Addr: s.wsAddr,
|
||||
Message: "WebSocket server already running",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Create HTTP server with WebSocket handler
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", s.wsHub.Handler())
|
||||
|
||||
server := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// Start listener to get actual address
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
log.Error("mcp: ws start listen failed", "addr", addr, "err", err)
|
||||
return nil, WSStartOutput{}, fmt.Errorf("failed to listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
actualAddr := ln.Addr().String()
|
||||
s.wsServer = server
|
||||
s.wsAddr = actualAddr
|
||||
|
||||
// Start server in background
|
||||
go func() {
|
||||
if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
log.Error("mcp: ws server error", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil, WSStartOutput{
|
||||
Success: true,
|
||||
Addr: actualAddr,
|
||||
Message: fmt.Sprintf("WebSocket server started at ws://%s/ws", actualAddr),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// wsInfo handles the ws_info tool call.
|
||||
func (s *Service) wsInfo(ctx context.Context, req *mcp.CallToolRequest, input WSInfoInput) (*mcp.CallToolResult, WSInfoOutput, error) {
|
||||
s.logger.Info("MCP tool execution", "tool", "ws_info", "user", log.Username())
|
||||
|
||||
stats := s.wsHub.Stats()
|
||||
|
||||
return nil, WSInfoOutput{
|
||||
Clients: stats.Clients,
|
||||
Channels: stats.Channels,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProcessEventCallback is a callback function for process events.
|
||||
// It can be registered with the process service to forward events to WebSocket.
|
||||
type ProcessEventCallback struct {
|
||||
hub *ws.Hub
|
||||
}
|
||||
|
||||
// NewProcessEventCallback creates a callback that forwards process events to WebSocket.
|
||||
func NewProcessEventCallback(hub *ws.Hub) *ProcessEventCallback {
|
||||
return &ProcessEventCallback{hub: hub}
|
||||
}
|
||||
|
||||
// OnProcessOutput forwards process output to WebSocket subscribers.
|
||||
func (c *ProcessEventCallback) OnProcessOutput(processID string, line string) {
|
||||
if c.hub != nil {
|
||||
_ = c.hub.SendProcessOutput(processID, line)
|
||||
}
|
||||
}
|
||||
|
||||
// OnProcessStatus forwards process status changes to WebSocket subscribers.
|
||||
func (c *ProcessEventCallback) OnProcessStatus(processID string, status string, exitCode int) {
|
||||
if c.hub != nil {
|
||||
_ = c.hub.SendProcessStatus(processID, status, exitCode)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,174 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-ws"
|
||||
)
|
||||
|
||||
// TestWSToolsRegistered_Good verifies that WebSocket tools are registered when hub is available.
|
||||
func TestWSToolsRegistered_Good(t *testing.T) {
|
||||
// Create a new MCP service without ws hub - tools should not be registered
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.wsHub != nil {
|
||||
t.Error("WS hub should be nil by default")
|
||||
}
|
||||
|
||||
if s.server == nil {
|
||||
t.Fatal("Server should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWSStartInput_Good verifies the WSStartInput struct has expected fields.
|
||||
func TestWSStartInput_Good(t *testing.T) {
|
||||
input := WSStartInput{
|
||||
Addr: ":9090",
|
||||
}
|
||||
|
||||
if input.Addr != ":9090" {
|
||||
t.Errorf("Expected addr ':9090', got %q", input.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWSStartInput_Defaults verifies default values.
|
||||
func TestWSStartInput_Defaults(t *testing.T) {
|
||||
input := WSStartInput{}
|
||||
|
||||
if input.Addr != "" {
|
||||
t.Errorf("Expected addr to default to empty, got %q", input.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWSStartOutput_Good verifies the WSStartOutput struct has expected fields.
|
||||
func TestWSStartOutput_Good(t *testing.T) {
|
||||
output := WSStartOutput{
|
||||
Success: true,
|
||||
Addr: "127.0.0.1:8080",
|
||||
Message: "WebSocket server started",
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("Expected Success to be true")
|
||||
}
|
||||
if output.Addr != "127.0.0.1:8080" {
|
||||
t.Errorf("Expected addr '127.0.0.1:8080', got %q", output.Addr)
|
||||
}
|
||||
if output.Message != "WebSocket server started" {
|
||||
t.Errorf("Expected message 'WebSocket server started', got %q", output.Message)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWSInfoInput_Good verifies the WSInfoInput struct exists (it's empty).
|
||||
func TestWSInfoInput_Good(t *testing.T) {
|
||||
input := WSInfoInput{}
|
||||
_ = input // Just verify it compiles
|
||||
}
|
||||
|
||||
// TestWSInfoOutput_Good verifies the WSInfoOutput struct has expected fields.
|
||||
func TestWSInfoOutput_Good(t *testing.T) {
|
||||
output := WSInfoOutput{
|
||||
Clients: 5,
|
||||
Channels: 3,
|
||||
}
|
||||
|
||||
if output.Clients != 5 {
|
||||
t.Errorf("Expected clients 5, got %d", output.Clients)
|
||||
}
|
||||
if output.Channels != 3 {
|
||||
t.Errorf("Expected channels 3, got %d", output.Channels)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWithWSHub_Good verifies the WithWSHub option.
|
||||
func TestWithWSHub_Good(t *testing.T) {
|
||||
hub := ws.NewHub()
|
||||
|
||||
s, err := New(WithWSHub(hub))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.wsHub != hub {
|
||||
t.Error("Expected wsHub to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWithWSHub_Nil verifies the WithWSHub option with nil.
|
||||
func TestWithWSHub_Nil(t *testing.T) {
|
||||
s, err := New(WithWSHub(nil))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.wsHub != nil {
|
||||
t.Error("Expected wsHub to be nil when passed nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessEventCallback_Good verifies the ProcessEventCallback struct.
|
||||
func TestProcessEventCallback_Good(t *testing.T) {
|
||||
hub := ws.NewHub()
|
||||
callback := NewProcessEventCallback(hub)
|
||||
|
||||
if callback.hub != hub {
|
||||
t.Error("Expected callback hub to be set")
|
||||
}
|
||||
|
||||
// Test that methods don't panic
|
||||
callback.OnProcessOutput("proc-1", "test output")
|
||||
callback.OnProcessStatus("proc-1", "exited", 0)
|
||||
}
|
||||
|
||||
// TestProcessEventCallback_NilHub verifies the ProcessEventCallback with nil hub doesn't panic.
|
||||
func TestProcessEventCallback_NilHub(t *testing.T) {
|
||||
callback := NewProcessEventCallback(nil)
|
||||
|
||||
if callback.hub != nil {
|
||||
t.Error("Expected callback hub to be nil")
|
||||
}
|
||||
|
||||
// Test that methods don't panic with nil hub
|
||||
callback.OnProcessOutput("proc-1", "test output")
|
||||
callback.OnProcessStatus("proc-1", "exited", 0)
|
||||
}
|
||||
|
||||
// TestServiceWSHub_Good verifies the WSHub getter method.
|
||||
func TestServiceWSHub_Good(t *testing.T) {
|
||||
hub := ws.NewHub()
|
||||
s, err := New(WithWSHub(hub))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.WSHub() != hub {
|
||||
t.Error("Expected WSHub() to return the hub")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceWSHub_Nil verifies the WSHub getter returns nil when not configured.
|
||||
func TestServiceWSHub_Nil(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.WSHub() != nil {
|
||||
t.Error("Expected WSHub() to return nil when not configured")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceProcessService_Nil verifies the ProcessService getter returns nil when not configured.
|
||||
func TestServiceProcessService_Nil(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
if s.ProcessService() != nil {
|
||||
t.Error("Expected ProcessService() to return nil when not configured")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,742 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"context"
|
||||
)
|
||||
|
||||
// jsonRPCRequest builds a raw JSON-RPC 2.0 request string with newline delimiter.
|
||||
func jsonRPCRequest(id int, method string, params any) string {
|
||||
msg := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"method": method,
|
||||
}
|
||||
if params != nil {
|
||||
msg["params"] = params
|
||||
}
|
||||
data, _ := json.Marshal(msg)
|
||||
return string(data) + "\n"
|
||||
}
|
||||
|
||||
// jsonRPCNotification builds a raw JSON-RPC 2.0 notification (no id).
|
||||
func jsonRPCNotification(method string) string {
|
||||
msg := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
}
|
||||
data, _ := json.Marshal(msg)
|
||||
return string(data) + "\n"
|
||||
}
|
||||
|
||||
// readJSONRPCResponse reads a single line-delimited JSON-RPC response and
|
||||
// returns the decoded map. It handles the case where the server sends a
|
||||
// ping request interleaved with responses (responds to it and keeps reading).
|
||||
func readJSONRPCResponse(t *testing.T, scanner *bufio.Scanner, conn net.Conn) map[string]any {
|
||||
t.Helper()
|
||||
for {
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
t.Fatalf("scanner error: %v", err)
|
||||
}
|
||||
t.Fatal("unexpected EOF reading JSON-RPC response")
|
||||
}
|
||||
line := scanner.Text()
|
||||
var msg map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &msg); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v\nraw: %s", err, line)
|
||||
}
|
||||
|
||||
// If this is a server-initiated request (e.g. ping), respond and keep reading.
|
||||
if method, ok := msg["method"]; ok {
|
||||
if id, hasID := msg["id"]; hasID {
|
||||
resp := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": map[string]any{},
|
||||
}
|
||||
data, _ := json.Marshal(resp)
|
||||
_, _ = conn.Write(append(data, '\n'))
|
||||
_ = method // consume
|
||||
continue
|
||||
}
|
||||
// Notification from server — ignore and keep reading
|
||||
continue
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
}
|
||||
|
||||
// --- TCP E2E Tests ---
|
||||
|
||||
func TestTCPTransport_E2E_FullRoundTrip(t *testing.T) {
|
||||
// Create a temp workspace with a known file
|
||||
tmpDir := t.TempDir()
|
||||
testContent := "hello from tcp e2e test"
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(testContent), 0644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Start TCP server on a random port
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.ServeTCP(ctx, "127.0.0.1:0")
|
||||
}()
|
||||
|
||||
// Wait for the server to start and get the actual address.
|
||||
// ServeTCP creates its own listener internally, so we need to probe.
|
||||
// We'll retry connecting for up to 2 seconds.
|
||||
var conn net.Conn
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
// Since ServeTCP binds :0, we can't predict the port. Instead, create
|
||||
// our own listener to find a free port, close it, then pass that port
|
||||
// to ServeTCP. This is a known race, but fine for tests.
|
||||
cancel()
|
||||
<-errCh
|
||||
|
||||
// Restart with a known port: find a free port first
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find free port: %v", err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
ln.Close()
|
||||
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
|
||||
errCh2 := make(chan error, 1)
|
||||
go func() {
|
||||
errCh2 <- s.ServeTCP(ctx2, addr)
|
||||
}()
|
||||
|
||||
// Wait for server to accept connections
|
||||
deadline = time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to TCP server at %s: %v", addr, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Set a read deadline to avoid hanging
|
||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
// Step 1: Send initialize request
|
||||
initReq := jsonRPCRequest(1, "initialize", map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
|
||||
})
|
||||
if _, err := conn.Write([]byte(initReq)); err != nil {
|
||||
t.Fatalf("Failed to send initialize: %v", err)
|
||||
}
|
||||
|
||||
// Read initialize response
|
||||
initResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if initResp["error"] != nil {
|
||||
t.Fatalf("Initialize returned error: %v", initResp["error"])
|
||||
}
|
||||
result, ok := initResp["result"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected result object, got %T", initResp["result"])
|
||||
}
|
||||
serverInfo, _ := result["serverInfo"].(map[string]any)
|
||||
if serverInfo["name"] != "core-cli" {
|
||||
t.Errorf("Expected server name 'core-cli', got %v", serverInfo["name"])
|
||||
}
|
||||
|
||||
// Step 2: Send notifications/initialized
|
||||
if _, err := conn.Write([]byte(jsonRPCNotification("notifications/initialized"))); err != nil {
|
||||
t.Fatalf("Failed to send initialized notification: %v", err)
|
||||
}
|
||||
|
||||
// Step 3: Send tools/list
|
||||
if _, err := conn.Write([]byte(jsonRPCRequest(2, "tools/list", nil))); err != nil {
|
||||
t.Fatalf("Failed to send tools/list: %v", err)
|
||||
}
|
||||
|
||||
toolsResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if toolsResp["error"] != nil {
|
||||
t.Fatalf("tools/list returned error: %v", toolsResp["error"])
|
||||
}
|
||||
|
||||
toolsResult, ok := toolsResp["result"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected result object for tools/list, got %T", toolsResp["result"])
|
||||
}
|
||||
tools, ok := toolsResult["tools"].([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
t.Fatal("Expected non-empty tools list")
|
||||
}
|
||||
|
||||
// Verify file_read is among the tools
|
||||
foundFileRead := false
|
||||
for _, tool := range tools {
|
||||
toolMap, _ := tool.(map[string]any)
|
||||
if toolMap["name"] == "file_read" {
|
||||
foundFileRead = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundFileRead {
|
||||
t.Error("Expected file_read tool in tools/list response")
|
||||
}
|
||||
|
||||
// Step 4: Call file_read
|
||||
callReq := jsonRPCRequest(3, "tools/call", map[string]any{
|
||||
"name": "file_read",
|
||||
"arguments": map[string]any{"path": "test.txt"},
|
||||
})
|
||||
if _, err := conn.Write([]byte(callReq)); err != nil {
|
||||
t.Fatalf("Failed to send tools/call: %v", err)
|
||||
}
|
||||
|
||||
callResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if callResp["error"] != nil {
|
||||
t.Fatalf("tools/call file_read returned error: %v", callResp["error"])
|
||||
}
|
||||
|
||||
callResult, ok := callResp["result"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected result object for tools/call, got %T", callResp["result"])
|
||||
}
|
||||
|
||||
// The MCP SDK wraps tool results in content array
|
||||
content, ok := callResult["content"].([]any)
|
||||
if !ok || len(content) == 0 {
|
||||
t.Fatal("Expected non-empty content in tools/call response")
|
||||
}
|
||||
|
||||
firstContent, _ := content[0].(map[string]any)
|
||||
text, _ := firstContent["text"].(string)
|
||||
if !strings.Contains(text, testContent) {
|
||||
t.Errorf("Expected file content to contain %q, got %q", testContent, text)
|
||||
}
|
||||
|
||||
// Graceful shutdown
|
||||
cancel2()
|
||||
err = <-errCh2
|
||||
if err != nil {
|
||||
t.Errorf("ServeTCP returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPTransport_E2E_FileWrite(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Find free port
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find free port: %v", err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
ln.Close()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.ServeTCP(ctx, addr)
|
||||
}()
|
||||
|
||||
// Connect
|
||||
var conn net.Conn
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
// Initialize handshake
|
||||
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
|
||||
})))
|
||||
readJSONRPCResponse(t, scanner, conn)
|
||||
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
|
||||
|
||||
// Write a file
|
||||
writeContent := "written via tcp transport"
|
||||
conn.Write([]byte(jsonRPCRequest(2, "tools/call", map[string]any{
|
||||
"name": "file_write",
|
||||
"arguments": map[string]any{"path": "tcp-written.txt", "content": writeContent},
|
||||
})))
|
||||
writeResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if writeResp["error"] != nil {
|
||||
t.Fatalf("file_write returned error: %v", writeResp["error"])
|
||||
}
|
||||
|
||||
// Verify file on disk
|
||||
diskContent, err := os.ReadFile(filepath.Join(tmpDir, "tcp-written.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read written file: %v", err)
|
||||
}
|
||||
if string(diskContent) != writeContent {
|
||||
t.Errorf("Expected %q on disk, got %q", writeContent, string(diskContent))
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-errCh
|
||||
}
|
||||
|
||||
// --- Unix Socket E2E Tests ---
|
||||
|
||||
// shortSocketPath returns a Unix socket path under /tmp that fits within
|
||||
// the macOS 104-byte sun_path limit. t.TempDir() paths on macOS are
|
||||
// often too long (>104 bytes) for Unix sockets.
|
||||
func shortSocketPath(t *testing.T, suffix string) string {
|
||||
t.Helper()
|
||||
path := fmt.Sprintf("/tmp/mcp-test-%s-%d.sock", suffix, os.Getpid())
|
||||
t.Cleanup(func() { os.Remove(path) })
|
||||
return path
|
||||
}
|
||||
|
||||
func TestUnixTransport_E2E_FullRoundTrip(t *testing.T) {
|
||||
// Create a temp workspace with a known file
|
||||
tmpDir := t.TempDir()
|
||||
testContent := "hello from unix e2e test"
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(testContent), 0644); err != nil {
|
||||
t.Fatalf("Failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Use a short socket path to avoid macOS 104-byte sun_path limit
|
||||
socketPath := shortSocketPath(t, "full")
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.ServeUnix(ctx, socketPath)
|
||||
}()
|
||||
|
||||
// Wait for socket to appear
|
||||
var conn net.Conn
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err = net.DialTimeout("unix", socketPath, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to Unix socket at %s: %v", socketPath, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
// Step 1: Initialize
|
||||
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
|
||||
})))
|
||||
initResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if initResp["error"] != nil {
|
||||
t.Fatalf("Initialize returned error: %v", initResp["error"])
|
||||
}
|
||||
|
||||
// Step 2: Send initialized notification
|
||||
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
|
||||
|
||||
// Step 3: tools/list
|
||||
conn.Write([]byte(jsonRPCRequest(2, "tools/list", nil)))
|
||||
toolsResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if toolsResp["error"] != nil {
|
||||
t.Fatalf("tools/list returned error: %v", toolsResp["error"])
|
||||
}
|
||||
|
||||
toolsResult, ok := toolsResp["result"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected result object, got %T", toolsResp["result"])
|
||||
}
|
||||
tools, ok := toolsResult["tools"].([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
t.Fatal("Expected non-empty tools list")
|
||||
}
|
||||
|
||||
// Step 4: Call file_read
|
||||
conn.Write([]byte(jsonRPCRequest(3, "tools/call", map[string]any{
|
||||
"name": "file_read",
|
||||
"arguments": map[string]any{"path": "test.txt"},
|
||||
})))
|
||||
callResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if callResp["error"] != nil {
|
||||
t.Fatalf("tools/call file_read returned error: %v", callResp["error"])
|
||||
}
|
||||
|
||||
callResult, ok := callResp["result"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected result object, got %T", callResp["result"])
|
||||
}
|
||||
content, ok := callResult["content"].([]any)
|
||||
if !ok || len(content) == 0 {
|
||||
t.Fatal("Expected non-empty content")
|
||||
}
|
||||
|
||||
firstContent, _ := content[0].(map[string]any)
|
||||
text, _ := firstContent["text"].(string)
|
||||
if !strings.Contains(text, testContent) {
|
||||
t.Errorf("Expected content to contain %q, got %q", testContent, text)
|
||||
}
|
||||
|
||||
// Graceful shutdown
|
||||
cancel()
|
||||
err = <-errCh
|
||||
if err != nil {
|
||||
t.Errorf("ServeUnix returned error: %v", err)
|
||||
}
|
||||
|
||||
// Verify socket file is cleaned up
|
||||
if _, err := os.Stat(socketPath); !os.IsNotExist(err) {
|
||||
t.Error("Expected socket file to be cleaned up after shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixTransport_E2E_DirList(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create some files and dirs
|
||||
os.MkdirAll(filepath.Join(tmpDir, "subdir"), 0755)
|
||||
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("one"), 0644)
|
||||
os.WriteFile(filepath.Join(tmpDir, "subdir", "file2.txt"), []byte("two"), 0644)
|
||||
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
socketPath := shortSocketPath(t, "dir")
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.ServeUnix(ctx, socketPath)
|
||||
}()
|
||||
|
||||
var conn net.Conn
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err = net.DialTimeout("unix", socketPath, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
// Initialize
|
||||
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
|
||||
})))
|
||||
readJSONRPCResponse(t, scanner, conn)
|
||||
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
|
||||
|
||||
// Call dir_list on root
|
||||
conn.Write([]byte(jsonRPCRequest(2, "tools/call", map[string]any{
|
||||
"name": "dir_list",
|
||||
"arguments": map[string]any{"path": "."},
|
||||
})))
|
||||
dirResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if dirResp["error"] != nil {
|
||||
t.Fatalf("dir_list returned error: %v", dirResp["error"])
|
||||
}
|
||||
|
||||
dirResult, ok := dirResp["result"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected result object, got %T", dirResp["result"])
|
||||
}
|
||||
dirContent, ok := dirResult["content"].([]any)
|
||||
if !ok || len(dirContent) == 0 {
|
||||
t.Fatal("Expected non-empty content in dir_list response")
|
||||
}
|
||||
|
||||
// The response content should mention our files
|
||||
firstItem, _ := dirContent[0].(map[string]any)
|
||||
text, _ := firstItem["text"].(string)
|
||||
if !strings.Contains(text, "file1.txt") && !strings.Contains(text, "subdir") {
|
||||
t.Errorf("Expected dir_list to contain file1.txt or subdir, got: %s", text)
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-errCh
|
||||
}
|
||||
|
||||
// --- Stdio Transport Tests ---
|
||||
|
||||
func TestStdioTransport_Documented_Skip(t *testing.T) {
|
||||
// The MCP SDK's StdioTransport binds directly to os.Stdin and os.Stdout,
|
||||
// with no way to inject custom io.Reader/io.Writer. Testing stdio transport
|
||||
// would require spawning the binary as a subprocess and piping JSON-RPC
|
||||
// messages through its stdin/stdout.
|
||||
//
|
||||
// Since the core MCP protocol handling is identical across all transports
|
||||
// (the transport layer only handles framing), and we thoroughly test the
|
||||
// protocol via TCP and Unix socket e2e tests, the stdio transport is
|
||||
// effectively covered. The only untested code path is the StdioTransport
|
||||
// adapter itself, which is a thin wrapper in the upstream SDK.
|
||||
//
|
||||
// If process-level testing is needed in the future, the approach would be:
|
||||
// 1. Build the binary: `go build -o /tmp/mcp-test ./cmd/...`
|
||||
// 2. Spawn it: exec.Command("/tmp/mcp-test", "mcp", "serve")
|
||||
// 3. Pipe JSON-RPC to stdin, read from stdout
|
||||
// 4. Verify responses match expected protocol
|
||||
t.Skip("stdio transport requires process spawning; protocol is covered by TCP and Unix e2e tests")
|
||||
}
|
||||
|
||||
// --- Helper: verify a specific tool exists in tools/list response ---
|
||||
|
||||
func assertToolExists(t *testing.T, tools []any, name string) {
|
||||
t.Helper()
|
||||
for _, tool := range tools {
|
||||
toolMap, _ := tool.(map[string]any)
|
||||
if toolMap["name"] == name {
|
||||
return
|
||||
}
|
||||
}
|
||||
toolNames := make([]string, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
toolMap, _ := tool.(map[string]any)
|
||||
if n, ok := toolMap["name"].(string); ok {
|
||||
toolNames = append(toolNames, n)
|
||||
}
|
||||
}
|
||||
t.Errorf("Expected tool %q in list, got: %v", name, toolNames)
|
||||
}
|
||||
|
||||
func TestTCPTransport_E2E_ToolsDiscovery(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find free port: %v", err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
ln.Close()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.ServeTCP(ctx, addr)
|
||||
}()
|
||||
|
||||
var conn net.Conn
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
// Initialize
|
||||
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
|
||||
})))
|
||||
readJSONRPCResponse(t, scanner, conn)
|
||||
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
|
||||
|
||||
// Get tools list
|
||||
conn.Write([]byte(jsonRPCRequest(2, "tools/list", nil)))
|
||||
toolsResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if toolsResp["error"] != nil {
|
||||
t.Fatalf("tools/list error: %v", toolsResp["error"])
|
||||
}
|
||||
toolsResult, _ := toolsResp["result"].(map[string]any)
|
||||
tools, _ := toolsResult["tools"].([]any)
|
||||
|
||||
// Verify all core tools are registered
|
||||
expectedTools := []string{
|
||||
"file_read", "file_write", "file_delete", "file_rename",
|
||||
"file_exists", "file_edit", "dir_list", "dir_create",
|
||||
"lang_detect", "lang_list",
|
||||
}
|
||||
for _, name := range expectedTools {
|
||||
assertToolExists(t, tools, name)
|
||||
}
|
||||
|
||||
// Log total tool count for visibility
|
||||
t.Logf("Server registered %d tools", len(tools))
|
||||
|
||||
cancel()
|
||||
<-errCh
|
||||
}
|
||||
|
||||
func TestTCPTransport_E2E_ErrorHandling(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
s, err := New(WithWorkspaceRoot(tmpDir))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find free port: %v", err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
ln.Close()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.ServeTCP(ctx, addr)
|
||||
}()
|
||||
|
||||
var conn net.Conn
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
conn, err = net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
scanner := bufio.NewScanner(conn)
|
||||
scanner.Buffer(make([]byte, 64*1024), 10*1024*1024)
|
||||
|
||||
// Initialize
|
||||
conn.Write([]byte(jsonRPCRequest(1, "initialize", map[string]any{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{"name": "TestClient", "version": "1.0.0"},
|
||||
})))
|
||||
readJSONRPCResponse(t, scanner, conn)
|
||||
conn.Write([]byte(jsonRPCNotification("notifications/initialized")))
|
||||
|
||||
// Try to read a nonexistent file
|
||||
conn.Write([]byte(jsonRPCRequest(2, "tools/call", map[string]any{
|
||||
"name": "file_read",
|
||||
"arguments": map[string]any{"path": "nonexistent.txt"},
|
||||
})))
|
||||
errResp := readJSONRPCResponse(t, scanner, conn)
|
||||
|
||||
// The MCP SDK wraps tool errors as isError content, not JSON-RPC errors.
|
||||
// Check both possibilities.
|
||||
if errResp["error"] != nil {
|
||||
// JSON-RPC level error — this is acceptable
|
||||
t.Logf("Got JSON-RPC error for nonexistent file: %v", errResp["error"])
|
||||
} else {
|
||||
errResult, _ := errResp["result"].(map[string]any)
|
||||
isError, _ := errResult["isError"].(bool)
|
||||
if !isError {
|
||||
// Check content for error indicator
|
||||
content, _ := errResult["content"].([]any)
|
||||
if len(content) > 0 {
|
||||
firstContent, _ := content[0].(map[string]any)
|
||||
text, _ := firstContent["text"].(string)
|
||||
t.Logf("Tool response for nonexistent file: %s", text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify tools/call without params returns an error
|
||||
conn.Write([]byte(jsonRPCRequest(3, "tools/call", nil)))
|
||||
noParamsResp := readJSONRPCResponse(t, scanner, conn)
|
||||
if noParamsResp["error"] == nil {
|
||||
t.Log("tools/call without params did not return JSON-RPC error (SDK may handle differently)")
|
||||
} else {
|
||||
errObj, _ := noParamsResp["error"].(map[string]any)
|
||||
code, _ := errObj["code"].(float64)
|
||||
if code != -32600 {
|
||||
t.Logf("tools/call without params returned error code: %v", code)
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
<-errCh
|
||||
}
|
||||
|
||||
// Suppress "unused import" for fmt — used in helpers
|
||||
var _ = fmt.Sprintf
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"forge.lthn.ai/core/go-log"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// ServeStdio starts the MCP server over stdin/stdout.
|
||||
// This is the default transport for CLI integrations.
|
||||
func (s *Service) ServeStdio(ctx context.Context) error {
|
||||
s.logger.Info("MCP Stdio server starting", "user", log.Username())
|
||||
return s.server.Run(ctx, &mcp.StdioTransport{})
|
||||
}
|
||||
|
|
@ -1,177 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
)
|
||||
|
||||
// DefaultTCPAddr is the default address for the MCP TCP server.
|
||||
const DefaultTCPAddr = "127.0.0.1:9100"
|
||||
|
||||
// diagMu protects diagWriter from concurrent access across tests and goroutines.
|
||||
var diagMu sync.Mutex
|
||||
|
||||
// diagWriter is the destination for warning and diagnostic messages.
|
||||
// Use diagPrintf to write to it safely.
|
||||
var diagWriter io.Writer = os.Stderr
|
||||
|
||||
// diagPrintf writes a formatted message to diagWriter under the mutex.
|
||||
func diagPrintf(format string, args ...any) {
|
||||
diagMu.Lock()
|
||||
defer diagMu.Unlock()
|
||||
fmt.Fprintf(diagWriter, format, args...)
|
||||
}
|
||||
|
||||
// setDiagWriter swaps the diagnostic writer and returns the previous one.
|
||||
// Used by tests to capture output without racing.
|
||||
func setDiagWriter(w io.Writer) io.Writer {
|
||||
diagMu.Lock()
|
||||
defer diagMu.Unlock()
|
||||
old := diagWriter
|
||||
diagWriter = w
|
||||
return old
|
||||
}
|
||||
|
||||
// maxMCPMessageSize is the maximum size for MCP JSON-RPC messages (10 MB).
|
||||
const maxMCPMessageSize = 10 * 1024 * 1024
|
||||
|
||||
// TCPTransport manages a TCP listener for MCP.
|
||||
type TCPTransport struct {
|
||||
addr string
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
// NewTCPTransport creates a new TCP transport listener.
|
||||
// It listens on the provided address (e.g. "localhost:9100").
|
||||
// Defaults to 127.0.0.1 when the host component is empty (e.g. ":9100").
|
||||
// Emits a security warning when explicitly binding to 0.0.0.0 (all interfaces).
|
||||
func NewTCPTransport(addr string) (*TCPTransport, error) {
|
||||
host, port, _ := net.SplitHostPort(addr)
|
||||
if host == "" {
|
||||
addr = net.JoinHostPort("127.0.0.1", port)
|
||||
} else if host == "0.0.0.0" {
|
||||
diagPrintf("WARNING: MCP TCP server binding to all interfaces (%s). Use 127.0.0.1 for local-only access.\n", addr)
|
||||
}
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TCPTransport{addr: addr, listener: listener}, nil
|
||||
}
|
||||
|
||||
// ServeTCP starts a TCP server for the MCP service.
|
||||
// It accepts connections and spawns a new MCP server session for each connection.
|
||||
func (s *Service) ServeTCP(ctx context.Context, addr string) error {
|
||||
t, err := NewTCPTransport(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = t.listener.Close() }()
|
||||
|
||||
// Close listener when context is cancelled to unblock Accept
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = t.listener.Close()
|
||||
}()
|
||||
|
||||
if addr == "" {
|
||||
addr = t.listener.Addr().String()
|
||||
}
|
||||
diagPrintf("MCP TCP server listening on %s\n", addr)
|
||||
|
||||
for {
|
||||
conn, err := t.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
diagPrintf("Accept error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
go s.handleConnection(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
// Note: We don't defer conn.Close() here because it's closed by the Server/Transport
|
||||
|
||||
// Create new server instance for this connection
|
||||
impl := &mcp.Implementation{
|
||||
Name: "core-cli",
|
||||
Version: "0.1.0",
|
||||
}
|
||||
server := mcp.NewServer(impl, nil)
|
||||
s.registerTools(server)
|
||||
|
||||
// Create transport for this connection
|
||||
transport := &connTransport{conn: conn}
|
||||
|
||||
// Run server (blocks until connection closed)
|
||||
// Server.Run calls Connect, then Read loop.
|
||||
if err := server.Run(ctx, transport); err != nil {
|
||||
diagPrintf("Connection error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// connTransport adapts net.Conn to mcp.Transport
|
||||
type connTransport struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (t *connTransport) Connect(ctx context.Context) (mcp.Connection, error) {
|
||||
scanner := bufio.NewScanner(t.conn)
|
||||
scanner.Buffer(make([]byte, 64*1024), maxMCPMessageSize)
|
||||
return &connConnection{
|
||||
conn: t.conn,
|
||||
scanner: scanner,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// connConnection implements mcp.Connection
|
||||
type connConnection struct {
|
||||
conn net.Conn
|
||||
scanner *bufio.Scanner
|
||||
}
|
||||
|
||||
func (c *connConnection) Read(ctx context.Context) (jsonrpc.Message, error) {
|
||||
// Blocks until line is read
|
||||
if !c.scanner.Scan() {
|
||||
if err := c.scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// EOF - connection closed cleanly
|
||||
return nil, io.EOF
|
||||
}
|
||||
line := c.scanner.Bytes()
|
||||
return jsonrpc.DecodeMessage(line)
|
||||
}
|
||||
|
||||
func (c *connConnection) Write(ctx context.Context, msg jsonrpc.Message) error {
|
||||
data, err := jsonrpc.EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Append newline for line-delimited JSON
|
||||
data = append(data, '\n')
|
||||
_, err = c.conn.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *connConnection) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *connConnection) SessionID() string {
|
||||
return "tcp-session" // Unique ID might be better, but optional
|
||||
}
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTCPTransport_Defaults(t *testing.T) {
|
||||
// Test that empty string gets replaced with default address constant
|
||||
// Note: We can't actually bind to 9100 as it may be in use,
|
||||
// so we verify the address is set correctly before Listen is called
|
||||
if DefaultTCPAddr != "127.0.0.1:9100" {
|
||||
t.Errorf("Expected default constant 127.0.0.1:9100, got %s", DefaultTCPAddr)
|
||||
}
|
||||
|
||||
// Test with a dynamic port to verify transport creation works
|
||||
tr, err := NewTCPTransport("127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create transport with dynamic port: %v", err)
|
||||
}
|
||||
defer tr.listener.Close()
|
||||
|
||||
// Verify we got a valid address
|
||||
if tr.addr != "127.0.0.1:0" {
|
||||
t.Errorf("Expected address to be set, got %s", tr.addr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTCPTransport_Warning(t *testing.T) {
|
||||
// Capture warning output via setDiagWriter (mutex-protected, no race).
|
||||
var buf bytes.Buffer
|
||||
old := setDiagWriter(&buf)
|
||||
defer setDiagWriter(old)
|
||||
|
||||
// Trigger warning
|
||||
tr, err := NewTCPTransport("0.0.0.0:9101")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create transport: %v", err)
|
||||
}
|
||||
defer tr.listener.Close()
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "WARNING") {
|
||||
t.Error("Expected warning for binding to 0.0.0.0, but didn't find it in stderr")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeTCP_Connection(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Use a random port for testing to avoid collisions
|
||||
addr := "127.0.0.1:0"
|
||||
|
||||
// Create transport first to get the actual address if we use :0
|
||||
tr, err := NewTCPTransport(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create transport: %v", err)
|
||||
}
|
||||
actualAddr := tr.listener.Addr().String()
|
||||
tr.listener.Close() // Close it so ServeTCP can re-open it or use the same address
|
||||
|
||||
// Start server in background
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.ServeTCP(ctx, actualAddr)
|
||||
}()
|
||||
|
||||
// Give it a moment to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Connect to the server
|
||||
conn, err := net.Dial("tcp", actualAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to server: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Verify we can write to it
|
||||
_, err = conn.Write([]byte("{}\n"))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to write to connection: %v", err)
|
||||
}
|
||||
|
||||
// Shutdown server
|
||||
cancel()
|
||||
err = <-errCh
|
||||
if err != nil {
|
||||
t.Errorf("ServeTCP returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_TCPTrigger(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Set MCP_ADDR to empty to trigger default TCP
|
||||
os.Setenv("MCP_ADDR", "")
|
||||
defer os.Unsetenv("MCP_ADDR")
|
||||
|
||||
// We use a random port for testing, but Run will try to use 127.0.0.1:9100 by default if we don't override.
|
||||
// Since 9100 might be in use, we'll set MCP_ADDR to use :0 (random port)
|
||||
os.Setenv("MCP_ADDR", "127.0.0.1:0")
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.Run(ctx)
|
||||
}()
|
||||
|
||||
// Give it a moment to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Since we can't easily get the actual port used by Run (it's internal),
|
||||
// we just verify it didn't immediately fail.
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatalf("Run failed immediately: %v", err)
|
||||
default:
|
||||
// still running, which is good
|
||||
}
|
||||
|
||||
cancel()
|
||||
_ = <-errCh
|
||||
}
|
||||
|
||||
func TestServeTCP_MultipleConnections(t *testing.T) {
|
||||
s, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
addr := "127.0.0.1:0"
|
||||
tr, err := NewTCPTransport(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create transport: %v", err)
|
||||
}
|
||||
actualAddr := tr.listener.Addr().String()
|
||||
tr.listener.Close()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.ServeTCP(ctx, actualAddr)
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Connect multiple clients
|
||||
const numClients = 3
|
||||
for i := range numClients {
|
||||
conn, err := net.Dial("tcp", actualAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Client %d failed to connect: %v", i, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
_, err = conn.Write([]byte("{}\n"))
|
||||
if err != nil {
|
||||
t.Errorf("Client %d failed to write: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
err = <-errCh
|
||||
if err != nil {
|
||||
t.Errorf("ServeTCP returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"forge.lthn.ai/core/go-log"
|
||||
)
|
||||
|
||||
// ServeUnix starts a Unix domain socket server for the MCP service.
|
||||
// The socket file is created at the given path and removed on shutdown.
|
||||
// It accepts connections and spawns a new MCP server session for each connection.
|
||||
func (s *Service) ServeUnix(ctx context.Context, socketPath string) error {
|
||||
// Clean up any stale socket file
|
||||
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
|
||||
s.logger.Warn("Failed to remove stale socket", "path", socketPath, "err", err)
|
||||
}
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
_ = os.Remove(socketPath)
|
||||
}()
|
||||
|
||||
// Close listener when context is cancelled to unblock Accept
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
s.logger.Security("MCP Unix server listening", "path", socketPath, "user", log.Username())
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
s.logger.Error("MCP Unix accept error", "err", err, "user", log.Username())
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Security("MCP Unix connection accepted", "user", log.Username())
|
||||
go s.handleConnection(ctx, conn)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue