Compare commits
42 commits
agent/read
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2a1c4e007 | ||
|
|
70e0c51dc5 | ||
|
|
7d32ed8661 | ||
|
|
cdb4bdbc45 | ||
|
|
e73deea84b | ||
|
|
cb62378a2b | ||
|
|
1e3a5996fa | ||
|
|
2acf186925 | ||
|
|
e2bc724bb4 | ||
|
|
8b48b33622 | ||
|
|
44206708f9 | ||
|
|
2910a0d588 | ||
|
|
494e05ada4 | ||
|
|
b29b8b5685 | ||
|
|
53bd7478a7 | ||
|
|
9fd2185a86 | ||
|
|
95f8ad387c | ||
|
|
b2ed228b3f | ||
|
|
903aba4695 | ||
|
|
96fd169239 | ||
|
|
633f295244 | ||
|
|
09f786fb80 | ||
|
|
f26ae14222 | ||
|
|
62c1949458 | ||
|
|
e5caa8d32e | ||
|
|
982a3b4b00 | ||
|
|
91803e32df | ||
|
|
65b686283f | ||
|
|
cbab302661 | ||
|
|
91297d733d | ||
|
|
ca07b6cd62 | ||
|
|
92727025e7 | ||
|
|
d8144fde09 | ||
|
|
f9c5362151 | ||
|
|
8f3afaa42a | ||
|
|
7fde0c1c21 | ||
|
|
f8f137b465 | ||
|
|
429f1c2b6c | ||
|
|
9f7dd84d4a | ||
|
|
9bd3084da4 | ||
|
|
20e4a381cf | ||
|
|
cd305904e5 |
80 changed files with 6666 additions and 831 deletions
|
|
@ -1,40 +1,38 @@
|
||||||
// SPDX-License-Identifier: EUPL-1.2
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
// brain-seed imports Claude Code MEMORY.md files into the OpenBrain knowledge
|
// brain-seed imports Claude Code MEMORY.md files into the OpenBrain knowledge
|
||||||
// store via the MCP HTTP API (brain_remember tool). The Laravel app handles
|
// store via the shared OpenBrain HTTP client. The Laravel app handles
|
||||||
// embedding, Qdrant storage, and MariaDB dual-write internally.
|
// embedding, Qdrant storage, and MariaDB dual-write internally.
|
||||||
//
|
//
|
||||||
// Usage:
|
// Usage:
|
||||||
//
|
//
|
||||||
// go run ./cmd/brain-seed -api-key YOUR_KEY
|
// go run ./cmd/brain-seed -api-key YOUR_KEY
|
||||||
// go run ./cmd/brain-seed -api-key YOUR_KEY -api https://lthn.sh/api/v1/mcp
|
// go run ./cmd/brain-seed -api-key YOUR_KEY -api https://api.lthn.sh
|
||||||
// go run ./cmd/brain-seed -api-key YOUR_KEY -dry-run
|
// go run ./cmd/brain-seed -api-key YOUR_KEY -dry-run
|
||||||
// go run ./cmd/brain-seed -api-key YOUR_KEY -plans
|
// go run ./cmd/brain-seed -api-key YOUR_KEY -plans
|
||||||
// go run ./cmd/brain-seed -api-key YOUR_KEY -claude-md # Also import CLAUDE.md files
|
// go run ./cmd/brain-seed -api-key YOUR_KEY -claude-md # Also import CLAUDE.md files
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
|
||||||
goio "io"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
core "dappco.re/go/core"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
coreio "dappco.re/go/io"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
|
brainclient "dappco.re/go/mcp/pkg/mcp/brain/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const seedDivider = "======================================================="
|
||||||
|
|
||||||
var (
|
var (
|
||||||
apiURL = flag.String("api", "https://lthn.sh/api/v1/mcp", "MCP API base URL")
|
apiURL = flag.String("api", brainclient.DefaultURL, "OpenBrain API base URL")
|
||||||
apiKey = flag.String("api-key", "", "MCP API key (Bearer token)")
|
apiKey = flag.String("api-key", core.Env("CORE_BRAIN_KEY"), "OpenBrain API key (Bearer token)")
|
||||||
server = flag.String("server", "hosthub-agent", "MCP server ID")
|
server = flag.String("server", "hosthub-agent", "Legacy MCP server ID flag; accepted for compatibility")
|
||||||
|
org = flag.String("org", core.Env("CORE_BRAIN_ORG"), "OpenBrain org for seeded memories")
|
||||||
agent = flag.String("agent", "charon", "Agent ID for attribution")
|
agent = flag.String("agent", "charon", "Agent ID for attribution")
|
||||||
dryRun = flag.Bool("dry-run", false, "Preview without storing")
|
dryRun = flag.Bool("dry-run", false, "Preview without storing")
|
||||||
plans = flag.Bool("plans", false, "Also import plan documents")
|
plans = flag.Bool("plans", false, "Also import plan documents")
|
||||||
|
|
@ -45,33 +43,33 @@ var (
|
||||||
maxChars = flag.Int("max-chars", 3800, "Max chars per section (embeddinggemma limit ~4000)")
|
maxChars = flag.Int("max-chars", 3800, "Max chars per section (embeddinggemma limit ~4000)")
|
||||||
)
|
)
|
||||||
|
|
||||||
// httpClient with TLS skip for non-public TLDs (.lthn.sh has real certs, but
|
var openbrain *brainclient.Client
|
||||||
// allow .lan/.local if someone has legacy config).
|
|
||||||
var httpClient = &http.Client{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
Transport: &http.Transport{
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
fmt.Println("OpenBrain Seed — MCP API Client")
|
core.Println("OpenBrain Seed — API Client")
|
||||||
fmt.Println(strings.Repeat("=", 55))
|
core.Println(seedDivider)
|
||||||
|
|
||||||
if *apiKey == "" && !*dryRun {
|
if *apiKey == "" && !*dryRun {
|
||||||
fmt.Println("ERROR: -api-key is required (or use -dry-run)")
|
core.Println("ERROR: -api-key is required (or use -dry-run)")
|
||||||
fmt.Println(" Generate one at: https://lthn.sh/admin/mcp/api-keys")
|
core.Println(" Generate one at: https://lthn.sh/admin/mcp/api-keys")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if *dryRun {
|
if *dryRun {
|
||||||
fmt.Println("[DRY RUN] — no data will be stored")
|
core.Println("[DRY RUN] — no data will be stored")
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("API: %s\n", *apiURL)
|
core.Print(nil, "API: %s", *apiURL)
|
||||||
fmt.Printf("Server: %s | Agent: %s\n", *server, *agent)
|
core.Print(nil, "Org: %s | Agent: %s", *org, *agent)
|
||||||
|
|
||||||
|
openbrain = brainclient.New(brainclient.Options{
|
||||||
|
URL: *apiURL,
|
||||||
|
Key: *apiKey,
|
||||||
|
Org: *org,
|
||||||
|
AgentID: *agent,
|
||||||
|
})
|
||||||
|
|
||||||
// Discover memory files
|
// Discover memory files
|
||||||
memPath := *memoryPath
|
memPath := *memoryPath
|
||||||
|
|
@ -80,7 +78,7 @@ func main() {
|
||||||
memPath = filepath.Join(home, ".claude", "projects", "*", "memory")
|
memPath = filepath.Join(home, ".claude", "projects", "*", "memory")
|
||||||
}
|
}
|
||||||
memFiles, _ := filepath.Glob(filepath.Join(memPath, "*.md"))
|
memFiles, _ := filepath.Glob(filepath.Join(memPath, "*.md"))
|
||||||
fmt.Printf("\nFound %d memory files\n", len(memFiles))
|
core.Print(nil, "\nFound %d memory files", len(memFiles))
|
||||||
|
|
||||||
// Discover plan files
|
// Discover plan files
|
||||||
var planFiles []string
|
var planFiles []string
|
||||||
|
|
@ -103,7 +101,7 @@ func main() {
|
||||||
hostUkNested, _ := filepath.Glob(filepath.Join(hostUkPath, "*", "*.md"))
|
hostUkNested, _ := filepath.Glob(filepath.Join(hostUkPath, "*", "*.md"))
|
||||||
planFiles = append(planFiles, hostUkNested...)
|
planFiles = append(planFiles, hostUkNested...)
|
||||||
|
|
||||||
fmt.Printf("Found %d plan files\n", len(planFiles))
|
core.Print(nil, "Found %d plan files", len(planFiles))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Discover CLAUDE.md files
|
// Discover CLAUDE.md files
|
||||||
|
|
@ -115,7 +113,7 @@ func main() {
|
||||||
cPath = filepath.Join(home, "Code")
|
cPath = filepath.Join(home, "Code")
|
||||||
}
|
}
|
||||||
claudeFiles = discoverClaudeMdFiles(cPath)
|
claudeFiles = discoverClaudeMdFiles(cPath)
|
||||||
fmt.Printf("Found %d CLAUDE.md files\n", len(claudeFiles))
|
core.Print(nil, "Found %d CLAUDE.md files", len(claudeFiles))
|
||||||
}
|
}
|
||||||
|
|
||||||
imported := 0
|
imported := 0
|
||||||
|
|
@ -123,11 +121,11 @@ func main() {
|
||||||
errors := 0
|
errors := 0
|
||||||
|
|
||||||
// Process memory files
|
// Process memory files
|
||||||
fmt.Println("\n--- Memory Files ---")
|
core.Println("\n--- Memory Files ---")
|
||||||
for _, f := range memFiles {
|
for _, f := range memFiles {
|
||||||
project := extractProject(f)
|
project := extractProject(f)
|
||||||
sections := parseMarkdownSections(f)
|
sections := parseMarkdownSections(f)
|
||||||
filename := strings.TrimSuffix(filepath.Base(f), ".md")
|
filename := core.TrimSuffix(filepath.Base(f), ".md")
|
||||||
|
|
||||||
if len(sections) == 0 {
|
if len(sections) == 0 {
|
||||||
coreerr.Warn("brain-seed: skip file (no sections)", "project", project, "file", filename)
|
coreerr.Warn("brain-seed: skip file (no sections)", "project", project, "file", filename)
|
||||||
|
|
@ -137,7 +135,7 @@ func main() {
|
||||||
|
|
||||||
for _, sec := range sections {
|
for _, sec := range sections {
|
||||||
content := sec.heading + "\n\n" + sec.content
|
content := sec.heading + "\n\n" + sec.content
|
||||||
if strings.TrimSpace(sec.content) == "" {
|
if core.Trim(sec.content) == "" {
|
||||||
skipped++
|
skipped++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -150,7 +148,7 @@ func main() {
|
||||||
content = truncate(content, *maxChars)
|
content = truncate(content, *maxChars)
|
||||||
|
|
||||||
if *dryRun {
|
if *dryRun {
|
||||||
fmt.Printf(" [DRY] %s/%s :: %s (%s) — %d chars\n",
|
core.Print(nil, " [DRY] %s/%s :: %s (%s) — %d chars",
|
||||||
project, filename, sec.heading, memType, len(content))
|
project, filename, sec.heading, memType, len(content))
|
||||||
imported++
|
imported++
|
||||||
continue
|
continue
|
||||||
|
|
@ -161,18 +159,18 @@ func main() {
|
||||||
errors++
|
errors++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fmt.Printf(" ok %s/%s :: %s (%s)\n", project, filename, sec.heading, memType)
|
core.Print(nil, " ok %s/%s :: %s (%s)", project, filename, sec.heading, memType)
|
||||||
imported++
|
imported++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process plan files
|
// Process plan files
|
||||||
if *plans && len(planFiles) > 0 {
|
if *plans && len(planFiles) > 0 {
|
||||||
fmt.Println("\n--- Plan Documents ---")
|
core.Println("\n--- Plan Documents ---")
|
||||||
for _, f := range planFiles {
|
for _, f := range planFiles {
|
||||||
project := extractProjectFromPlan(f)
|
project := extractProjectFromPlan(f)
|
||||||
sections := parseMarkdownSections(f)
|
sections := parseMarkdownSections(f)
|
||||||
filename := strings.TrimSuffix(filepath.Base(f), ".md")
|
filename := core.TrimSuffix(filepath.Base(f), ".md")
|
||||||
|
|
||||||
if len(sections) == 0 {
|
if len(sections) == 0 {
|
||||||
skipped++
|
skipped++
|
||||||
|
|
@ -181,7 +179,7 @@ func main() {
|
||||||
|
|
||||||
for _, sec := range sections {
|
for _, sec := range sections {
|
||||||
content := sec.heading + "\n\n" + sec.content
|
content := sec.heading + "\n\n" + sec.content
|
||||||
if strings.TrimSpace(sec.content) == "" {
|
if core.Trim(sec.content) == "" {
|
||||||
skipped++
|
skipped++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -190,7 +188,7 @@ func main() {
|
||||||
content = truncate(content, *maxChars)
|
content = truncate(content, *maxChars)
|
||||||
|
|
||||||
if *dryRun {
|
if *dryRun {
|
||||||
fmt.Printf(" [DRY] %s :: %s / %s (plan) — %d chars\n",
|
core.Print(nil, " [DRY] %s :: %s / %s (plan) — %d chars",
|
||||||
project, filename, sec.heading, len(content))
|
project, filename, sec.heading, len(content))
|
||||||
imported++
|
imported++
|
||||||
continue
|
continue
|
||||||
|
|
@ -201,7 +199,7 @@ func main() {
|
||||||
errors++
|
errors++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fmt.Printf(" ok %s :: %s / %s (plan)\n", project, filename, sec.heading)
|
core.Print(nil, " ok %s :: %s / %s (plan)", project, filename, sec.heading)
|
||||||
imported++
|
imported++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -209,7 +207,7 @@ func main() {
|
||||||
|
|
||||||
// Process CLAUDE.md files
|
// Process CLAUDE.md files
|
||||||
if *claudeMd && len(claudeFiles) > 0 {
|
if *claudeMd && len(claudeFiles) > 0 {
|
||||||
fmt.Println("\n--- CLAUDE.md Files ---")
|
core.Println("\n--- CLAUDE.md Files ---")
|
||||||
for _, f := range claudeFiles {
|
for _, f := range claudeFiles {
|
||||||
project := extractProjectFromClaudeMd(f)
|
project := extractProjectFromClaudeMd(f)
|
||||||
sections := parseMarkdownSections(f)
|
sections := parseMarkdownSections(f)
|
||||||
|
|
@ -221,7 +219,7 @@ func main() {
|
||||||
|
|
||||||
for _, sec := range sections {
|
for _, sec := range sections {
|
||||||
content := sec.heading + "\n\n" + sec.content
|
content := sec.heading + "\n\n" + sec.content
|
||||||
if strings.TrimSpace(sec.content) == "" {
|
if core.Trim(sec.content) == "" {
|
||||||
skipped++
|
skipped++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -230,7 +228,7 @@ func main() {
|
||||||
content = truncate(content, *maxChars)
|
content = truncate(content, *maxChars)
|
||||||
|
|
||||||
if *dryRun {
|
if *dryRun {
|
||||||
fmt.Printf(" [DRY] %s :: CLAUDE.md / %s (convention) — %d chars\n",
|
core.Print(nil, " [DRY] %s :: CLAUDE.md / %s (convention) — %d chars",
|
||||||
project, sec.heading, len(content))
|
project, sec.heading, len(content))
|
||||||
imported++
|
imported++
|
||||||
continue
|
continue
|
||||||
|
|
@ -241,74 +239,44 @@ func main() {
|
||||||
errors++
|
errors++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
fmt.Printf(" ok %s :: CLAUDE.md / %s (convention)\n", project, sec.heading)
|
core.Print(nil, " ok %s :: CLAUDE.md / %s (convention)", project, sec.heading)
|
||||||
imported++
|
imported++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("\n%s\n", strings.Repeat("=", 55))
|
core.Print(nil, "\n%s", seedDivider)
|
||||||
prefix := ""
|
prefix := ""
|
||||||
if *dryRun {
|
if *dryRun {
|
||||||
prefix = "[DRY RUN] "
|
prefix = "[DRY RUN] "
|
||||||
}
|
}
|
||||||
fmt.Printf("%sImported: %d | Skipped: %d | Errors: %d\n", prefix, imported, skipped, errors)
|
core.Print(nil, "%sImported: %d | Skipped: %d | Errors: %d", prefix, imported, skipped, errors)
|
||||||
}
|
}
|
||||||
|
|
||||||
// callBrainRemember sends a memory to the MCP API via brain_remember tool.
|
// callBrainRemember sends a memory to OpenBrain via /v1/brain/remember.
|
||||||
func callBrainRemember(content, memType string, tags []string, project string, confidence float64) error {
|
func callBrainRemember(content, memType string, tags []string, project string, confidence float64) error {
|
||||||
args := map[string]any{
|
if openbrain == nil {
|
||||||
"content": content,
|
openbrain = brainclient.New(brainclient.Options{
|
||||||
"type": memType,
|
URL: *apiURL,
|
||||||
"tags": tags,
|
Key: *apiKey,
|
||||||
"confidence": confidence,
|
Org: *org,
|
||||||
|
AgentID: *agent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
input := brainclient.RememberInput{
|
||||||
|
Content: content,
|
||||||
|
Type: memType,
|
||||||
|
Tags: tags,
|
||||||
|
Org: *org,
|
||||||
|
AgentID: *agent,
|
||||||
|
Confidence: confidence,
|
||||||
}
|
}
|
||||||
if project != "" && project != "unknown" {
|
if project != "" && project != "unknown" {
|
||||||
args["project"] = project
|
input.Project = project
|
||||||
}
|
}
|
||||||
|
_, err := openbrain.Remember(context.Background(), input)
|
||||||
payload := map[string]any{
|
return coreerr.Wrap(err, "callBrainRemember", "remember")
|
||||||
"server": *server,
|
|
||||||
"tool": "brain_remember",
|
|
||||||
"arguments": args,
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return coreerr.E("callBrainRemember", "marshal", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", *apiURL+"/tools/call", bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return coreerr.E("callBrainRemember", "request", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+*apiKey)
|
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return coreerr.E("callBrainRemember", "http", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
respBody, _ := goio.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return coreerr.E("callBrainRemember", "HTTP "+string(respBody), nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
var result struct {
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Error string `json:"error"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
||||||
return coreerr.E("callBrainRemember", "decode", err)
|
|
||||||
}
|
|
||||||
if !result.Success {
|
|
||||||
return coreerr.E("callBrainRemember", "API: "+result.Error, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncate caps content to maxLen chars, appending an ellipsis if truncated.
|
// truncate caps content to maxLen chars, appending an ellipsis if truncated.
|
||||||
|
|
@ -318,12 +286,21 @@ func truncate(s string, maxLen int) string {
|
||||||
}
|
}
|
||||||
// Find last space before limit to avoid splitting mid-word
|
// Find last space before limit to avoid splitting mid-word
|
||||||
cut := maxLen
|
cut := maxLen
|
||||||
if idx := strings.LastIndex(s[:maxLen], " "); idx > maxLen-200 {
|
if idx := lastByteIndex(s[:maxLen], ' '); idx > maxLen-200 {
|
||||||
cut = idx
|
cut = idx
|
||||||
}
|
}
|
||||||
return s[:cut] + "…"
|
return s[:cut] + "…"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func lastByteIndex(s string, target byte) int {
|
||||||
|
for i := len(s) - 1; i >= 0; i-- {
|
||||||
|
if s[i] == target {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
// discoverClaudeMdFiles finds CLAUDE.md files across a code directory.
|
// discoverClaudeMdFiles finds CLAUDE.md files across a code directory.
|
||||||
func discoverClaudeMdFiles(codePath string) []string {
|
func discoverClaudeMdFiles(codePath string) []string {
|
||||||
var files []string
|
var files []string
|
||||||
|
|
@ -340,7 +317,7 @@ func discoverClaudeMdFiles(codePath string) []string {
|
||||||
}
|
}
|
||||||
// Limit depth
|
// Limit depth
|
||||||
rel, _ := filepath.Rel(codePath, path)
|
rel, _ := filepath.Rel(codePath, path)
|
||||||
if strings.Count(rel, string(os.PathSeparator)) > 3 {
|
if len(core.Split(rel, string(os.PathSeparator))) > 4 {
|
||||||
return filepath.SkipDir
|
return filepath.SkipDir
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -370,19 +347,19 @@ func parseMarkdownSections(path string) []section {
|
||||||
}
|
}
|
||||||
|
|
||||||
var sections []section
|
var sections []section
|
||||||
lines := strings.Split(data, "\n")
|
lines := core.Split(data, "\n")
|
||||||
var curHeading string
|
var curHeading string
|
||||||
var curContent []string
|
var curContent []string
|
||||||
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if m := headingRe.FindStringSubmatch(line); m != nil {
|
if m := headingRe.FindStringSubmatch(line); m != nil {
|
||||||
if curHeading != "" && len(curContent) > 0 {
|
if curHeading != "" && len(curContent) > 0 {
|
||||||
text := strings.TrimSpace(strings.Join(curContent, "\n"))
|
text := core.Trim(core.Join("\n", curContent...))
|
||||||
if text != "" {
|
if text != "" {
|
||||||
sections = append(sections, section{curHeading, text})
|
sections = append(sections, section{curHeading, text})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
curHeading = strings.TrimSpace(m[1])
|
curHeading = core.Trim(m[1])
|
||||||
curContent = nil
|
curContent = nil
|
||||||
} else {
|
} else {
|
||||||
curContent = append(curContent, line)
|
curContent = append(curContent, line)
|
||||||
|
|
@ -391,17 +368,17 @@ func parseMarkdownSections(path string) []section {
|
||||||
|
|
||||||
// Flush last section
|
// Flush last section
|
||||||
if curHeading != "" && len(curContent) > 0 {
|
if curHeading != "" && len(curContent) > 0 {
|
||||||
text := strings.TrimSpace(strings.Join(curContent, "\n"))
|
text := core.Trim(core.Join("\n", curContent...))
|
||||||
if text != "" {
|
if text != "" {
|
||||||
sections = append(sections, section{curHeading, text})
|
sections = append(sections, section{curHeading, text})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no headings found, treat entire file as one section
|
// If no headings found, treat entire file as one section
|
||||||
if len(sections) == 0 && strings.TrimSpace(data) != "" {
|
if len(sections) == 0 && core.Trim(data) != "" {
|
||||||
sections = append(sections, section{
|
sections = append(sections, section{
|
||||||
heading: strings.TrimSuffix(filepath.Base(path), ".md"),
|
heading: core.TrimSuffix(filepath.Base(path), ".md"),
|
||||||
content: strings.TrimSpace(data),
|
content: core.Trim(data),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -459,7 +436,7 @@ func inferType(heading, content, source string) string {
|
||||||
return "convention"
|
return "convention"
|
||||||
}
|
}
|
||||||
|
|
||||||
lower := strings.ToLower(heading + " " + content)
|
lower := core.Lower(heading + " " + content)
|
||||||
patterns := map[string][]string{
|
patterns := map[string][]string{
|
||||||
"architecture": {"architecture", "stack", "infrastructure", "layer", "service mesh"},
|
"architecture": {"architecture", "stack", "infrastructure", "layer", "service mesh"},
|
||||||
"convention": {"convention", "standard", "naming", "pattern", "rule", "coding"},
|
"convention": {"convention", "standard", "naming", "pattern", "rule", "coding"},
|
||||||
|
|
@ -470,7 +447,7 @@ func inferType(heading, content, source string) string {
|
||||||
}
|
}
|
||||||
for t, keywords := range patterns {
|
for t, keywords := range patterns {
|
||||||
for _, kw := range keywords {
|
for _, kw := range keywords {
|
||||||
if strings.Contains(lower, kw) {
|
if core.Contains(lower, kw) {
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -485,7 +462,7 @@ func buildTags(filename, source, project string) []string {
|
||||||
tags = append(tags, "project:"+project)
|
tags = append(tags, "project:"+project)
|
||||||
}
|
}
|
||||||
if filename != "MEMORY" && filename != "CLAUDE" {
|
if filename != "MEMORY" && filename != "CLAUDE" {
|
||||||
tags = append(tags, strings.ReplaceAll(strings.ReplaceAll(filename, "-", " "), "_", " "))
|
tags = append(tags, core.Replace(core.Replace(filename, "-", " "), "_", " "))
|
||||||
}
|
}
|
||||||
return tags
|
return tags
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"forge.lthn.ai/core/cli/pkg/cli"
|
"dappco.re/go/cli/pkg/cli"
|
||||||
mcpcmd "dappco.re/go/mcp/cmd/mcpcmd"
|
mcpcmd "dappco.re/go/mcp/cmd/mcpcmd"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,14 @@
|
||||||
// Package mcpcmd provides the MCP server command.
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
// Package mcpcmd registers the `mcp` and `mcp serve` CLI commands.
|
||||||
|
//
|
||||||
|
// Wiring example:
|
||||||
|
//
|
||||||
|
// cli.Main(cli.WithCommands("mcp", mcpcmd.AddMCPCommands))
|
||||||
//
|
//
|
||||||
// Commands:
|
// Commands:
|
||||||
// - mcp serve: Start the MCP server for AI tool integration
|
// - mcp Start the MCP server on stdio (default transport).
|
||||||
|
// - mcp serve Start the MCP server with auto-selected transport.
|
||||||
package mcpcmd
|
package mcpcmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
@ -10,75 +17,89 @@ import (
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
"dappco.re/go/mcp/pkg/mcp"
|
"dappco.re/go/mcp/pkg/mcp"
|
||||||
"dappco.re/go/mcp/pkg/mcp/agentic"
|
"dappco.re/go/mcp/pkg/mcp/agentic"
|
||||||
"dappco.re/go/mcp/pkg/mcp/brain"
|
"dappco.re/go/mcp/pkg/mcp/brain"
|
||||||
"forge.lthn.ai/core/cli/pkg/cli"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var workspaceFlag string
|
// newMCPService is the service constructor, indirected for tests.
|
||||||
var unrestrictedFlag bool
|
|
||||||
|
|
||||||
var newMCPService = mcp.New
|
var newMCPService = mcp.New
|
||||||
|
|
||||||
|
// runMCPService starts the MCP server, indirected for tests.
|
||||||
var runMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
var runMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
return svc.Run(ctx)
|
return svc.Run(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shutdownMCPService performs graceful shutdown, indirected for tests.
|
||||||
var shutdownMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
var shutdownMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
return svc.Shutdown(ctx)
|
return svc.Shutdown(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
var mcpCmd = &cli.Command{
|
// workspaceFlag mirrors the --workspace CLI flag value.
|
||||||
Use: "mcp",
|
var workspaceFlag string
|
||||||
Short: "MCP server for AI tool integration",
|
|
||||||
Long: "Model Context Protocol (MCP) server providing file operations, RAG, and metrics tools.",
|
// unrestrictedFlag mirrors the --unrestricted CLI flag value.
|
||||||
|
var unrestrictedFlag bool
|
||||||
|
|
||||||
|
// AddMCPCommands registers the `mcp` command tree on the Core instance.
|
||||||
|
//
|
||||||
|
// cli.Main(cli.WithCommands("mcp", mcpcmd.AddMCPCommands))
|
||||||
|
func AddMCPCommands(c *core.Core) {
|
||||||
|
c.Command("mcp", core.Command{
|
||||||
|
Description: "Model Context Protocol server (stdio, TCP, Unix socket, HTTP).",
|
||||||
|
Action: runServeAction,
|
||||||
|
Flags: core.NewOptions(
|
||||||
|
core.Option{Key: "workspace", Value: ""},
|
||||||
|
core.Option{Key: "w", Value: ""},
|
||||||
|
core.Option{Key: "unrestricted", Value: false},
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Command("mcp/serve", core.Command{
|
||||||
|
Description: "Start the MCP server with auto-selected transport (stdio, TCP, Unix, or HTTP).",
|
||||||
|
Action: runServeAction,
|
||||||
|
Flags: core.NewOptions(
|
||||||
|
core.Option{Key: "workspace", Value: ""},
|
||||||
|
core.Option{Key: "w", Value: ""},
|
||||||
|
core.Option{Key: "unrestricted", Value: false},
|
||||||
|
),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var serveCmd = &cli.Command{
|
// runServeAction is the CLI entrypoint for `mcp` and `mcp serve`.
|
||||||
Use: "serve",
|
//
|
||||||
Short: "Start the MCP server",
|
// opts := core.NewOptions(core.Option{Key: "workspace", Value: "."})
|
||||||
Long: `Start the MCP server on stdio (default), TCP, Unix socket, or HTTP.
|
// result := runServeAction(opts)
|
||||||
|
func runServeAction(opts core.Options) core.Result {
|
||||||
|
workspaceFlag = core.Trim(firstNonEmpty(opts.String("workspace"), opts.String("w")))
|
||||||
|
unrestrictedFlag = opts.Bool("unrestricted")
|
||||||
|
|
||||||
The server provides file operations plus the brain and agentic subsystems
|
if err := runServe(); err != nil {
|
||||||
registered by this command.
|
return core.Result{Value: err, OK: false}
|
||||||
|
}
|
||||||
Environment variables:
|
return core.Result{OK: true}
|
||||||
MCP_ADDR TCP address to listen on (e.g., "localhost:9999")
|
|
||||||
MCP_UNIX_SOCKET
|
|
||||||
Unix socket path to listen on (e.g., "/tmp/core-mcp.sock")
|
|
||||||
Selected after MCP_ADDR and before stdio.
|
|
||||||
MCP_HTTP_ADDR
|
|
||||||
HTTP address to listen on (e.g., "127.0.0.1:9101")
|
|
||||||
Selected before MCP_ADDR and stdio.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
# Start with stdio transport (for Claude Code integration)
|
|
||||||
core mcp serve
|
|
||||||
|
|
||||||
# Start with workspace restriction
|
|
||||||
core mcp serve --workspace /path/to/project
|
|
||||||
|
|
||||||
# Start unrestricted (explicit opt-in)
|
|
||||||
core mcp serve --unrestricted
|
|
||||||
|
|
||||||
# Start TCP server
|
|
||||||
MCP_ADDR=localhost:9999 core mcp serve`,
|
|
||||||
RunE: func(cmd *cli.Command, args []string) error {
|
|
||||||
return runServe()
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func initFlags() {
|
// firstNonEmpty returns the first non-empty string argument.
|
||||||
cli.StringFlag(serveCmd, &workspaceFlag, "workspace", "w", "", "Restrict file operations to this directory")
|
//
|
||||||
cli.BoolFlag(serveCmd, &unrestrictedFlag, "unrestricted", "", false, "Disable filesystem sandboxing entirely")
|
// firstNonEmpty("", "foo") == "foo"
|
||||||
}
|
// firstNonEmpty("bar", "baz") == "bar"
|
||||||
|
func firstNonEmpty(values ...string) string {
|
||||||
// AddMCPCommands registers the 'mcp' command and all subcommands.
|
for _, v := range values {
|
||||||
func AddMCPCommands(root *cli.Command) {
|
if v != "" {
|
||||||
initFlags()
|
return v
|
||||||
mcpCmd.AddCommand(serveCmd)
|
}
|
||||||
root.AddCommand(mcpCmd)
|
}
|
||||||
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runServe wires the MCP service together and blocks until the context is
|
||||||
|
// cancelled by SIGINT/SIGTERM or a transport error.
|
||||||
|
//
|
||||||
|
// if err := runServe(); err != nil {
|
||||||
|
// core.Error("mcp serve failed", "err", err)
|
||||||
|
// }
|
||||||
func runServe() error {
|
func runServe() error {
|
||||||
opts := mcp.Options{}
|
opts := mcp.Options{}
|
||||||
|
|
||||||
|
|
@ -88,22 +109,20 @@ func runServe() error {
|
||||||
opts.WorkspaceRoot = workspaceFlag
|
opts.WorkspaceRoot = workspaceFlag
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register OpenBrain and agentic subsystems
|
// Register OpenBrain and agentic subsystems.
|
||||||
opts.Subsystems = []mcp.Subsystem{
|
opts.Subsystems = []mcp.Subsystem{
|
||||||
brain.NewDirect(),
|
brain.NewDirect(),
|
||||||
agentic.NewPrep(),
|
agentic.NewPrep(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the MCP service
|
|
||||||
svc, err := newMCPService(opts)
|
svc, err := newMCPService(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return cli.Wrap(err, "create MCP service")
|
return core.E("mcpcmd.runServe", "create MCP service", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = shutdownMCPService(svc, context.Background())
|
_ = shutdownMCPService(svc, context.Background())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Set up signal handling for clean shutdown
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|
@ -111,10 +130,12 @@ func runServe() error {
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-sigCh
|
select {
|
||||||
|
case <-sigCh:
|
||||||
cancel()
|
cancel()
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Run the server (blocks until context cancelled or error)
|
|
||||||
return runMCPService(svc, ctx)
|
return runMCPService(svc, ctx)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,18 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
package mcpcmd
|
package mcpcmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
"dappco.re/go/mcp/pkg/mcp"
|
"dappco.re/go/mcp/pkg/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRunServe_Good_ShutsDownService(t *testing.T) {
|
func TestCmdMCP_RunServe_Good_ShutsDownService(t *testing.T) {
|
||||||
oldNew := newMCPService
|
restore := stubMCPService(t)
|
||||||
oldRun := runMCPService
|
defer restore()
|
||||||
oldShutdown := shutdownMCPService
|
|
||||||
oldWorkspace := workspaceFlag
|
|
||||||
oldUnrestricted := unrestrictedFlag
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
newMCPService = oldNew
|
|
||||||
runMCPService = oldRun
|
|
||||||
shutdownMCPService = oldShutdown
|
|
||||||
workspaceFlag = oldWorkspace
|
|
||||||
unrestrictedFlag = oldUnrestricted
|
|
||||||
})
|
|
||||||
|
|
||||||
workspaceFlag = ""
|
workspaceFlag = ""
|
||||||
unrestrictedFlag = false
|
unrestrictedFlag = false
|
||||||
|
|
@ -50,3 +42,186 @@ func TestRunServe_Good_ShutsDownService(t *testing.T) {
|
||||||
t.Fatal("expected shutdownMCPService to be called")
|
t.Fatal("expected shutdownMCPService to be called")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCmdMCP_RunServeAction_Good_PropagatesFlags(t *testing.T) {
|
||||||
|
restore := stubMCPService(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
workspaceFlag = ""
|
||||||
|
unrestrictedFlag = false
|
||||||
|
|
||||||
|
var gotOpts mcp.Options
|
||||||
|
newMCPService = func(opts mcp.Options) (*mcp.Service, error) {
|
||||||
|
gotOpts = opts
|
||||||
|
return mcp.New(mcp.Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
}
|
||||||
|
runMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
shutdownMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp := t.TempDir()
|
||||||
|
opts := core.NewOptions(core.Option{Key: "workspace", Value: tmp})
|
||||||
|
|
||||||
|
result := runServeAction(opts)
|
||||||
|
if !result.OK {
|
||||||
|
t.Fatalf("expected OK, got %+v", result)
|
||||||
|
}
|
||||||
|
if gotOpts.WorkspaceRoot != tmp {
|
||||||
|
t.Fatalf("expected workspace root %q, got %q", tmp, gotOpts.WorkspaceRoot)
|
||||||
|
}
|
||||||
|
if gotOpts.Unrestricted {
|
||||||
|
t.Fatal("expected Unrestricted=false when --workspace is set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdMCP_RunServeAction_Good_UnrestrictedFlag(t *testing.T) {
|
||||||
|
restore := stubMCPService(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
workspaceFlag = ""
|
||||||
|
unrestrictedFlag = false
|
||||||
|
|
||||||
|
var gotOpts mcp.Options
|
||||||
|
newMCPService = func(opts mcp.Options) (*mcp.Service, error) {
|
||||||
|
gotOpts = opts
|
||||||
|
return mcp.New(mcp.Options{Unrestricted: true})
|
||||||
|
}
|
||||||
|
runMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
shutdownMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := core.NewOptions(core.Option{Key: "unrestricted", Value: true})
|
||||||
|
|
||||||
|
result := runServeAction(opts)
|
||||||
|
if !result.OK {
|
||||||
|
t.Fatalf("expected OK, got %+v", result)
|
||||||
|
}
|
||||||
|
if !gotOpts.Unrestricted {
|
||||||
|
t.Fatal("expected Unrestricted=true when --unrestricted is set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdMCP_RunServe_Bad_CreateServiceFails(t *testing.T) {
|
||||||
|
restore := stubMCPService(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
workspaceFlag = ""
|
||||||
|
unrestrictedFlag = false
|
||||||
|
|
||||||
|
sentinel := core.E("mcpcmd.test", "boom", nil)
|
||||||
|
newMCPService = func(opts mcp.Options) (*mcp.Service, error) {
|
||||||
|
return nil, sentinel
|
||||||
|
}
|
||||||
|
runMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
|
t.Fatal("runMCPService should not be called when New fails")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
shutdownMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
|
t.Fatal("shutdownMCPService should not be called when New fails")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := runServe()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when newMCPService fails")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdMCP_RunServeAction_Bad_PropagatesFailure(t *testing.T) {
|
||||||
|
restore := stubMCPService(t)
|
||||||
|
defer restore()
|
||||||
|
|
||||||
|
workspaceFlag = ""
|
||||||
|
unrestrictedFlag = false
|
||||||
|
|
||||||
|
newMCPService = func(opts mcp.Options) (*mcp.Service, error) {
|
||||||
|
return nil, core.E("mcpcmd.test", "construction failed", nil)
|
||||||
|
}
|
||||||
|
runMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
shutdownMCPService = func(svc *mcp.Service, ctx context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := runServeAction(core.NewOptions())
|
||||||
|
if result.OK {
|
||||||
|
t.Fatal("expected runServeAction to fail when service creation fails")
|
||||||
|
}
|
||||||
|
if result.Value == nil {
|
||||||
|
t.Fatal("expected error value on failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdMCP_FirstNonEmpty_Ugly_HandlesAllVariants(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
values []string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"no args", nil, ""},
|
||||||
|
{"empty string", []string{""}, ""},
|
||||||
|
{"all empty", []string{"", "", ""}, ""},
|
||||||
|
{"first non-empty", []string{"foo", "bar"}, "foo"},
|
||||||
|
{"skip empty", []string{"", "baz"}, "baz"},
|
||||||
|
{"mixed", []string{"", "", "last"}, "last"},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := firstNonEmpty(tc.values...)
|
||||||
|
if got != tc.want {
|
||||||
|
t.Fatalf("firstNonEmpty(%v) = %q, want %q", tc.values, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdMCP_AddMCPCommands_Good_RegistersMcpTree(t *testing.T) {
|
||||||
|
c := core.New()
|
||||||
|
AddMCPCommands(c)
|
||||||
|
|
||||||
|
commands := c.Commands()
|
||||||
|
if len(commands) == 0 {
|
||||||
|
t.Fatal("expected at least one registered command")
|
||||||
|
}
|
||||||
|
|
||||||
|
mustHave := map[string]bool{
|
||||||
|
"mcp": false,
|
||||||
|
"mcp/serve": false,
|
||||||
|
}
|
||||||
|
for _, path := range commands {
|
||||||
|
if _, ok := mustHave[path]; ok {
|
||||||
|
mustHave[path] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for path, present := range mustHave {
|
||||||
|
if !present {
|
||||||
|
t.Fatalf("expected command %q to be registered", path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stubMCPService captures the package-level function pointers and returns a
|
||||||
|
// restore hook so each test can mutate them without leaking into siblings.
|
||||||
|
func stubMCPService(t *testing.T) func() {
|
||||||
|
t.Helper()
|
||||||
|
oldNew := newMCPService
|
||||||
|
oldRun := runMCPService
|
||||||
|
oldShutdown := shutdownMCPService
|
||||||
|
oldWorkspace := workspaceFlag
|
||||||
|
oldUnrestricted := unrestrictedFlag
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
newMCPService = oldNew
|
||||||
|
runMCPService = oldRun
|
||||||
|
shutdownMCPService = oldShutdown
|
||||||
|
workspaceFlag = oldWorkspace
|
||||||
|
unrestrictedFlag = oldUnrestricted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
29
cmd/openbrain-mcp/README.md
Normal file
29
cmd/openbrain-mcp/README.md
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
# openbrain-mcp
|
||||||
|
|
||||||
|
`openbrain-mcp` is a thin stdio MCP wrapper for the OpenBrain tools registered in `pkg/mcp/brain`.
|
||||||
|
|
||||||
|
Install:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go install dappco.re/go/mcp/cmd/openbrain-mcp@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
Add it to Claude Code:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
claude mcp add openbrain -- openbrain-mcp --brain-url=http://127.0.0.1:8000/v1/brain --api-key=$OPENBRAIN_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
The wrapper exposes:
|
||||||
|
|
||||||
|
- `brain_remember`
|
||||||
|
- `brain_recall`
|
||||||
|
- `brain_forget`
|
||||||
|
- `brain_list`
|
||||||
|
|
||||||
|
Flags:
|
||||||
|
|
||||||
|
- `--brain-url`: OpenBrain BrainService URL. Defaults to `http://127.0.0.1:8000/v1/brain`.
|
||||||
|
- `--api-key`: OpenBrain API key. Defaults to `OPENBRAIN_API_KEY`.
|
||||||
|
|
||||||
|
The process logs to stderr only. Stdout is reserved for MCP framing.
|
||||||
95
cmd/openbrain-mcp/main.go
Normal file
95
cmd/openbrain-mcp/main.go
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
// openbrain-mcp exposes the OpenBrain MCP tools over stdio for Claude Code.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
|
"dappco.re/go/mcp/pkg/mcp"
|
||||||
|
"dappco.re/go/mcp/pkg/mcp/brain"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultBrainURL = "http://127.0.0.1:8000/v1/brain"
|
||||||
|
|
||||||
|
var (
|
||||||
|
brainURLFlag = flag.String("brain-url", defaultBrainURL, "OpenBrain BrainService URL")
|
||||||
|
apiKeyFlag = flag.String("api-key", "", "OpenBrain API key (defaults to OPENBRAIN_API_KEY)")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := run(); err != nil {
|
||||||
|
coreerr.Error("openbrain-mcp failed", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func run() error {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if err := configureBrainEnv(*brainURLFlag, *apiKeyFlag); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
svc, err := mcp.New(mcp.Options{
|
||||||
|
Subsystems: []mcp.Subsystem{
|
||||||
|
brain.NewDirect(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return core.E("openbrain-mcp.run", "create MCP service", err)
|
||||||
|
}
|
||||||
|
defer shutdownService(svc)
|
||||||
|
|
||||||
|
if err := svc.ServeStdio(ctx); err != nil && !core.Is(err, context.Canceled) {
|
||||||
|
return core.E("openbrain-mcp.run", "serve stdio", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureBrainEnv(brainURL, apiKey string) error {
|
||||||
|
baseURL := directBrainBaseURL(brainURL)
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = directBrainBaseURL(defaultBrainURL)
|
||||||
|
}
|
||||||
|
if err := os.Setenv("CORE_BRAIN_URL", baseURL); err != nil {
|
||||||
|
return core.E("openbrain-mcp.configure", "set CORE_BRAIN_URL", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := core.Trim(apiKey)
|
||||||
|
if key == "" {
|
||||||
|
key = core.Trim(core.Env("OPENBRAIN_API_KEY"))
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := os.Setenv("CORE_BRAIN_KEY", key); err != nil {
|
||||||
|
return core.E("openbrain-mcp.configure", "set CORE_BRAIN_KEY", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func directBrainBaseURL(brainURL string) string {
|
||||||
|
baseURL := core.Trim(brainURL)
|
||||||
|
baseURL = core.TrimSuffix(baseURL, "/")
|
||||||
|
baseURL = core.TrimSuffix(baseURL, "/v1/brain")
|
||||||
|
return core.TrimSuffix(baseURL, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func shutdownService(svc *mcp.Service) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := svc.Shutdown(ctx); err != nil {
|
||||||
|
coreerr.Error("openbrain-mcp shutdown failed", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
141
docs/security-vulnerabilities.md
Normal file
141
docs/security-vulnerabilities.md
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
# Security Vulnerabilities — Accepted Findings + Operator Mitigations
|
||||||
|
|
||||||
|
This document records security findings (govulncheck, etc.) that have been
|
||||||
|
manually reviewed and **accepted with documented rationale** rather than
|
||||||
|
patched. Each entry names the CVE, what makes it not-applicable to our use
|
||||||
|
case, and any operator-side mitigations required to keep that not-applicable
|
||||||
|
status valid.
|
||||||
|
|
||||||
|
Audit history:
|
||||||
|
- Mantis #323 — 9 ollama CVEs reviewed and documented (2026-04-25)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## github.com/ollama/ollama (indirect via go-rag)
|
||||||
|
|
||||||
|
**Status as of 2026-04-25:** all 9 CVEs filed in Mantis #323 are **UNFIXED
|
||||||
|
upstream** per [pkg.go.dev/vuln](https://pkg.go.dev/vuln/). Pin-bumping does
|
||||||
|
not resolve any of them. We are on `v0.18.1` indirect; ollama upstream is at
|
||||||
|
`v0.21.2` (2026-04-23).
|
||||||
|
|
||||||
|
**Our usage scope:** the entire workspace imports `github.com/ollama/ollama/api`
|
||||||
|
from exactly ONE file (`go-rag/ollama.go`). The surface in use is **3 symbols
|
||||||
|
only**:
|
||||||
|
- `api.NewClient(baseURL, *http.Client)` — constructor
|
||||||
|
- `api.Client` — struct value (held as a field by `OllamaClient`)
|
||||||
|
- `api.EmbedRequest` — embedding-request DTO
|
||||||
|
|
||||||
|
**We are a CLIENT** of someone else's Ollama server. We do NOT host an Ollama
|
||||||
|
server. Most CVEs in the list are server-side code paths that govulncheck's
|
||||||
|
reachability graph flags because the package is imported, but our actual call
|
||||||
|
sites do not traverse those paths.
|
||||||
|
|
||||||
|
### CVE-by-CVE reachability assessment
|
||||||
|
|
||||||
|
| CVE | Description | Reachable from our call graph? | Action |
|
||||||
|
|---|---|---|---|
|
||||||
|
| GO-2025-3548 (CVE-2024-12886) | DoS via crafted GZIP | NO — server-side parser | Accept |
|
||||||
|
| GO-2025-3557 (CVE-2025-0315) | Resource alloc without limits | NO — server-side dispatcher | Accept |
|
||||||
|
| GO-2025-3558 | Out-of-bounds read | NO — server-side inference | Accept |
|
||||||
|
| GO-2025-3559 | Divide by zero | NO — server-side inference | Accept |
|
||||||
|
| GO-2025-3582 | Null pointer deref DoS | NO — server-side handler | Accept |
|
||||||
|
| GO-2025-3689 | Divide by zero | NO — server-side inference | Accept |
|
||||||
|
| GO-2025-3695 | Server DoS | NO — server-side handler | Accept |
|
||||||
|
| GO-2025-3824 (CVE-2025-51471) | Cross-domain token exposure | **CONDITIONAL** — see below | Watch |
|
||||||
|
| GO-2025-4251 (CVE-2025-63389) | Missing auth on model-mgmt | **OPERATOR-SIDE** — see below | Runbook |
|
||||||
|
|
||||||
|
### GO-2025-3824 — token-exposure watch flag
|
||||||
|
|
||||||
|
This CVE concerns auth tokens leaking across domain boundaries when Ollama
|
||||||
|
clients pass authentication. Currently `NewOllamaClient(cfg)` constructs over
|
||||||
|
plain HTTP/HTTPS without auth headers — the embedding client connects to a
|
||||||
|
trusted local Ollama instance per the deployment runbook below.
|
||||||
|
|
||||||
|
**If we ever add auth-token plumbing to the Ollama client** (e.g. for hosted
|
||||||
|
Ollama services), re-evaluate this CVE. The reachability flips from NO to YES
|
||||||
|
the moment we set an Authorization header on `api.NewClient`.
|
||||||
|
|
||||||
|
### GO-2025-4251 — operator-side mitigation required
|
||||||
|
|
||||||
|
This CVE is a missing authentication / authorization gap on Ollama's
|
||||||
|
model-management endpoints. The vulnerability is in the **Ollama server**,
|
||||||
|
not our client code. Our client doesn't expose model-management calls;
|
||||||
|
operators do via running an Ollama server.
|
||||||
|
|
||||||
|
**Operator mitigation (REQUIRED):** see "Ollama deployment" section below.
|
||||||
|
Operators MUST front their Ollama instance with network-level access controls
|
||||||
|
or an authentication proxy. This is also Ollama upstream's own recommendation
|
||||||
|
in the advisory.
|
||||||
|
|
||||||
|
### Watch flag
|
||||||
|
|
||||||
|
If any of the 9 CVEs gets a fixed version released, re-evaluate:
|
||||||
|
- Bump `go-rag/go.mod` require for `github.com/ollama/ollama` to the fixed version
|
||||||
|
- Re-run govulncheck and prune entries from this document accordingly
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Ollama deployment — operator runbook
|
||||||
|
|
||||||
|
The Ollama instance the agent connects to runs OUTSIDE of our application
|
||||||
|
boundary. Operators are responsible for these mitigations:
|
||||||
|
|
||||||
|
### 1. Network-level isolation (mandatory)
|
||||||
|
|
||||||
|
Bind the Ollama server to a private interface or front it with a reverse proxy:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# OPTION A — localhost-only binding (single-host deployments)
|
||||||
|
OLLAMA_HOST=127.0.0.1:11434 ollama serve
|
||||||
|
|
||||||
|
# OPTION B — private network only (multi-host fleet)
|
||||||
|
# Bind to the wireguard / tailscale / private-VLAN interface, not 0.0.0.0
|
||||||
|
OLLAMA_HOST=10.42.0.5:11434 ollama serve
|
||||||
|
```
|
||||||
|
|
||||||
|
**Never** expose Ollama directly to the public internet. GO-2025-4251 makes
|
||||||
|
model-management operations possible without auth.
|
||||||
|
|
||||||
|
### 2. Reverse proxy with auth (recommended for shared deployments)
|
||||||
|
|
||||||
|
If multiple agents share an Ollama server, front it with nginx/caddy/traefik
|
||||||
|
adding HTTP Basic Auth or an authentication proxy (oauth2-proxy, authentik):
|
||||||
|
|
||||||
|
```nginx
|
||||||
|
location /ollama/ {
|
||||||
|
auth_basic "Ollama API";
|
||||||
|
auth_basic_user_file /etc/nginx/ollama.htpasswd;
|
||||||
|
proxy_pass http://10.42.0.5:11434/;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Configure the agent's `OllamaConfig.Endpoint` to point at the reverse proxy
|
||||||
|
URL, and add an `Authorization` header to the http.Client passed to
|
||||||
|
`api.NewClient`. (When that change lands, re-evaluate GO-2025-3824 per
|
||||||
|
the watch-flag note above.)
|
||||||
|
|
||||||
|
### 3. CI-side govulncheck filter
|
||||||
|
|
||||||
|
Until upstream Ollama ships fixes for any of the 9 CVEs, CI should suppress
|
||||||
|
just these specific findings (not blanket-suppress all govulncheck output):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
govulncheck ./... 2>&1 | grep -vE 'GO-2025-(3548|3557|3558|3559|3582|3689|3695|3824|4251)\b'
|
||||||
|
```
|
||||||
|
|
||||||
|
When a CVE gets a fix and we bump past it, drop that CVE ID from the grep
|
||||||
|
filter so future regressions surface cleanly.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## How to add to this document
|
||||||
|
|
||||||
|
When a new accepted finding lands:
|
||||||
|
|
||||||
|
1. Open a new H2 section named for the dependency
|
||||||
|
2. Document the reachability + rationale per CVE in a table
|
||||||
|
3. Add operator-side mitigations if any
|
||||||
|
4. Update the audit-history bullet at the top with a Mantis ticket reference
|
||||||
|
|
||||||
|
**Do NOT add findings here without a Mantis ticket.** Every accepted finding
|
||||||
|
must have a tracker entry so the rationale is auditable + reviewable.
|
||||||
27
go.mod
27
go.mod
|
|
@ -4,26 +4,25 @@ go 1.26.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
dappco.re/go/core v0.8.0-alpha.1
|
dappco.re/go/core v0.8.0-alpha.1
|
||||||
forge.lthn.ai/core/api v0.1.5
|
dappco.re/go/ai v0.8.0-alpha.1
|
||||||
forge.lthn.ai/core/cli v0.3.7
|
dappco.re/go/api v0.8.0-alpha.1
|
||||||
forge.lthn.ai/core/go-ai v0.1.12
|
dappco.re/go/cli v0.8.0-alpha.1
|
||||||
forge.lthn.ai/core/go-io v0.1.7
|
dappco.re/go/io v0.8.0-alpha.1
|
||||||
forge.lthn.ai/core/go-log v0.0.4
|
dappco.re/go/log v0.8.0-alpha.1
|
||||||
forge.lthn.ai/core/go-process v0.2.9
|
dappco.re/go/process v0.8.0-alpha.1
|
||||||
forge.lthn.ai/core/go-rag v0.1.11
|
dappco.re/go/rag v0.8.0-alpha.1
|
||||||
forge.lthn.ai/core/go-webview v0.1.6
|
dappco.re/go/webview v0.8.0-alpha.1
|
||||||
forge.lthn.ai/core/go-ws v0.2.5
|
dappco.re/go/ws v0.8.0-alpha.1
|
||||||
github.com/gin-gonic/gin v1.12.0
|
github.com/gin-gonic/gin v1.12.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/modelcontextprotocol/go-sdk v1.4.1
|
github.com/modelcontextprotocol/go-sdk v1.5.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
forge.lthn.ai/core/go v0.3.3 // indirect
|
dappco.re/go/i18n v0.8.0-alpha.1 // indirect
|
||||||
forge.lthn.ai/core/go-i18n v0.1.7 // indirect
|
dappco.re/go/inference v0.8.0-alpha.1 // indirect
|
||||||
forge.lthn.ai/core/go-inference v0.1.6 // indirect
|
|
||||||
github.com/99designs/gqlgen v0.17.88 // indirect
|
github.com/99designs/gqlgen v0.17.88 // indirect
|
||||||
github.com/KyleBanks/depth v1.2.1 // indirect
|
github.com/KyleBanks/depth v1.2.1 // indirect
|
||||||
github.com/agnivade/levenshtein v1.2.1 // indirect
|
github.com/agnivade/levenshtein v1.2.1 // indirect
|
||||||
|
|
@ -149,3 +148,5 @@ require (
|
||||||
google.golang.org/grpc v1.79.2 // indirect
|
google.golang.org/grpc v1.79.2 // indirect
|
||||||
google.golang.org/protobuf v1.36.11 // indirect
|
google.golang.org/protobuf v1.36.11 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
|
replace dappco.re/go/core/process => ../go-process
|
||||||
|
|
|
||||||
48
go.sum
48
go.sum
|
|
@ -1,29 +1,25 @@
|
||||||
dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
|
dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
|
||||||
dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
|
dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
|
||||||
forge.lthn.ai/core/api v0.1.5 h1:NwZrcOyBjaiz5/cn0n0tnlMUodi8Or6FHMx59C7Kv2o=
|
dappco.re/go/core/ai v0.2.2 h1:fkSKm3ezAljYbghlax5qHDm11uq7LUyIedIQO1PtdcY=
|
||||||
forge.lthn.ai/core/api v0.1.5/go.mod h1:PBnaWyOVXSOGy+0x2XAPUFMYJxQ2CNhppia/D06ZPII=
|
dappco.re/go/core/ai v0.2.2/go.mod h1:+MZN/EArn/W2ag91McL034WxdMSO4IPqFcQER5/POGU=
|
||||||
forge.lthn.ai/core/cli v0.3.7 h1:1GrbaGg0wDGHr6+klSbbGyN/9sSbHvFbdySJznymhwg=
|
dappco.re/go/core/api v0.3.0 h1:uWYgDQ+B4e5pXPX3S5lMsqSJamfpui3LWD5hcdwvWew=
|
||||||
forge.lthn.ai/core/cli v0.3.7/go.mod h1:DBUppJkA9P45ZFGgI2B8VXw1rAZxamHoI/KG7fRvTNs=
|
dappco.re/go/core/api v0.3.0/go.mod h1:1ZDNwPHV6YjkUsjtC3nfLk6U4eqWlQ6qj6yT/MB8r6k=
|
||||||
forge.lthn.ai/core/go v0.3.3 h1:kYYZ2nRYy0/Be3cyuLJspRjLqTMxpckVyhb/7Sw2gd0=
|
dappco.re/go/core/cli v0.5.2 h1:mo+PERo3lUytE+r3ArHr8o2nTftXjgPPsU/rn3ETXDM=
|
||||||
forge.lthn.ai/core/go v0.3.3/go.mod h1:Cp4ac25pghvO2iqOu59t1GyngTKVOzKB5/VPdhRi9CQ=
|
dappco.re/go/core/cli v0.5.2/go.mod h1:D4zfn3ec/hb72AWX/JWDvkW+h2WDKQcxGUrzoss7q2s=
|
||||||
forge.lthn.ai/core/go-ai v0.1.12 h1:OHt0bUABlyhvgxZxyMwueRoh8rS3YKWGFY6++zCAwC8=
|
dappco.re/go/core/i18n v0.2.3 h1:GqFaTR1I0SfSEc4WtsAkgao+jp8X5qcMPqrX0eMAOrY=
|
||||||
forge.lthn.ai/core/go-ai v0.1.12/go.mod h1:5Pc9lszxgkO7Aj2Z3dtq4L9Xk9l/VNN+Baj1t///OCM=
|
dappco.re/go/core/i18n v0.2.3/go.mod h1:LoyX/4fIEJO/wiHY3Q682+4P0Ob7zPemcATfwp0JBUg=
|
||||||
forge.lthn.ai/core/go-i18n v0.1.7 h1:aHkAoc3W8fw3RPNvw/UszQbjyFWXHszzbZgty3SwyAA=
|
dappco.re/go/core/inference v0.3.0 h1:ANFnlVO1LEYDipeDeBgqmb8CHvOTUFhMPyfyHGqO0IY=
|
||||||
forge.lthn.ai/core/go-i18n v0.1.7/go.mod h1:0VDjwtY99NSj2iqwrI09h5GUsJeM9s48MLkr+/Dn4G8=
|
dappco.re/go/core/inference v0.3.0/go.mod h1:wbRY0v6iwOoJCpTvcBFarAM08bMgpPcrF6yv3vccYoA=
|
||||||
forge.lthn.ai/core/go-inference v0.1.6 h1:ce42zC0zO8PuISUyAukAN1NACEdWp5wF1mRgnh5+58E=
|
dappco.re/go/core/io v0.4.1 h1:15dm7ldhFIAuZOrBiQG6XVZDpSvCxtZsUXApwTAB3wQ=
|
||||||
forge.lthn.ai/core/go-inference v0.1.6/go.mod h1:jfWz+IJX55wAH98+ic6FEqqGB6/P31CHlg7VY7pxREw=
|
dappco.re/go/core/io v0.4.1/go.mod h1:w71dukyunczLb8frT9JOd5B78PjwWQD3YAXiCt3AcPA=
|
||||||
forge.lthn.ai/core/go-io v0.1.7 h1:Tdb6sqh+zz1lsGJaNX9RFWM6MJ/RhSAyxfulLXrJsbk=
|
dappco.re/go/core/log v0.1.2 h1:pQSZxKD8VycdvjNJmatXbPSq2OxcP2xHbF20zgFIiZI=
|
||||||
forge.lthn.ai/core/go-io v0.1.7/go.mod h1:8lRLFk4Dnp5cR/Cyzh9WclD5566TbpdRgwcH7UZLWn4=
|
dappco.re/go/core/log v0.1.2/go.mod h1:Nkqb8gsXhZAO8VLpx7B8i1iAmohhzqA20b9Zr8VUcJs=
|
||||||
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
|
dappco.re/go/core/rag v0.1.13 h1:R2Q+Xw5YenT4uFemXLBu+xQYtyUIYGSmMln5/Z+nol4=
|
||||||
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
|
dappco.re/go/core/rag v0.1.13/go.mod h1:wthXtCqYEChjlGIHcJXetlgk49lPDmzG6jFWd1PEIZc=
|
||||||
forge.lthn.ai/core/go-process v0.2.9 h1:Wql+5TUF+lfU2oJ9I+S764MkTqJhBsuyMM0v1zsfZC4=
|
dappco.re/go/core/webview v0.2.1 h1:rdy2sV+MS6RZsav8BiARJxtWhfx7eOAJp3b1Ynp1sYs=
|
||||||
forge.lthn.ai/core/go-process v0.2.9/go.mod h1:NIzZOF5IVYYCjHkcNIGcg1mZH+bzGoie4SlZUDYOKIM=
|
dappco.re/go/core/webview v0.2.1/go.mod h1:Qdo1V/sJJwOnL0hYd3+vzVUJxWYC8eGyILZROya6KoM=
|
||||||
forge.lthn.ai/core/go-rag v0.1.11 h1:KXTOtnOdrx8YKmvnj0EOi2EI/+cKjE8w2PpJCQIrSd8=
|
dappco.re/go/core/ws v0.4.0 h1:yEDV9whXyo+GWzBSjuB3NiLiH2bmBPBWD6rydwHyBn8=
|
||||||
forge.lthn.ai/core/go-rag v0.1.11/go.mod h1:vIlOKVD1SdqqjkJ2XQyXPuKPtiajz/STPLCaDpqOzk8=
|
dappco.re/go/core/ws v0.4.0/go.mod h1:L1rrgW6zU+DztcVBJW2yO5Lm3rGXpyUMOA8OL9zsAok=
|
||||||
forge.lthn.ai/core/go-webview v0.1.6 h1:szXQxRJf2bOZJKh3v1P01B1Vf9mgXaBCXzh0EZu9aoc=
|
|
||||||
forge.lthn.ai/core/go-webview v0.1.6/go.mod h1:5n1tECD1wBV/uFZRY9ZjfPFO5TYZrlaR3mQFwvO2nek=
|
|
||||||
forge.lthn.ai/core/go-ws v0.2.5 h1:ZIV7Yrv01R/xpJUogA5vrfP9yB9li1w7EV3eZFMt8h0=
|
|
||||||
forge.lthn.ai/core/go-ws v0.2.5/go.mod h1:C3riJyLLcV6QhLvYlq3P/XkGTsN598qQeGBoLdoHBU4=
|
|
||||||
github.com/99designs/gqlgen v0.17.88 h1:neMQDgehMwT1vYIOx/w5ZYPUU/iMNAJzRO44I5Intoc=
|
github.com/99designs/gqlgen v0.17.88 h1:neMQDgehMwT1vYIOx/w5ZYPUU/iMNAJzRO44I5Intoc=
|
||||||
github.com/99designs/gqlgen v0.17.88/go.mod h1:qeqYFEgOeSKqWedOjogPizimp2iu4E23bdPvl4jTYic=
|
github.com/99designs/gqlgen v0.17.88/go.mod h1:qeqYFEgOeSKqWedOjogPizimp2iu4E23bdPvl4jTYic=
|
||||||
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
|
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
|
||||||
|
|
@ -222,8 +218,8 @@ github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2J
|
||||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||||
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
||||||
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||||
github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc=
|
github.com/modelcontextprotocol/go-sdk v1.5.0 h1:CHU0FIX9kpueNkxuYtfYQn1Z0slhFzBZuq+x6IiblIU=
|
||||||
github.com/modelcontextprotocol/go-sdk v1.4.1/go.mod h1:Bo/mS87hPQqHSRkMv4dQq1XCu6zv4INdXnFZabkNU6s=
|
github.com/modelcontextprotocol/go-sdk v1.5.0/go.mod h1:gggDIhoemhWs3BGkGwd1umzEXCEMMvAnhTrnbXJKKKA=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
|
|
||||||
|
|
@ -4,17 +4,15 @@ package agentic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -54,7 +52,7 @@ func (s *PrepSubsystem) registerDispatchTool(svc *coremcp.Service) {
|
||||||
// agentCommand returns the command and args for a given agent type.
|
// agentCommand returns the command and args for a given agent type.
|
||||||
// Supports model variants: "gemini", "gemini:flash", "gemini:pro", "claude", "claude:haiku".
|
// Supports model variants: "gemini", "gemini:flash", "gemini:pro", "claude", "claude:haiku".
|
||||||
func agentCommand(agent, prompt string) (string, []string, error) {
|
func agentCommand(agent, prompt string) (string, []string, error) {
|
||||||
parts := strings.SplitN(agent, ":", 2)
|
parts := core.SplitN(agent, ":", 2)
|
||||||
base := parts[0]
|
base := parts[0]
|
||||||
model := ""
|
model := ""
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
|
|
@ -78,7 +76,7 @@ func agentCommand(agent, prompt string) (string, []string, error) {
|
||||||
return "claude", args, nil
|
return "claude", args, nil
|
||||||
case "local":
|
case "local":
|
||||||
home, _ := os.UserHomeDir()
|
home, _ := os.UserHomeDir()
|
||||||
script := filepath.Join(home, "Code", "core", "agent", "scripts", "local-agent.sh")
|
script := core.Path(home, "Code", "core", "agent", "scripts", "local-agent.sh")
|
||||||
return "bash", []string{script, prompt}, nil
|
return "bash", []string{script, prompt}, nil
|
||||||
default:
|
default:
|
||||||
return "", nil, coreerr.E("agentCommand", "unknown agent: "+agent, nil)
|
return "", nil, coreerr.E("agentCommand", "unknown agent: "+agent, nil)
|
||||||
|
|
@ -86,6 +84,9 @@ func agentCommand(agent, prompt string) (string, []string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest, input DispatchInput) (*mcp.CallToolResult, DispatchOutput, error) {
|
func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest, input DispatchInput) (*mcp.CallToolResult, DispatchOutput, error) {
|
||||||
|
progress := coremcp.NewProgressNotifier(ctx, req)
|
||||||
|
const dispatchProgressTotal = 4
|
||||||
|
|
||||||
if input.Repo == "" {
|
if input.Repo == "" {
|
||||||
return nil, DispatchOutput{}, coreerr.E("dispatch", "repo is required", nil)
|
return nil, DispatchOutput{}, coreerr.E("dispatch", "repo is required", nil)
|
||||||
}
|
}
|
||||||
|
|
@ -102,7 +103,10 @@ func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest,
|
||||||
input.Template = "coding"
|
input.Template = "coding"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_ = progress.Send(1, dispatchProgressTotal, "validated dispatch request")
|
||||||
|
|
||||||
// Step 1: Prep the sandboxed workspace
|
// Step 1: Prep the sandboxed workspace
|
||||||
|
_ = progress.Send(2, dispatchProgressTotal, "preparing workspace")
|
||||||
prepInput := PrepInput{
|
prepInput := PrepInput{
|
||||||
Repo: input.Repo,
|
Repo: input.Repo,
|
||||||
Org: input.Org,
|
Org: input.Org,
|
||||||
|
|
@ -117,16 +121,18 @@ func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, DispatchOutput{}, coreerr.E("dispatch", "prep workspace failed", err)
|
return nil, DispatchOutput{}, coreerr.E("dispatch", "prep workspace failed", err)
|
||||||
}
|
}
|
||||||
|
_ = progress.Send(3, dispatchProgressTotal, core.Sprintf("workspace prepared for %s", prepOut.Branch))
|
||||||
|
|
||||||
wsDir := prepOut.WorkspaceDir
|
wsDir := prepOut.WorkspaceDir
|
||||||
srcDir := filepath.Join(wsDir, "src")
|
srcDir := core.Path(wsDir, "src")
|
||||||
|
|
||||||
// The prompt is just: read PROMPT.md and do the work
|
// The prompt is just: read PROMPT.md and do the work
|
||||||
prompt := "Read PROMPT.md for instructions. All context files (CLAUDE.md, TODO.md, CONTEXT.md, CONSUMERS.md, RECENT.md) are in the parent directory. Work in this directory."
|
prompt := "Read PROMPT.md for instructions. All context files (CLAUDE.md, TODO.md, CONTEXT.md, CONSUMERS.md, RECENT.md) are in the parent directory. Work in this directory."
|
||||||
|
|
||||||
if input.DryRun {
|
if input.DryRun {
|
||||||
// Read PROMPT.md for the dry run output
|
// Read PROMPT.md for the dry run output
|
||||||
promptRaw, _ := coreio.Local.Read(filepath.Join(wsDir, "PROMPT.md"))
|
promptRaw, _ := coreio.Local.Read(core.Path(wsDir, "PROMPT.md"))
|
||||||
|
_ = progress.Send(dispatchProgressTotal, dispatchProgressTotal, "dry run complete")
|
||||||
return nil, DispatchOutput{
|
return nil, DispatchOutput{
|
||||||
Success: true,
|
Success: true,
|
||||||
Agent: input.Agent,
|
Agent: input.Agent,
|
||||||
|
|
@ -150,6 +156,7 @@ func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest,
|
||||||
StartedAt: time.Now(),
|
StartedAt: time.Now(),
|
||||||
Runs: 0,
|
Runs: 0,
|
||||||
})
|
})
|
||||||
|
_ = progress.Send(dispatchProgressTotal, dispatchProgressTotal, "queued until an agent slot is available")
|
||||||
return nil, DispatchOutput{
|
return nil, DispatchOutput{
|
||||||
Success: true,
|
Success: true,
|
||||||
Agent: input.Agent,
|
Agent: input.Agent,
|
||||||
|
|
@ -172,8 +179,10 @@ func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest,
|
||||||
StartedAt: time.Now(),
|
StartedAt: time.Now(),
|
||||||
Runs: 1,
|
Runs: 1,
|
||||||
})
|
})
|
||||||
|
_ = progress.Send(3.5, dispatchProgressTotal, "dispatch slot acquired")
|
||||||
|
|
||||||
// Step 4: Spawn agent as a detached process
|
// Step 4: Spawn agent as a detached process
|
||||||
|
_ = progress.Send(4, dispatchProgressTotal, core.Sprintf("spawning agent %s", input.Agent))
|
||||||
// Uses Setpgid so the agent survives parent (MCP server) death.
|
// Uses Setpgid so the agent survives parent (MCP server) death.
|
||||||
// Output goes directly to log file (not buffered in memory).
|
// Output goes directly to log file (not buffered in memory).
|
||||||
command, args, err := agentCommand(input.Agent, prompt)
|
command, args, err := agentCommand(input.Agent, prompt)
|
||||||
|
|
@ -181,7 +190,7 @@ func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest,
|
||||||
return nil, DispatchOutput{}, err
|
return nil, DispatchOutput{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputFile := filepath.Join(wsDir, fmt.Sprintf("agent-%s.log", input.Agent))
|
outputFile := core.Path(wsDir, core.Sprintf("agent-%s.log", input.Agent))
|
||||||
outFile, err := os.Create(outputFile)
|
outFile, err := os.Create(outputFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, DispatchOutput{}, coreerr.E("dispatch", "failed to create log file", err)
|
return nil, DispatchOutput{}, coreerr.E("dispatch", "failed to create log file", err)
|
||||||
|
|
@ -222,6 +231,7 @@ func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest,
|
||||||
}
|
}
|
||||||
|
|
||||||
pid := cmd.Process.Pid
|
pid := cmd.Process.Pid
|
||||||
|
_ = progress.Send(dispatchProgressTotal, dispatchProgressTotal, "agent process started")
|
||||||
|
|
||||||
// Update status with PID now that agent is running
|
// Update status with PID now that agent is running
|
||||||
s.saveStatus(wsDir, &WorkspaceStatus{
|
s.saveStatus(wsDir, &WorkspaceStatus{
|
||||||
|
|
@ -247,7 +257,7 @@ func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest,
|
||||||
status := "completed"
|
status := "completed"
|
||||||
channel := coremcp.ChannelAgentComplete
|
channel := coremcp.ChannelAgentComplete
|
||||||
payload := map[string]any{
|
payload := map[string]any{
|
||||||
"workspace": filepath.Base(wsDir),
|
"workspace": core.PathBase(wsDir),
|
||||||
"repo": input.Repo,
|
"repo": input.Repo,
|
||||||
"org": input.Org,
|
"org": input.Org,
|
||||||
"agent": input.Agent,
|
"agent": input.Agent,
|
||||||
|
|
@ -257,11 +267,11 @@ func (s *PrepSubsystem) dispatch(ctx context.Context, req *mcp.CallToolRequest,
|
||||||
// Update status to completed or blocked.
|
// Update status to completed or blocked.
|
||||||
if st, err := readStatus(wsDir); err == nil {
|
if st, err := readStatus(wsDir); err == nil {
|
||||||
st.PID = 0
|
st.PID = 0
|
||||||
if data, err := coreio.Local.Read(filepath.Join(wsDir, "src", "BLOCKED.md")); err == nil {
|
if data, err := coreio.Local.Read(core.Path(wsDir, "src", "BLOCKED.md")); err == nil {
|
||||||
status = "blocked"
|
status = "blocked"
|
||||||
channel = coremcp.ChannelAgentBlocked
|
channel = coremcp.ChannelAgentBlocked
|
||||||
st.Status = status
|
st.Status = status
|
||||||
st.Question = strings.TrimSpace(data)
|
st.Question = core.Trim(data)
|
||||||
if st.Question != "" {
|
if st.Question != "" {
|
||||||
payload["question"] = st.Question
|
payload["question"] = st.Question
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,11 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -101,14 +100,14 @@ func (s *PrepSubsystem) createEpic(ctx context.Context, req *mcp.CallToolRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 2: Build epic body with checklist
|
// Step 2: Build epic body with checklist
|
||||||
var body strings.Builder
|
body := core.NewBuilder()
|
||||||
if input.Body != "" {
|
if input.Body != "" {
|
||||||
body.WriteString(input.Body)
|
body.WriteString(input.Body)
|
||||||
body.WriteString("\n\n")
|
body.WriteString("\n\n")
|
||||||
}
|
}
|
||||||
body.WriteString("## Tasks\n\n")
|
body.WriteString("## Tasks\n\n")
|
||||||
for _, child := range children {
|
for _, child := range children {
|
||||||
body.WriteString(fmt.Sprintf("- [ ] #%d %s\n", child.Number, child.Title))
|
body.WriteString(core.Sprintf("- [ ] #%d %s\n", child.Number, child.Title))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 3: Create epic issue
|
// Step 3: Create epic issue
|
||||||
|
|
@ -157,8 +156,12 @@ func (s *PrepSubsystem) createIssue(ctx context.Context, org, repo, title, body
|
||||||
payload["labels"] = labelIDs
|
payload["labels"] = labelIDs
|
||||||
}
|
}
|
||||||
|
|
||||||
data, _ := json.Marshal(payload)
|
r := core.JSONMarshal(payload)
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues", s.forgeURL, org, repo)
|
if !r.OK {
|
||||||
|
return ChildRef{}, coreerr.E("createIssue", "failed to encode issue payload", nil)
|
||||||
|
}
|
||||||
|
data := r.Value.([]byte)
|
||||||
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/issues", s.forgeURL, org, repo)
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(data))
|
req, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(data))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "token "+s.forgeToken)
|
req.Header.Set("Authorization", "token "+s.forgeToken)
|
||||||
|
|
@ -170,7 +173,7 @@ func (s *PrepSubsystem) createIssue(ctx context.Context, org, repo, title, body
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != 201 {
|
if resp.StatusCode != 201 {
|
||||||
return ChildRef{}, coreerr.E("createIssue", fmt.Sprintf("returned %d", resp.StatusCode), nil)
|
return ChildRef{}, coreerr.E("createIssue", core.Sprintf("returned %d", resp.StatusCode), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result struct {
|
var result struct {
|
||||||
|
|
@ -193,7 +196,7 @@ func (s *PrepSubsystem) resolveLabelIDs(ctx context.Context, org, repo string, n
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch existing labels
|
// Fetch existing labels
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/labels?limit=50", s.forgeURL, org, repo)
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/labels?limit=50", s.forgeURL, org, repo)
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
req.Header.Set("Authorization", "token "+s.forgeToken)
|
req.Header.Set("Authorization", "token "+s.forgeToken)
|
||||||
|
|
||||||
|
|
@ -246,12 +249,16 @@ func (s *PrepSubsystem) createLabel(ctx context.Context, org, repo, name string)
|
||||||
colour = "#6b7280"
|
colour = "#6b7280"
|
||||||
}
|
}
|
||||||
|
|
||||||
payload, _ := json.Marshal(map[string]string{
|
r := core.JSONMarshal(map[string]string{
|
||||||
"name": name,
|
"name": name,
|
||||||
"color": colour,
|
"color": colour,
|
||||||
})
|
})
|
||||||
|
if !r.OK {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
payload := r.Value.([]byte)
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/labels", s.forgeURL, org, repo)
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/labels", s.forgeURL, org, repo)
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
|
req, _ := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "token "+s.forgeToken)
|
req.Header.Set("Authorization", "token "+s.forgeToken)
|
||||||
|
|
|
||||||
|
|
@ -3,17 +3,12 @@
|
||||||
package agentic
|
package agentic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ingestFindings reads the agent output log and creates issues via the API
|
// ingestFindings reads the agent output log and creates issues via the API
|
||||||
|
|
@ -25,10 +20,7 @@ func (s *PrepSubsystem) ingestFindings(wsDir string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the log file
|
// Read the log file
|
||||||
logFiles, err := filepath.Glob(filepath.Join(wsDir, "agent-*.log"))
|
logFiles := core.PathGlob(core.Path(wsDir, "agent-*.log"))
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(logFiles) == 0 {
|
if len(logFiles) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -41,7 +33,7 @@ func (s *PrepSubsystem) ingestFindings(wsDir string) {
|
||||||
body := contentStr
|
body := contentStr
|
||||||
|
|
||||||
// Skip quota errors
|
// Skip quota errors
|
||||||
if strings.Contains(body, "QUOTA_EXHAUSTED") || strings.Contains(body, "QuotaError") {
|
if core.Contains(body, "QUOTA_EXHAUSTED") || core.Contains(body, "QuotaError") {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -56,13 +48,13 @@ func (s *PrepSubsystem) ingestFindings(wsDir string) {
|
||||||
// Determine issue type from the template used
|
// Determine issue type from the template used
|
||||||
issueType := "task"
|
issueType := "task"
|
||||||
priority := "normal"
|
priority := "normal"
|
||||||
if strings.Contains(body, "security") || strings.Contains(body, "Security") {
|
if core.Contains(body, "security") || core.Contains(body, "Security") {
|
||||||
issueType = "bug"
|
issueType = "bug"
|
||||||
priority = "high"
|
priority = "high"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a single issue per repo with all findings in the body
|
// Create a single issue per repo with all findings in the body
|
||||||
title := fmt.Sprintf("Scan findings for %s (%d items)", st.Repo, findings)
|
title := core.Sprintf("Scan findings for %s (%d items)", st.Repo, findings)
|
||||||
|
|
||||||
// Truncate body to reasonable size for issue description
|
// Truncate body to reasonable size for issue description
|
||||||
description := body
|
description := body
|
||||||
|
|
@ -86,7 +78,7 @@ func countFileRefs(body string) int {
|
||||||
}
|
}
|
||||||
if j < len(body) && body[j] == '`' {
|
if j < len(body) && body[j] == '`' {
|
||||||
ref := body[i+1 : j]
|
ref := body[i+1 : j]
|
||||||
if strings.Contains(ref, ".go:") || strings.Contains(ref, ".php:") {
|
if core.Contains(ref, ".go:") || core.Contains(ref, ".php:") {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -102,25 +94,22 @@ func (s *PrepSubsystem) createIssueViaAPI(repo, title, description, issueType, p
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the agent API key from file
|
// Read the agent API key from file
|
||||||
home, _ := os.UserHomeDir()
|
home := core.Env("HOME")
|
||||||
apiKeyData, err := coreio.Local.Read(filepath.Join(home, ".claude", "agent-api.key"))
|
apiKeyData, err := coreio.Local.Read(core.Path(home, ".claude", "agent-api.key"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
apiKey := strings.TrimSpace(apiKeyData)
|
apiKey := core.Trim(apiKeyData)
|
||||||
|
|
||||||
payload, err := json.Marshal(map[string]string{
|
payloadStr := core.JSONMarshalString(map[string]string{
|
||||||
"title": title,
|
"title": title,
|
||||||
"description": description,
|
"description": description,
|
||||||
"type": issueType,
|
"type": issueType,
|
||||||
"priority": priority,
|
"priority": priority,
|
||||||
"reporter": "cladius",
|
"reporter": "cladius",
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", s.brainURL+"/v1/issues", bytes.NewReader(payload))
|
req, err := http.NewRequest("POST", s.brainURL+"/v1/issues", core.NewReader(payloadStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,11 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -49,6 +49,12 @@ func (s *PrepSubsystem) registerIssueTools(svc *coremcp.Service) {
|
||||||
Description: "Dispatch an agent to work on a Forge issue. Assigns the issue as a lock, prepends the issue body to TODO.md, creates an issue-specific branch, and spawns the agent.",
|
Description: "Dispatch an agent to work on a Forge issue. Assigns the issue as a lock, prepends the issue body to TODO.md, creates an issue-specific branch, and spawns the agent.",
|
||||||
}, s.dispatchIssue)
|
}, s.dispatchIssue)
|
||||||
|
|
||||||
|
// agentic_issue_dispatch is the spec-aligned name for the same action.
|
||||||
|
coremcp.AddToolRecorded(svc, server, "agentic", &mcp.Tool{
|
||||||
|
Name: "agentic_issue_dispatch",
|
||||||
|
Description: "Dispatch an agent to work on a Forge issue. Spec-aligned alias for agentic_dispatch_issue.",
|
||||||
|
}, s.dispatchIssue)
|
||||||
|
|
||||||
coremcp.AddToolRecorded(svc, server, "agentic", &mcp.Tool{
|
coremcp.AddToolRecorded(svc, server, "agentic", &mcp.Tool{
|
||||||
Name: "agentic_pr",
|
Name: "agentic_pr",
|
||||||
Description: "Create a pull request from an agent workspace. Pushes the branch and creates a Forge PR linked to the tracked issue, if any.",
|
Description: "Create a pull request from an agent workspace. Pushes the branch and creates a Forge PR linked to the tracked issue, if any.",
|
||||||
|
|
@ -77,10 +83,10 @@ func (s *PrepSubsystem) dispatchIssue(ctx context.Context, req *mcp.CallToolRequ
|
||||||
return nil, DispatchOutput{}, err
|
return nil, DispatchOutput{}, err
|
||||||
}
|
}
|
||||||
if issue.State != "open" {
|
if issue.State != "open" {
|
||||||
return nil, DispatchOutput{}, coreerr.E("dispatchIssue", fmt.Sprintf("issue %d is %s, not open", input.Issue, issue.State), nil)
|
return nil, DispatchOutput{}, coreerr.E("dispatchIssue", core.Sprintf("issue %d is %s, not open", input.Issue, issue.State), nil)
|
||||||
}
|
}
|
||||||
if issue.Assignee != nil && issue.Assignee.Login != "" {
|
if issue.Assignee != nil && issue.Assignee.Login != "" {
|
||||||
return nil, DispatchOutput{}, coreerr.E("dispatchIssue", fmt.Sprintf("issue %d is already assigned to %s", input.Issue, issue.Assignee.Login), nil)
|
return nil, DispatchOutput{}, coreerr.E("dispatchIssue", core.Sprintf("issue %d is already assigned to %s", input.Issue, issue.Assignee.Login), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !input.DryRun {
|
if !input.DryRun {
|
||||||
|
|
@ -124,7 +130,7 @@ func (s *PrepSubsystem) dispatchIssue(ctx context.Context, req *mcp.CallToolRequ
|
||||||
func (s *PrepSubsystem) unlockIssue(ctx context.Context, org, repo string, issue int, labels []struct {
|
func (s *PrepSubsystem) unlockIssue(ctx context.Context, org, repo string, issue int, labels []struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}) error {
|
}) error {
|
||||||
updateURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", s.forgeURL, org, repo, issue)
|
updateURL := core.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", s.forgeURL, org, repo, issue)
|
||||||
issueLabels := make([]string, 0, len(labels))
|
issueLabels := make([]string, 0, len(labels))
|
||||||
for _, label := range labels {
|
for _, label := range labels {
|
||||||
if label.Name == "in-progress" {
|
if label.Name == "in-progress" {
|
||||||
|
|
@ -135,13 +141,14 @@ func (s *PrepSubsystem) unlockIssue(ctx context.Context, org, repo string, issue
|
||||||
if issueLabels == nil {
|
if issueLabels == nil {
|
||||||
issueLabels = []string{}
|
issueLabels = []string{}
|
||||||
}
|
}
|
||||||
payload, err := json.Marshal(map[string]any{
|
r := core.JSONMarshal(map[string]any{
|
||||||
"assignees": []string{},
|
"assignees": []string{},
|
||||||
"labels": issueLabels,
|
"labels": issueLabels,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if !r.OK {
|
||||||
return coreerr.E("unlockIssue", "failed to encode issue unlock", err)
|
return coreerr.E("unlockIssue", "failed to encode issue unlock", nil)
|
||||||
}
|
}
|
||||||
|
payload := r.Value.([]byte)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, updateURL, bytes.NewReader(payload))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, updateURL, bytes.NewReader(payload))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -156,14 +163,14 @@ func (s *PrepSubsystem) unlockIssue(ctx context.Context, org, repo string, issue
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode >= http.StatusBadRequest {
|
if resp.StatusCode >= http.StatusBadRequest {
|
||||||
return coreerr.E("unlockIssue", fmt.Sprintf("issue unlock returned %d", resp.StatusCode), nil)
|
return coreerr.E("unlockIssue", core.Sprintf("issue unlock returned %d", resp.StatusCode), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) fetchIssue(ctx context.Context, org, repo string, issue int) (*forgeIssue, error) {
|
func (s *PrepSubsystem) fetchIssue(ctx context.Context, org, repo string, issue int) (*forgeIssue, error) {
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", s.forgeURL, org, repo, issue)
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", s.forgeURL, org, repo, issue)
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, coreerr.E("fetchIssue", "failed to build request", err)
|
return nil, coreerr.E("fetchIssue", "failed to build request", err)
|
||||||
|
|
@ -176,7 +183,7 @@ func (s *PrepSubsystem) fetchIssue(ctx context.Context, org, repo string, issue
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, coreerr.E("fetchIssue", fmt.Sprintf("issue %d not found in %s/%s", issue, org, repo), nil)
|
return nil, coreerr.E("fetchIssue", core.Sprintf("issue %d not found in %s/%s", issue, org, repo), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var out forgeIssue
|
var out forgeIssue
|
||||||
|
|
@ -187,14 +194,15 @@ func (s *PrepSubsystem) fetchIssue(ctx context.Context, org, repo string, issue
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) lockIssue(ctx context.Context, org, repo string, issue int, assignee string) error {
|
func (s *PrepSubsystem) lockIssue(ctx context.Context, org, repo string, issue int, assignee string) error {
|
||||||
updateURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", s.forgeURL, org, repo, issue)
|
updateURL := core.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", s.forgeURL, org, repo, issue)
|
||||||
payload, err := json.Marshal(map[string]any{
|
r := core.JSONMarshal(map[string]any{
|
||||||
"assignees": []string{assignee},
|
"assignees": []string{assignee},
|
||||||
"labels": []string{"in-progress"},
|
"labels": []string{"in-progress"},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if !r.OK {
|
||||||
return coreerr.E("lockIssue", "failed to encode issue update", err)
|
return coreerr.E("lockIssue", "failed to encode issue update", nil)
|
||||||
}
|
}
|
||||||
|
payload := r.Value.([]byte)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, updateURL, bytes.NewReader(payload))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPatch, updateURL, bytes.NewReader(payload))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -209,7 +217,7 @@ func (s *PrepSubsystem) lockIssue(ctx context.Context, org, repo string, issue i
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode >= http.StatusBadRequest {
|
if resp.StatusCode >= http.StatusBadRequest {
|
||||||
return coreerr.E("lockIssue", fmt.Sprintf("issue update returned %d", resp.StatusCode), nil)
|
return coreerr.E("lockIssue", core.Sprintf("issue update returned %d", resp.StatusCode), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,11 @@ package agentic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -64,7 +63,7 @@ func (s *PrepSubsystem) mirror(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
skipped := make([]string, 0)
|
skipped := make([]string, 0)
|
||||||
|
|
||||||
for _, repo := range repos {
|
for _, repo := range repos {
|
||||||
repoDir := filepath.Join(basePath, repo)
|
repoDir := core.Path(basePath, repo)
|
||||||
if !hasRemote(repoDir, "github") {
|
if !hasRemote(repoDir, "github") {
|
||||||
skipped = append(skipped, repo+": no github remote")
|
skipped = append(skipped, repo+": no github remote")
|
||||||
continue
|
continue
|
||||||
|
|
@ -88,7 +87,7 @@ func (s *PrepSubsystem) mirror(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
}
|
}
|
||||||
|
|
||||||
if files > maxFiles {
|
if files > maxFiles {
|
||||||
sync.Skipped = fmt.Sprintf("%d files exceeds limit of %d", files, maxFiles)
|
sync.Skipped = core.Sprintf("%d files exceeds limit of %d", files, maxFiles)
|
||||||
synced = append(synced, sync)
|
synced = append(synced, sync)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,12 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -349,11 +348,11 @@ func (s *PrepSubsystem) planList(_ context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
|
|
||||||
var plans []Plan
|
var plans []Plan
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
|
if entry.IsDir() || !core.HasSuffix(entry.Name(), ".json") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
id := strings.TrimSuffix(entry.Name(), ".json")
|
id := core.TrimSuffix(entry.Name(), ".json")
|
||||||
plan, err := readPlan(dir, id)
|
plan, err := readPlan(dir, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
|
|
@ -422,41 +421,41 @@ func (s *PrepSubsystem) planCheckpoint(_ context.Context, _ *mcp.CallToolRequest
|
||||||
// --- Helpers ---
|
// --- Helpers ---
|
||||||
|
|
||||||
func (s *PrepSubsystem) plansDir() string {
|
func (s *PrepSubsystem) plansDir() string {
|
||||||
return filepath.Join(s.codePath, ".core", "plans")
|
return core.Path(s.codePath, ".core", "plans")
|
||||||
}
|
}
|
||||||
|
|
||||||
func planPath(dir, id string) string {
|
func planPath(dir, id string) string {
|
||||||
return filepath.Join(dir, id+".json")
|
return core.Path(dir, id+".json")
|
||||||
}
|
}
|
||||||
|
|
||||||
func generatePlanID(title string) string {
|
func generatePlanID(title string) string {
|
||||||
slug := strings.Map(func(r rune) rune {
|
b := core.NewBuilder()
|
||||||
if r >= 'a' && r <= 'z' || r >= '0' && r <= '9' || r == '-' {
|
b.Grow(len(title))
|
||||||
return r
|
for _, r := range title {
|
||||||
|
switch {
|
||||||
|
case r >= 'a' && r <= 'z', r >= '0' && r <= '9', r == '-':
|
||||||
|
b.WriteRune(r)
|
||||||
|
case r >= 'A' && r <= 'Z':
|
||||||
|
b.WriteRune(r + 32)
|
||||||
|
case r == ' ':
|
||||||
|
b.WriteByte('-')
|
||||||
}
|
}
|
||||||
if r >= 'A' && r <= 'Z' {
|
|
||||||
return r + 32
|
|
||||||
}
|
}
|
||||||
if r == ' ' {
|
slug := b.String()
|
||||||
return '-'
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}, title)
|
|
||||||
|
|
||||||
// Trim consecutive dashes and cap length
|
// Collapse consecutive dashes and cap length
|
||||||
for strings.Contains(slug, "--") {
|
for core.Contains(slug, "--") {
|
||||||
slug = strings.ReplaceAll(slug, "--", "-")
|
slug = core.Replace(slug, "--", "-")
|
||||||
}
|
}
|
||||||
slug = strings.Trim(slug, "-")
|
slug = trimDashes(slug)
|
||||||
if len(slug) > 30 {
|
if len(slug) > 30 {
|
||||||
slug = slug[:30]
|
slug = trimDashes(slug[:30])
|
||||||
}
|
}
|
||||||
slug = strings.TrimRight(slug, "-")
|
|
||||||
|
|
||||||
// Append short random suffix for uniqueness
|
// Append short random suffix for uniqueness
|
||||||
b := make([]byte, 3)
|
rnd := make([]byte, 3)
|
||||||
rand.Read(b)
|
rand.Read(rnd)
|
||||||
return slug + "-" + hex.EncodeToString(b)
|
return slug + "-" + hex.EncodeToString(rnd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func readPlan(dir, id string) (*Plan, error) {
|
func readPlan(dir, id string) (*Plan, error) {
|
||||||
|
|
@ -466,8 +465,8 @@ func readPlan(dir, id string) (*Plan, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var plan Plan
|
var plan Plan
|
||||||
if err := json.Unmarshal([]byte(data), &plan); err != nil {
|
if r := core.JSONUnmarshal([]byte(data), &plan); !r.OK {
|
||||||
return nil, coreerr.E("readPlan", "failed to parse plan "+id, err)
|
return nil, coreerr.E("readPlan", "failed to parse plan "+id, nil)
|
||||||
}
|
}
|
||||||
return &plan, nil
|
return &plan, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,15 +6,13 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -66,8 +64,8 @@ func (s *PrepSubsystem) createPR(ctx context.Context, _ *mcp.CallToolRequest, in
|
||||||
return nil, CreatePROutput{}, coreerr.E("createPR", "no Forge token configured", nil)
|
return nil, CreatePROutput{}, coreerr.E("createPR", "no Forge token configured", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
wsDir := filepath.Join(s.workspaceRoot(), input.Workspace)
|
wsDir := core.Path(s.workspaceRoot(), input.Workspace)
|
||||||
srcDir := filepath.Join(wsDir, "src")
|
srcDir := core.Path(wsDir, "src")
|
||||||
|
|
||||||
if _, err := coreio.Local.List(srcDir); err != nil {
|
if _, err := coreio.Local.List(srcDir); err != nil {
|
||||||
return nil, CreatePROutput{}, coreerr.E("createPR", "workspace not found: "+input.Workspace, nil)
|
return nil, CreatePROutput{}, coreerr.E("createPR", "workspace not found: "+input.Workspace, nil)
|
||||||
|
|
@ -87,7 +85,7 @@ func (s *PrepSubsystem) createPR(ctx context.Context, _ *mcp.CallToolRequest, in
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, CreatePROutput{}, coreerr.E("createPR", "failed to detect branch", err)
|
return nil, CreatePROutput{}, coreerr.E("createPR", "failed to detect branch", err)
|
||||||
}
|
}
|
||||||
st.Branch = strings.TrimSpace(string(out))
|
st.Branch = core.Trim(string(out))
|
||||||
}
|
}
|
||||||
|
|
||||||
org := st.Org
|
org := st.Org
|
||||||
|
|
@ -105,7 +103,7 @@ func (s *PrepSubsystem) createPR(ctx context.Context, _ *mcp.CallToolRequest, in
|
||||||
title = st.Task
|
title = st.Task
|
||||||
}
|
}
|
||||||
if title == "" {
|
if title == "" {
|
||||||
title = fmt.Sprintf("Agent work on %s", st.Branch)
|
title = core.Sprintf("Agent work on %s", st.Branch)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build PR body
|
// Build PR body
|
||||||
|
|
@ -143,7 +141,7 @@ func (s *PrepSubsystem) createPR(ctx context.Context, _ *mcp.CallToolRequest, in
|
||||||
|
|
||||||
// Comment on issue if tracked
|
// Comment on issue if tracked
|
||||||
if st.Issue > 0 {
|
if st.Issue > 0 {
|
||||||
comment := fmt.Sprintf("Pull request created: %s", prURL)
|
comment := core.Sprintf("Pull request created: %s", prURL)
|
||||||
s.commentOnIssue(ctx, org, st.Repo, st.Issue, comment)
|
s.commentOnIssue(ctx, org, st.Repo, st.Issue, comment)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -159,17 +157,17 @@ func (s *PrepSubsystem) createPR(ctx context.Context, _ *mcp.CallToolRequest, in
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) buildPRBody(st *WorkspaceStatus) string {
|
func (s *PrepSubsystem) buildPRBody(st *WorkspaceStatus) string {
|
||||||
var b strings.Builder
|
b := core.NewBuilder()
|
||||||
b.WriteString("## Summary\n\n")
|
b.WriteString("## Summary\n\n")
|
||||||
if st.Task != "" {
|
if st.Task != "" {
|
||||||
b.WriteString(st.Task)
|
b.WriteString(st.Task)
|
||||||
b.WriteString("\n\n")
|
b.WriteString("\n\n")
|
||||||
}
|
}
|
||||||
if st.Issue > 0 {
|
if st.Issue > 0 {
|
||||||
b.WriteString(fmt.Sprintf("Closes #%d\n\n", st.Issue))
|
b.WriteString(core.Sprintf("Closes #%d\n\n", st.Issue))
|
||||||
}
|
}
|
||||||
b.WriteString(fmt.Sprintf("**Agent:** %s\n", st.Agent))
|
b.WriteString(core.Sprintf("**Agent:** %s\n", st.Agent))
|
||||||
b.WriteString(fmt.Sprintf("**Runs:** %d\n", st.Runs))
|
b.WriteString(core.Sprintf("**Runs:** %d\n", st.Runs))
|
||||||
b.WriteString("\n---\n*Created by agentic dispatch*\n")
|
b.WriteString("\n---\n*Created by agentic dispatch*\n")
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
@ -185,7 +183,7 @@ func (s *PrepSubsystem) forgeCreatePR(ctx context.Context, org, repo, head, base
|
||||||
return "", 0, coreerr.E("forgeCreatePR", "failed to marshal PR payload", err)
|
return "", 0, coreerr.E("forgeCreatePR", "failed to marshal PR payload", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls", s.forgeURL, org, repo)
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/pulls", s.forgeURL, org, repo)
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, coreerr.E("forgeCreatePR", "failed to build PR request", err)
|
return "", 0, coreerr.E("forgeCreatePR", "failed to build PR request", err)
|
||||||
|
|
@ -202,10 +200,10 @@ func (s *PrepSubsystem) forgeCreatePR(ctx context.Context, org, repo, head, base
|
||||||
if resp.StatusCode != 201 {
|
if resp.StatusCode != 201 {
|
||||||
var errBody map[string]any
|
var errBody map[string]any
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&errBody); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&errBody); err != nil {
|
||||||
return "", 0, coreerr.E("forgeCreatePR", fmt.Sprintf("HTTP %d with unreadable error body", resp.StatusCode), err)
|
return "", 0, coreerr.E("forgeCreatePR", core.Sprintf("HTTP %d with unreadable error body", resp.StatusCode), err)
|
||||||
}
|
}
|
||||||
msg, _ := errBody["message"].(string)
|
msg, _ := errBody["message"].(string)
|
||||||
return "", 0, coreerr.E("forgeCreatePR", fmt.Sprintf("HTTP %d: %s", resp.StatusCode, msg), nil)
|
return "", 0, coreerr.E("forgeCreatePR", core.Sprintf("HTTP %d: %s", resp.StatusCode, msg), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var pr struct {
|
var pr struct {
|
||||||
|
|
@ -225,7 +223,7 @@ func (s *PrepSubsystem) commentOnIssue(ctx context.Context, org, repo string, is
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d/comments", s.forgeURL, org, repo, issue)
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/issues/%d/comments", s.forgeURL, org, repo, issue)
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(payload))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -337,7 +335,7 @@ func (s *PrepSubsystem) listPRs(ctx context.Context, _ *mcp.CallToolRequest, inp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) listRepoPRs(ctx context.Context, org, repo, state string) ([]PRInfo, error) {
|
func (s *PrepSubsystem) listRepoPRs(ctx context.Context, org, repo, state string) ([]PRInfo, error) {
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/pulls?state=%s&limit=10",
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/pulls?state=%s&limit=10",
|
||||||
s.forgeURL, org, repo, state)
|
s.forgeURL, org, repo, state)
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
req.Header.Set("Authorization", "token "+s.forgeToken)
|
req.Header.Set("Authorization", "token "+s.forgeToken)
|
||||||
|
|
@ -348,7 +346,7 @@ func (s *PrepSubsystem) listRepoPRs(ctx context.Context, org, repo, state string
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return nil, coreerr.E("listRepoPRs", fmt.Sprintf("HTTP %d for "+repo, resp.StatusCode), nil)
|
return nil, coreerr.E("listRepoPRs", core.Sprintf("HTTP %d for "+repo, resp.StatusCode), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var prs []struct {
|
var prs []struct {
|
||||||
|
|
|
||||||
|
|
@ -8,18 +8,14 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
goio "io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
@ -46,17 +42,17 @@ var (
|
||||||
//
|
//
|
||||||
// prep := NewPrep()
|
// prep := NewPrep()
|
||||||
func NewPrep() *PrepSubsystem {
|
func NewPrep() *PrepSubsystem {
|
||||||
home, _ := os.UserHomeDir()
|
home := core.Env("HOME")
|
||||||
|
|
||||||
forgeToken := os.Getenv("FORGE_TOKEN")
|
forgeToken := core.Env("FORGE_TOKEN")
|
||||||
if forgeToken == "" {
|
if forgeToken == "" {
|
||||||
forgeToken = os.Getenv("GITEA_TOKEN")
|
forgeToken = core.Env("GITEA_TOKEN")
|
||||||
}
|
}
|
||||||
|
|
||||||
brainKey := os.Getenv("CORE_BRAIN_KEY")
|
brainKey := core.Env("CORE_BRAIN_KEY")
|
||||||
if brainKey == "" {
|
if brainKey == "" {
|
||||||
if data, err := coreio.Local.Read(filepath.Join(home, ".claude", "brain.key")); err == nil {
|
if data, err := coreio.Local.Read(core.Path(home, ".claude", "brain.key")); err == nil {
|
||||||
brainKey = strings.TrimSpace(data)
|
brainKey = core.Trim(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -65,8 +61,8 @@ func NewPrep() *PrepSubsystem {
|
||||||
forgeToken: forgeToken,
|
forgeToken: forgeToken,
|
||||||
brainURL: envOr("CORE_BRAIN_URL", "https://api.lthn.sh"),
|
brainURL: envOr("CORE_BRAIN_URL", "https://api.lthn.sh"),
|
||||||
brainKey: brainKey,
|
brainKey: brainKey,
|
||||||
specsPath: envOr("SPECS_PATH", filepath.Join(home, "Code", "host-uk", "specs")),
|
specsPath: envOr("SPECS_PATH", core.Path(home, "Code", "host-uk", "specs")),
|
||||||
codePath: envOr("CODE_PATH", filepath.Join(home, "Code")),
|
codePath: envOr("CODE_PATH", core.Path(home, "Code")),
|
||||||
client: &http.Client{Timeout: 30 * time.Second},
|
client: &http.Client{Timeout: 30 * time.Second},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -84,24 +80,24 @@ func (s *PrepSubsystem) emitChannel(ctx context.Context, channel string, data an
|
||||||
}
|
}
|
||||||
|
|
||||||
func envOr(key, fallback string) string {
|
func envOr(key, fallback string) string {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := core.Env(key); v != "" {
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeRepoPathSegment(value, field string, allowSubdirs bool) (string, error) {
|
func sanitizeRepoPathSegment(value, field string, allowSubdirs bool) (string, error) {
|
||||||
if strings.TrimSpace(value) != value {
|
if core.Trim(value) != value {
|
||||||
return "", coreerr.E("prepWorkspace", field+" contains whitespace", nil)
|
return "", coreerr.E("prepWorkspace", field+" contains whitespace", nil)
|
||||||
}
|
}
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
if strings.Contains(value, "\\") {
|
if core.Contains(value, "\\") {
|
||||||
return "", coreerr.E("prepWorkspace", field+" contains invalid path separator", nil)
|
return "", coreerr.E("prepWorkspace", field+" contains invalid path separator", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(value, "/")
|
parts := core.Split(value, "/")
|
||||||
if !allowSubdirs && len(parts) != 1 {
|
if !allowSubdirs && len(parts) != 1 {
|
||||||
return "", coreerr.E("prepWorkspace", field+" may not contain subdirectories", nil)
|
return "", coreerr.E("prepWorkspace", field+" may not contain subdirectories", nil)
|
||||||
}
|
}
|
||||||
|
|
@ -161,7 +157,7 @@ func (s *PrepSubsystem) Shutdown(_ context.Context) error { return nil }
|
||||||
|
|
||||||
// workspaceRoot returns the base directory for agent workspaces.
|
// workspaceRoot returns the base directory for agent workspaces.
|
||||||
func (s *PrepSubsystem) workspaceRoot() string {
|
func (s *PrepSubsystem) workspaceRoot() string {
|
||||||
return filepath.Join(s.codePath, ".core", "workspace")
|
return core.Path(s.codePath, ".core", "workspace")
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Input/Output types ---
|
// --- Input/Output types ---
|
||||||
|
|
@ -227,8 +223,8 @@ func (s *PrepSubsystem) prepWorkspace(ctx context.Context, _ *mcp.CallToolReques
|
||||||
// Workspace root: .core/workspace/{repo}-{timestamp}/
|
// Workspace root: .core/workspace/{repo}-{timestamp}/
|
||||||
wsRoot := s.workspaceRoot()
|
wsRoot := s.workspaceRoot()
|
||||||
coreio.Local.EnsureDir(wsRoot)
|
coreio.Local.EnsureDir(wsRoot)
|
||||||
wsName := fmt.Sprintf("%s-%d", input.Repo, time.Now().Unix())
|
wsName := core.Sprintf("%s-%d", input.Repo, time.Now().Unix())
|
||||||
wsDir := filepath.Join(wsRoot, wsName)
|
wsDir := core.Path(wsRoot, wsName)
|
||||||
|
|
||||||
// Create workspace structure
|
// Create workspace structure
|
||||||
// kb/ and specs/ will be created inside src/ after clone
|
// kb/ and specs/ will be created inside src/ after clone
|
||||||
|
|
@ -236,10 +232,10 @@ func (s *PrepSubsystem) prepWorkspace(ctx context.Context, _ *mcp.CallToolReques
|
||||||
out := PrepOutput{WorkspaceDir: wsDir}
|
out := PrepOutput{WorkspaceDir: wsDir}
|
||||||
|
|
||||||
// Source repo path
|
// Source repo path
|
||||||
repoPath := filepath.Join(s.codePath, "core", input.Repo)
|
repoPath := core.Path(s.codePath, "core", input.Repo)
|
||||||
|
|
||||||
// 1. Clone repo into src/ and create feature branch
|
// 1. Clone repo into src/ and create feature branch
|
||||||
srcDir := filepath.Join(wsDir, "src")
|
srcDir := core.Path(wsDir, "src")
|
||||||
cloneCmd := exec.CommandContext(ctx, "git", "clone", repoPath, srcDir)
|
cloneCmd := exec.CommandContext(ctx, "git", "clone", repoPath, srcDir)
|
||||||
if err := cloneCmd.Run(); err != nil {
|
if err := cloneCmd.Run(); err != nil {
|
||||||
return nil, PrepOutput{}, coreerr.E("prepWorkspace", "failed to clone repository", err)
|
return nil, PrepOutput{}, coreerr.E("prepWorkspace", "failed to clone repository", err)
|
||||||
|
|
@ -251,12 +247,12 @@ func (s *PrepSubsystem) prepWorkspace(ctx context.Context, _ *mcp.CallToolReques
|
||||||
taskSlug := branchSlug(input.Task)
|
taskSlug := branchSlug(input.Task)
|
||||||
if input.Issue > 0 {
|
if input.Issue > 0 {
|
||||||
issueSlug := branchSlug(input.Task)
|
issueSlug := branchSlug(input.Task)
|
||||||
branchName = fmt.Sprintf("agent/issue-%d", input.Issue)
|
branchName = core.Sprintf("agent/issue-%d", input.Issue)
|
||||||
if issueSlug != "" {
|
if issueSlug != "" {
|
||||||
branchName += "-" + issueSlug
|
branchName += "-" + issueSlug
|
||||||
}
|
}
|
||||||
} else if taskSlug != "" {
|
} else if taskSlug != "" {
|
||||||
branchName = fmt.Sprintf("agent/%s", taskSlug)
|
branchName = core.Sprintf("agent/%s", taskSlug)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if branchName != "" {
|
if branchName != "" {
|
||||||
|
|
@ -269,29 +265,29 @@ func (s *PrepSubsystem) prepWorkspace(ctx context.Context, _ *mcp.CallToolReques
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create context dirs inside src/
|
// Create context dirs inside src/
|
||||||
coreio.Local.EnsureDir(filepath.Join(srcDir, "kb"))
|
coreio.Local.EnsureDir(core.Path(srcDir, "kb"))
|
||||||
coreio.Local.EnsureDir(filepath.Join(srcDir, "specs"))
|
coreio.Local.EnsureDir(core.Path(srcDir, "specs"))
|
||||||
|
|
||||||
// Remote stays as local clone origin — agent cannot push to forge.
|
// Remote stays as local clone origin — agent cannot push to forge.
|
||||||
// Reviewer pulls changes from workspace and pushes after verification.
|
// Reviewer pulls changes from workspace and pushes after verification.
|
||||||
|
|
||||||
// 2. Copy CLAUDE.md and GEMINI.md to workspace
|
// 2. Copy CLAUDE.md and GEMINI.md to workspace
|
||||||
claudeMdPath := filepath.Join(repoPath, "CLAUDE.md")
|
claudeMdPath := core.Path(repoPath, "CLAUDE.md")
|
||||||
if data, err := coreio.Local.Read(claudeMdPath); err == nil {
|
if data, err := coreio.Local.Read(claudeMdPath); err == nil {
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "CLAUDE.md"), data)
|
_ = writeAtomic(core.Path(wsDir, "src", "CLAUDE.md"), data)
|
||||||
out.ClaudeMd = true
|
out.ClaudeMd = true
|
||||||
}
|
}
|
||||||
// Copy GEMINI.md from core/agent (ethics framework for all agents)
|
// Copy GEMINI.md from core/agent (ethics framework for all agents)
|
||||||
agentGeminiMd := filepath.Join(s.codePath, "core", "agent", "GEMINI.md")
|
agentGeminiMd := core.Path(s.codePath, "core", "agent", "GEMINI.md")
|
||||||
if data, err := coreio.Local.Read(agentGeminiMd); err == nil {
|
if data, err := coreio.Local.Read(agentGeminiMd); err == nil {
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "GEMINI.md"), data)
|
_ = writeAtomic(core.Path(wsDir, "src", "GEMINI.md"), data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy persona if specified
|
// Copy persona if specified
|
||||||
if persona != "" {
|
if persona != "" {
|
||||||
personaPath := filepath.Join(s.codePath, "core", "agent", "prompts", "personas", persona+".md")
|
personaPath := core.Path(s.codePath, "core", "agent", "prompts", "personas", persona+".md")
|
||||||
if data, err := coreio.Local.Read(personaPath); err == nil {
|
if data, err := coreio.Local.Read(personaPath); err == nil {
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "PERSONA.md"), data)
|
_ = writeAtomic(core.Path(wsDir, "src", "PERSONA.md"), data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -299,9 +295,9 @@ func (s *PrepSubsystem) prepWorkspace(ctx context.Context, _ *mcp.CallToolReques
|
||||||
if input.Issue > 0 {
|
if input.Issue > 0 {
|
||||||
s.generateTodo(ctx, input.Org, input.Repo, input.Issue, wsDir)
|
s.generateTodo(ctx, input.Org, input.Repo, input.Issue, wsDir)
|
||||||
} else if input.Task != "" {
|
} else if input.Task != "" {
|
||||||
todo := fmt.Sprintf("# TASK: %s\n\n**Repo:** %s/%s\n**Status:** ready\n\n## Objective\n\n%s\n",
|
todo := core.Sprintf("# TASK: %s\n\n**Repo:** %s/%s\n**Status:** ready\n\n## Objective\n\n%s\n",
|
||||||
input.Task, input.Org, input.Repo, input.Task)
|
input.Task, input.Org, input.Repo, input.Task)
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "TODO.md"), todo)
|
_ = writeAtomic(core.Path(wsDir, "src", "TODO.md"), todo)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Generate CONTEXT.md from OpenBrain
|
// 4. Generate CONTEXT.md from OpenBrain
|
||||||
|
|
@ -333,12 +329,12 @@ func (s *PrepSubsystem) prepWorkspace(ctx context.Context, _ *mcp.CallToolReques
|
||||||
|
|
||||||
// branchSlug converts a free-form string into a git-friendly branch suffix.
|
// branchSlug converts a free-form string into a git-friendly branch suffix.
|
||||||
func branchSlug(value string) string {
|
func branchSlug(value string) string {
|
||||||
value = strings.ToLower(strings.TrimSpace(value))
|
value = core.Lower(core.Trim(value))
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
var b strings.Builder
|
b := core.NewBuilder()
|
||||||
b.Grow(len(value))
|
b.Grow(len(value))
|
||||||
lastDash := false
|
lastDash := false
|
||||||
for _, r := range value {
|
for _, r := range value {
|
||||||
|
|
@ -359,14 +355,42 @@ func branchSlug(value string) string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slug := strings.Trim(b.String(), "-")
|
slug := trimDashes(b.String())
|
||||||
if len(slug) > 40 {
|
if len(slug) > 40 {
|
||||||
slug = slug[:40]
|
slug = trimDashes(slug[:40])
|
||||||
slug = strings.Trim(slug, "-")
|
|
||||||
}
|
}
|
||||||
return slug
|
return slug
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sanitizeFilename replaces non-alphanumeric characters (except - _ .) with dashes.
|
||||||
|
func sanitizeFilename(title string) string {
|
||||||
|
b := core.NewBuilder()
|
||||||
|
b.Grow(len(title))
|
||||||
|
for _, r := range title {
|
||||||
|
switch {
|
||||||
|
case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9',
|
||||||
|
r == '-', r == '_', r == '.':
|
||||||
|
b.WriteRune(r)
|
||||||
|
default:
|
||||||
|
b.WriteByte('-')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// trimDashes strips leading and trailing dash characters from a string.
|
||||||
|
func trimDashes(s string) string {
|
||||||
|
start := 0
|
||||||
|
for start < len(s) && s[start] == '-' {
|
||||||
|
start++
|
||||||
|
}
|
||||||
|
end := len(s)
|
||||||
|
for end > start && s[end-1] == '-' {
|
||||||
|
end--
|
||||||
|
}
|
||||||
|
return s[start:end]
|
||||||
|
}
|
||||||
|
|
||||||
// --- Prompt templates ---
|
// --- Prompt templates ---
|
||||||
|
|
||||||
func (s *PrepSubsystem) writePromptTemplate(template, wsDir string) {
|
func (s *PrepSubsystem) writePromptTemplate(template, wsDir string) {
|
||||||
|
|
@ -434,7 +458,7 @@ Do NOT push. Commit only — a reviewer will verify and push.
|
||||||
prompt = "Read TODO.md and complete the task. Work in src/.\n"
|
prompt = "Read TODO.md and complete the task. Work in src/.\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "PROMPT.md"), prompt)
|
_ = writeAtomic(core.Path(wsDir, "src", "PROMPT.md"), prompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Plan template rendering ---
|
// --- Plan template rendering ---
|
||||||
|
|
@ -443,11 +467,11 @@ Do NOT push. Commit only — a reviewer will verify and push.
|
||||||
// and writes PLAN.md into the workspace src/ directory.
|
// and writes PLAN.md into the workspace src/ directory.
|
||||||
func (s *PrepSubsystem) writePlanFromTemplate(templateSlug string, variables map[string]string, task string, wsDir string) {
|
func (s *PrepSubsystem) writePlanFromTemplate(templateSlug string, variables map[string]string, task string, wsDir string) {
|
||||||
// Look for template in core/agent/prompts/templates/
|
// Look for template in core/agent/prompts/templates/
|
||||||
templatePath := filepath.Join(s.codePath, "core", "agent", "prompts", "templates", templateSlug+".yaml")
|
templatePath := core.Path(s.codePath, "core", "agent", "prompts", "templates", templateSlug+".yaml")
|
||||||
content, err := coreio.Local.Read(templatePath)
|
content, err := coreio.Local.Read(templatePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Try .yml extension
|
// Try .yml extension
|
||||||
templatePath = filepath.Join(s.codePath, "core", "agent", "prompts", "templates", templateSlug+".yml")
|
templatePath = core.Path(s.codePath, "core", "agent", "prompts", "templates", templateSlug+".yml")
|
||||||
content, err = coreio.Local.Read(templatePath)
|
content, err = coreio.Local.Read(templatePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return // Template not found, skip silently
|
return // Template not found, skip silently
|
||||||
|
|
@ -456,8 +480,8 @@ func (s *PrepSubsystem) writePlanFromTemplate(templateSlug string, variables map
|
||||||
|
|
||||||
// Substitute variables ({{variable_name}} → value)
|
// Substitute variables ({{variable_name}} → value)
|
||||||
for key, value := range variables {
|
for key, value := range variables {
|
||||||
content = strings.ReplaceAll(content, "{{"+key+"}}", value)
|
content = core.Replace(content, "{{"+key+"}}", value)
|
||||||
content = strings.ReplaceAll(content, "{{ "+key+" }}", value)
|
content = core.Replace(content, "{{ "+key+" }}", value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the YAML to render as markdown
|
// Parse the YAML to render as markdown
|
||||||
|
|
@ -477,7 +501,7 @@ func (s *PrepSubsystem) writePlanFromTemplate(templateSlug string, variables map
|
||||||
}
|
}
|
||||||
|
|
||||||
// Render as PLAN.md
|
// Render as PLAN.md
|
||||||
var plan strings.Builder
|
plan := core.NewBuilder()
|
||||||
plan.WriteString("# Plan: " + tmpl.Name + "\n\n")
|
plan.WriteString("# Plan: " + tmpl.Name + "\n\n")
|
||||||
if task != "" {
|
if task != "" {
|
||||||
plan.WriteString("**Task:** " + task + "\n\n")
|
plan.WriteString("**Task:** " + task + "\n\n")
|
||||||
|
|
@ -495,7 +519,7 @@ func (s *PrepSubsystem) writePlanFromTemplate(templateSlug string, variables map
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, phase := range tmpl.Phases {
|
for i, phase := range tmpl.Phases {
|
||||||
plan.WriteString(fmt.Sprintf("## Phase %d: %s\n\n", i+1, phase.Name))
|
plan.WriteString(core.Sprintf("## Phase %d: %s\n\n", i+1, phase.Name))
|
||||||
if phase.Description != "" {
|
if phase.Description != "" {
|
||||||
plan.WriteString(phase.Description + "\n\n")
|
plan.WriteString(phase.Description + "\n\n")
|
||||||
}
|
}
|
||||||
|
|
@ -512,7 +536,7 @@ func (s *PrepSubsystem) writePlanFromTemplate(templateSlug string, variables map
|
||||||
plan.WriteString("\n**Commit after completing this phase.**\n\n---\n\n")
|
plan.WriteString("\n**Commit after completing this phase.**\n\n---\n\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "PLAN.md"), plan.String())
|
_ = writeAtomic(core.Path(wsDir, "src", "PLAN.md"), plan.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Helpers (unchanged) ---
|
// --- Helpers (unchanged) ---
|
||||||
|
|
@ -522,7 +546,7 @@ func (s *PrepSubsystem) pullWiki(ctx context.Context, org, repo, wsDir string) i
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/wiki/pages", s.forgeURL, org, repo)
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/wiki/pages", s.forgeURL, org, repo)
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0
|
return 0
|
||||||
|
|
@ -553,7 +577,7 @@ func (s *PrepSubsystem) pullWiki(ctx context.Context, org, repo, wsDir string) i
|
||||||
subURL = page.Title
|
subURL = page.Title
|
||||||
}
|
}
|
||||||
|
|
||||||
pageURL := fmt.Sprintf("%s/api/v1/repos/%s/%s/wiki/page/%s", s.forgeURL, org, repo, subURL)
|
pageURL := core.Sprintf("%s/api/v1/repos/%s/%s/wiki/page/%s", s.forgeURL, org, repo, subURL)
|
||||||
pageReq, err := http.NewRequestWithContext(ctx, "GET", pageURL, nil)
|
pageReq, err := http.NewRequestWithContext(ctx, "GET", pageURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
|
|
@ -585,14 +609,9 @@ func (s *PrepSubsystem) pullWiki(ctx context.Context, org, repo, wsDir string) i
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
filename := strings.Map(func(r rune) rune {
|
filename := sanitizeFilename(page.Title) + ".md"
|
||||||
if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
return '-'
|
|
||||||
}, page.Title) + ".md"
|
|
||||||
|
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "kb", filename), string(content))
|
_ = writeAtomic(core.Path(wsDir, "src", "kb", filename), string(content))
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -604,9 +623,9 @@ func (s *PrepSubsystem) copySpecs(wsDir string) int {
|
||||||
count := 0
|
count := 0
|
||||||
|
|
||||||
for _, file := range specFiles {
|
for _, file := range specFiles {
|
||||||
src := filepath.Join(s.specsPath, file)
|
src := core.Path(s.specsPath, file)
|
||||||
if data, err := coreio.Local.Read(src); err == nil {
|
if data, err := coreio.Local.Read(src); err == nil {
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "specs", file), data)
|
_ = writeAtomic(core.Path(wsDir, "src", "specs", file), data)
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -629,7 +648,7 @@ func (s *PrepSubsystem) generateContext(ctx context.Context, repo, wsDir string)
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", s.brainURL+"/v1/brain/recall", strings.NewReader(string(body)))
|
req, err := http.NewRequestWithContext(ctx, "POST", s.brainURL+"/v1/brain/recall", core.NewReader(string(body)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
@ -646,18 +665,18 @@ func (s *PrepSubsystem) generateContext(ctx context.Context, repo, wsDir string)
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
respData, err := goio.ReadAll(resp.Body)
|
readResult := core.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if !readResult.OK {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
var result struct {
|
var result struct {
|
||||||
Memories []map[string]any `json:"memories"`
|
Memories []map[string]any `json:"memories"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(respData, &result); err != nil {
|
if ur := core.JSONUnmarshal([]byte(readResult.Value.(string)), &result); !ur.OK {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
var content strings.Builder
|
content := core.NewBuilder()
|
||||||
content.WriteString("# Context — " + repo + "\n\n")
|
content.WriteString("# Context — " + repo + "\n\n")
|
||||||
content.WriteString("> Relevant knowledge from OpenBrain.\n\n")
|
content.WriteString("> Relevant knowledge from OpenBrain.\n\n")
|
||||||
|
|
||||||
|
|
@ -666,15 +685,15 @@ func (s *PrepSubsystem) generateContext(ctx context.Context, repo, wsDir string)
|
||||||
memContent, _ := mem["content"].(string)
|
memContent, _ := mem["content"].(string)
|
||||||
memProject, _ := mem["project"].(string)
|
memProject, _ := mem["project"].(string)
|
||||||
score, _ := mem["score"].(float64)
|
score, _ := mem["score"].(float64)
|
||||||
content.WriteString(fmt.Sprintf("### %d. %s [%s] (score: %.3f)\n\n%s\n\n", i+1, memProject, memType, score, memContent))
|
content.WriteString(core.Sprintf("### %d. %s [%s] (score: %.3f)\n\n%s\n\n", i+1, memProject, memType, score, memContent))
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "CONTEXT.md"), content.String())
|
_ = writeAtomic(core.Path(wsDir, "src", "CONTEXT.md"), content.String())
|
||||||
return len(result.Memories)
|
return len(result.Memories)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) findConsumers(repo, wsDir string) int {
|
func (s *PrepSubsystem) findConsumers(repo, wsDir string) int {
|
||||||
goWorkPath := filepath.Join(s.codePath, "go.work")
|
goWorkPath := core.Path(s.codePath, "go.work")
|
||||||
modulePath := "forge.lthn.ai/core/" + repo
|
modulePath := "forge.lthn.ai/core/" + repo
|
||||||
|
|
||||||
workData, err := coreio.Local.Read(goWorkPath)
|
workData, err := coreio.Local.Read(goWorkPath)
|
||||||
|
|
@ -683,19 +702,19 @@ func (s *PrepSubsystem) findConsumers(repo, wsDir string) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
var consumers []string
|
var consumers []string
|
||||||
for _, line := range strings.Split(workData, "\n") {
|
for _, line := range core.Split(workData, "\n") {
|
||||||
line = strings.TrimSpace(line)
|
line = core.Trim(line)
|
||||||
if !strings.HasPrefix(line, "./") {
|
if !core.HasPrefix(line, "./") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
dir := filepath.Join(s.codePath, strings.TrimPrefix(line, "./"))
|
dir := core.Path(s.codePath, core.TrimPrefix(line, "./"))
|
||||||
goMod := filepath.Join(dir, "go.mod")
|
goMod := core.Path(dir, "go.mod")
|
||||||
modData, err := coreio.Local.Read(goMod)
|
modData, err := coreio.Local.Read(goMod)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.Contains(modData, modulePath) && !strings.HasPrefix(modData, "module "+modulePath) {
|
if core.Contains(modData, modulePath) && !core.HasPrefix(modData, "module "+modulePath) {
|
||||||
consumers = append(consumers, filepath.Base(dir))
|
consumers = append(consumers, core.PathBase(dir))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -705,8 +724,8 @@ func (s *PrepSubsystem) findConsumers(repo, wsDir string) int {
|
||||||
for _, c := range consumers {
|
for _, c := range consumers {
|
||||||
content += "- " + c + "\n"
|
content += "- " + c + "\n"
|
||||||
}
|
}
|
||||||
content += fmt.Sprintf("\n**Breaking change risk: %d consumers.**\n", len(consumers))
|
content += core.Sprintf("\n**Breaking change risk: %d consumers.**\n", len(consumers))
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "CONSUMERS.md"), content)
|
_ = writeAtomic(core.Path(wsDir, "src", "CONSUMERS.md"), content)
|
||||||
}
|
}
|
||||||
|
|
||||||
return len(consumers)
|
return len(consumers)
|
||||||
|
|
@ -720,10 +739,10 @@ func (s *PrepSubsystem) gitLog(repoPath, wsDir string) int {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := strings.Split(strings.TrimSpace(string(output)), "\n")
|
lines := core.Split(core.Trim(string(output)), "\n")
|
||||||
if len(lines) > 0 && lines[0] != "" {
|
if len(lines) > 0 && lines[0] != "" {
|
||||||
content := "# Recent Changes\n\n```\n" + string(output) + "```\n"
|
content := "# Recent Changes\n\n```\n" + string(output) + "```\n"
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "RECENT.md"), content)
|
_ = writeAtomic(core.Path(wsDir, "src", "RECENT.md"), content)
|
||||||
}
|
}
|
||||||
|
|
||||||
return len(lines)
|
return len(lines)
|
||||||
|
|
@ -734,7 +753,7 @@ func (s *PrepSubsystem) generateTodo(ctx context.Context, org, repo string, issu
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", s.forgeURL, org, repo, issue)
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", s.forgeURL, org, repo, issue)
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
req.Header.Set("Authorization", "token "+s.forgeToken)
|
req.Header.Set("Authorization", "token "+s.forgeToken)
|
||||||
|
|
||||||
|
|
@ -753,11 +772,11 @@ func (s *PrepSubsystem) generateTodo(ctx context.Context, org, repo string, issu
|
||||||
}
|
}
|
||||||
json.NewDecoder(resp.Body).Decode(&issueData)
|
json.NewDecoder(resp.Body).Decode(&issueData)
|
||||||
|
|
||||||
content := fmt.Sprintf("# TASK: %s\n\n", issueData.Title)
|
content := core.Sprintf("# TASK: %s\n\n", issueData.Title)
|
||||||
content += fmt.Sprintf("**Status:** ready\n")
|
content += core.Sprintf("**Status:** ready\n")
|
||||||
content += fmt.Sprintf("**Source:** %s/%s/%s/issues/%d\n", s.forgeURL, org, repo, issue)
|
content += core.Sprintf("**Source:** %s/%s/%s/issues/%d\n", s.forgeURL, org, repo, issue)
|
||||||
content += fmt.Sprintf("**Repo:** %s/%s\n\n---\n\n", org, repo)
|
content += core.Sprintf("**Repo:** %s/%s\n\n---\n\n", org, repo)
|
||||||
content += "## Objective\n\n" + issueData.Body + "\n"
|
content += "## Objective\n\n" + issueData.Body + "\n"
|
||||||
|
|
||||||
_ = writeAtomic(filepath.Join(wsDir, "src", "TODO.md"), content)
|
_ = writeAtomic(core.Path(wsDir, "src", "TODO.md"), content)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,18 +3,19 @@
|
||||||
package agentic
|
package agentic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// os.Create, os.Open, os.DevNull, os.Environ, os.FindProcess are used for
|
||||||
|
// process spawning and management — no core equivalents for these OS primitives.
|
||||||
|
|
||||||
// DispatchConfig controls agent dispatch behaviour.
|
// DispatchConfig controls agent dispatch behaviour.
|
||||||
type DispatchConfig struct {
|
type DispatchConfig struct {
|
||||||
DefaultAgent string `yaml:"default_agent"`
|
DefaultAgent string `yaml:"default_agent"`
|
||||||
|
|
@ -43,7 +44,7 @@ type AgentsConfig struct {
|
||||||
// loadAgentsConfig reads config/agents.yaml from the code path.
|
// loadAgentsConfig reads config/agents.yaml from the code path.
|
||||||
func (s *PrepSubsystem) loadAgentsConfig() *AgentsConfig {
|
func (s *PrepSubsystem) loadAgentsConfig() *AgentsConfig {
|
||||||
paths := []string{
|
paths := []string{
|
||||||
filepath.Join(s.codePath, ".core", "agents.yaml"),
|
core.Path(s.codePath, ".core", "agents.yaml"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, path := range paths {
|
for _, path := range paths {
|
||||||
|
|
@ -79,9 +80,16 @@ func (s *PrepSubsystem) delayForAgent(agent string) time.Duration {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse reset time
|
// Parse reset time (format: "HH:MM")
|
||||||
resetHour, resetMin := 6, 0
|
resetHour, resetMin := 6, 0
|
||||||
fmt.Sscanf(rate.ResetUTC, "%d:%d", &resetHour, &resetMin)
|
if parts := core.Split(rate.ResetUTC, ":"); len(parts) == 2 {
|
||||||
|
if h, ok := parseSimpleInt(parts[0]); ok {
|
||||||
|
resetHour = h
|
||||||
|
}
|
||||||
|
if m, ok := parseSimpleInt(parts[1]); ok {
|
||||||
|
resetMin = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
resetToday := time.Date(now.Year(), now.Month(), now.Day(), resetHour, resetMin, 0, 0, time.UTC)
|
resetToday := time.Date(now.Year(), now.Month(), now.Day(), resetHour, resetMin, 0, 0, time.UTC)
|
||||||
|
|
@ -115,9 +123,9 @@ func (s *PrepSubsystem) listWorkspaceDirs() []string {
|
||||||
if !entry.IsDir() {
|
if !entry.IsDir() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
path := filepath.Join(wsRoot, entry.Name())
|
path := core.Path(wsRoot, entry.Name())
|
||||||
// Check if this dir has a status.json (it's a workspace)
|
// Check if this dir has a status.json (it's a workspace)
|
||||||
if coreio.Local.IsFile(filepath.Join(path, "status.json")) {
|
if coreio.Local.IsFile(core.Path(path, "status.json")) {
|
||||||
dirs = append(dirs, path)
|
dirs = append(dirs, path)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -128,8 +136,8 @@ func (s *PrepSubsystem) listWorkspaceDirs() []string {
|
||||||
}
|
}
|
||||||
for _, sub := range subEntries {
|
for _, sub := range subEntries {
|
||||||
if sub.IsDir() {
|
if sub.IsDir() {
|
||||||
subPath := filepath.Join(path, sub.Name())
|
subPath := core.Path(path, sub.Name())
|
||||||
if coreio.Local.IsFile(filepath.Join(subPath, "status.json")) {
|
if coreio.Local.IsFile(core.Path(subPath, "status.json")) {
|
||||||
dirs = append(dirs, subPath)
|
dirs = append(dirs, subPath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -146,7 +154,7 @@ func (s *PrepSubsystem) countRunningByAgent(agent string) int {
|
||||||
if err != nil || st.Status != "running" {
|
if err != nil || st.Status != "running" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
stBase := strings.SplitN(st.Agent, ":", 2)[0]
|
stBase := core.SplitN(st.Agent, ":", 2)[0]
|
||||||
if stBase != agent {
|
if stBase != agent {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -162,7 +170,7 @@ func (s *PrepSubsystem) countRunningByAgent(agent string) int {
|
||||||
|
|
||||||
// baseAgent strips the model variant (gemini:flash → gemini).
|
// baseAgent strips the model variant (gemini:flash → gemini).
|
||||||
func baseAgent(agent string) string {
|
func baseAgent(agent string) string {
|
||||||
return strings.SplitN(agent, ":", 2)[0]
|
return core.SplitN(agent, ":", 2)[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// canDispatchAgent checks if we're under the concurrency limit for a specific agent type.
|
// canDispatchAgent checks if we're under the concurrency limit for a specific agent type.
|
||||||
|
|
@ -176,6 +184,23 @@ func (s *PrepSubsystem) canDispatchAgent(agent string) bool {
|
||||||
return s.countRunningByAgent(base) < limit
|
return s.countRunningByAgent(base) < limit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseSimpleInt parses a small non-negative integer from a string.
|
||||||
|
// Returns (value, true) on success, (0, false) on failure.
|
||||||
|
func parseSimpleInt(s string) (int, bool) {
|
||||||
|
s = core.Trim(s)
|
||||||
|
if s == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
n := 0
|
||||||
|
for _, r := range s {
|
||||||
|
if r < '0' || r > '9' {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
n = n*10 + int(r-'0')
|
||||||
|
}
|
||||||
|
return n, true
|
||||||
|
}
|
||||||
|
|
||||||
// canDispatch is kept for backwards compat.
|
// canDispatch is kept for backwards compat.
|
||||||
func (s *PrepSubsystem) canDispatch() bool {
|
func (s *PrepSubsystem) canDispatch() bool {
|
||||||
return true
|
return true
|
||||||
|
|
@ -205,7 +230,7 @@ func (s *PrepSubsystem) drainQueue() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
srcDir := filepath.Join(wsDir, "src")
|
srcDir := core.Path(wsDir, "src")
|
||||||
prompt := "Read PROMPT.md for instructions. All context files (CLAUDE.md, TODO.md, CONTEXT.md, CONSUMERS.md, RECENT.md) are in the parent directory. Work in this directory."
|
prompt := "Read PROMPT.md for instructions. All context files (CLAUDE.md, TODO.md, CONTEXT.md, CONSUMERS.md, RECENT.md) are in the parent directory. Work in this directory."
|
||||||
|
|
||||||
command, args, err := agentCommand(st.Agent, prompt)
|
command, args, err := agentCommand(st.Agent, prompt)
|
||||||
|
|
@ -213,7 +238,7 @@ func (s *PrepSubsystem) drainQueue() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
outputFile := filepath.Join(wsDir, fmt.Sprintf("agent-%s.log", st.Agent))
|
outputFile := core.Path(wsDir, core.Sprintf("agent-%s.log", st.Agent))
|
||||||
outFile, err := os.Create(outputFile)
|
outFile, err := os.Create(outputFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -5,19 +5,18 @@ package agentic
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func listLocalRepos(basePath string) []string {
|
func listLocalRepos(basePath string) []string {
|
||||||
entries, err := os.ReadDir(basePath)
|
entries, err := coreio.Local.List(basePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -35,7 +34,7 @@ func hasRemote(repoDir, remote string) bool {
|
||||||
cmd := exec.Command("git", "remote", "get-url", remote)
|
cmd := exec.Command("git", "remote", "get-url", remote)
|
||||||
cmd.Dir = repoDir
|
cmd.Dir = repoDir
|
||||||
if out, err := cmd.Output(); err == nil {
|
if out, err := cmd.Output(); err == nil {
|
||||||
return strings.TrimSpace(string(out)) != ""
|
return core.Trim(string(out)) != ""
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
@ -48,7 +47,7 @@ func commitsAhead(repoDir, baseRef, headRef string) int {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
count, err := parsePositiveInt(strings.TrimSpace(string(out)))
|
count, err := parsePositiveInt(core.Trim(string(out)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
@ -64,8 +63,8 @@ func filesChanged(repoDir, baseRef, headRef string) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
count := 0
|
count := 0
|
||||||
for _, line := range strings.Split(strings.TrimSpace(string(out)), "\n") {
|
for _, line := range core.Split(core.Trim(string(out)), "\n") {
|
||||||
if strings.TrimSpace(line) != "" {
|
if core.Trim(line) != "" {
|
||||||
count++
|
count++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -79,11 +78,11 @@ func gitOutput(repoDir string, args ...string) (string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", coreerr.E("gitOutput", string(out), err)
|
return "", coreerr.E("gitOutput", string(out), err)
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(string(out)), nil
|
return core.Trim(string(out)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parsePositiveInt(value string) (int, error) {
|
func parsePositiveInt(value string) (int, error) {
|
||||||
value = strings.TrimSpace(value)
|
value = core.Trim(value)
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return 0, coreerr.E("parsePositiveInt", "empty value", nil)
|
return 0, coreerr.E("parsePositiveInt", "empty value", nil)
|
||||||
}
|
}
|
||||||
|
|
@ -148,11 +147,11 @@ func createGitHubPR(ctx context.Context, repoDir, repo string, commits, files in
|
||||||
return "", coreerr.E("createGitHubPR", string(out), err)
|
return "", coreerr.E("createGitHubPR", string(out), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
lines := core.Split(core.Trim(string(out)), "\n")
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(lines[len(lines)-1]), nil
|
return core.Trim(lines[len(lines)-1]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureDevBranch(repoDir string) error {
|
func ensureDevBranch(repoDir string) error {
|
||||||
|
|
@ -194,7 +193,7 @@ func parseRetryAfter(detail string) time.Duration {
|
||||||
return 5 * time.Minute
|
return 5 * time.Minute
|
||||||
}
|
}
|
||||||
|
|
||||||
switch strings.ToLower(match[2]) {
|
switch core.Lower(match[2]) {
|
||||||
case "hour", "hours":
|
case "hour", "hours":
|
||||||
return time.Duration(n) * time.Hour
|
return time.Duration(n) * time.Hour
|
||||||
case "second", "seconds":
|
case "second", "seconds":
|
||||||
|
|
@ -205,5 +204,5 @@ func parseRetryAfter(detail string) time.Duration {
|
||||||
}
|
}
|
||||||
|
|
||||||
func repoRootFromCodePath(codePath string) string {
|
func repoRootFromCodePath(codePath string) string {
|
||||||
return filepath.Join(codePath, "core")
|
return core.Path(codePath, "core")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,16 +4,14 @@ package agentic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -52,8 +50,8 @@ func (s *PrepSubsystem) resume(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
return nil, ResumeOutput{}, coreerr.E("resume", "workspace is required", nil)
|
return nil, ResumeOutput{}, coreerr.E("resume", "workspace is required", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
wsDir := filepath.Join(s.workspaceRoot(), input.Workspace)
|
wsDir := core.Path(s.workspaceRoot(), input.Workspace)
|
||||||
srcDir := filepath.Join(wsDir, "src")
|
srcDir := core.Path(wsDir, "src")
|
||||||
|
|
||||||
// Verify workspace exists
|
// Verify workspace exists
|
||||||
if _, err := coreio.Local.List(srcDir); err != nil {
|
if _, err := coreio.Local.List(srcDir); err != nil {
|
||||||
|
|
@ -78,8 +76,8 @@ func (s *PrepSubsystem) resume(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
|
|
||||||
// Write ANSWER.md if answer provided
|
// Write ANSWER.md if answer provided
|
||||||
if input.Answer != "" {
|
if input.Answer != "" {
|
||||||
answerPath := filepath.Join(srcDir, "ANSWER.md")
|
answerPath := core.Path(srcDir, "ANSWER.md")
|
||||||
content := fmt.Sprintf("# Answer\n\n%s\n", input.Answer)
|
content := core.Sprintf("# Answer\n\n%s\n", input.Answer)
|
||||||
if err := writeAtomic(answerPath, content); err != nil {
|
if err := writeAtomic(answerPath, content); err != nil {
|
||||||
return nil, ResumeOutput{}, coreerr.E("resume", "failed to write ANSWER.md", err)
|
return nil, ResumeOutput{}, coreerr.E("resume", "failed to write ANSWER.md", err)
|
||||||
}
|
}
|
||||||
|
|
@ -102,7 +100,7 @@ func (s *PrepSubsystem) resume(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
}
|
}
|
||||||
|
|
||||||
// Spawn agent as detached process (survives parent death)
|
// Spawn agent as detached process (survives parent death)
|
||||||
outputFile := filepath.Join(wsDir, fmt.Sprintf("agent-%s-run%d.log", agent, st.Runs+1))
|
outputFile := core.Path(wsDir, core.Sprintf("agent-%s-run%d.log", agent, st.Runs+1))
|
||||||
|
|
||||||
command, args, err := agentCommand(agent, prompt)
|
command, args, err := agentCommand(agent, prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -154,10 +152,10 @@ func (s *PrepSubsystem) resume(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
"branch": st.Branch,
|
"branch": st.Branch,
|
||||||
}
|
}
|
||||||
|
|
||||||
if data, err := coreio.Local.Read(filepath.Join(srcDir, "BLOCKED.md")); err == nil {
|
if data, err := coreio.Local.Read(core.Path(srcDir, "BLOCKED.md")); err == nil {
|
||||||
status = "blocked"
|
status = "blocked"
|
||||||
channel = coremcp.ChannelAgentBlocked
|
channel = coremcp.ChannelAgentBlocked
|
||||||
st.Question = strings.TrimSpace(data)
|
st.Question = core.Trim(data)
|
||||||
if st.Question != "" {
|
if st.Question != "" {
|
||||||
payload["question"] = st.Question
|
payload["question"] = st.Question
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,16 +5,14 @@ package agentic
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -93,7 +91,7 @@ func (s *PrepSubsystem) reviewQueue(ctx context.Context, _ *mcp.CallToolRequest,
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
repoDir := filepath.Join(basePath, repo)
|
repoDir := core.Path(basePath, repo)
|
||||||
reviewer := input.Reviewer
|
reviewer := input.Reviewer
|
||||||
if reviewer == "" {
|
if reviewer == "" {
|
||||||
reviewer = "coderabbit"
|
reviewer = "coderabbit"
|
||||||
|
|
@ -137,7 +135,7 @@ func (s *PrepSubsystem) findReviewCandidates(basePath string) []string {
|
||||||
if !entry.IsDir() {
|
if !entry.IsDir() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
repoDir := filepath.Join(basePath, entry.Name())
|
repoDir := core.Path(basePath, entry.Name())
|
||||||
if !hasRemote(repoDir, "github") {
|
if !hasRemote(repoDir, "github") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -154,22 +152,22 @@ func (s *PrepSubsystem) reviewRepo(ctx context.Context, repoDir, repo, reviewer
|
||||||
|
|
||||||
if rl := s.loadRateLimitState(); rl != nil && rl.Limited && time.Now().Before(rl.RetryAt) {
|
if rl := s.loadRateLimitState(); rl != nil && rl.Limited && time.Now().Before(rl.RetryAt) {
|
||||||
result.Verdict = "rate_limited"
|
result.Verdict = "rate_limited"
|
||||||
result.Detail = fmt.Sprintf("retry after %s", rl.RetryAt.Format(time.RFC3339))
|
result.Detail = core.Sprintf("retry after %s", rl.RetryAt.Format(time.RFC3339))
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := reviewerCommand(ctx, repoDir, reviewer)
|
cmd := reviewerCommand(ctx, repoDir, reviewer)
|
||||||
cmd.Dir = repoDir
|
cmd.Dir = repoDir
|
||||||
out, err := cmd.CombinedOutput()
|
out, err := cmd.CombinedOutput()
|
||||||
output := strings.TrimSpace(string(out))
|
output := core.Trim(string(out))
|
||||||
|
|
||||||
if strings.Contains(strings.ToLower(output), "rate limit") {
|
if core.Contains(core.Lower(output), "rate limit") {
|
||||||
result.Verdict = "rate_limited"
|
result.Verdict = "rate_limited"
|
||||||
result.Detail = output
|
result.Detail = output
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil && !strings.Contains(output, "No findings") && !strings.Contains(output, "no issues") {
|
if err != nil && !core.Contains(output, "No findings") && !core.Contains(output, "no issues") {
|
||||||
result.Verdict = "error"
|
result.Verdict = "error"
|
||||||
if output != "" {
|
if output != "" {
|
||||||
result.Detail = output
|
result.Detail = output
|
||||||
|
|
@ -182,7 +180,7 @@ func (s *PrepSubsystem) reviewRepo(ctx context.Context, repoDir, repo, reviewer
|
||||||
s.storeReviewOutput(repoDir, repo, reviewer, output)
|
s.storeReviewOutput(repoDir, repo, reviewer, output)
|
||||||
result.Findings = countFindingHints(output)
|
result.Findings = countFindingHints(output)
|
||||||
|
|
||||||
if strings.Contains(output, "No findings") || strings.Contains(output, "no issues") || strings.Contains(output, "LGTM") {
|
if core.Contains(output, "No findings") || core.Contains(output, "no issues") || core.Contains(output, "LGTM") {
|
||||||
result.Verdict = "clean"
|
result.Verdict = "clean"
|
||||||
if dryRun {
|
if dryRun {
|
||||||
result.Action = "skipped (dry run)"
|
result.Action = "skipped (dry run)"
|
||||||
|
|
@ -198,7 +196,7 @@ func (s *PrepSubsystem) reviewRepo(ctx context.Context, repoDir, repo, reviewer
|
||||||
mergeCmd.Dir = repoDir
|
mergeCmd.Dir = repoDir
|
||||||
if mergeOut, err := mergeCmd.CombinedOutput(); err == nil {
|
if mergeOut, err := mergeCmd.CombinedOutput(); err == nil {
|
||||||
result.Action = "merged"
|
result.Action = "merged"
|
||||||
result.Detail = strings.TrimSpace(string(mergeOut))
|
result.Detail = core.Trim(string(mergeOut))
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -219,7 +217,7 @@ func (s *PrepSubsystem) reviewRepo(ctx context.Context, repoDir, repo, reviewer
|
||||||
|
|
||||||
func (s *PrepSubsystem) storeReviewOutput(repoDir, repo, reviewer, output string) {
|
func (s *PrepSubsystem) storeReviewOutput(repoDir, repo, reviewer, output string) {
|
||||||
home := reviewQueueHomeDir()
|
home := reviewQueueHomeDir()
|
||||||
dataDir := filepath.Join(home, ".core", "training", "reviews")
|
dataDir := core.Path(home, ".core", "training", "reviews")
|
||||||
if err := coreio.Local.EnsureDir(dataDir); err != nil {
|
if err := coreio.Local.EnsureDir(dataDir); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -235,13 +233,13 @@ func (s *PrepSubsystem) storeReviewOutput(repoDir, repo, reviewer, output string
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
name := fmt.Sprintf("%s-%s-%d.json", repo, reviewer, time.Now().Unix())
|
name := core.Sprintf("%s-%s-%d.json", repo, reviewer, time.Now().Unix())
|
||||||
_ = writeAtomic(filepath.Join(dataDir, name), string(data))
|
_ = writeAtomic(core.Path(dataDir, name), string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) saveRateLimitState(info *RateLimitInfo) {
|
func (s *PrepSubsystem) saveRateLimitState(info *RateLimitInfo) {
|
||||||
home := reviewQueueHomeDir()
|
home := reviewQueueHomeDir()
|
||||||
path := filepath.Join(home, ".core", "coderabbit-ratelimit.json")
|
path := core.Path(home, ".core", "coderabbit-ratelimit.json")
|
||||||
data, err := json.Marshal(info)
|
data, err := json.Marshal(info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|
@ -251,7 +249,7 @@ func (s *PrepSubsystem) saveRateLimitState(info *RateLimitInfo) {
|
||||||
|
|
||||||
func (s *PrepSubsystem) loadRateLimitState() *RateLimitInfo {
|
func (s *PrepSubsystem) loadRateLimitState() *RateLimitInfo {
|
||||||
home := reviewQueueHomeDir()
|
home := reviewQueueHomeDir()
|
||||||
path := filepath.Join(home, ".core", "coderabbit-ratelimit.json")
|
path := core.Path(home, ".core", "coderabbit-ratelimit.json")
|
||||||
data, err := coreio.Local.Read(path)
|
data, err := coreio.Local.Read(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,10 @@ package agentic
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
core "dappco.re/go/core"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -81,7 +80,7 @@ func (s *PrepSubsystem) scan(ctx context.Context, _ *mcp.CallToolRequest, input
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
var unique []ScanIssue
|
var unique []ScanIssue
|
||||||
for _, issue := range allIssues {
|
for _, issue := range allIssues {
|
||||||
key := fmt.Sprintf("%s#%d", issue.Repo, issue.Number)
|
key := core.Sprintf("%s#%d", issue.Repo, issue.Number)
|
||||||
if !seen[key] {
|
if !seen[key] {
|
||||||
seen[key] = true
|
seen[key] = true
|
||||||
unique = append(unique, issue)
|
unique = append(unique, issue)
|
||||||
|
|
@ -100,7 +99,7 @@ func (s *PrepSubsystem) scan(ctx context.Context, _ *mcp.CallToolRequest, input
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) listOrgRepos(ctx context.Context, org string) ([]string, error) {
|
func (s *PrepSubsystem) listOrgRepos(ctx context.Context, org string) ([]string, error) {
|
||||||
url := fmt.Sprintf("%s/api/v1/orgs/%s/repos?limit=50", s.forgeURL, org)
|
url := core.Sprintf("%s/api/v1/orgs/%s/repos?limit=50", s.forgeURL, org)
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
req.Header.Set("Authorization", "token "+s.forgeToken)
|
req.Header.Set("Authorization", "token "+s.forgeToken)
|
||||||
|
|
||||||
|
|
@ -110,7 +109,7 @@ func (s *PrepSubsystem) listOrgRepos(ctx context.Context, org string) ([]string,
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return nil, coreerr.E("listOrgRepos", fmt.Sprintf("HTTP %d listing repos", resp.StatusCode), nil)
|
return nil, coreerr.E("listOrgRepos", core.Sprintf("HTTP %d listing repos", resp.StatusCode), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var repos []struct {
|
var repos []struct {
|
||||||
|
|
@ -126,7 +125,7 @@ func (s *PrepSubsystem) listOrgRepos(ctx context.Context, org string) ([]string,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) listRepoIssues(ctx context.Context, org, repo, label string) ([]ScanIssue, error) {
|
func (s *PrepSubsystem) listRepoIssues(ctx context.Context, org, repo, label string) ([]ScanIssue, error) {
|
||||||
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues?state=open&labels=%s&limit=10&type=issues",
|
url := core.Sprintf("%s/api/v1/repos/%s/%s/issues?state=open&labels=%s&limit=10&type=issues",
|
||||||
s.forgeURL, org, repo, label)
|
s.forgeURL, org, repo, label)
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||||
req.Header.Set("Authorization", "token "+s.forgeToken)
|
req.Header.Set("Authorization", "token "+s.forgeToken)
|
||||||
|
|
@ -137,7 +136,7 @@ func (s *PrepSubsystem) listRepoIssues(ctx context.Context, org, repo, label str
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return nil, coreerr.E("listRepoIssues", fmt.Sprintf("HTTP %d for "+repo, resp.StatusCode), nil)
|
return nil, coreerr.E("listRepoIssues", core.Sprintf("HTTP %d for "+repo, resp.StatusCode), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var issues []struct {
|
var issues []struct {
|
||||||
|
|
@ -170,7 +169,7 @@ func (s *PrepSubsystem) listRepoIssues(ctx context.Context, org, repo, label str
|
||||||
Title: issue.Title,
|
Title: issue.Title,
|
||||||
Labels: labels,
|
Labels: labels,
|
||||||
Assignee: assignee,
|
Assignee: assignee,
|
||||||
URL: strings.Replace(issue.HTMLURL, "https://forge.lthn.ai", s.forgeURL, 1),
|
URL: core.Replace(issue.HTMLURL, "https://forge.lthn.ai", s.forgeURL),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,16 +6,18 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// os.Stat and os.FindProcess are used for workspace age detection and PID
|
||||||
|
// liveness checks — these are OS-level queries with no core equivalent.
|
||||||
|
|
||||||
// Workspace status file convention:
|
// Workspace status file convention:
|
||||||
//
|
//
|
||||||
// {workspace}/status.json — current state of the workspace
|
// {workspace}/status.json — current state of the workspace
|
||||||
|
|
@ -57,23 +59,23 @@ func writeStatus(wsDir string, status *WorkspaceStatus) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return writeAtomic(filepath.Join(wsDir, "status.json"), string(data))
|
return writeAtomic(core.JoinPath(wsDir, "status.json"), string(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) saveStatus(wsDir string, status *WorkspaceStatus) {
|
func (s *PrepSubsystem) saveStatus(wsDir string, status *WorkspaceStatus) {
|
||||||
if err := writeStatus(wsDir, status); err != nil {
|
if err := writeStatus(wsDir, status); err != nil {
|
||||||
coreerr.Warn("failed to write workspace status", "workspace", filepath.Base(wsDir), "err", err)
|
coreerr.Warn("failed to write workspace status", "workspace", core.PathBase(wsDir), "err", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func readStatus(wsDir string) (*WorkspaceStatus, error) {
|
func readStatus(wsDir string) (*WorkspaceStatus, error) {
|
||||||
data, err := coreio.Local.Read(filepath.Join(wsDir, "status.json"))
|
data, err := coreio.Local.Read(core.JoinPath(wsDir, "status.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var s WorkspaceStatus
|
var s WorkspaceStatus
|
||||||
if err := json.Unmarshal([]byte(data), &s); err != nil {
|
if r := core.JSONUnmarshal([]byte(data), &s); !r.OK {
|
||||||
return nil, err
|
return nil, coreerr.E("readStatus", "failed to parse status.json", nil)
|
||||||
}
|
}
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
@ -126,7 +128,7 @@ func (s *PrepSubsystem) status(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
var workspaces []WorkspaceInfo
|
var workspaces []WorkspaceInfo
|
||||||
|
|
||||||
for _, wsDir := range wsDirs {
|
for _, wsDir := range wsDirs {
|
||||||
name := filepath.Base(wsDir)
|
name := core.PathBase(wsDir)
|
||||||
|
|
||||||
// Filter by specific workspace if requested
|
// Filter by specific workspace if requested
|
||||||
if input.Workspace != "" && name != input.Workspace {
|
if input.Workspace != "" && name != input.Workspace {
|
||||||
|
|
@ -139,7 +141,7 @@ func (s *PrepSubsystem) status(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
st, err := readStatus(wsDir)
|
st, err := readStatus(wsDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Legacy workspace (no status.json) — check for log file
|
// Legacy workspace (no status.json) — check for log file
|
||||||
logFiles, _ := filepath.Glob(filepath.Join(wsDir, "agent-*.log"))
|
logFiles := core.PathGlob(core.Path(wsDir, "agent-*.log"))
|
||||||
if len(logFiles) > 0 {
|
if len(logFiles) > 0 {
|
||||||
info.Status = "completed"
|
info.Status = "completed"
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -177,10 +179,10 @@ func (s *PrepSubsystem) status(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process died — check for BLOCKED.md
|
// Process died — check for BLOCKED.md
|
||||||
blockedPath := filepath.Join(wsDir, "src", "BLOCKED.md")
|
blockedPath := core.Path(wsDir, "src", "BLOCKED.md")
|
||||||
if data, err := coreio.Local.Read(blockedPath); err == nil {
|
if data, err := coreio.Local.Read(blockedPath); err == nil {
|
||||||
info.Status = "blocked"
|
info.Status = "blocked"
|
||||||
info.Question = strings.TrimSpace(data)
|
info.Question = core.Trim(data)
|
||||||
st.Status = "blocked"
|
st.Status = "blocked"
|
||||||
st.Question = info.Question
|
st.Question = info.Question
|
||||||
status = "blocked"
|
status = "blocked"
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,20 @@ package agentic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"path/filepath"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultWatchPollInterval = 5 * time.Second
|
||||||
|
defaultWatchTimeout = 60 * time.Second
|
||||||
|
maxWatchTimeout = 30 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
// WatchInput is the input for agentic_watch.
|
// WatchInput is the input for agentic_watch.
|
||||||
type WatchInput struct {
|
type WatchInput struct {
|
||||||
Workspaces []string `json:"workspaces,omitempty"`
|
Workspaces []string `json:"workspaces,omitempty"`
|
||||||
|
|
@ -49,13 +55,10 @@ func (s *PrepSubsystem) registerWatchTool(svc *coremcp.Service) {
|
||||||
func (s *PrepSubsystem) watch(ctx context.Context, req *mcp.CallToolRequest, input WatchInput) (*mcp.CallToolResult, WatchOutput, error) {
|
func (s *PrepSubsystem) watch(ctx context.Context, req *mcp.CallToolRequest, input WatchInput) (*mcp.CallToolResult, WatchOutput, error) {
|
||||||
pollInterval := time.Duration(input.PollInterval) * time.Second
|
pollInterval := time.Duration(input.PollInterval) * time.Second
|
||||||
if pollInterval <= 0 {
|
if pollInterval <= 0 {
|
||||||
pollInterval = 5 * time.Second
|
pollInterval = defaultWatchPollInterval
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := time.Duration(input.Timeout) * time.Second
|
timeout := resolveWatchTimeout(input)
|
||||||
if timeout <= 0 {
|
|
||||||
timeout = 10 * time.Minute
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
deadline := start.Add(timeout)
|
deadline := start.Add(timeout)
|
||||||
|
|
@ -69,6 +72,14 @@ func (s *PrepSubsystem) watch(ctx context.Context, req *mcp.CallToolRequest, inp
|
||||||
return nil, WatchOutput{Success: true, Duration: "0s"}, nil
|
return nil, WatchOutput{Success: true, Duration: "0s"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
notifier := coremcp.NewProgressNotifier(ctx, req)
|
||||||
|
progress := float64(0)
|
||||||
|
total := float64(len(targets))
|
||||||
|
|
||||||
|
sendProgress := func(current float64, status WorkspaceStatus) {
|
||||||
|
_ = notifier.Send(current, total, core.Sprintf("%s %s (%s)", status.Repo, status.Status, status.Agent))
|
||||||
|
}
|
||||||
|
|
||||||
remaining := make(map[string]struct{}, len(targets))
|
remaining := make(map[string]struct{}, len(targets))
|
||||||
for _, workspace := range targets {
|
for _, workspace := range targets {
|
||||||
remaining[workspace] = struct{}{}
|
remaining[workspace] = struct{}{}
|
||||||
|
|
@ -106,6 +117,11 @@ func (s *PrepSubsystem) watch(ctx context.Context, req *mcp.CallToolRequest, inp
|
||||||
|
|
||||||
switch info.Status {
|
switch info.Status {
|
||||||
case "completed", "merged", "ready-for-review":
|
case "completed", "merged", "ready-for-review":
|
||||||
|
status := WorkspaceStatus{
|
||||||
|
Repo: info.Repo,
|
||||||
|
Agent: info.Agent,
|
||||||
|
Status: info.Status,
|
||||||
|
}
|
||||||
completed = append(completed, WatchResult{
|
completed = append(completed, WatchResult{
|
||||||
Workspace: info.Name,
|
Workspace: info.Name,
|
||||||
Agent: info.Agent,
|
Agent: info.Agent,
|
||||||
|
|
@ -116,7 +132,14 @@ func (s *PrepSubsystem) watch(ctx context.Context, req *mcp.CallToolRequest, inp
|
||||||
PRURL: info.PRURL,
|
PRURL: info.PRURL,
|
||||||
})
|
})
|
||||||
delete(remaining, info.Name)
|
delete(remaining, info.Name)
|
||||||
|
progress++
|
||||||
|
sendProgress(progress, status)
|
||||||
case "failed", "blocked":
|
case "failed", "blocked":
|
||||||
|
status := WorkspaceStatus{
|
||||||
|
Repo: info.Repo,
|
||||||
|
Agent: info.Agent,
|
||||||
|
Status: info.Status,
|
||||||
|
}
|
||||||
failed = append(failed, WatchResult{
|
failed = append(failed, WatchResult{
|
||||||
Workspace: info.Name,
|
Workspace: info.Name,
|
||||||
Agent: info.Agent,
|
Agent: info.Agent,
|
||||||
|
|
@ -127,6 +150,8 @@ func (s *PrepSubsystem) watch(ctx context.Context, req *mcp.CallToolRequest, inp
|
||||||
PRURL: info.PRURL,
|
PRURL: info.PRURL,
|
||||||
})
|
})
|
||||||
delete(remaining, info.Name)
|
delete(remaining, info.Name)
|
||||||
|
progress++
|
||||||
|
sendProgress(progress, status)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -139,6 +164,19 @@ func (s *PrepSubsystem) watch(ctx context.Context, req *mcp.CallToolRequest, inp
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveWatchTimeout(input WatchInput) time.Duration {
|
||||||
|
if input.Timeout <= 0 {
|
||||||
|
return defaultWatchTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
maxSeconds := int(maxWatchTimeout / time.Second)
|
||||||
|
if input.Timeout > maxSeconds {
|
||||||
|
return maxWatchTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Duration(input.Timeout) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) findActiveWorkspaces() []string {
|
func (s *PrepSubsystem) findActiveWorkspaces() []string {
|
||||||
wsDirs := s.listWorkspaceDirs()
|
wsDirs := s.listWorkspaceDirs()
|
||||||
if len(wsDirs) == 0 {
|
if len(wsDirs) == 0 {
|
||||||
|
|
@ -153,15 +191,15 @@ func (s *PrepSubsystem) findActiveWorkspaces() []string {
|
||||||
}
|
}
|
||||||
switch st.Status {
|
switch st.Status {
|
||||||
case "running", "queued":
|
case "running", "queued":
|
||||||
active = append(active, filepath.Base(wsDir))
|
active = append(active, core.PathBase(wsDir))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return active
|
return active
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PrepSubsystem) resolveWorkspaceDir(name string) string {
|
func (s *PrepSubsystem) resolveWorkspaceDir(name string) string {
|
||||||
if filepath.IsAbs(name) {
|
if core.PathIsAbs(name) {
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
return filepath.Join(s.workspaceRoot(), name)
|
return core.JoinPath(s.workspaceRoot(), name)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
41
pkg/mcp/agentic/watch_test.go
Normal file
41
pkg/mcp/agentic/watch_test.go
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package agentic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWatchDefaults_Good_RFCOneMinuteTimeout(t *testing.T) {
|
||||||
|
if defaultWatchTimeout != 60*time.Second {
|
||||||
|
t.Fatalf("expected default watch timeout to be 60s, got %s", defaultWatchTimeout)
|
||||||
|
}
|
||||||
|
if defaultWatchPollInterval != 5*time.Second {
|
||||||
|
t.Fatalf("expected default poll interval to be 5s, got %s", defaultWatchPollInterval)
|
||||||
|
}
|
||||||
|
if maxWatchTimeout != 30*time.Minute {
|
||||||
|
t.Fatalf("expected max watch timeout to be 30m, got %s", maxWatchTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveWatchTimeout_Good_HonorsInputTimeout(t *testing.T) {
|
||||||
|
got := resolveWatchTimeout(WatchInput{Timeout: 10})
|
||||||
|
if got != 10*time.Second {
|
||||||
|
t.Fatalf("expected input timeout to be honored as 10s, got %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveWatchTimeout_Good_ClampsInputTimeout(t *testing.T) {
|
||||||
|
got := resolveWatchTimeout(WatchInput{Timeout: int((10 * time.Hour) / time.Second)})
|
||||||
|
if got != 30*time.Minute {
|
||||||
|
t.Fatalf("expected input timeout to clamp to 30m, got %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveWatchTimeout_Good_ZeroUsesDefault(t *testing.T) {
|
||||||
|
got := resolveWatchTimeout(WatchInput{Timeout: 0})
|
||||||
|
if got != defaultWatchTimeout {
|
||||||
|
t.Fatalf("expected zero timeout to use default %s, got %s", defaultWatchTimeout, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -4,23 +4,26 @@ package agentic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// os.CreateTemp, os.Remove, os.Rename are framework-boundary calls for
|
||||||
|
// atomic file writes — no core equivalent exists for temp file creation.
|
||||||
|
|
||||||
// writeAtomic writes content to path by staging it in a temporary file and
|
// writeAtomic writes content to path by staging it in a temporary file and
|
||||||
// renaming it into place.
|
// renaming it into place.
|
||||||
//
|
//
|
||||||
// This avoids exposing partially written workspace files to agents that may
|
// This avoids exposing partially written workspace files to agents that may
|
||||||
// read status, prompt, or plan documents while they are being updated.
|
// read status, prompt, or plan documents while they are being updated.
|
||||||
func writeAtomic(path, content string) error {
|
func writeAtomic(path, content string) error {
|
||||||
dir := filepath.Dir(path)
|
dir := core.PathDir(path)
|
||||||
if err := coreio.Local.EnsureDir(dir); err != nil {
|
if err := coreio.Local.EnsureDir(dir); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".*.tmp")
|
tmp, err := os.CreateTemp(dir, "."+core.PathBase(path)+".*.tmp")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
397
pkg/mcp/authz.go
Normal file
397
pkg/mcp/authz.go
Normal file
|
|
@ -0,0 +1,397 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// authTokenPrefix is the prefix used by HTTP Authorization headers.
|
||||||
|
authTokenPrefix = "Bearer "
|
||||||
|
// authDefaultJWTTTL is the default validity duration for minted JWTs.
|
||||||
|
authDefaultJWTTTL = time.Hour
|
||||||
|
// authJWTSecretEnv is the HMAC secret used for JWT signing and verification.
|
||||||
|
authJWTSecretEnv = "MCP_JWT_SECRET"
|
||||||
|
// authJWTTTLSecondsEnv allows overriding token lifetime.
|
||||||
|
authJWTTTLSecondsEnv = "MCP_JWT_TTL_SECONDS"
|
||||||
|
)
|
||||||
|
|
||||||
|
// authClaims is the compact claim payload stored inside our internal JWTs.
|
||||||
|
type authClaims struct {
|
||||||
|
Workspace string `json:"workspace,omitempty"`
|
||||||
|
Entitlements []string `json:"entitlements,omitempty"`
|
||||||
|
Subject string `json:"sub,omitempty"`
|
||||||
|
Issuer string `json:"iss,omitempty"`
|
||||||
|
IssuedAt int64 `json:"iat,omitempty"`
|
||||||
|
ExpiresAt int64 `json:"exp,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type authContextKey struct{}
|
||||||
|
|
||||||
|
func withAuthClaims(ctx context.Context, claims *authClaims) context.Context {
|
||||||
|
if ctx == nil {
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, authContextKey{}, claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
func claimsFromContext(ctx context.Context) *authClaims {
|
||||||
|
if ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if c := ctx.Value(authContextKey{}); c != nil {
|
||||||
|
if cl, ok := c.(*authClaims); ok {
|
||||||
|
return cl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authConfig holds token verification options derived from environment.
|
||||||
|
type authConfig struct {
|
||||||
|
apiToken string
|
||||||
|
secret []byte
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func currentAuthConfig(apiToken string) authConfig {
|
||||||
|
cfg := authConfig{
|
||||||
|
apiToken: apiToken,
|
||||||
|
secret: []byte(core.Env(authJWTSecretEnv)),
|
||||||
|
ttl: authDefaultJWTTTL,
|
||||||
|
}
|
||||||
|
if len(cfg.secret) == 0 {
|
||||||
|
cfg.secret = []byte(apiToken)
|
||||||
|
}
|
||||||
|
if ttlRaw := core.Trim(core.Env(authJWTTTLSecondsEnv)); ttlRaw != "" {
|
||||||
|
if ttlVal, err := strconv.Atoi(ttlRaw); err == nil && ttlVal > 0 {
|
||||||
|
cfg.ttl = time.Duration(ttlVal) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractBearerToken(raw string) string {
|
||||||
|
raw = core.Trim(raw)
|
||||||
|
if core.HasPrefix(raw, authTokenPrefix) {
|
||||||
|
return core.Trim(core.TrimPrefix(raw, authTokenPrefix))
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAuthClaims(authToken, apiToken string) (*authClaims, error) {
|
||||||
|
cfg := currentAuthConfig(apiToken)
|
||||||
|
if cfg.apiToken == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
tkn := extractBearerToken(authToken)
|
||||||
|
if tkn == "" {
|
||||||
|
return nil, coreerr.E("", "missing bearer token", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if subtle.ConstantTimeCompare([]byte(tkn), []byte(cfg.apiToken)) == 1 {
|
||||||
|
return &authClaims{
|
||||||
|
Subject: "api-key",
|
||||||
|
IssuedAt: time.Now().Unix(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.secret) == 0 {
|
||||||
|
return nil, coreerr.E("", "jwt secret is not configured", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := core.Split(tkn, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, coreerr.E("", "invalid token format", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
headerJSON, err := decodeJWTSection(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var header map[string]any
|
||||||
|
if err := json.Unmarshal(headerJSON, &header); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if alg, _ := header["alg"].(string); alg != "" && alg != "HS256" {
|
||||||
|
return nil, coreerr.E("", core.Sprintf("unsupported jwt algorithm: %s", alg), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
signatureBase := parts[0] + "." + parts[1]
|
||||||
|
mac := hmac.New(sha256.New, cfg.secret)
|
||||||
|
mac.Write([]byte(signatureBase))
|
||||||
|
expectedSig := mac.Sum(nil)
|
||||||
|
actualSig, err := decodeJWTSection(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !hmac.Equal(expectedSig, actualSig) {
|
||||||
|
return nil, coreerr.E("", "invalid token signature", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadJSON, err := decodeJWTSection(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var claims authClaims
|
||||||
|
if err := json.Unmarshal(payloadJSON, &claims); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if claims.ExpiresAt > 0 && claims.ExpiresAt < now {
|
||||||
|
return nil, coreerr.E("", "token has expired", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeJWTSection(value string) ([]byte, error) {
|
||||||
|
raw, err := base64.RawURLEncoding.DecodeString(value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeJWTSection(value []byte) string {
|
||||||
|
return base64.RawURLEncoding.EncodeToString(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mintJWTToken(rawClaims authClaims, cfg authConfig) (string, error) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if rawClaims.IssuedAt == 0 {
|
||||||
|
rawClaims.IssuedAt = now
|
||||||
|
}
|
||||||
|
if rawClaims.ExpiresAt == 0 {
|
||||||
|
rawClaims.ExpiresAt = now + int64(cfg.ttl.Seconds())
|
||||||
|
}
|
||||||
|
header := map[string]string{
|
||||||
|
"alg": "HS256",
|
||||||
|
"typ": "JWT",
|
||||||
|
}
|
||||||
|
headerJSON, err := json.Marshal(header)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
payloadJSON, err := json.Marshal(rawClaims)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
signingInput := encodeJWTSection(headerJSON) + "." + encodeJWTSection(payloadJSON)
|
||||||
|
mac := hmac.New(sha256.New, cfg.secret)
|
||||||
|
mac.Write([]byte(signingInput))
|
||||||
|
signature := mac.Sum(nil)
|
||||||
|
|
||||||
|
return signingInput + "." + encodeJWTSection(signature), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func authClaimsFromToolRequest(ctx context.Context, req *mcp.CallToolRequest, apiToken string) (claims *authClaims, inTransport bool, err error) {
|
||||||
|
cfg := currentAuthConfig(apiToken)
|
||||||
|
if cfg.apiToken == "" {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if req != nil {
|
||||||
|
extra := req.GetExtra()
|
||||||
|
if extra == nil || extra.Header == nil {
|
||||||
|
return nil, true, coreerr.E("", "missing request auth metadata", nil)
|
||||||
|
}
|
||||||
|
raw := extra.Header.Get("Authorization")
|
||||||
|
parsed, err := parseAuthClaims(raw, apiToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, true, err
|
||||||
|
}
|
||||||
|
return parsed, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims = claimsFromContext(ctx); claims != nil {
|
||||||
|
return claims, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) authorizeToolAccess(ctx context.Context, req *mcp.CallToolRequest, tool string, input any) error {
|
||||||
|
apiToken := core.Env("MCP_AUTH_TOKEN")
|
||||||
|
cfg := currentAuthConfig(apiToken)
|
||||||
|
if cfg.apiToken == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, inTransport, err := authClaimsFromToolRequest(ctx, req, apiToken)
|
||||||
|
if err != nil {
|
||||||
|
return coreerr.E("auth", "unauthorized", err)
|
||||||
|
}
|
||||||
|
if !inTransport {
|
||||||
|
// Allow direct service method calls in-process, while still enforcing
|
||||||
|
// transport requests where auth metadata is present.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if claims == nil {
|
||||||
|
return coreerr.E("auth", "unauthorized", coreerr.E("", "missing auth claims", nil))
|
||||||
|
}
|
||||||
|
if !claims.canRunTool(tool) {
|
||||||
|
return coreerr.E("auth", "forbidden", coreerr.E("", "tool not allowed for token", nil))
|
||||||
|
}
|
||||||
|
if !claims.canAccessWorkspaceFromInput(input) {
|
||||||
|
return coreerr.E("auth", "forbidden", coreerr.E("", "workspace scope mismatch", nil))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *authClaims) canRunTool(tool string) bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(c.Entitlements) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
toolAllow := "tool:" + tool
|
||||||
|
for _, e := range c.Entitlements {
|
||||||
|
switch e {
|
||||||
|
case "*", "tool:*", "tools:*":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
if e == tool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if e == toolAllow || e == "tools:"+tool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *authClaims) canAccessWorkspaceFromInput(input any) bool {
|
||||||
|
if c == nil || c.Workspace == "" || c.Workspace == "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
target := inputWorkspaceFromValue(input)
|
||||||
|
if target == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return workspaceMatch(c.Workspace, target)
|
||||||
|
}
|
||||||
|
|
||||||
|
func workspaceMatch(claimed, target string) bool {
|
||||||
|
if core.Trim(claimed) == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if core.Trim(target) == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if claimed == target {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if core.HasSuffix(claimed, "*") {
|
||||||
|
prefix := core.TrimSuffix(claimed, "*")
|
||||||
|
return core.HasPrefix(target, prefix)
|
||||||
|
}
|
||||||
|
return core.HasPrefix(target, claimed+"/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func inputWorkspaceFromValue(input any) string {
|
||||||
|
if input == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
v := reflect.ValueOf(input)
|
||||||
|
for v.Kind() == reflect.Pointer && !v.IsNil() {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
if !v.IsValid() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Map:
|
||||||
|
return workspaceFromMap(v)
|
||||||
|
case reflect.Struct:
|
||||||
|
return workspaceFromStruct(v)
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func workspaceFromMap(v reflect.Value) string {
|
||||||
|
if v.IsNil() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
keyType := v.Type().Key()
|
||||||
|
if keyType.Kind() != reflect.String {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, key := range []string{
|
||||||
|
"workspace",
|
||||||
|
"repo",
|
||||||
|
"repository",
|
||||||
|
"project",
|
||||||
|
"workspace_id",
|
||||||
|
} {
|
||||||
|
mapKey := reflect.ValueOf(key)
|
||||||
|
if mapKey.Type() != keyType {
|
||||||
|
if mapKey.Type().ConvertibleTo(keyType) {
|
||||||
|
mapKey = mapKey.Convert(keyType)
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mapKey.IsValid() {
|
||||||
|
raw := v.MapIndex(mapKey)
|
||||||
|
if raw.IsValid() && raw.Kind() == reflect.String {
|
||||||
|
return core.Trim(raw.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func workspaceFromStruct(v reflect.Value) string {
|
||||||
|
t := v.Type()
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
f := v.Field(i)
|
||||||
|
ft := t.Field(i)
|
||||||
|
if !f.CanInterface() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := []string{core.Lower(ft.Name)}
|
||||||
|
if tag := ft.Tag.Get("json"); tag != "" {
|
||||||
|
keys = append(keys, core.Lower(core.Split(tag, ",")[0]))
|
||||||
|
}
|
||||||
|
for _, candidate := range keys {
|
||||||
|
if candidate != "workspace" && candidate != "repo" && candidate != "repository" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch f.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
if s := core.Trim(f.String()); s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
case reflect.Pointer:
|
||||||
|
if f.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.Elem().Kind() == reflect.String {
|
||||||
|
if s := core.Trim(f.Elem().String()); s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
@ -7,9 +7,9 @@ package brain
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
"dappco.re/go/mcp/pkg/mcp/ide"
|
"dappco.re/go/mcp/pkg/mcp/ide"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// errBridgeNotAvailable is returned when a tool requires the Laravel bridge
|
// errBridgeNotAvailable is returned when a tool requires the Laravel bridge
|
||||||
|
|
@ -60,15 +60,15 @@ func (s *Subsystem) RegisterTools(svc *coremcp.Service) {
|
||||||
func (s *Subsystem) handleBridgeMessage(msg ide.BridgeMessage) {
|
func (s *Subsystem) handleBridgeMessage(msg ide.BridgeMessage) {
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case "brain_remember":
|
case "brain_remember":
|
||||||
emitBridgeChannel(context.Background(), s.notifier, coremcp.ChannelBrainRememberDone, bridgePayload(msg.Data, "type", "project"))
|
emitBridgeChannel(context.Background(), s.notifier, coremcp.ChannelBrainRememberDone, bridgePayload(msg.Data, "org", "type", "project"))
|
||||||
case "brain_recall":
|
case "brain_recall":
|
||||||
payload := bridgePayload(msg.Data, "query", "project", "type", "agent_id")
|
payload := bridgePayload(msg.Data, "query", "org", "project", "type", "agent_id")
|
||||||
payload["count"] = bridgeCount(msg.Data)
|
payload["count"] = bridgeCount(msg.Data)
|
||||||
emitBridgeChannel(context.Background(), s.notifier, coremcp.ChannelBrainRecallDone, payload)
|
emitBridgeChannel(context.Background(), s.notifier, coremcp.ChannelBrainRecallDone, payload)
|
||||||
case "brain_forget":
|
case "brain_forget":
|
||||||
emitBridgeChannel(context.Background(), s.notifier, coremcp.ChannelBrainForgetDone, bridgePayload(msg.Data, "id", "reason"))
|
emitBridgeChannel(context.Background(), s.notifier, coremcp.ChannelBrainForgetDone, bridgePayload(msg.Data, "id", "reason"))
|
||||||
case "brain_list":
|
case "brain_list":
|
||||||
emitBridgeChannel(context.Background(), s.notifier, coremcp.ChannelBrainListDone, bridgePayload(msg.Data, "project", "type", "agent_id", "limit"))
|
emitBridgeChannel(context.Background(), s.notifier, coremcp.ChannelBrainListDone, bridgePayload(msg.Data, "org", "project", "type", "agent_id", "limit"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -89,6 +89,8 @@ func TestSubsystem_Good_BridgeRecallNotification(t *testing.T) {
|
||||||
Type: "brain_recall",
|
Type: "brain_recall",
|
||||||
Data: map[string]any{
|
Data: map[string]any{
|
||||||
"query": "how does scoring work?",
|
"query": "how does scoring work?",
|
||||||
|
"org": "core",
|
||||||
|
"project": "eaas",
|
||||||
"memories": []any{
|
"memories": []any{
|
||||||
map[string]any{"id": "m1"},
|
map[string]any{"id": "m1"},
|
||||||
map[string]any{"id": "m2"},
|
map[string]any{"id": "m2"},
|
||||||
|
|
@ -110,6 +112,9 @@ func TestSubsystem_Good_BridgeRecallNotification(t *testing.T) {
|
||||||
if payload["query"] != "how does scoring work?" {
|
if payload["query"] != "how does scoring work?" {
|
||||||
t.Fatalf("expected query to be forwarded, got %v", payload["query"])
|
t.Fatalf("expected query to be forwarded, got %v", payload["query"])
|
||||||
}
|
}
|
||||||
|
if payload["org"] != "core" {
|
||||||
|
t.Fatalf("expected org to be forwarded, got %v", payload["org"])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Struct round-trip tests ---
|
// --- Struct round-trip tests ---
|
||||||
|
|
@ -119,6 +124,7 @@ func TestRememberInput_Good_RoundTrip(t *testing.T) {
|
||||||
Content: "LEM scoring was blind to negative emotions",
|
Content: "LEM scoring was blind to negative emotions",
|
||||||
Type: "bug",
|
Type: "bug",
|
||||||
Tags: []string{"scoring", "lem"},
|
Tags: []string{"scoring", "lem"},
|
||||||
|
Org: "core",
|
||||||
Project: "eaas",
|
Project: "eaas",
|
||||||
Confidence: 0.95,
|
Confidence: 0.95,
|
||||||
Supersedes: "550e8400-e29b-41d4-a716-446655440000",
|
Supersedes: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
|
@ -138,6 +144,9 @@ func TestRememberInput_Good_RoundTrip(t *testing.T) {
|
||||||
if len(out.Tags) != 2 || out.Tags[0] != "scoring" {
|
if len(out.Tags) != 2 || out.Tags[0] != "scoring" {
|
||||||
t.Errorf("round-trip mismatch: tags")
|
t.Errorf("round-trip mismatch: tags")
|
||||||
}
|
}
|
||||||
|
if out.Org != "core" {
|
||||||
|
t.Errorf("round-trip mismatch: org %q != core", out.Org)
|
||||||
|
}
|
||||||
if out.Confidence != 0.95 {
|
if out.Confidence != 0.95 {
|
||||||
t.Errorf("round-trip mismatch: confidence %f != 0.95", out.Confidence)
|
t.Errorf("round-trip mismatch: confidence %f != 0.95", out.Confidence)
|
||||||
}
|
}
|
||||||
|
|
@ -167,6 +176,7 @@ func TestRecallInput_Good_RoundTrip(t *testing.T) {
|
||||||
Query: "how does verdict classification work?",
|
Query: "how does verdict classification work?",
|
||||||
TopK: 5,
|
TopK: 5,
|
||||||
Filter: RecallFilter{
|
Filter: RecallFilter{
|
||||||
|
Org: "core",
|
||||||
Project: "eaas",
|
Project: "eaas",
|
||||||
MinConfidence: 0.5,
|
MinConfidence: 0.5,
|
||||||
},
|
},
|
||||||
|
|
@ -182,7 +192,7 @@ func TestRecallInput_Good_RoundTrip(t *testing.T) {
|
||||||
if out.Query != in.Query || out.TopK != 5 {
|
if out.Query != in.Query || out.TopK != 5 {
|
||||||
t.Errorf("round-trip mismatch: query or topK")
|
t.Errorf("round-trip mismatch: query or topK")
|
||||||
}
|
}
|
||||||
if out.Filter.Project != "eaas" || out.Filter.MinConfidence != 0.5 {
|
if out.Filter.Org != "core" || out.Filter.Project != "eaas" || out.Filter.MinConfidence != 0.5 {
|
||||||
t.Errorf("round-trip mismatch: filter")
|
t.Errorf("round-trip mismatch: filter")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -194,6 +204,7 @@ func TestMemory_Good_RoundTrip(t *testing.T) {
|
||||||
Type: "decision",
|
Type: "decision",
|
||||||
Content: "Use Qdrant for vector search",
|
Content: "Use Qdrant for vector search",
|
||||||
Tags: []string{"architecture", "openbrain"},
|
Tags: []string{"architecture", "openbrain"},
|
||||||
|
Org: "core",
|
||||||
Project: "php-agentic",
|
Project: "php-agentic",
|
||||||
Confidence: 0.9,
|
Confidence: 0.9,
|
||||||
CreatedAt: "2026-03-03T12:00:00+00:00",
|
CreatedAt: "2026-03-03T12:00:00+00:00",
|
||||||
|
|
@ -207,7 +218,7 @@ func TestMemory_Good_RoundTrip(t *testing.T) {
|
||||||
if err := json.Unmarshal(data, &out); err != nil {
|
if err := json.Unmarshal(data, &out); err != nil {
|
||||||
t.Fatalf("unmarshal failed: %v", err)
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
}
|
}
|
||||||
if out.ID != in.ID || out.AgentID != "virgil" || out.Type != "decision" {
|
if out.ID != in.ID || out.AgentID != "virgil" || out.Type != "decision" || out.Org != "core" {
|
||||||
t.Errorf("round-trip mismatch: %+v", out)
|
t.Errorf("round-trip mismatch: %+v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -232,6 +243,7 @@ func TestForgetInput_Good_RoundTrip(t *testing.T) {
|
||||||
|
|
||||||
func TestListInput_Good_RoundTrip(t *testing.T) {
|
func TestListInput_Good_RoundTrip(t *testing.T) {
|
||||||
in := ListInput{
|
in := ListInput{
|
||||||
|
Org: "core",
|
||||||
Project: "eaas",
|
Project: "eaas",
|
||||||
Type: "decision",
|
Type: "decision",
|
||||||
AgentID: "charon",
|
AgentID: "charon",
|
||||||
|
|
@ -245,7 +257,7 @@ func TestListInput_Good_RoundTrip(t *testing.T) {
|
||||||
if err := json.Unmarshal(data, &out); err != nil {
|
if err := json.Unmarshal(data, &out); err != nil {
|
||||||
t.Fatalf("unmarshal failed: %v", err)
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
}
|
}
|
||||||
if out.Project != "eaas" || out.Type != "decision" || out.AgentID != "charon" || out.Limit != 20 {
|
if out.Org != "core" || out.Project != "eaas" || out.Type != "decision" || out.AgentID != "charon" || out.Limit != 20 {
|
||||||
t.Errorf("round-trip mismatch: %+v", out)
|
t.Errorf("round-trip mismatch: %+v", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
673
pkg/mcp/brain/client/client.go
Normal file
673
pkg/mcp/brain/client/client.go
Normal file
|
|
@ -0,0 +1,673 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
// Package client provides the shared OpenBrain HTTP client.
|
||||||
|
//
|
||||||
|
// c := client.New(client.Options{URL: core.Env("CORE_BRAIN_URL"), Key: core.Env("CORE_BRAIN_KEY")})
|
||||||
|
// _, err := c.Remember(ctx, client.RememberInput{
|
||||||
|
// Org: "core",
|
||||||
|
// Project: "mcp",
|
||||||
|
// Content: "Use one OpenBrain client for retry and circuit-breaker policy.",
|
||||||
|
// Type: "decision",
|
||||||
|
// })
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
cryptorand "crypto/rand"
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"math/big"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
coreio "dappco.re/go/core/io"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultURL = "https://api.lthn.sh"
|
||||||
|
insecureBrainEnv = "CORE_BRAIN_INSECURE"
|
||||||
|
brainKeyFileMode = fs.FileMode(0o600)
|
||||||
|
defaultAgentID = "cladius"
|
||||||
|
defaultTimeout = 30 * time.Second
|
||||||
|
defaultMaxAttempts = 3
|
||||||
|
defaultBaseDelay = 100 * time.Millisecond
|
||||||
|
defaultFailureThreshold = 3
|
||||||
|
defaultSuccessThreshold = 1
|
||||||
|
defaultCircuitCooldown = 30 * time.Second
|
||||||
|
defaultMaxResponseBytes = int64(1 << 20)
|
||||||
|
maxBackoffDelay = 30 * time.Second
|
||||||
|
maxRetryAfterDelay = 60 * time.Second
|
||||||
|
defaultRecallTopK = 10
|
||||||
|
defaultListLimit = 50
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrCircuitOpen is returned when repeated upstream failures have opened the circuit.
|
||||||
|
var ErrCircuitOpen = core.NewError("brain client circuit open")
|
||||||
|
|
||||||
|
// Options configures the shared OpenBrain client.
|
||||||
|
type Options struct {
|
||||||
|
URL string
|
||||||
|
Key string
|
||||||
|
Org string
|
||||||
|
AgentID string
|
||||||
|
HTTPClient *http.Client
|
||||||
|
MaxAttempts int
|
||||||
|
BaseDelay time.Duration
|
||||||
|
MaxResponseBytes int64
|
||||||
|
CircuitBreaker *CircuitBreaker
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client calls the Laravel /v1/brain/* API with shared retry and circuit policy.
|
||||||
|
type Client struct {
|
||||||
|
apiURL string
|
||||||
|
apiKey string
|
||||||
|
org string
|
||||||
|
agentID string
|
||||||
|
httpClient *http.Client
|
||||||
|
maxAttempts int
|
||||||
|
baseDelay time.Duration
|
||||||
|
maxResponseBytes int64
|
||||||
|
circuitBreaker *CircuitBreaker
|
||||||
|
configErr error
|
||||||
|
sleepFunc func(context.Context, time.Duration) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RememberInput is the request body for POST /v1/brain/remember.
|
||||||
|
type RememberInput struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Tags []string `json:"tags,omitempty"`
|
||||||
|
Org string `json:"org,omitempty"`
|
||||||
|
Project string `json:"project,omitempty"`
|
||||||
|
AgentID string `json:"agent_id,omitempty"`
|
||||||
|
Confidence float64 `json:"confidence,omitempty"`
|
||||||
|
Supersedes string `json:"supersedes,omitempty"`
|
||||||
|
ExpiresIn int `json:"expires_in,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecallInput is the request body for POST /v1/brain/recall.
|
||||||
|
type RecallInput struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
Org string `json:"org,omitempty"`
|
||||||
|
Project string `json:"project,omitempty"`
|
||||||
|
Type any `json:"type,omitempty"`
|
||||||
|
AgentID string `json:"agent_id,omitempty"`
|
||||||
|
MinConfidence float64 `json:"min_confidence,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForgetInput selects the memory removed by DELETE /v1/brain/forget/{id}.
|
||||||
|
type ForgetInput struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Reason string `json:"reason,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListInput provides URL parameters for GET /v1/brain/list.
|
||||||
|
type ListInput struct {
|
||||||
|
Org string `json:"org,omitempty"`
|
||||||
|
Project string `json:"project,omitempty"`
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
AgentID string `json:"agent_id,omitempty"`
|
||||||
|
Limit int `json:"limit,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CircuitState is the current breaker state.
|
||||||
|
type CircuitState string
|
||||||
|
|
||||||
|
const (
|
||||||
|
CircuitClosed CircuitState = "closed"
|
||||||
|
CircuitOpen CircuitState = "open"
|
||||||
|
CircuitHalfOpen CircuitState = "half_open"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CircuitBreakerOptions controls when the circuit opens and recovers.
|
||||||
|
type CircuitBreakerOptions struct {
|
||||||
|
FailureThreshold int
|
||||||
|
SuccessThreshold int
|
||||||
|
Cooldown time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// CircuitBreaker protects OpenBrain from repeated failed calls.
|
||||||
|
type CircuitBreaker struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
state CircuitState
|
||||||
|
failureThreshold int
|
||||||
|
successThreshold int
|
||||||
|
cooldown time.Duration
|
||||||
|
consecutiveFails int
|
||||||
|
consecutiveWins int
|
||||||
|
openedAt time.Time
|
||||||
|
halfOpenInFlight bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a shared OpenBrain client.
|
||||||
|
func New(options Options) *Client {
|
||||||
|
apiURL := core.Trim(options.URL)
|
||||||
|
if apiURL == "" {
|
||||||
|
apiURL = DefaultURL
|
||||||
|
}
|
||||||
|
configErr := validateAPIURL(apiURL)
|
||||||
|
agentID := core.Trim(options.AgentID)
|
||||||
|
if agentID == "" {
|
||||||
|
agentID = defaultAgentID
|
||||||
|
}
|
||||||
|
httpClient := options.HTTPClient
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = &http.Client{Timeout: defaultTimeout}
|
||||||
|
}
|
||||||
|
maxAttempts := options.MaxAttempts
|
||||||
|
if maxAttempts <= 0 {
|
||||||
|
maxAttempts = defaultMaxAttempts
|
||||||
|
}
|
||||||
|
baseDelay := options.BaseDelay
|
||||||
|
if baseDelay <= 0 {
|
||||||
|
baseDelay = defaultBaseDelay
|
||||||
|
}
|
||||||
|
maxResponseBytes := options.MaxResponseBytes
|
||||||
|
if maxResponseBytes <= 0 {
|
||||||
|
maxResponseBytes = defaultMaxResponseBytes
|
||||||
|
}
|
||||||
|
breaker := options.CircuitBreaker
|
||||||
|
if breaker == nil {
|
||||||
|
breaker = NewCircuitBreaker(CircuitBreakerOptions{})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
apiURL: core.TrimSuffix(apiURL, "/"),
|
||||||
|
apiKey: core.Trim(options.Key),
|
||||||
|
org: core.Trim(options.Org),
|
||||||
|
agentID: agentID,
|
||||||
|
httpClient: httpClient,
|
||||||
|
maxAttempts: maxAttempts,
|
||||||
|
baseDelay: baseDelay,
|
||||||
|
maxResponseBytes: maxResponseBytes,
|
||||||
|
circuitBreaker: breaker,
|
||||||
|
configErr: configErr,
|
||||||
|
sleepFunc: sleepDuration,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFromEnvironment reads CORE_BRAIN_* settings and ~/.claude/brain.key.
|
||||||
|
func NewFromEnvironment() *Client {
|
||||||
|
apiKey, configErr := apiKeyFromEnvironment()
|
||||||
|
client := New(Options{
|
||||||
|
URL: envOr("CORE_BRAIN_URL", DefaultURL),
|
||||||
|
Key: apiKey,
|
||||||
|
Org: core.Env("CORE_BRAIN_ORG"),
|
||||||
|
AgentID: core.Env("CORE_BRAIN_AGENT_ID"),
|
||||||
|
})
|
||||||
|
if configErr != nil {
|
||||||
|
client.configErr = configErr
|
||||||
|
}
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAPIURL(apiURL string) error {
|
||||||
|
parsed, err := url.Parse(apiURL)
|
||||||
|
if err != nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||||
|
return core.E("brain.client", "invalid API URL", err)
|
||||||
|
}
|
||||||
|
if parsed.Scheme == "https" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if parsed.Scheme == "http" && core.Trim(core.Env(insecureBrainEnv)) == "true" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return core.E("brain.client", "API URL must use https unless CORE_BRAIN_INSECURE=true", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteBrainKey stores the OpenBrain API key at ~/.claude/brain.key with owner-only permissions.
|
||||||
|
func WriteBrainKey(apiKey string) error {
|
||||||
|
home := core.Env("HOME")
|
||||||
|
if home == "" {
|
||||||
|
return core.E("brain.client", "HOME not set", nil)
|
||||||
|
}
|
||||||
|
return writeBrainKeyFile(brainKeyPath(home), apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCircuitBreaker creates a circuit breaker with OpenBrain defaults.
|
||||||
|
func NewCircuitBreaker(options CircuitBreakerOptions) *CircuitBreaker {
|
||||||
|
failureThreshold := options.FailureThreshold
|
||||||
|
if failureThreshold <= 0 {
|
||||||
|
failureThreshold = defaultFailureThreshold
|
||||||
|
}
|
||||||
|
successThreshold := options.SuccessThreshold
|
||||||
|
if successThreshold <= 0 {
|
||||||
|
successThreshold = defaultSuccessThreshold
|
||||||
|
}
|
||||||
|
cooldown := options.Cooldown
|
||||||
|
if cooldown <= 0 {
|
||||||
|
cooldown = defaultCircuitCooldown
|
||||||
|
}
|
||||||
|
return &CircuitBreaker{
|
||||||
|
state: CircuitClosed,
|
||||||
|
failureThreshold: failureThreshold,
|
||||||
|
successThreshold: successThreshold,
|
||||||
|
cooldown: cooldown,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// State returns the current breaker state.
|
||||||
|
func (breaker *CircuitBreaker) State() CircuitState {
|
||||||
|
if breaker == nil {
|
||||||
|
return CircuitClosed
|
||||||
|
}
|
||||||
|
breaker.lock.Lock()
|
||||||
|
defer breaker.lock.Unlock()
|
||||||
|
return breaker.stateNow(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remember stores a memory in OpenBrain.
|
||||||
|
func (c *Client) Remember(ctx context.Context, input RememberInput) (map[string]any, error) {
|
||||||
|
input.Org = c.orgFor(input.Org)
|
||||||
|
input.AgentID = c.agentFor(input.AgentID)
|
||||||
|
return c.Call(ctx, http.MethodPost, "/v1/brain/remember", input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recall searches memories in OpenBrain.
|
||||||
|
func (c *Client) Recall(ctx context.Context, input RecallInput) (map[string]any, error) {
|
||||||
|
input.Org = c.orgFor(input.Org)
|
||||||
|
input.AgentID = c.agentFor(input.AgentID)
|
||||||
|
if input.TopK == 0 {
|
||||||
|
input.TopK = defaultRecallTopK
|
||||||
|
}
|
||||||
|
return c.Call(ctx, http.MethodPost, "/v1/brain/recall", input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forget removes one memory from OpenBrain.
|
||||||
|
func (c *Client) Forget(ctx context.Context, input ForgetInput) (map[string]any, error) {
|
||||||
|
return c.Call(ctx, http.MethodDelete, core.Concat("/v1/brain/forget/", url.PathEscape(input.ID)), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List returns memories from OpenBrain using URL query filters.
|
||||||
|
func (c *Client) List(ctx context.Context, input ListInput) (map[string]any, error) {
|
||||||
|
input.Org = c.orgFor(input.Org)
|
||||||
|
if input.Limit == 0 {
|
||||||
|
input.Limit = defaultListLimit
|
||||||
|
}
|
||||||
|
values := url.Values{}
|
||||||
|
if input.Org != "" {
|
||||||
|
values.Set("org", input.Org)
|
||||||
|
}
|
||||||
|
if input.Project != "" {
|
||||||
|
values.Set("project", input.Project)
|
||||||
|
}
|
||||||
|
if input.Type != "" {
|
||||||
|
values.Set("type", input.Type)
|
||||||
|
}
|
||||||
|
if input.AgentID != "" {
|
||||||
|
values.Set("agent_id", input.AgentID)
|
||||||
|
}
|
||||||
|
values.Set("limit", core.Sprintf("%d", input.Limit))
|
||||||
|
return c.Call(ctx, http.MethodGet, core.Concat("/v1/brain/list?", values.Encode()), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call performs one OpenBrain API request through retry and circuit-breaker policy.
|
||||||
|
func (c *Client) Call(ctx context.Context, method, path string, body any) (map[string]any, error) {
|
||||||
|
if c.configErr != nil {
|
||||||
|
return nil, c.configErr
|
||||||
|
}
|
||||||
|
if c.apiKey == "" {
|
||||||
|
return nil, core.E("brain.client", "no API key (set CORE_BRAIN_KEY or create ~/.claude/brain.key)", nil)
|
||||||
|
}
|
||||||
|
if err := c.circuitBreaker.beforeRequest(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyString := ""
|
||||||
|
if body != nil {
|
||||||
|
bodyString = core.JSONMarshalString(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 1; attempt <= c.maxAttempts; attempt++ {
|
||||||
|
payload, retryable, retryAfter, hasRetryAfter, err := c.doOnce(ctx, method, path, bodyString, body != nil)
|
||||||
|
if err == nil {
|
||||||
|
c.circuitBreaker.recordSuccess()
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lastErr = err
|
||||||
|
if !retryable {
|
||||||
|
c.circuitBreaker.recordIgnored()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
c.circuitBreaker.recordFailure()
|
||||||
|
if c.circuitBreaker.State() == CircuitOpen || attempt == c.maxAttempts {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
var sleepErr error
|
||||||
|
if hasRetryAfter {
|
||||||
|
sleepErr = c.sleepFor(ctx, retryAfter)
|
||||||
|
} else {
|
||||||
|
sleepErr = c.sleep(ctx, attempt)
|
||||||
|
}
|
||||||
|
if sleepErr != nil {
|
||||||
|
lastErr = sleepErr
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) doOnce(ctx context.Context, method, path, bodyString string, hasBody bool) (map[string]any, bool, time.Duration, bool, error) {
|
||||||
|
var reader io.Reader
|
||||||
|
if hasBody {
|
||||||
|
reader = core.NewReader(bodyString)
|
||||||
|
}
|
||||||
|
requestURL, err := c.requestURL(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, 0, false, err
|
||||||
|
}
|
||||||
|
request, err := http.NewRequestWithContext(ctx, method, requestURL, reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, 0, false, core.E("brain.client", "create request", err)
|
||||||
|
}
|
||||||
|
request.Header.Set("Accept", "application/json")
|
||||||
|
request.Header.Set("Authorization", core.Concat("Bearer ", c.apiKey))
|
||||||
|
if hasBody {
|
||||||
|
request.Header.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := c.httpClient.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, false, 0, false, core.E("brain.client", "request cancelled", ctx.Err())
|
||||||
|
}
|
||||||
|
return nil, true, 0, false, core.E("brain.client", "request failed", err)
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
readResult := core.ReadAll(io.LimitReader(response.Body, c.maxResponseBytes+1))
|
||||||
|
if !readResult.OK {
|
||||||
|
if readErr, ok := readResult.Value.(error); ok {
|
||||||
|
return nil, false, 0, false, core.E("brain.client", "read response", readErr)
|
||||||
|
}
|
||||||
|
return nil, false, 0, false, core.E("brain.client", "read response", nil)
|
||||||
|
}
|
||||||
|
raw := readResult.Value.(string)
|
||||||
|
if int64(len(raw)) > c.maxResponseBytes {
|
||||||
|
return nil, false, 0, false, core.E("brain.client", "response too large", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.StatusCode >= http.StatusBadRequest {
|
||||||
|
retryAfter, hasRetryAfter := parseRetryAfter(response.Header.Get("Retry-After"), time.Now())
|
||||||
|
return nil, retryableStatus(response.StatusCode), retryAfter, hasRetryAfter, core.E("brain.client", core.Concat("upstream returned ", response.Status, ": ", core.Trim(raw)), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := map[string]any{}
|
||||||
|
if parseResult := core.JSONUnmarshalString(raw, &result); !parseResult.OK {
|
||||||
|
if parseErr, ok := parseResult.Value.(error); ok {
|
||||||
|
return nil, false, 0, false, core.E("brain.client", "parse response", parseErr)
|
||||||
|
}
|
||||||
|
return nil, false, 0, false, core.E("brain.client", "parse response", nil)
|
||||||
|
}
|
||||||
|
return result, false, 0, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) requestURL(path string) (string, error) {
|
||||||
|
parsed, err := url.Parse(path)
|
||||||
|
if err == nil && (parsed.IsAbs() || parsed.Host != "") {
|
||||||
|
return "", core.E("brain.client", "absolute request URL rejected", nil)
|
||||||
|
}
|
||||||
|
if !core.HasPrefix(path, "/") {
|
||||||
|
path = core.Concat("/", path)
|
||||||
|
}
|
||||||
|
return core.Concat(c.apiURL, path), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) sleep(ctx context.Context, attempt int) error {
|
||||||
|
retryAttempt := attempt - 1
|
||||||
|
delay := jitteredBackoffDelay(c.baseDelay, retryAttempt)
|
||||||
|
return c.sleepFor(ctx, delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) sleepFor(ctx context.Context, delay time.Duration) error {
|
||||||
|
if c.sleepFunc != nil {
|
||||||
|
return c.sleepFunc(ctx, delay)
|
||||||
|
}
|
||||||
|
return sleepDuration(ctx, delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleepDuration(ctx context.Context, delay time.Duration) error {
|
||||||
|
if delay <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(delay)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return core.E("brain.client", "request cancelled", ctx.Err())
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func jitteredBackoffDelay(baseDelay time.Duration, attempt int) time.Duration {
|
||||||
|
limit := backoffDelayLimit(baseDelay, attempt)
|
||||||
|
if limit <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
jitter, err := cryptorand.Int(cryptorand.Reader, big.NewInt(int64(limit)))
|
||||||
|
if err != nil {
|
||||||
|
return limit
|
||||||
|
}
|
||||||
|
return time.Duration(jitter.Int64())
|
||||||
|
}
|
||||||
|
|
||||||
|
func backoffDelayLimit(baseDelay time.Duration, attempt int) time.Duration {
|
||||||
|
if baseDelay <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if baseDelay >= maxBackoffDelay {
|
||||||
|
return maxBackoffDelay
|
||||||
|
}
|
||||||
|
if attempt <= 0 {
|
||||||
|
return baseDelay
|
||||||
|
}
|
||||||
|
|
||||||
|
delay := baseDelay
|
||||||
|
for i := 0; i < attempt; i++ {
|
||||||
|
if delay >= maxBackoffDelay/2 {
|
||||||
|
return maxBackoffDelay
|
||||||
|
}
|
||||||
|
delay *= 2
|
||||||
|
}
|
||||||
|
if delay > maxBackoffDelay {
|
||||||
|
return maxBackoffDelay
|
||||||
|
}
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) orgFor(org string) string {
|
||||||
|
org = core.Trim(org)
|
||||||
|
if org != "" {
|
||||||
|
return org
|
||||||
|
}
|
||||||
|
return c.org
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) agentFor(agentID string) string {
|
||||||
|
agentID = core.Trim(agentID)
|
||||||
|
if agentID != "" {
|
||||||
|
return agentID
|
||||||
|
}
|
||||||
|
return c.agentID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (breaker *CircuitBreaker) beforeRequest() error {
|
||||||
|
if breaker == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
breaker.lock.Lock()
|
||||||
|
defer breaker.lock.Unlock()
|
||||||
|
|
||||||
|
state := breaker.stateNow(time.Now())
|
||||||
|
if state == CircuitOpen {
|
||||||
|
return ErrCircuitOpen
|
||||||
|
}
|
||||||
|
if state == CircuitHalfOpen {
|
||||||
|
if breaker.halfOpenInFlight {
|
||||||
|
return ErrCircuitOpen
|
||||||
|
}
|
||||||
|
breaker.halfOpenInFlight = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (breaker *CircuitBreaker) recordSuccess() {
|
||||||
|
if breaker == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
breaker.lock.Lock()
|
||||||
|
defer breaker.lock.Unlock()
|
||||||
|
|
||||||
|
breaker.halfOpenInFlight = false
|
||||||
|
breaker.consecutiveFails = 0
|
||||||
|
breaker.consecutiveWins++
|
||||||
|
if breaker.state == CircuitHalfOpen && breaker.consecutiveWins >= breaker.successThreshold {
|
||||||
|
breaker.state = CircuitClosed
|
||||||
|
breaker.consecutiveWins = 0
|
||||||
|
}
|
||||||
|
if breaker.state == CircuitClosed {
|
||||||
|
breaker.consecutiveWins = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (breaker *CircuitBreaker) recordFailure() {
|
||||||
|
if breaker == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
breaker.lock.Lock()
|
||||||
|
defer breaker.lock.Unlock()
|
||||||
|
|
||||||
|
breaker.halfOpenInFlight = false
|
||||||
|
breaker.consecutiveWins = 0
|
||||||
|
breaker.consecutiveFails++
|
||||||
|
if breaker.state == CircuitHalfOpen || breaker.consecutiveFails >= breaker.failureThreshold {
|
||||||
|
breaker.state = CircuitOpen
|
||||||
|
breaker.openedAt = time.Now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (breaker *CircuitBreaker) recordIgnored() {
|
||||||
|
if breaker == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
breaker.lock.Lock()
|
||||||
|
defer breaker.lock.Unlock()
|
||||||
|
breaker.halfOpenInFlight = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (breaker *CircuitBreaker) stateNow(now time.Time) CircuitState {
|
||||||
|
if breaker.state == "" {
|
||||||
|
breaker.state = CircuitClosed
|
||||||
|
}
|
||||||
|
if breaker.state == CircuitOpen && now.Sub(breaker.openedAt) >= breaker.cooldown {
|
||||||
|
breaker.state = CircuitHalfOpen
|
||||||
|
breaker.consecutiveFails = 0
|
||||||
|
breaker.consecutiveWins = 0
|
||||||
|
breaker.halfOpenInFlight = false
|
||||||
|
}
|
||||||
|
return breaker.state
|
||||||
|
}
|
||||||
|
|
||||||
|
func retryableStatus(statusCode int) bool {
|
||||||
|
return statusCode == http.StatusRequestTimeout || statusCode == http.StatusTooManyRequests || statusCode >= http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRetryAfter(value string, now time.Time) (time.Duration, bool) {
|
||||||
|
value = core.Trim(value)
|
||||||
|
if value == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if seconds, err := strconv.ParseInt(value, 10, 64); err == nil {
|
||||||
|
if seconds <= 0 {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
maxSeconds := int64(maxRetryAfterDelay / time.Second)
|
||||||
|
if seconds > maxSeconds {
|
||||||
|
return maxRetryAfterDelay, true
|
||||||
|
}
|
||||||
|
return time.Duration(seconds) * time.Second, true
|
||||||
|
}
|
||||||
|
|
||||||
|
retryAt, err := http.ParseTime(value)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
delay := retryAt.Sub(now)
|
||||||
|
if delay <= 0 {
|
||||||
|
return 0, true
|
||||||
|
}
|
||||||
|
if delay > maxRetryAfterDelay {
|
||||||
|
return maxRetryAfterDelay, true
|
||||||
|
}
|
||||||
|
return delay, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func envOr(key, fallback string) string {
|
||||||
|
value := core.Env(key)
|
||||||
|
if value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func apiKeyFromEnvironment() (string, error) {
|
||||||
|
if apiKey := core.Trim(core.Env("CORE_BRAIN_KEY")); apiKey != "" {
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
home := core.Env("HOME")
|
||||||
|
if home == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
apiKey, err := readBrainKeyFile(brainKeyPath(home))
|
||||||
|
if err != nil {
|
||||||
|
if core.Is(err, fs.ErrNotExist) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return apiKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func brainKeyPath(home string) string {
|
||||||
|
return core.JoinPath(home, ".claude", "brain.key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBrainKeyFile(path string) (string, error) {
|
||||||
|
info, err := coreio.Local.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if brainKeyModeInsecure(info.Mode().Perm()) {
|
||||||
|
return "", core.E("brain.client", "brain.key has insecure permissions, expected 0600", nil)
|
||||||
|
}
|
||||||
|
data, err := coreio.Local.Read(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return core.Trim(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeBrainKeyFile(path, apiKey string) error {
|
||||||
|
if err := coreio.Local.WriteMode(path, core.Trim(apiKey)+"\n", brainKeyFileMode); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.Chmod(path, brainKeyFileMode); err != nil {
|
||||||
|
return core.E("brain.client", "chmod brain.key", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func brainKeyModeInsecure(mode fs.FileMode) bool {
|
||||||
|
return mode.Perm()&^brainKeyFileMode != 0
|
||||||
|
}
|
||||||
595
pkg/mcp/brain/client/client_test.go
Normal file
595
pkg/mcp/brain/client/client_test.go
Normal file
|
|
@ -0,0 +1,595 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientRemember_Good_SendsOrgAndAuth(t *testing.T) {
|
||||||
|
var gotBody map[string]any
|
||||||
|
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Fatalf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if r.URL.Path != "/v1/brain/remember" {
|
||||||
|
t.Fatalf("expected /v1/brain/remember, got %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||||||
|
t.Fatalf("expected bearer token, got %q", r.Header.Get("Authorization"))
|
||||||
|
}
|
||||||
|
gotBody = readRequestBody(t, r)
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"id": "mem-1"})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
Org: "core",
|
||||||
|
AgentID: "codex",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 1,
|
||||||
|
})
|
||||||
|
result, err := c.Remember(context.Background(), RememberInput{
|
||||||
|
Content: "remember org",
|
||||||
|
Type: "decision",
|
||||||
|
Project: "mcp",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Remember failed: %v", err)
|
||||||
|
}
|
||||||
|
if result["id"] != "mem-1" {
|
||||||
|
t.Fatalf("expected id mem-1, got %v", result["id"])
|
||||||
|
}
|
||||||
|
if gotBody["org"] != "core" {
|
||||||
|
t.Fatalf("expected org=core, got %v", gotBody["org"])
|
||||||
|
}
|
||||||
|
if gotBody["project"] != "mcp" {
|
||||||
|
t.Fatalf("expected project=mcp, got %v", gotBody["project"])
|
||||||
|
}
|
||||||
|
if gotBody["agent_id"] != "codex" {
|
||||||
|
t.Fatalf("expected agent_id=codex, got %v", gotBody["agent_id"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientList_Good_SendsOrgURLParam(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
t.Fatalf("expected GET, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if r.URL.Path != "/v1/brain/list" {
|
||||||
|
t.Fatalf("expected /v1/brain/list, got %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
if got := r.URL.Query().Get("org"); got != "core" {
|
||||||
|
t.Fatalf("expected org=core, got %q", got)
|
||||||
|
}
|
||||||
|
if got := r.URL.Query().Get("project"); got != "mcp" {
|
||||||
|
t.Fatalf("expected project=mcp, got %q", got)
|
||||||
|
}
|
||||||
|
if got := r.URL.Query().Get("limit"); got != "50" {
|
||||||
|
t.Fatalf("expected default limit=50, got %q", got)
|
||||||
|
}
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{URL: server.URL, Key: "test-key", Org: "core", HTTPClient: server.Client(), MaxAttempts: 1})
|
||||||
|
if _, err := c.List(context.Background(), ListInput{Project: "mcp"}); err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Good_BuildsRequestAgainstAPIURL(t *testing.T) {
|
||||||
|
gotHost := ""
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Fatalf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if r.URL.Path != "/v1/brain/remember" {
|
||||||
|
t.Fatalf("expected /v1/brain/remember, got %s", r.URL.Path)
|
||||||
|
}
|
||||||
|
gotHost = r.Host
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"id": "mem-1"})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := c.Call(context.Background(), http.MethodPost, "/v1/brain/remember", map[string]any{"content": "safe"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Call failed: %v", err)
|
||||||
|
}
|
||||||
|
if result["id"] != "mem-1" {
|
||||||
|
t.Fatalf("expected id mem-1, got %v", result["id"])
|
||||||
|
}
|
||||||
|
if gotHost != strings.TrimPrefix(server.URL, "https://") {
|
||||||
|
t.Fatalf("expected host %s, got %s", strings.TrimPrefix(server.URL, "https://"), gotHost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Bad_RejectsAbsoluteRequestURL(t *testing.T) {
|
||||||
|
for _, requestPath := range []string{"http://attacker.com/leak", "https://attacker.com/leak"} {
|
||||||
|
t.Run(requestPath, func(t *testing.T) {
|
||||||
|
calls := 0
|
||||||
|
c := New(Options{
|
||||||
|
URL: "https://brain.test",
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||||
|
calls++
|
||||||
|
return nil, core.E("test", "unexpected HTTP request", nil)
|
||||||
|
})},
|
||||||
|
MaxAttempts: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := c.Call(context.Background(), http.MethodPost, requestPath, map[string]any{"content": "leak"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected absolute URL error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "absolute request URL rejected") {
|
||||||
|
t.Fatalf("expected absolute URL rejection, got %v", err)
|
||||||
|
}
|
||||||
|
if calls != 0 {
|
||||||
|
t.Fatalf("expected no HTTP requests, got %d", calls)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientNew_Bad_RejectsHTTPAPIURLWithoutInsecureEnv(t *testing.T) {
|
||||||
|
t.Setenv(insecureBrainEnv, "")
|
||||||
|
|
||||||
|
c := New(Options{URL: "http://internal/", Key: "test-key"})
|
||||||
|
if c.configErr == nil {
|
||||||
|
t.Fatal("expected insecure HTTP API URL to be rejected")
|
||||||
|
}
|
||||||
|
if !strings.Contains(c.configErr.Error(), "API URL must use https unless CORE_BRAIN_INSECURE=true") {
|
||||||
|
t.Fatalf("expected insecure API URL error, got %v", c.configErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientNew_Good_AllowsHTTPAPIURLWithInsecureEnv(t *testing.T) {
|
||||||
|
t.Setenv(insecureBrainEnv, "true")
|
||||||
|
|
||||||
|
c := New(Options{URL: "http://internal/", Key: "test-key"})
|
||||||
|
if c.configErr != nil {
|
||||||
|
t.Fatalf("expected insecure HTTP API URL to be allowed, got %v", c.configErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Good_Retries503ThenSucceeds(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
writeJSON(t, w, http.StatusServiceUnavailable, map[string]any{"error": "down"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Nanosecond,
|
||||||
|
})
|
||||||
|
if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil {
|
||||||
|
t.Fatalf("Recall failed after retry: %v", err)
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Fatalf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Good_Retries408ThenSucceeds(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
writeJSON(t, w, http.StatusRequestTimeout, map[string]any{"error": "timeout"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Nanosecond,
|
||||||
|
})
|
||||||
|
if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil {
|
||||||
|
t.Fatalf("Recall failed after retry: %v", err)
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Fatalf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Good_Retries429ThenSucceeds(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
writeJSON(t, w, http.StatusTooManyRequests, map[string]any{"error": "rate limited"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Nanosecond,
|
||||||
|
})
|
||||||
|
if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil {
|
||||||
|
t.Fatalf("Recall failed after retry: %v", err)
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Fatalf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Good_Retries429UsingRetryAfterSeconds(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
w.Header().Set("Retry-After", "2")
|
||||||
|
writeJSON(t, w, http.StatusTooManyRequests, map[string]any{"error": "rate limited"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Nanosecond,
|
||||||
|
})
|
||||||
|
sleeps := []time.Duration{}
|
||||||
|
c.sleepFunc = func(ctx context.Context, delay time.Duration) error {
|
||||||
|
sleeps = append(sleeps, delay)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil {
|
||||||
|
t.Fatalf("Recall failed after retry: %v", err)
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Fatalf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
if len(sleeps) != 1 {
|
||||||
|
t.Fatalf("expected one retry sleep, got %d", len(sleeps))
|
||||||
|
}
|
||||||
|
if sleeps[0] != 2*time.Second {
|
||||||
|
t.Fatalf("expected Retry-After sleep of 2s, got %v", sleeps[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientSleep_Good_AppliesJitterAcrossClients(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
c1 := New(Options{URL: "https://brain.test", Key: "test-key", BaseDelay: 10 * time.Second})
|
||||||
|
c2 := New(Options{URL: "https://brain.test", Key: "test-key", BaseDelay: 10 * time.Second})
|
||||||
|
|
||||||
|
var delay1 time.Duration
|
||||||
|
var delay2 time.Duration
|
||||||
|
c1.sleepFunc = func(ctx context.Context, delay time.Duration) error {
|
||||||
|
delay1 = delay
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c2.sleepFunc = func(ctx context.Context, delay time.Duration) error {
|
||||||
|
delay2 = delay
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
if err := c1.sleep(ctx, 3); err != nil {
|
||||||
|
t.Fatalf("first client sleep failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := c2.sleep(ctx, 3); err != nil {
|
||||||
|
t.Fatalf("second client sleep failed: %v", err)
|
||||||
|
}
|
||||||
|
if delay1 < 0 || delay1 > maxBackoffDelay {
|
||||||
|
t.Fatalf("first client delay out of range: %v", delay1)
|
||||||
|
}
|
||||||
|
if delay2 < 0 || delay2 > maxBackoffDelay {
|
||||||
|
t.Fatalf("second client delay out of range: %v", delay2)
|
||||||
|
}
|
||||||
|
if delay1 != delay2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatalf("expected jitter to produce different delays for two clients, both got %v", delay1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJitteredBackoffDelay_Good_CapsHighAttempt(t *testing.T) {
|
||||||
|
if limit := backoffDelayLimit(defaultBaseDelay, 20); limit != maxBackoffDelay {
|
||||||
|
t.Fatalf("expected high-attempt backoff limit %v, got %v", maxBackoffDelay, limit)
|
||||||
|
}
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
if delay := jitteredBackoffDelay(defaultBaseDelay, 20); delay < 0 || delay > maxBackoffDelay {
|
||||||
|
t.Fatalf("expected high-attempt jitter <= %v, got %v", maxBackoffDelay, delay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJitteredBackoffDelay_Good_UsesFullJitterRange(t *testing.T) {
|
||||||
|
limit := 800 * time.Millisecond
|
||||||
|
if got := backoffDelayLimit(100*time.Millisecond, 3); got != limit {
|
||||||
|
t.Fatalf("expected attempt 3 backoff limit %v, got %v", limit, got)
|
||||||
|
}
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
if delay := jitteredBackoffDelay(100*time.Millisecond, 3); delay < 0 || delay > limit {
|
||||||
|
t.Fatalf("expected jitter in [0, %v], got %v", limit, delay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Good_Retries429WithPastRetryAfterDateWithoutNegativeSleep(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
w.Header().Set("Retry-After", "Wed, 21 Oct 2015 07:28:00 GMT")
|
||||||
|
writeJSON(t, w, http.StatusTooManyRequests, map[string]any{"error": "rate limited"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Nanosecond,
|
||||||
|
})
|
||||||
|
sleeps := []time.Duration{}
|
||||||
|
c.sleepFunc = func(ctx context.Context, delay time.Duration) error {
|
||||||
|
sleeps = append(sleeps, delay)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil {
|
||||||
|
t.Fatalf("Recall failed after retry: %v", err)
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Fatalf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
if len(sleeps) != 1 {
|
||||||
|
t.Fatalf("expected one retry sleep, got %d", len(sleeps))
|
||||||
|
}
|
||||||
|
if sleeps[0] != 0 {
|
||||||
|
t.Fatalf("expected past Retry-After date to sleep zero, got %v", sleeps[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Good_CapsRetryAfterDelay(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
if attempts == 1 {
|
||||||
|
w.Header().Set("Retry-After", "9999")
|
||||||
|
writeJSON(t, w, http.StatusServiceUnavailable, map[string]any{"error": "down"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"memories": []any{}})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Nanosecond,
|
||||||
|
})
|
||||||
|
sleeps := []time.Duration{}
|
||||||
|
c.sleepFunc = func(ctx context.Context, delay time.Duration) error {
|
||||||
|
sleeps = append(sleeps, delay)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := c.Recall(context.Background(), RecallInput{Query: "retry"}); err != nil {
|
||||||
|
t.Fatalf("Recall failed after retry: %v", err)
|
||||||
|
}
|
||||||
|
if attempts != 2 {
|
||||||
|
t.Fatalf("expected 2 attempts, got %d", attempts)
|
||||||
|
}
|
||||||
|
if len(sleeps) != 1 {
|
||||||
|
t.Fatalf("expected one retry sleep, got %d", len(sleeps))
|
||||||
|
}
|
||||||
|
if sleeps[0] != maxRetryAfterDelay {
|
||||||
|
t.Fatalf("expected capped Retry-After sleep of %v, got %v", maxRetryAfterDelay, sleeps[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Bad_DoesNotRetry400(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
writeJSON(t, w, http.StatusBadRequest, map[string]any{"error": "bad request"})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Nanosecond,
|
||||||
|
})
|
||||||
|
if _, err := c.Recall(context.Background(), RecallInput{Query: "bad"}); err == nil {
|
||||||
|
t.Fatal("expected 400 error")
|
||||||
|
}
|
||||||
|
if attempts != 1 {
|
||||||
|
t.Fatalf("expected one attempt for 400, got %d", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Bad_Continuous503OpensCircuit(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
writeJSON(t, w, http.StatusServiceUnavailable, map[string]any{"error": "down"})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
breaker := NewCircuitBreaker(CircuitBreakerOptions{
|
||||||
|
FailureThreshold: 3,
|
||||||
|
SuccessThreshold: 1,
|
||||||
|
Cooldown: time.Hour,
|
||||||
|
})
|
||||||
|
c := New(Options{
|
||||||
|
URL: server.URL,
|
||||||
|
Key: "test-key",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Nanosecond,
|
||||||
|
CircuitBreaker: breaker,
|
||||||
|
})
|
||||||
|
|
||||||
|
if _, err := c.Recall(context.Background(), RecallInput{Query: "down"}); err == nil {
|
||||||
|
t.Fatal("expected 503 error")
|
||||||
|
}
|
||||||
|
if breaker.State() != CircuitOpen {
|
||||||
|
t.Fatalf("expected circuit open, got %s", breaker.State())
|
||||||
|
}
|
||||||
|
if _, err := c.Recall(context.Background(), RecallInput{Query: "down"}); !core.Is(err, ErrCircuitOpen) {
|
||||||
|
t.Fatalf("expected ErrCircuitOpen, got %v", err)
|
||||||
|
}
|
||||||
|
if attempts != 3 {
|
||||||
|
t.Fatalf("expected no network attempt after circuit open, got %d attempts", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCall_Bad_ContextCancellation(t *testing.T) {
|
||||||
|
attempts := 0
|
||||||
|
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempts++
|
||||||
|
writeJSON(t, w, http.StatusOK, map[string]any{"ok": true})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
c := New(Options{URL: server.URL, Key: "test-key", HTTPClient: server.Client(), MaxAttempts: 3})
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if _, err := c.Recall(ctx, RecallInput{Query: "cancelled"}); !core.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context.Canceled, got %v", err)
|
||||||
|
}
|
||||||
|
if attempts != 0 {
|
||||||
|
t.Fatalf("expected cancelled request to avoid network, got %d attempts", attempts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteBrainKey_Good_Uses0600(t *testing.T) {
|
||||||
|
home := t.TempDir()
|
||||||
|
path := filepath.Join(home, ".claude", "brain.key")
|
||||||
|
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||||
|
t.Fatalf("create fixture dir: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(path, []byte("old-key\n"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write fixture: %v", err)
|
||||||
|
}
|
||||||
|
t.Setenv("HOME", home)
|
||||||
|
|
||||||
|
if err := WriteBrainKey("test-key"); err != nil {
|
||||||
|
t.Fatalf("WriteBrainKey failed: %v", err)
|
||||||
|
}
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stat brain key: %v", err)
|
||||||
|
}
|
||||||
|
if got := info.Mode().Perm(); got != brainKeyFileMode {
|
||||||
|
t.Fatalf("expected brain.key mode %v, got %v", brainKeyFileMode, got)
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read brain key: %v", err)
|
||||||
|
}
|
||||||
|
if got := string(data); got != "test-key\n" {
|
||||||
|
t.Fatalf("expected written key, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrainKeyFile_Bad_RejectsInsecurePermissions(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "brain.key")
|
||||||
|
if err := os.WriteFile(path, []byte("test-key\n"), brainKeyFileMode); err != nil {
|
||||||
|
t.Fatalf("write fixture: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.Chmod(path, 0o644); err != nil {
|
||||||
|
t.Fatalf("chmod fixture: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := readBrainKeyFile(path); err == nil {
|
||||||
|
t.Fatal("expected insecure permissions error")
|
||||||
|
} else if !strings.Contains(err.Error(), "brain.key has insecure permissions, expected 0600") {
|
||||||
|
t.Fatalf("expected insecure permissions error, got %v", err)
|
||||||
|
}
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stat brain key: %v", err)
|
||||||
|
}
|
||||||
|
if got := info.Mode().Perm(); got != 0o644 {
|
||||||
|
t.Fatalf("read should not chmod brain.key, got mode %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readRequestBody(t *testing.T, r *http.Request) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
readResult := core.ReadAll(r.Body)
|
||||||
|
if !readResult.OK {
|
||||||
|
t.Fatalf("failed to read body: %v", readResult.Value)
|
||||||
|
}
|
||||||
|
body := map[string]any{}
|
||||||
|
if decodeResult := core.JSONUnmarshalString(readResult.Value.(string), &body); !decodeResult.OK {
|
||||||
|
t.Fatalf("failed to decode body: %v", decodeResult.Value)
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(t *testing.T, w http.ResponseWriter, status int, payload any) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
if _, err := w.Write([]byte(core.JSONMarshalString(payload))); err != nil {
|
||||||
|
t.Fatalf("failed to write response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (fn roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||||
|
return fn(request)
|
||||||
|
}
|
||||||
|
|
@ -3,20 +3,12 @@
|
||||||
package brain
|
package brain
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
goio "io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreio "forge.lthn.ai/core/go-io"
|
brainclient "dappco.re/go/mcp/pkg/mcp/brain/client"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -29,9 +21,7 @@ type channelSender func(ctx context.Context, channel string, data any)
|
||||||
// Unlike Subsystem (which uses the IDE WebSocket bridge), this calls the
|
// Unlike Subsystem (which uses the IDE WebSocket bridge), this calls the
|
||||||
// Laravel API directly — suitable for standalone core-mcp usage.
|
// Laravel API directly — suitable for standalone core-mcp usage.
|
||||||
type DirectSubsystem struct {
|
type DirectSubsystem struct {
|
||||||
apiURL string
|
apiClient *brainclient.Client
|
||||||
apiKey string
|
|
||||||
client *http.Client
|
|
||||||
onChannel channelSender
|
onChannel channelSender
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -58,23 +48,17 @@ func (s *DirectSubsystem) OnChannel(fn func(ctx context.Context, channel string,
|
||||||
// Reads CORE_BRAIN_URL and CORE_BRAIN_KEY from environment, or falls back
|
// Reads CORE_BRAIN_URL and CORE_BRAIN_KEY from environment, or falls back
|
||||||
// to ~/.claude/brain.key for the API key.
|
// to ~/.claude/brain.key for the API key.
|
||||||
func NewDirect() *DirectSubsystem {
|
func NewDirect() *DirectSubsystem {
|
||||||
apiURL := os.Getenv("CORE_BRAIN_URL")
|
return NewDirectWithClient(brainclient.NewFromEnvironment())
|
||||||
if apiURL == "" {
|
}
|
||||||
apiURL = "https://api.lthn.sh"
|
|
||||||
}
|
|
||||||
|
|
||||||
apiKey := os.Getenv("CORE_BRAIN_KEY")
|
// NewDirectWithClient creates a direct brain subsystem using the shared client.
|
||||||
if apiKey == "" {
|
//
|
||||||
if data, err := coreio.Local.Read(os.ExpandEnv("$HOME/.claude/brain.key")); err == nil {
|
// brain := NewDirectWithClient(client.New(client.Options{URL: "http://127.0.0.1:8080", Key: "test"}))
|
||||||
apiKey = strings.TrimSpace(data)
|
func NewDirectWithClient(apiClient *brainclient.Client) *DirectSubsystem {
|
||||||
}
|
if apiClient == nil {
|
||||||
}
|
apiClient = brainclient.NewFromEnvironment()
|
||||||
|
|
||||||
return &DirectSubsystem{
|
|
||||||
apiURL: apiURL,
|
|
||||||
apiKey: apiKey,
|
|
||||||
client: &http.Client{Timeout: 30 * time.Second},
|
|
||||||
}
|
}
|
||||||
|
return &DirectSubsystem{apiClient: apiClient}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name implements mcp.Subsystem.
|
// Name implements mcp.Subsystem.
|
||||||
|
|
@ -100,7 +84,7 @@ func (s *DirectSubsystem) RegisterTools(svc *coremcp.Service) {
|
||||||
|
|
||||||
coremcp.AddToolRecorded(svc, server, "brain", &mcp.Tool{
|
coremcp.AddToolRecorded(svc, server, "brain", &mcp.Tool{
|
||||||
Name: "brain_list",
|
Name: "brain_list",
|
||||||
Description: "List memories in OpenBrain with optional filtering by project, type, and agent.",
|
Description: "List memories in OpenBrain with optional filtering by org, project, type, and agent.",
|
||||||
}, s.list)
|
}, s.list)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -108,57 +92,19 @@ func (s *DirectSubsystem) RegisterTools(svc *coremcp.Service) {
|
||||||
func (s *DirectSubsystem) Shutdown(_ context.Context) error { return nil }
|
func (s *DirectSubsystem) Shutdown(_ context.Context) error { return nil }
|
||||||
|
|
||||||
func (s *DirectSubsystem) apiCall(ctx context.Context, method, path string, body any) (map[string]any, error) {
|
func (s *DirectSubsystem) apiCall(ctx context.Context, method, path string, body any) (map[string]any, error) {
|
||||||
if s.apiKey == "" {
|
return s.client().Call(ctx, method, path, body)
|
||||||
return nil, coreerr.E("brain.apiCall", "no API key (set CORE_BRAIN_KEY or create ~/.claude/brain.key)", nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
var reqBody goio.Reader
|
|
||||||
if body != nil {
|
|
||||||
data, err := json.Marshal(body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, coreerr.E("brain.apiCall", "marshal request", err)
|
|
||||||
}
|
|
||||||
reqBody = bytes.NewReader(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, s.apiURL+path, reqBody)
|
|
||||||
if err != nil {
|
|
||||||
return nil, coreerr.E("brain.apiCall", "create request", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+s.apiKey)
|
|
||||||
|
|
||||||
resp, err := s.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, coreerr.E("brain.apiCall", "API call failed", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
respData, err := goio.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, coreerr.E("brain.apiCall", "read response", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
|
||||||
return nil, coreerr.E("brain.apiCall", "API returned "+string(respData), nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
var result map[string]any
|
|
||||||
if err := json.Unmarshal(respData, &result); err != nil {
|
|
||||||
return nil, coreerr.E("brain.apiCall", "parse response", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DirectSubsystem) remember(ctx context.Context, _ *mcp.CallToolRequest, input RememberInput) (*mcp.CallToolResult, RememberOutput, error) {
|
func (s *DirectSubsystem) remember(ctx context.Context, _ *mcp.CallToolRequest, input RememberInput) (*mcp.CallToolResult, RememberOutput, error) {
|
||||||
result, err := s.apiCall(ctx, "POST", "/v1/brain/remember", map[string]any{
|
result, err := s.client().Remember(ctx, brainclient.RememberInput{
|
||||||
"content": input.Content,
|
Content: input.Content,
|
||||||
"type": input.Type,
|
Type: input.Type,
|
||||||
"tags": input.Tags,
|
Tags: input.Tags,
|
||||||
"project": input.Project,
|
Org: input.Org,
|
||||||
"agent_id": "cladius",
|
Project: input.Project,
|
||||||
|
Confidence: input.Confidence,
|
||||||
|
Supersedes: input.Supersedes,
|
||||||
|
ExpiresIn: input.ExpiresIn,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RememberOutput{}, err
|
return nil, RememberOutput{}, err
|
||||||
|
|
@ -168,6 +114,7 @@ func (s *DirectSubsystem) remember(ctx context.Context, _ *mcp.CallToolRequest,
|
||||||
if s.onChannel != nil {
|
if s.onChannel != nil {
|
||||||
s.onChannel(ctx, coremcp.ChannelBrainRememberDone, map[string]any{
|
s.onChannel(ctx, coremcp.ChannelBrainRememberDone, map[string]any{
|
||||||
"id": id,
|
"id": id,
|
||||||
|
"org": input.Org,
|
||||||
"type": input.Type,
|
"type": input.Type,
|
||||||
"project": input.Project,
|
"project": input.Project,
|
||||||
})
|
})
|
||||||
|
|
@ -180,54 +127,26 @@ func (s *DirectSubsystem) remember(ctx context.Context, _ *mcp.CallToolRequest,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DirectSubsystem) recall(ctx context.Context, _ *mcp.CallToolRequest, input RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
func (s *DirectSubsystem) recall(ctx context.Context, _ *mcp.CallToolRequest, input RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
||||||
body := map[string]any{
|
result, err := s.client().Recall(ctx, brainclient.RecallInput{
|
||||||
"query": input.Query,
|
Query: input.Query,
|
||||||
"top_k": input.TopK,
|
TopK: input.TopK,
|
||||||
"agent_id": "cladius",
|
Org: input.Filter.Org,
|
||||||
}
|
Project: input.Filter.Project,
|
||||||
if input.Filter.Project != "" {
|
Type: input.Filter.Type,
|
||||||
body["project"] = input.Filter.Project
|
AgentID: input.Filter.AgentID,
|
||||||
}
|
MinConfidence: input.Filter.MinConfidence,
|
||||||
if input.Filter.Type != nil {
|
})
|
||||||
body["type"] = input.Filter.Type
|
|
||||||
}
|
|
||||||
if input.TopK == 0 {
|
|
||||||
body["top_k"] = 10
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := s.apiCall(ctx, "POST", "/v1/brain/recall", body)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RecallOutput{}, err
|
return nil, RecallOutput{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var memories []Memory
|
memories := memoriesFromResult(result)
|
||||||
if mems, ok := result["memories"].([]any); ok {
|
|
||||||
for _, m := range mems {
|
|
||||||
if mm, ok := m.(map[string]any); ok {
|
|
||||||
mem := Memory{
|
|
||||||
Content: fmt.Sprintf("%v", mm["content"]),
|
|
||||||
Type: fmt.Sprintf("%v", mm["type"]),
|
|
||||||
Project: fmt.Sprintf("%v", mm["project"]),
|
|
||||||
AgentID: fmt.Sprintf("%v", mm["agent_id"]),
|
|
||||||
CreatedAt: fmt.Sprintf("%v", mm["created_at"]),
|
|
||||||
}
|
|
||||||
if id, ok := mm["id"].(string); ok {
|
|
||||||
mem.ID = id
|
|
||||||
}
|
|
||||||
if score, ok := mm["score"].(float64); ok {
|
|
||||||
mem.Confidence = score
|
|
||||||
}
|
|
||||||
if source, ok := mm["source"].(string); ok {
|
|
||||||
mem.Tags = append(mem.Tags, "source:"+source)
|
|
||||||
}
|
|
||||||
memories = append(memories, mem)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.onChannel != nil {
|
if s.onChannel != nil {
|
||||||
s.onChannel(ctx, coremcp.ChannelBrainRecallDone, map[string]any{
|
s.onChannel(ctx, coremcp.ChannelBrainRecallDone, map[string]any{
|
||||||
"query": input.Query,
|
"query": input.Query,
|
||||||
|
"org": input.Filter.Org,
|
||||||
|
"project": input.Filter.Project,
|
||||||
"count": len(memories),
|
"count": len(memories),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -239,7 +158,7 @@ func (s *DirectSubsystem) recall(ctx context.Context, _ *mcp.CallToolRequest, in
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DirectSubsystem) forget(ctx context.Context, _ *mcp.CallToolRequest, input ForgetInput) (*mcp.CallToolResult, ForgetOutput, error) {
|
func (s *DirectSubsystem) forget(ctx context.Context, _ *mcp.CallToolRequest, input ForgetInput) (*mcp.CallToolResult, ForgetOutput, error) {
|
||||||
_, err := s.apiCall(ctx, "DELETE", "/v1/brain/forget/"+input.ID, nil)
|
_, err := s.client().Forget(ctx, brainclient.ForgetInput{ID: input.ID, Reason: input.Reason})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ForgetOutput{}, err
|
return nil, ForgetOutput{}, err
|
||||||
}
|
}
|
||||||
|
|
@ -263,51 +182,22 @@ func (s *DirectSubsystem) list(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
if limit == 0 {
|
if limit == 0 {
|
||||||
limit = 50
|
limit = 50
|
||||||
}
|
}
|
||||||
|
result, err := s.client().List(ctx, brainclient.ListInput{
|
||||||
values := url.Values{}
|
Org: input.Org,
|
||||||
if input.Project != "" {
|
Project: input.Project,
|
||||||
values.Set("project", input.Project)
|
Type: input.Type,
|
||||||
}
|
AgentID: input.AgentID,
|
||||||
if input.Type != "" {
|
Limit: limit,
|
||||||
values.Set("type", input.Type)
|
})
|
||||||
}
|
|
||||||
if input.AgentID != "" {
|
|
||||||
values.Set("agent_id", input.AgentID)
|
|
||||||
}
|
|
||||||
values.Set("limit", fmt.Sprintf("%d", limit))
|
|
||||||
|
|
||||||
result, err := s.apiCall(ctx, http.MethodGet, "/v1/brain/list?"+values.Encode(), nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ListOutput{}, err
|
return nil, ListOutput{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var memories []Memory
|
memories := memoriesFromResult(result)
|
||||||
if mems, ok := result["memories"].([]any); ok {
|
|
||||||
for _, m := range mems {
|
|
||||||
if mm, ok := m.(map[string]any); ok {
|
|
||||||
mem := Memory{
|
|
||||||
Content: fmt.Sprintf("%v", mm["content"]),
|
|
||||||
Type: fmt.Sprintf("%v", mm["type"]),
|
|
||||||
Project: fmt.Sprintf("%v", mm["project"]),
|
|
||||||
AgentID: fmt.Sprintf("%v", mm["agent_id"]),
|
|
||||||
CreatedAt: fmt.Sprintf("%v", mm["created_at"]),
|
|
||||||
}
|
|
||||||
if id, ok := mm["id"].(string); ok {
|
|
||||||
mem.ID = id
|
|
||||||
}
|
|
||||||
if score, ok := mm["score"].(float64); ok {
|
|
||||||
mem.Confidence = score
|
|
||||||
}
|
|
||||||
if source, ok := mm["source"].(string); ok {
|
|
||||||
mem.Tags = append(mem.Tags, "source:"+source)
|
|
||||||
}
|
|
||||||
memories = append(memories, mem)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.onChannel != nil {
|
if s.onChannel != nil {
|
||||||
s.onChannel(ctx, coremcp.ChannelBrainListDone, map[string]any{
|
s.onChannel(ctx, coremcp.ChannelBrainListDone, map[string]any{
|
||||||
|
"org": input.Org,
|
||||||
"project": input.Project,
|
"project": input.Project,
|
||||||
"type": input.Type,
|
"type": input.Type,
|
||||||
"agent_id": input.AgentID,
|
"agent_id": input.AgentID,
|
||||||
|
|
@ -321,3 +211,57 @@ func (s *DirectSubsystem) list(ctx context.Context, _ *mcp.CallToolRequest, inpu
|
||||||
Memories: memories,
|
Memories: memories,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DirectSubsystem) client() *brainclient.Client {
|
||||||
|
if s.apiClient == nil {
|
||||||
|
s.apiClient = brainclient.NewFromEnvironment()
|
||||||
|
}
|
||||||
|
return s.apiClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// memoriesFromResult extracts Memory entries from an API response map.
|
||||||
|
func memoriesFromResult(result map[string]any) []Memory {
|
||||||
|
var memories []Memory
|
||||||
|
mems, ok := result["memories"].([]any)
|
||||||
|
if !ok {
|
||||||
|
return memories
|
||||||
|
}
|
||||||
|
for _, m := range mems {
|
||||||
|
mm, ok := m.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mem := Memory{
|
||||||
|
Content: stringFromMap(mm, "content"),
|
||||||
|
Type: stringFromMap(mm, "type"),
|
||||||
|
Org: stringFromMap(mm, "org"),
|
||||||
|
Project: stringFromMap(mm, "project"),
|
||||||
|
AgentID: stringFromMap(mm, "agent_id"),
|
||||||
|
CreatedAt: stringFromMap(mm, "created_at"),
|
||||||
|
}
|
||||||
|
if id, ok := mm["id"].(string); ok {
|
||||||
|
mem.ID = id
|
||||||
|
}
|
||||||
|
if score, ok := mm["score"].(float64); ok {
|
||||||
|
mem.Confidence = score
|
||||||
|
}
|
||||||
|
if source, ok := mm["source"].(string); ok {
|
||||||
|
mem.Tags = append(mem.Tags, "source:"+source)
|
||||||
|
}
|
||||||
|
memories = append(memories, mem)
|
||||||
|
}
|
||||||
|
return memories
|
||||||
|
}
|
||||||
|
|
||||||
|
// stringFromMap extracts a string value from a map, returning "" if missing or wrong type.
|
||||||
|
func stringFromMap(m map[string]any, key string) string {
|
||||||
|
v, ok := m[key]
|
||||||
|
if !ok || v == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
s, ok := v.(string)
|
||||||
|
if !ok {
|
||||||
|
return core.Sprintf("%v", v)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,14 +8,21 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
brainclient "dappco.re/go/mcp/pkg/mcp/brain/client"
|
||||||
)
|
)
|
||||||
|
|
||||||
// newTestDirect creates a DirectSubsystem pointing at a test server.
|
// newTestDirect creates a DirectSubsystem pointing at a test server.
|
||||||
func newTestDirect(url string) *DirectSubsystem {
|
func newTestDirect(url string) *DirectSubsystem {
|
||||||
return &DirectSubsystem{
|
return &DirectSubsystem{
|
||||||
apiURL: url,
|
apiClient: brainclient.New(brainclient.Options{
|
||||||
apiKey: "test-key",
|
URL: url,
|
||||||
client: http.DefaultClient,
|
Key: "test-key",
|
||||||
|
HTTPClient: http.DefaultClient,
|
||||||
|
MaxAttempts: 1,
|
||||||
|
BaseDelay: time.Nanosecond,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -84,7 +91,12 @@ func TestApiCall_Good_GetNilBody(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApiCall_Bad_NoApiKey(t *testing.T) {
|
func TestApiCall_Bad_NoApiKey(t *testing.T) {
|
||||||
s := &DirectSubsystem{apiKey: "", client: http.DefaultClient}
|
s := &DirectSubsystem{apiClient: brainclient.New(brainclient.Options{
|
||||||
|
URL: "http://example.test",
|
||||||
|
Key: "",
|
||||||
|
HTTPClient: http.DefaultClient,
|
||||||
|
MaxAttempts: 1,
|
||||||
|
})}
|
||||||
_, err := s.apiCall(context.Background(), "GET", "/test", nil)
|
_, err := s.apiCall(context.Background(), "GET", "/test", nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error when apiKey is empty")
|
t.Error("expected error when apiKey is empty")
|
||||||
|
|
@ -121,9 +133,12 @@ func TestApiCall_Bad_InvalidJson(t *testing.T) {
|
||||||
|
|
||||||
func TestApiCall_Bad_Unreachable(t *testing.T) {
|
func TestApiCall_Bad_Unreachable(t *testing.T) {
|
||||||
s := &DirectSubsystem{
|
s := &DirectSubsystem{
|
||||||
apiURL: "http://127.0.0.1:1", // nothing listening
|
apiClient: brainclient.New(brainclient.Options{
|
||||||
apiKey: "key",
|
URL: "http://127.0.0.1:1", // nothing listening
|
||||||
client: http.DefaultClient,
|
Key: "key",
|
||||||
|
HTTPClient: http.DefaultClient,
|
||||||
|
MaxAttempts: 1,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
_, err := s.apiCall(context.Background(), "GET", "/test", nil)
|
_, err := s.apiCall(context.Background(), "GET", "/test", nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
@ -143,6 +158,9 @@ func TestDirectRemember_Good(t *testing.T) {
|
||||||
if body["agent_id"] != "cladius" {
|
if body["agent_id"] != "cladius" {
|
||||||
t.Errorf("expected agent_id=cladius, got %v", body["agent_id"])
|
t.Errorf("expected agent_id=cladius, got %v", body["agent_id"])
|
||||||
}
|
}
|
||||||
|
if body["org"] != "core" {
|
||||||
|
t.Errorf("expected org=core, got %v", body["org"])
|
||||||
|
}
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
json.NewEncoder(w).Encode(map[string]any{"id": "mem-456"})
|
json.NewEncoder(w).Encode(map[string]any{"id": "mem-456"})
|
||||||
}))
|
}))
|
||||||
|
|
@ -152,6 +170,7 @@ func TestDirectRemember_Good(t *testing.T) {
|
||||||
_, out, err := s.remember(context.Background(), nil, RememberInput{
|
_, out, err := s.remember(context.Background(), nil, RememberInput{
|
||||||
Content: "test memory",
|
Content: "test memory",
|
||||||
Type: "observation",
|
Type: "observation",
|
||||||
|
Org: "core",
|
||||||
Project: "test-project",
|
Project: "test-project",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -188,6 +207,9 @@ func TestDirectRecall_Good(t *testing.T) {
|
||||||
if body["query"] != "scoring algorithm" {
|
if body["query"] != "scoring algorithm" {
|
||||||
t.Errorf("unexpected query: %v", body["query"])
|
t.Errorf("unexpected query: %v", body["query"])
|
||||||
}
|
}
|
||||||
|
if body["org"] != "core" {
|
||||||
|
t.Errorf("expected org=core, got %v", body["org"])
|
||||||
|
}
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
"memories": []any{
|
"memories": []any{
|
||||||
|
|
@ -195,6 +217,7 @@ func TestDirectRecall_Good(t *testing.T) {
|
||||||
"id": "mem-1",
|
"id": "mem-1",
|
||||||
"content": "scoring uses weighted average",
|
"content": "scoring uses weighted average",
|
||||||
"type": "architecture",
|
"type": "architecture",
|
||||||
|
"org": "core",
|
||||||
"project": "eaas",
|
"project": "eaas",
|
||||||
"agent_id": "virgil",
|
"agent_id": "virgil",
|
||||||
"score": 0.92,
|
"score": 0.92,
|
||||||
|
|
@ -209,7 +232,7 @@ func TestDirectRecall_Good(t *testing.T) {
|
||||||
_, out, err := s.recall(context.Background(), nil, RecallInput{
|
_, out, err := s.recall(context.Background(), nil, RecallInput{
|
||||||
Query: "scoring algorithm",
|
Query: "scoring algorithm",
|
||||||
TopK: 5,
|
TopK: 5,
|
||||||
Filter: RecallFilter{Project: "eaas"},
|
Filter: RecallFilter{Org: "core", Project: "eaas"},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("recall failed: %v", err)
|
t.Fatalf("recall failed: %v", err)
|
||||||
|
|
@ -220,6 +243,9 @@ func TestDirectRecall_Good(t *testing.T) {
|
||||||
if out.Memories[0].ID != "mem-1" {
|
if out.Memories[0].ID != "mem-1" {
|
||||||
t.Errorf("expected id=mem-1, got %q", out.Memories[0].ID)
|
t.Errorf("expected id=mem-1, got %q", out.Memories[0].ID)
|
||||||
}
|
}
|
||||||
|
if out.Memories[0].Org != "core" {
|
||||||
|
t.Errorf("expected org=core, got %q", out.Memories[0].Org)
|
||||||
|
}
|
||||||
if out.Memories[0].Confidence != 0.92 {
|
if out.Memories[0].Confidence != 0.92 {
|
||||||
t.Errorf("expected score=0.92, got %f", out.Memories[0].Confidence)
|
t.Errorf("expected score=0.92, got %f", out.Memories[0].Confidence)
|
||||||
}
|
}
|
||||||
|
|
@ -356,6 +382,9 @@ func TestDirectList_Good(t *testing.T) {
|
||||||
if got := r.URL.Query().Get("project"); got != "eaas" {
|
if got := r.URL.Query().Get("project"); got != "eaas" {
|
||||||
t.Errorf("expected project=eaas, got %q", got)
|
t.Errorf("expected project=eaas, got %q", got)
|
||||||
}
|
}
|
||||||
|
if got := r.URL.Query().Get("org"); got != "core" {
|
||||||
|
t.Errorf("expected org=core, got %q", got)
|
||||||
|
}
|
||||||
if got := r.URL.Query().Get("type"); got != "decision" {
|
if got := r.URL.Query().Get("type"); got != "decision" {
|
||||||
t.Errorf("expected type=decision, got %q", got)
|
t.Errorf("expected type=decision, got %q", got)
|
||||||
}
|
}
|
||||||
|
|
@ -372,6 +401,7 @@ func TestDirectList_Good(t *testing.T) {
|
||||||
"id": "mem-1",
|
"id": "mem-1",
|
||||||
"content": "use qdrant",
|
"content": "use qdrant",
|
||||||
"type": "decision",
|
"type": "decision",
|
||||||
|
"org": "core",
|
||||||
"project": "eaas",
|
"project": "eaas",
|
||||||
"agent_id": "virgil",
|
"agent_id": "virgil",
|
||||||
"score": 0.88,
|
"score": 0.88,
|
||||||
|
|
@ -384,6 +414,7 @@ func TestDirectList_Good(t *testing.T) {
|
||||||
|
|
||||||
s := newTestDirect(srv.URL)
|
s := newTestDirect(srv.URL)
|
||||||
_, out, err := s.list(context.Background(), nil, ListInput{
|
_, out, err := s.list(context.Background(), nil, ListInput{
|
||||||
|
Org: "core",
|
||||||
Project: "eaas",
|
Project: "eaas",
|
||||||
Type: "decision",
|
Type: "decision",
|
||||||
AgentID: "virgil",
|
AgentID: "virgil",
|
||||||
|
|
@ -401,6 +432,9 @@ func TestDirectList_Good(t *testing.T) {
|
||||||
if out.Memories[0].Confidence != 0.88 {
|
if out.Memories[0].Confidence != 0.88 {
|
||||||
t.Errorf("expected score=0.88, got %f", out.Memories[0].Confidence)
|
t.Errorf("expected score=0.88, got %f", out.Memories[0].Confidence)
|
||||||
}
|
}
|
||||||
|
if out.Memories[0].Org != "core" {
|
||||||
|
t.Errorf("expected org=core, got %q", out.Memories[0].Org)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDirectList_Good_EmitsAgentIDChannelPayload(t *testing.T) {
|
func TestDirectList_Good_EmitsAgentIDChannelPayload(t *testing.T) {
|
||||||
|
|
@ -422,6 +456,7 @@ func TestDirectList_Good_EmitsAgentIDChannelPayload(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
_, out, err := s.list(context.Background(), nil, ListInput{
|
_, out, err := s.list(context.Background(), nil, ListInput{
|
||||||
|
Org: "core",
|
||||||
Project: "eaas",
|
Project: "eaas",
|
||||||
Type: "decision",
|
Type: "decision",
|
||||||
AgentID: "virgil",
|
AgentID: "virgil",
|
||||||
|
|
@ -445,6 +480,9 @@ func TestDirectList_Good_EmitsAgentIDChannelPayload(t *testing.T) {
|
||||||
if gotPayload["project"] != "eaas" {
|
if gotPayload["project"] != "eaas" {
|
||||||
t.Fatalf("expected project=eaas, got %v", gotPayload["project"])
|
t.Fatalf("expected project=eaas, got %v", gotPayload["project"])
|
||||||
}
|
}
|
||||||
|
if gotPayload["org"] != "core" {
|
||||||
|
t.Fatalf("expected org=core, got %v", gotPayload["org"])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDirectList_Good_DefaultLimit(t *testing.T) {
|
func TestDirectList_Good_DefaultLimit(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,11 @@ package brain
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"dappco.re/go/api"
|
||||||
|
"dappco.re/go/core/api/pkg/provider"
|
||||||
|
"dappco.re/go/ws"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
"dappco.re/go/mcp/pkg/mcp/ide"
|
"dappco.re/go/mcp/pkg/mcp/ide"
|
||||||
"forge.lthn.ai/core/api"
|
|
||||||
"forge.lthn.ai/core/api/pkg/provider"
|
|
||||||
"forge.lthn.ai/core/go-ws"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -91,6 +91,7 @@ func (p *BrainProvider) Describe() []api.RouteDescription {
|
||||||
"content": map[string]any{"type": "string"},
|
"content": map[string]any{"type": "string"},
|
||||||
"type": map[string]any{"type": "string"},
|
"type": map[string]any{"type": "string"},
|
||||||
"tags": map[string]any{"type": "array", "items": map[string]any{"type": "string"}},
|
"tags": map[string]any{"type": "array", "items": map[string]any{"type": "string"}},
|
||||||
|
"org": map[string]any{"type": "string"},
|
||||||
"project": map[string]any{"type": "string"},
|
"project": map[string]any{"type": "string"},
|
||||||
"confidence": map[string]any{"type": "number"},
|
"confidence": map[string]any{"type": "number"},
|
||||||
},
|
},
|
||||||
|
|
@ -119,6 +120,7 @@ func (p *BrainProvider) Describe() []api.RouteDescription {
|
||||||
"filter": map[string]any{
|
"filter": map[string]any{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": map[string]any{
|
"properties": map[string]any{
|
||||||
|
"org": map[string]any{"type": "string"},
|
||||||
"project": map[string]any{"type": "string"},
|
"project": map[string]any{"type": "string"},
|
||||||
"type": map[string]any{"type": "string"},
|
"type": map[string]any{"type": "string"},
|
||||||
},
|
},
|
||||||
|
|
@ -161,7 +163,7 @@ func (p *BrainProvider) Describe() []api.RouteDescription {
|
||||||
Method: "GET",
|
Method: "GET",
|
||||||
Path: "/list",
|
Path: "/list",
|
||||||
Summary: "List memories",
|
Summary: "List memories",
|
||||||
Description: "List memories with optional filtering by project, type, and agent.",
|
Description: "List memories with optional filtering by org, project, type, and agent.",
|
||||||
Tags: []string{"brain"},
|
Tags: []string{"brain"},
|
||||||
Response: map[string]any{
|
Response: map[string]any{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|
@ -208,6 +210,7 @@ func (p *BrainProvider) remember(c *gin.Context) {
|
||||||
"content": input.Content,
|
"content": input.Content,
|
||||||
"type": input.Type,
|
"type": input.Type,
|
||||||
"tags": input.Tags,
|
"tags": input.Tags,
|
||||||
|
"org": input.Org,
|
||||||
"project": input.Project,
|
"project": input.Project,
|
||||||
"confidence": input.Confidence,
|
"confidence": input.Confidence,
|
||||||
"supersedes": input.Supersedes,
|
"supersedes": input.Supersedes,
|
||||||
|
|
@ -220,6 +223,7 @@ func (p *BrainProvider) remember(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
p.emitEvent(coremcp.ChannelBrainRememberDone, map[string]any{
|
p.emitEvent(coremcp.ChannelBrainRememberDone, map[string]any{
|
||||||
|
"org": input.Org,
|
||||||
"type": input.Type,
|
"type": input.Type,
|
||||||
"project": input.Project,
|
"project": input.Project,
|
||||||
})
|
})
|
||||||
|
|
@ -299,6 +303,7 @@ func (p *BrainProvider) list(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
project := c.Query("project")
|
project := c.Query("project")
|
||||||
|
org := c.Query("org")
|
||||||
typ := c.Query("type")
|
typ := c.Query("type")
|
||||||
agentID := c.Query("agent_id")
|
agentID := c.Query("agent_id")
|
||||||
limit := c.Query("limit")
|
limit := c.Query("limit")
|
||||||
|
|
@ -306,6 +311,7 @@ func (p *BrainProvider) list(c *gin.Context) {
|
||||||
err := p.bridge.Send(ide.BridgeMessage{
|
err := p.bridge.Send(ide.BridgeMessage{
|
||||||
Type: "brain_list",
|
Type: "brain_list",
|
||||||
Data: map[string]any{
|
Data: map[string]any{
|
||||||
|
"org": org,
|
||||||
"project": project,
|
"project": project,
|
||||||
"type": typ,
|
"type": typ,
|
||||||
"agent_id": agentID,
|
"agent_id": agentID,
|
||||||
|
|
@ -318,6 +324,7 @@ func (p *BrainProvider) list(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
p.emitEvent(coremcp.ChannelBrainListDone, map[string]any{
|
p.emitEvent(coremcp.ChannelBrainListDone, map[string]any{
|
||||||
|
"org": org,
|
||||||
"project": project,
|
"project": project,
|
||||||
"type": typ,
|
"type": typ,
|
||||||
"agent_id": agentID,
|
"agent_id": agentID,
|
||||||
|
|
@ -354,14 +361,14 @@ func (p *BrainProvider) emitEvent(channel string, data any) {
|
||||||
func (p *BrainProvider) handleBridgeMessage(msg ide.BridgeMessage) {
|
func (p *BrainProvider) handleBridgeMessage(msg ide.BridgeMessage) {
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case "brain_remember":
|
case "brain_remember":
|
||||||
p.emitEvent(coremcp.ChannelBrainRememberDone, bridgePayload(msg.Data, "type", "project"))
|
p.emitEvent(coremcp.ChannelBrainRememberDone, bridgePayload(msg.Data, "org", "type", "project"))
|
||||||
case "brain_recall":
|
case "brain_recall":
|
||||||
payload := bridgePayload(msg.Data, "query", "project", "type", "agent_id")
|
payload := bridgePayload(msg.Data, "query", "org", "project", "type", "agent_id")
|
||||||
payload["count"] = bridgeCount(msg.Data)
|
payload["count"] = bridgeCount(msg.Data)
|
||||||
p.emitEvent(coremcp.ChannelBrainRecallDone, payload)
|
p.emitEvent(coremcp.ChannelBrainRecallDone, payload)
|
||||||
case "brain_forget":
|
case "brain_forget":
|
||||||
p.emitEvent(coremcp.ChannelBrainForgetDone, bridgePayload(msg.Data, "id", "reason"))
|
p.emitEvent(coremcp.ChannelBrainForgetDone, bridgePayload(msg.Data, "id", "reason"))
|
||||||
case "brain_list":
|
case "brain_list":
|
||||||
p.emitEvent(coremcp.ChannelBrainListDone, bridgePayload(msg.Data, "project", "type", "agent_id", "limit"))
|
p.emitEvent(coremcp.ChannelBrainListDone, bridgePayload(msg.Data, "org", "project", "type", "agent_id", "limit"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,13 +5,16 @@ package brain
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
"dappco.re/go/mcp/pkg/mcp/ide"
|
"dappco.re/go/mcp/pkg/mcp/ide"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const brainOrgMaxLength = 128
|
||||||
|
|
||||||
// emitChannel pushes a brain event through the shared notifier.
|
// emitChannel pushes a brain event through the shared notifier.
|
||||||
func (s *Subsystem) emitChannel(ctx context.Context, channel string, data any) {
|
func (s *Subsystem) emitChannel(ctx context.Context, channel string, data any) {
|
||||||
if s.notifier != nil {
|
if s.notifier != nil {
|
||||||
|
|
@ -23,11 +26,12 @@ func (s *Subsystem) emitChannel(ctx context.Context, channel string, data any) {
|
||||||
|
|
||||||
// RememberInput is the input for brain_remember.
|
// RememberInput is the input for brain_remember.
|
||||||
//
|
//
|
||||||
// input := RememberInput{Content: "Use Qdrant for vector search", Type: "decision"}
|
// input := RememberInput{Content: "Use Qdrant for vector search", Type: "decision", Org: "core"}
|
||||||
type RememberInput struct {
|
type RememberInput struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Tags []string `json:"tags,omitempty"`
|
Tags []string `json:"tags,omitempty"`
|
||||||
|
Org string `json:"org,omitempty"`
|
||||||
Project string `json:"project,omitempty"`
|
Project string `json:"project,omitempty"`
|
||||||
Confidence float64 `json:"confidence,omitempty"`
|
Confidence float64 `json:"confidence,omitempty"`
|
||||||
Supersedes string `json:"supersedes,omitempty"`
|
Supersedes string `json:"supersedes,omitempty"`
|
||||||
|
|
@ -54,8 +58,9 @@ type RecallInput struct {
|
||||||
|
|
||||||
// RecallFilter holds optional filter criteria for brain_recall.
|
// RecallFilter holds optional filter criteria for brain_recall.
|
||||||
//
|
//
|
||||||
// filter := RecallFilter{Project: "core/mcp", MinConfidence: 0.5}
|
// filter := RecallFilter{Org: "core", Project: "core/mcp", MinConfidence: 0.5}
|
||||||
type RecallFilter struct {
|
type RecallFilter struct {
|
||||||
|
Org string `json:"org,omitempty"`
|
||||||
Project string `json:"project,omitempty"`
|
Project string `json:"project,omitempty"`
|
||||||
Type any `json:"type,omitempty"`
|
Type any `json:"type,omitempty"`
|
||||||
AgentID string `json:"agent_id,omitempty"`
|
AgentID string `json:"agent_id,omitempty"`
|
||||||
|
|
@ -80,6 +85,7 @@ type Memory struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Tags []string `json:"tags,omitempty"`
|
Tags []string `json:"tags,omitempty"`
|
||||||
|
Org string `json:"org,omitempty"`
|
||||||
Project string `json:"project,omitempty"`
|
Project string `json:"project,omitempty"`
|
||||||
Confidence float64 `json:"confidence"`
|
Confidence float64 `json:"confidence"`
|
||||||
SupersedesID string `json:"supersedes_id,omitempty"`
|
SupersedesID string `json:"supersedes_id,omitempty"`
|
||||||
|
|
@ -107,8 +113,9 @@ type ForgetOutput struct {
|
||||||
|
|
||||||
// ListInput is the input for brain_list.
|
// ListInput is the input for brain_list.
|
||||||
//
|
//
|
||||||
// input := ListInput{Project: "core/mcp", Limit: 50}
|
// input := ListInput{Org: "core", Project: "core/mcp", Limit: 50}
|
||||||
type ListInput struct {
|
type ListInput struct {
|
||||||
|
Org string `json:"org,omitempty"`
|
||||||
Project string `json:"project,omitempty"`
|
Project string `json:"project,omitempty"`
|
||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
AgentID string `json:"agent_id,omitempty"`
|
AgentID string `json:"agent_id,omitempty"`
|
||||||
|
|
@ -124,6 +131,25 @@ type ListOutput struct {
|
||||||
Memories []Memory `json:"memories"`
|
Memories []Memory `json:"memories"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateBrainOrg(org string) error {
|
||||||
|
if utf8.RuneCountInString(org) > brainOrgMaxLength {
|
||||||
|
return coreerr.E("brain.validate", "org exceeds maximum length of 128 characters", nil)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRememberInput(input RememberInput) error {
|
||||||
|
return validateBrainOrg(input.Org)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateRecallInput(input RecallInput) error {
|
||||||
|
return validateBrainOrg(input.Filter.Org)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateListInput(input ListInput) error {
|
||||||
|
return validateBrainOrg(input.Org)
|
||||||
|
}
|
||||||
|
|
||||||
// -- Tool registration --------------------------------------------------------
|
// -- Tool registration --------------------------------------------------------
|
||||||
|
|
||||||
func (s *Subsystem) registerBrainTools(svc *coremcp.Service) {
|
func (s *Subsystem) registerBrainTools(svc *coremcp.Service) {
|
||||||
|
|
@ -145,13 +171,16 @@ func (s *Subsystem) registerBrainTools(svc *coremcp.Service) {
|
||||||
|
|
||||||
coremcp.AddToolRecorded(svc, server, "brain", &mcp.Tool{
|
coremcp.AddToolRecorded(svc, server, "brain", &mcp.Tool{
|
||||||
Name: "brain_list",
|
Name: "brain_list",
|
||||||
Description: "List memories in the shared OpenBrain knowledge store. Supports filtering by project, type, and agent. No vector search -- use brain_recall for semantic queries.",
|
Description: "List memories in the shared OpenBrain knowledge store. Supports filtering by org, project, type, and agent. No vector search -- use brain_recall for semantic queries.",
|
||||||
}, s.brainList)
|
}, s.brainList)
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- Tool handlers ------------------------------------------------------------
|
// -- Tool handlers ------------------------------------------------------------
|
||||||
|
|
||||||
func (s *Subsystem) brainRemember(ctx context.Context, _ *mcp.CallToolRequest, input RememberInput) (*mcp.CallToolResult, RememberOutput, error) {
|
func (s *Subsystem) brainRemember(ctx context.Context, _ *mcp.CallToolRequest, input RememberInput) (*mcp.CallToolResult, RememberOutput, error) {
|
||||||
|
if err := validateRememberInput(input); err != nil {
|
||||||
|
return nil, RememberOutput{}, err
|
||||||
|
}
|
||||||
if s.bridge == nil {
|
if s.bridge == nil {
|
||||||
return nil, RememberOutput{}, errBridgeNotAvailable
|
return nil, RememberOutput{}, errBridgeNotAvailable
|
||||||
}
|
}
|
||||||
|
|
@ -162,6 +191,7 @@ func (s *Subsystem) brainRemember(ctx context.Context, _ *mcp.CallToolRequest, i
|
||||||
"content": input.Content,
|
"content": input.Content,
|
||||||
"type": input.Type,
|
"type": input.Type,
|
||||||
"tags": input.Tags,
|
"tags": input.Tags,
|
||||||
|
"org": input.Org,
|
||||||
"project": input.Project,
|
"project": input.Project,
|
||||||
"confidence": input.Confidence,
|
"confidence": input.Confidence,
|
||||||
"supersedes": input.Supersedes,
|
"supersedes": input.Supersedes,
|
||||||
|
|
@ -173,6 +203,7 @@ func (s *Subsystem) brainRemember(ctx context.Context, _ *mcp.CallToolRequest, i
|
||||||
}
|
}
|
||||||
|
|
||||||
s.emitChannel(ctx, coremcp.ChannelBrainRememberDone, map[string]any{
|
s.emitChannel(ctx, coremcp.ChannelBrainRememberDone, map[string]any{
|
||||||
|
"org": input.Org,
|
||||||
"type": input.Type,
|
"type": input.Type,
|
||||||
"project": input.Project,
|
"project": input.Project,
|
||||||
})
|
})
|
||||||
|
|
@ -184,6 +215,9 @@ func (s *Subsystem) brainRemember(ctx context.Context, _ *mcp.CallToolRequest, i
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Subsystem) brainRecall(ctx context.Context, _ *mcp.CallToolRequest, input RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
func (s *Subsystem) brainRecall(ctx context.Context, _ *mcp.CallToolRequest, input RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
||||||
|
if err := validateRecallInput(input); err != nil {
|
||||||
|
return nil, RecallOutput{}, err
|
||||||
|
}
|
||||||
if s.bridge == nil {
|
if s.bridge == nil {
|
||||||
return nil, RecallOutput{}, errBridgeNotAvailable
|
return nil, RecallOutput{}, errBridgeNotAvailable
|
||||||
}
|
}
|
||||||
|
|
@ -234,6 +268,9 @@ func (s *Subsystem) brainForget(ctx context.Context, _ *mcp.CallToolRequest, inp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Subsystem) brainList(ctx context.Context, _ *mcp.CallToolRequest, input ListInput) (*mcp.CallToolResult, ListOutput, error) {
|
func (s *Subsystem) brainList(ctx context.Context, _ *mcp.CallToolRequest, input ListInput) (*mcp.CallToolResult, ListOutput, error) {
|
||||||
|
if err := validateListInput(input); err != nil {
|
||||||
|
return nil, ListOutput{}, err
|
||||||
|
}
|
||||||
if s.bridge == nil {
|
if s.bridge == nil {
|
||||||
return nil, ListOutput{}, errBridgeNotAvailable
|
return nil, ListOutput{}, errBridgeNotAvailable
|
||||||
}
|
}
|
||||||
|
|
@ -245,6 +282,7 @@ func (s *Subsystem) brainList(ctx context.Context, _ *mcp.CallToolRequest, input
|
||||||
err := s.bridge.Send(ide.BridgeMessage{
|
err := s.bridge.Send(ide.BridgeMessage{
|
||||||
Type: "brain_list",
|
Type: "brain_list",
|
||||||
Data: map[string]any{
|
Data: map[string]any{
|
||||||
|
"org": input.Org,
|
||||||
"project": input.Project,
|
"project": input.Project,
|
||||||
"type": input.Type,
|
"type": input.Type,
|
||||||
"agent_id": input.AgentID,
|
"agent_id": input.AgentID,
|
||||||
|
|
@ -256,6 +294,7 @@ func (s *Subsystem) brainList(ctx context.Context, _ *mcp.CallToolRequest, input
|
||||||
}
|
}
|
||||||
|
|
||||||
s.emitChannel(ctx, coremcp.ChannelBrainListDone, map[string]any{
|
s.emitChannel(ctx, coremcp.ChannelBrainListDone, map[string]any{
|
||||||
|
"org": input.Org,
|
||||||
"project": input.Project,
|
"project": input.Project,
|
||||||
"type": input.Type,
|
"type": input.Type,
|
||||||
"agent_id": input.AgentID,
|
"agent_id": input.AgentID,
|
||||||
|
|
|
||||||
166
pkg/mcp/brain/tools_test.go
Normal file
166
pkg/mcp/brain/tools_test.go
Normal file
|
|
@ -0,0 +1,166 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package brain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"dappco.re/go/mcp/pkg/mcp/ide"
|
||||||
|
"dappco.re/go/ws"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var brainToolTestUpgrader = websocket.Upgrader{
|
||||||
|
CheckOrigin: func(_ *http.Request) bool { return true },
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnectedBrainToolSubsystem(t *testing.T) (*Subsystem, <-chan ide.BridgeMessage) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
messages := make(chan ide.BridgeMessage, 8)
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := brainToolTestUpgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("upgrade error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var msg ide.BridgeMessage
|
||||||
|
if err := conn.ReadJSON(&msg); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages <- msg
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
hub := ws.NewHub()
|
||||||
|
go hub.Run(ctx)
|
||||||
|
|
||||||
|
cfg := ide.DefaultConfig()
|
||||||
|
cfg.LaravelWSURL = "ws" + strings.TrimPrefix(srv.URL, "http")
|
||||||
|
cfg.ReconnectInterval = 10 * time.Millisecond
|
||||||
|
cfg.MaxReconnectInterval = 10 * time.Millisecond
|
||||||
|
|
||||||
|
bridge := ide.NewBridge(hub, cfg)
|
||||||
|
bridge.Start(ctx)
|
||||||
|
waitBrainToolBridgeConnected(t, bridge)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
bridge.Shutdown()
|
||||||
|
cancel()
|
||||||
|
srv.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return New(bridge), messages
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitBrainToolBridgeConnected(t *testing.T, bridge *ide.Bridge) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if bridge.Connected() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatal("bridge did not connect within timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBrainToolBridgeMessage(t *testing.T, messages <-chan ide.BridgeMessage) ide.BridgeMessage {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-messages:
|
||||||
|
return msg
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for bridge message")
|
||||||
|
return ide.BridgeMessage{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertBrainOrgValidationError(t *testing.T, err error) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected org validation error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "org exceeds maximum length of 128 characters") {
|
||||||
|
t.Fatalf("expected org length error, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrainRemember_Good_OrgLengthBoundary(t *testing.T) {
|
||||||
|
sub, messages := newConnectedBrainToolSubsystem(t)
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
org string
|
||||||
|
}{
|
||||||
|
{name: "non_empty", org: "core"},
|
||||||
|
{name: "empty", org: ""},
|
||||||
|
{name: "boundary", org: strings.Repeat("a", brainOrgMaxLength)},
|
||||||
|
} {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, out, err := sub.brainRemember(context.Background(), nil, RememberInput{
|
||||||
|
Content: "test memory",
|
||||||
|
Type: "observation",
|
||||||
|
Org: tc.org,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("brainRemember failed: %v", err)
|
||||||
|
}
|
||||||
|
if !out.Success {
|
||||||
|
t.Fatal("expected success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := readBrainToolBridgeMessage(t, messages)
|
||||||
|
if msg.Type != "brain_remember" {
|
||||||
|
t.Fatalf("expected brain_remember message, got %q", msg.Type)
|
||||||
|
}
|
||||||
|
data, ok := msg.Data.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected bridge data map, got %T", msg.Data)
|
||||||
|
}
|
||||||
|
if data["org"] != tc.org {
|
||||||
|
t.Fatalf("expected org %q, got %v", tc.org, data["org"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrainRemember_Bad_OrgTooLong(t *testing.T) {
|
||||||
|
sub := New(nil)
|
||||||
|
|
||||||
|
_, _, err := sub.brainRemember(context.Background(), nil, RememberInput{
|
||||||
|
Content: "test memory",
|
||||||
|
Type: "observation",
|
||||||
|
Org: strings.Repeat("a", brainOrgMaxLength+1),
|
||||||
|
})
|
||||||
|
|
||||||
|
assertBrainOrgValidationError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrainOrgValidation_Bad_RecallAndListRejectBeforeBridge(t *testing.T) {
|
||||||
|
sub := New(nil)
|
||||||
|
tooLong := strings.Repeat("a", brainOrgMaxLength+1)
|
||||||
|
|
||||||
|
_, _, err := sub.brainRecall(context.Background(), nil, RecallInput{
|
||||||
|
Query: "test",
|
||||||
|
Filter: RecallFilter{Org: tooLong},
|
||||||
|
})
|
||||||
|
assertBrainOrgValidationError(t, err)
|
||||||
|
|
||||||
|
_, _, err = sub.brainList(context.Background(), nil, ListInput{
|
||||||
|
Org: tooLong,
|
||||||
|
})
|
||||||
|
assertBrainOrgValidationError(t, err)
|
||||||
|
}
|
||||||
|
|
@ -3,13 +3,12 @@
|
||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
core "dappco.re/go/core"
|
core "dappco.re/go/core"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
api "forge.lthn.ai/core/api"
|
api "dappco.re/go/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
// maxBodySize is the maximum request body size accepted by bridged tool endpoints.
|
// maxBodySize is the maximum request body size accepted by bridged tool endpoints.
|
||||||
|
|
@ -48,7 +47,7 @@ func BridgeToAPI(svc *Service, bridge *api.ToolBridge) {
|
||||||
if !r.OK {
|
if !r.OK {
|
||||||
if err, ok := r.Value.(error); ok {
|
if err, ok := r.Value.(error); ok {
|
||||||
var maxBytesErr *http.MaxBytesError
|
var maxBytesErr *http.MaxBytesError
|
||||||
if errors.As(err, &maxBytesErr) || core.Contains(err.Error(), "request body too large") {
|
if core.As(err, &maxBytesErr) || core.Contains(err.Error(), "request body too large") {
|
||||||
c.JSON(http.StatusRequestEntityTooLarge, api.Fail("request_too_large", "Request body exceeds 10 MB limit"))
|
c.JSON(http.StatusRequestEntityTooLarge, api.Fail("request_too_large", "Request body exceeds 10 MB limit"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -63,7 +62,7 @@ func BridgeToAPI(svc *Service, bridge *api.ToolBridge) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Body present + error = likely bad input (malformed JSON).
|
// Body present + error = likely bad input (malformed JSON).
|
||||||
// No body + error = tool execution failure.
|
// No body + error = tool execution failure.
|
||||||
if errors.Is(err, errInvalidRESTInput) {
|
if core.Is(err, errInvalidRESTInput) {
|
||||||
c.JSON(http.StatusBadRequest, api.Fail("invalid_input", "Malformed JSON in request body"))
|
c.JSON(http.StatusBadRequest, api.Fail("invalid_input", "Malformed JSON in request body"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ import (
|
||||||
"dappco.re/go/mcp/pkg/mcp/agentic"
|
"dappco.re/go/mcp/pkg/mcp/agentic"
|
||||||
"dappco.re/go/mcp/pkg/mcp/brain"
|
"dappco.re/go/mcp/pkg/mcp/brain"
|
||||||
"dappco.re/go/mcp/pkg/mcp/ide"
|
"dappco.re/go/mcp/pkg/mcp/ide"
|
||||||
api "forge.lthn.ai/core/api"
|
api "dappco.re/go/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
@ -81,13 +81,16 @@ func TestBridgeToAPI_Good_DescribableGroup(t *testing.T) {
|
||||||
var dg api.DescribableGroup = bridge
|
var dg api.DescribableGroup = bridge
|
||||||
descs := dg.Describe()
|
descs := dg.Describe()
|
||||||
|
|
||||||
if len(descs) != len(svc.Tools()) {
|
// ToolBridge.Describe prepends a GET entry describing the tool listing
|
||||||
t.Fatalf("expected %d descriptions, got %d", len(svc.Tools()), len(descs))
|
// endpoint, so the expected count is svc.Tools() + 1.
|
||||||
|
wantDescs := len(svc.Tools()) + 1
|
||||||
|
if len(descs) != wantDescs {
|
||||||
|
t.Fatalf("expected %d descriptions, got %d", wantDescs, len(descs))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, d := range descs {
|
for _, d := range descs {
|
||||||
if d.Method != "POST" {
|
if d.Method != "POST" && d.Method != "GET" {
|
||||||
t.Errorf("expected Method=POST for %s, got %q", d.Path, d.Method)
|
t.Errorf("expected Method=POST or GET for %s, got %q", d.Path, d.Method)
|
||||||
}
|
}
|
||||||
if d.Summary == "" {
|
if d.Summary == "" {
|
||||||
t.Errorf("expected non-empty Summary for %s", d.Path)
|
t.Errorf("expected non-empty Summary for %s", d.Path)
|
||||||
|
|
@ -250,7 +253,7 @@ func TestBridgeToAPI_Good_EndToEnd(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify a tool endpoint is reachable through the engine.
|
// Verify a tool endpoint is reachable through the engine.
|
||||||
resp2, err := http.Post(srv.URL+"/tools/lang_list", "application/json", nil)
|
resp2, err := http.Post(srv.URL+"/tools/lang_list", "application/json", strings.NewReader("{}"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("lang_list request failed: %v", err)
|
t.Fatalf("lang_list request failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
coreerr "dappco.re/go/log"
|
||||||
"forge.lthn.ai/core/go-ws"
|
"dappco.re/go/ws"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ws"
|
"dappco.re/go/ws"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,13 @@ package ide
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
core "dappco.re/go/core"
|
core "dappco.re/go/core"
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
coreerr "dappco.re/go/log"
|
||||||
"forge.lthn.ai/core/go-ws"
|
"dappco.re/go/ws"
|
||||||
)
|
)
|
||||||
|
|
||||||
// errBridgeNotAvailable is returned when a tool requires the Laravel bridge
|
// errBridgeNotAvailable is returned when a tool requires the Laravel bridge
|
||||||
|
|
@ -556,7 +555,7 @@ func stringFromAny(v any) string {
|
||||||
switch value := v.(type) {
|
switch value := v.(type) {
|
||||||
case string:
|
case string:
|
||||||
return value
|
return value
|
||||||
case fmt.Stringer:
|
case interface{ String() string }:
|
||||||
return value.String()
|
return value.String()
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
coreerr "dappco.re/go/log"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ package ide
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
|
|
@ -86,6 +87,46 @@ type DashboardMetricsOutput struct {
|
||||||
Metrics DashboardMetrics `json:"metrics"`
|
Metrics DashboardMetrics `json:"metrics"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DashboardStateInput is the input for ide_dashboard_state.
|
||||||
|
//
|
||||||
|
// input := DashboardStateInput{}
|
||||||
|
type DashboardStateInput struct{}
|
||||||
|
|
||||||
|
// DashboardStateOutput is the output for ide_dashboard_state.
|
||||||
|
//
|
||||||
|
// // out.State["theme"] == "dark"
|
||||||
|
type DashboardStateOutput struct {
|
||||||
|
State map[string]any `json:"state"` // arbitrary key/value map
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"` // when the state last changed
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardUpdateInput is the input for ide_dashboard_update.
|
||||||
|
//
|
||||||
|
// input := DashboardUpdateInput{
|
||||||
|
// State: map[string]any{"theme": "light", "sidebar": true},
|
||||||
|
// Replace: false,
|
||||||
|
// }
|
||||||
|
type DashboardUpdateInput struct {
|
||||||
|
State map[string]any `json:"state"` // partial or full state
|
||||||
|
Replace bool `json:"replace,omitempty"` // true to overwrite, false to merge (default)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardUpdateOutput is the output for ide_dashboard_update.
|
||||||
|
//
|
||||||
|
// // out.State reflects the merged/replaced state
|
||||||
|
type DashboardUpdateOutput struct {
|
||||||
|
State map[string]any `json:"state"` // merged state after the update
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"` // when the state was applied
|
||||||
|
}
|
||||||
|
|
||||||
|
// dashboardStateStore holds the mutable dashboard UI state shared between the
|
||||||
|
// IDE frontend and MCP callers. Access is guarded by dashboardStateMu.
|
||||||
|
var (
|
||||||
|
dashboardStateMu sync.RWMutex
|
||||||
|
dashboardStateStore = map[string]any{}
|
||||||
|
dashboardStateUpdated time.Time
|
||||||
|
)
|
||||||
|
|
||||||
func (s *Subsystem) registerDashboardTools(svc *coremcp.Service) {
|
func (s *Subsystem) registerDashboardTools(svc *coremcp.Service) {
|
||||||
server := svc.Server()
|
server := svc.Server()
|
||||||
coremcp.AddToolRecorded(svc, server, "ide", &mcp.Tool{
|
coremcp.AddToolRecorded(svc, server, "ide", &mcp.Tool{
|
||||||
|
|
@ -102,6 +143,16 @@ func (s *Subsystem) registerDashboardTools(svc *coremcp.Service) {
|
||||||
Name: "ide_dashboard_metrics",
|
Name: "ide_dashboard_metrics",
|
||||||
Description: "Get aggregate build and agent metrics for a time period",
|
Description: "Get aggregate build and agent metrics for a time period",
|
||||||
}, s.dashboardMetrics)
|
}, s.dashboardMetrics)
|
||||||
|
|
||||||
|
coremcp.AddToolRecorded(svc, server, "ide", &mcp.Tool{
|
||||||
|
Name: "ide_dashboard_state",
|
||||||
|
Description: "Get the current dashboard UI state (arbitrary key/value map shared with the IDE).",
|
||||||
|
}, s.dashboardState)
|
||||||
|
|
||||||
|
coremcp.AddToolRecorded(svc, server, "ide", &mcp.Tool{
|
||||||
|
Name: "ide_dashboard_update",
|
||||||
|
Description: "Update the dashboard UI state. Merges into existing state by default; set replace=true to overwrite.",
|
||||||
|
}, s.dashboardUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
// dashboardOverview returns a platform overview with bridge status and
|
// dashboardOverview returns a platform overview with bridge status and
|
||||||
|
|
@ -211,3 +262,79 @@ func (s *Subsystem) dashboardMetrics(_ context.Context, _ *mcp.CallToolRequest,
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dashboardState returns the current dashboard UI state as a snapshot.
|
||||||
|
//
|
||||||
|
// out := s.dashboardState(ctx, nil, DashboardStateInput{})
|
||||||
|
func (s *Subsystem) dashboardState(_ context.Context, _ *mcp.CallToolRequest, _ DashboardStateInput) (*mcp.CallToolResult, DashboardStateOutput, error) {
|
||||||
|
dashboardStateMu.RLock()
|
||||||
|
defer dashboardStateMu.RUnlock()
|
||||||
|
|
||||||
|
snapshot := make(map[string]any, len(dashboardStateStore))
|
||||||
|
for k, v := range dashboardStateStore {
|
||||||
|
snapshot[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, DashboardStateOutput{
|
||||||
|
State: snapshot,
|
||||||
|
UpdatedAt: dashboardStateUpdated,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dashboardUpdate merges or replaces the dashboard UI state and emits an
|
||||||
|
// activity event so the IDE can react to the change.
|
||||||
|
//
|
||||||
|
// out := s.dashboardUpdate(ctx, nil, DashboardUpdateInput{State: map[string]any{"theme": "dark"}})
|
||||||
|
func (s *Subsystem) dashboardUpdate(ctx context.Context, _ *mcp.CallToolRequest, input DashboardUpdateInput) (*mcp.CallToolResult, DashboardUpdateOutput, error) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
dashboardStateMu.Lock()
|
||||||
|
if input.Replace || dashboardStateStore == nil {
|
||||||
|
dashboardStateStore = make(map[string]any, len(input.State))
|
||||||
|
}
|
||||||
|
for k, v := range input.State {
|
||||||
|
dashboardStateStore[k] = v
|
||||||
|
}
|
||||||
|
dashboardStateUpdated = now
|
||||||
|
|
||||||
|
snapshot := make(map[string]any, len(dashboardStateStore))
|
||||||
|
for k, v := range dashboardStateStore {
|
||||||
|
snapshot[k] = v
|
||||||
|
}
|
||||||
|
dashboardStateMu.Unlock()
|
||||||
|
|
||||||
|
// Record the change on the activity feed so ide_dashboard_activity
|
||||||
|
// reflects state transitions alongside build/session events.
|
||||||
|
s.recordActivity("dashboard_state", "dashboard state updated")
|
||||||
|
|
||||||
|
// Push the update over the Laravel bridge when available so web clients
|
||||||
|
// stay in sync with desktop tooling.
|
||||||
|
if s.bridge != nil {
|
||||||
|
_ = s.bridge.Send(BridgeMessage{
|
||||||
|
Type: "dashboard_update",
|
||||||
|
Data: snapshot,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Surface the change on the shared MCP notifier so connected sessions
|
||||||
|
// receive a JSON-RPC notification alongside the tool response.
|
||||||
|
if s.notifier != nil {
|
||||||
|
s.notifier.ChannelSend(ctx, "dashboard.state.updated", map[string]any{
|
||||||
|
"state": snapshot,
|
||||||
|
"updatedAt": now,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, DashboardUpdateOutput{
|
||||||
|
State: snapshot,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetDashboardState clears the shared dashboard state. Intended for tests.
|
||||||
|
func resetDashboardState() {
|
||||||
|
dashboardStateMu.Lock()
|
||||||
|
defer dashboardStateMu.Unlock()
|
||||||
|
dashboardStateStore = map[string]any{}
|
||||||
|
dashboardStateUpdated = time.Time{}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
coremcp "dappco.re/go/mcp/pkg/mcp"
|
coremcp "dappco.re/go/mcp/pkg/mcp"
|
||||||
"forge.lthn.ai/core/go-ws"
|
"dappco.re/go/ws"
|
||||||
)
|
)
|
||||||
|
|
||||||
// --- Helpers ---
|
// --- Helpers ---
|
||||||
|
|
@ -949,3 +949,76 @@ func TestChatSend_Good_BridgeMessageType(t *testing.T) {
|
||||||
t.Fatal("timed out waiting for bridge message")
|
t.Fatal("timed out waiting for bridge message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestToolsDashboard_DashboardState_Good returns an empty state when the
|
||||||
|
// store has not been touched.
|
||||||
|
func TestToolsDashboard_DashboardState_Good(t *testing.T) {
|
||||||
|
t.Cleanup(resetDashboardState)
|
||||||
|
|
||||||
|
sub := newNilBridgeSubsystem()
|
||||||
|
_, out, err := sub.dashboardState(context.Background(), nil, DashboardStateInput{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dashboardState failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(out.State) != 0 {
|
||||||
|
t.Fatalf("expected empty state, got %v", out.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsDashboard_DashboardUpdate_Good merges the supplied state into the
|
||||||
|
// shared store and reflects it back on a subsequent dashboardState call.
|
||||||
|
func TestToolsDashboard_DashboardUpdate_Good(t *testing.T) {
|
||||||
|
t.Cleanup(resetDashboardState)
|
||||||
|
|
||||||
|
sub := newNilBridgeSubsystem()
|
||||||
|
|
||||||
|
_, updateOut, err := sub.dashboardUpdate(context.Background(), nil, DashboardUpdateInput{
|
||||||
|
State: map[string]any{"theme": "dark"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dashboardUpdate failed: %v", err)
|
||||||
|
}
|
||||||
|
if updateOut.State["theme"] != "dark" {
|
||||||
|
t.Fatalf("expected theme 'dark', got %v", updateOut.State["theme"])
|
||||||
|
}
|
||||||
|
|
||||||
|
_, readOut, err := sub.dashboardState(context.Background(), nil, DashboardStateInput{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dashboardState failed: %v", err)
|
||||||
|
}
|
||||||
|
if readOut.State["theme"] != "dark" {
|
||||||
|
t.Fatalf("expected persisted theme 'dark', got %v", readOut.State["theme"])
|
||||||
|
}
|
||||||
|
if readOut.UpdatedAt.IsZero() {
|
||||||
|
t.Fatal("expected non-zero UpdatedAt after update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsDashboard_DashboardUpdate_Ugly replaces (not merges) prior state
|
||||||
|
// when Replace=true.
|
||||||
|
func TestToolsDashboard_DashboardUpdate_Ugly(t *testing.T) {
|
||||||
|
t.Cleanup(resetDashboardState)
|
||||||
|
|
||||||
|
sub := newNilBridgeSubsystem()
|
||||||
|
|
||||||
|
_, _, err := sub.dashboardUpdate(context.Background(), nil, DashboardUpdateInput{
|
||||||
|
State: map[string]any{"theme": "dark", "sidebar": true},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("seed dashboardUpdate failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, out, err := sub.dashboardUpdate(context.Background(), nil, DashboardUpdateInput{
|
||||||
|
State: map[string]any{"theme": "light"},
|
||||||
|
Replace: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("replace dashboardUpdate failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, ok := out.State["sidebar"]; ok {
|
||||||
|
t.Fatal("expected sidebar to be removed after replace")
|
||||||
|
}
|
||||||
|
if out.State["theme"] != "light" {
|
||||||
|
t.Fatalf("expected theme 'light', got %v", out.State["theme"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
18
pkg/mcp/ipc.go
Normal file
18
pkg/mcp/ipc.go
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Service) handleChannelPushIPC(ctx context.Context, ev ChannelPush) core.Result {
|
||||||
|
if core.Trim(ev.Channel) == "" {
|
||||||
|
return core.Result{Value: core.E("mcp.HandleIPCEvents", "channel is required", nil), OK: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.ChannelSend(ctx, ev.Channel, ev.Data)
|
||||||
|
return core.Result{OK: true}
|
||||||
|
}
|
||||||
111
pkg/mcp/ipc_test.go
Normal file
111
pkg/mcp/ipc_test.go
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIPC_HandleIPCEvents_Good(t *testing.T) {
|
||||||
|
svc, err := New(Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel, session, clientConn := connectNotificationSession(t, svc)
|
||||||
|
defer cancel()
|
||||||
|
defer session.Close()
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
clientConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
read := readNotificationMessageUntil(t, clientConn, func(msg map[string]any) bool {
|
||||||
|
return msg["method"] == ChannelNotificationMethod
|
||||||
|
})
|
||||||
|
|
||||||
|
result := svc.HandleIPCEvents(nil, ChannelPush{
|
||||||
|
Channel: "agent.completed",
|
||||||
|
Data: map[string]any{
|
||||||
|
"repo": "core/mcp",
|
||||||
|
"ok": true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !result.OK {
|
||||||
|
t.Fatalf("HandleIPCEvents() returned non-OK result: %#v", result.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
res := <-read
|
||||||
|
if res.err != nil {
|
||||||
|
t.Fatalf("failed to read channel notification: %v", res.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
params, ok := res.msg["params"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected params object, got %T", res.msg["params"])
|
||||||
|
}
|
||||||
|
if params["channel"] != "agent.completed" {
|
||||||
|
t.Fatalf("expected channel agent.completed, got %#v", params["channel"])
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, ok := params["data"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected data object, got %T", params["data"])
|
||||||
|
}
|
||||||
|
if payload["repo"] != "core/mcp" || payload["ok"] != true {
|
||||||
|
t.Fatalf("unexpected payload: %#v", payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPC_HandleIPCEvents_Bad(t *testing.T) {
|
||||||
|
svc, err := New(Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := svc.HandleIPCEvents(nil, ChannelPush{
|
||||||
|
Channel: " \t ",
|
||||||
|
Data: map[string]any{"ok": false},
|
||||||
|
})
|
||||||
|
if result.OK {
|
||||||
|
t.Fatal("expected empty ChannelPush channel to fail")
|
||||||
|
}
|
||||||
|
if _, ok := result.Value.(error); !ok {
|
||||||
|
t.Fatalf("expected error result value, got %T", result.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPC_HandleIPCEvents_Ugly(t *testing.T) {
|
||||||
|
svc, err := New(Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel, session, clientConn := connectNotificationSession(t, svc)
|
||||||
|
defer cancel()
|
||||||
|
defer session.Close()
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
clientConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
read := readNotificationMessageUntil(t, clientConn, func(msg map[string]any) bool {
|
||||||
|
params, ok := msg["params"].(map[string]any)
|
||||||
|
return msg["method"] == ChannelNotificationMethod && ok && params["channel"] == "agent.edge"
|
||||||
|
})
|
||||||
|
|
||||||
|
result := svc.HandleIPCEvents(nil, ChannelPush{Channel: "agent.edge"})
|
||||||
|
if !result.OK {
|
||||||
|
t.Fatalf("HandleIPCEvents() returned non-OK result: %#v", result.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
res := <-read
|
||||||
|
if res.err != nil {
|
||||||
|
t.Fatalf("failed to read edge notification: %v", res.err)
|
||||||
|
}
|
||||||
|
params, ok := res.msg["params"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected params object, got %T", res.msg["params"])
|
||||||
|
}
|
||||||
|
if _, ok := params["data"]; !ok {
|
||||||
|
t.Fatalf("expected data key for nil ChannelPush data: %#v", params)
|
||||||
|
}
|
||||||
|
if params["data"] != nil {
|
||||||
|
t.Fatalf("expected nil data, got %#v", params["data"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -5,22 +5,20 @@
|
||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"iter"
|
"iter"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
core "dappco.re/go/core"
|
core "dappco.re/go/core"
|
||||||
"forge.lthn.ai/core/go-io"
|
"dappco.re/go/io"
|
||||||
"forge.lthn.ai/core/go-log"
|
"dappco.re/go/log"
|
||||||
"forge.lthn.ai/core/go-process"
|
"dappco.re/go/process"
|
||||||
"forge.lthn.ai/core/go-ws"
|
"dappco.re/go/ws"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -74,7 +72,8 @@ func New(opts Options) (*Service, error) {
|
||||||
|
|
||||||
server := mcp.NewServer(impl, &mcp.ServerOptions{
|
server := mcp.NewServer(impl, &mcp.ServerOptions{
|
||||||
Capabilities: &mcp.ServerCapabilities{
|
Capabilities: &mcp.ServerCapabilities{
|
||||||
Tools: &mcp.ToolCapabilities{ListChanged: true},
|
Resources: &mcp.ResourceCapabilities{ListChanged: false},
|
||||||
|
Tools: &mcp.ToolCapabilities{ListChanged: false},
|
||||||
Logging: &mcp.LoggingCapabilities{},
|
Logging: &mcp.LoggingCapabilities{},
|
||||||
Experimental: channelCapability(),
|
Experimental: channelCapability(),
|
||||||
},
|
},
|
||||||
|
|
@ -245,15 +244,15 @@ func (s *Service) resolveWorkspacePath(path string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.workspaceRoot == "" {
|
if s.workspaceRoot == "" {
|
||||||
return filepath.Clean(path)
|
return core.CleanPath(path, "/")
|
||||||
}
|
}
|
||||||
|
|
||||||
clean := filepath.Clean(string(filepath.Separator) + path)
|
clean := core.CleanPath(string(filepath.Separator)+path, "/")
|
||||||
clean = strings.TrimPrefix(clean, string(filepath.Separator))
|
clean = core.TrimPrefix(clean, string(filepath.Separator))
|
||||||
if clean == "." || clean == "" {
|
if clean == "." || clean == "" {
|
||||||
return s.workspaceRoot
|
return s.workspaceRoot
|
||||||
}
|
}
|
||||||
return filepath.Join(s.workspaceRoot, clean)
|
return core.Path(s.workspaceRoot, clean)
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerTools adds the built-in tool groups to the MCP server.
|
// registerTools adds the built-in tool groups to the MCP server.
|
||||||
|
|
@ -317,6 +316,7 @@ func (s *Service) registerTools(server *mcp.Server) {
|
||||||
s.registerProcessTools(server)
|
s.registerProcessTools(server)
|
||||||
s.registerWebviewTools(server)
|
s.registerWebviewTools(server)
|
||||||
s.registerWSTools(server)
|
s.registerWSTools(server)
|
||||||
|
s.registerWSClientTools(server)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tool input/output types for MCP file operations.
|
// Tool input/output types for MCP file operations.
|
||||||
|
|
@ -543,8 +543,8 @@ func (s *Service) listDirectory(ctx context.Context, req *mcp.CallToolRequest, i
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ListDirectoryOutput{}, log.E("mcp.listDirectory", "failed to list directory", err)
|
return nil, ListDirectoryOutput{}, log.E("mcp.listDirectory", "failed to list directory", err)
|
||||||
}
|
}
|
||||||
sort.Slice(entries, func(i, j int) bool {
|
slices.SortFunc(entries, func(a, b os.DirEntry) int {
|
||||||
return entries[i].Name() < entries[j].Name()
|
return cmp.Compare(a.Name(), b.Name())
|
||||||
})
|
})
|
||||||
result := make([]DirectoryEntry, 0, len(entries))
|
result := make([]DirectoryEntry, 0, len(entries))
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
|
|
@ -615,7 +615,7 @@ func (s *Service) fileExists(ctx context.Context, req *mcp.CallToolRequest, inpu
|
||||||
|
|
||||||
info, err := s.medium.Stat(input.Path)
|
info, err := s.medium.Stat(input.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if core.Is(err, os.ErrNotExist) {
|
||||||
return nil, FileExistsOutput{Exists: false, IsDir: false, Path: input.Path}, nil
|
return nil, FileExistsOutput{Exists: false, IsDir: false, Path: input.Path}, nil
|
||||||
}
|
}
|
||||||
return nil, FileExistsOutput{}, log.E("mcp.fileExists", "failed to stat path", err)
|
return nil, FileExistsOutput{}, log.E("mcp.fileExists", "failed to stat path", err)
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,11 @@ package mcp
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNew_Good_DefaultWorkspace(t *testing.T) {
|
func TestMcp_New_Good_DefaultWorkspace(t *testing.T) {
|
||||||
cwd, err := os.Getwd()
|
cwd, err := os.Getwd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to get working directory: %v", err)
|
t.Fatalf("Failed to get working directory: %v", err)
|
||||||
|
|
@ -25,7 +26,7 @@ func TestNew_Good_DefaultWorkspace(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNew_Good_CustomWorkspace(t *testing.T) {
|
func TestMcp_New_Good_CustomWorkspace(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
s, err := New(Options{WorkspaceRoot: tmpDir})
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
|
@ -41,7 +42,7 @@ func TestNew_Good_CustomWorkspace(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNew_Good_NoRestriction(t *testing.T) {
|
func TestMcp_New_Good_NoRestriction(t *testing.T) {
|
||||||
s, err := New(Options{Unrestricted: true})
|
s, err := New(Options{Unrestricted: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create service: %v", err)
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
|
@ -55,7 +56,7 @@ func TestNew_Good_NoRestriction(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNew_Good_RegistersBuiltInTools(t *testing.T) {
|
func TestMcp_New_Good_RegistersBuiltInTools(t *testing.T) {
|
||||||
s, err := New(Options{})
|
s, err := New(Options{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create service: %v", err)
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
|
@ -95,7 +96,47 @@ func TestNew_Good_RegistersBuiltInTools(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetSupportedLanguages_Good_IncludesAllDetectedLanguages(t *testing.T) {
|
func TestMcp_New_Bad_NilSubsystemIgnored(t *testing.T) {
|
||||||
|
s, err := New(Options{Subsystems: []Subsystem{nil}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New failed with nil subsystem: %v", err)
|
||||||
|
}
|
||||||
|
if len(s.Subsystems()) != 0 {
|
||||||
|
t.Fatalf("expected nil subsystem to be ignored, got %d subsystems", len(s.Subsystems()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_New_Ugly_ConcurrentConstruction(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
const workers = 8
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errs := make(chan error, workers)
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
errs <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.workspaceRoot != tmpDir || s.medium == nil {
|
||||||
|
errs <- os.ErrInvalid
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(errs)
|
||||||
|
|
||||||
|
for err := range errs {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("concurrent New failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_GetSupportedLanguages_Good_IncludesAllDetectedLanguages(t *testing.T) {
|
||||||
s, err := New(Options{})
|
s, err := New(Options{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create service: %v", err)
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
|
@ -146,7 +187,40 @@ func TestGetSupportedLanguages_Good_IncludesAllDetectedLanguages(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDetectLanguageFromPath_Good_KnownExtensions(t *testing.T) {
|
func TestMcp_GetSupportedLanguages_Bad_IgnoresUnsupportedInputState(t *testing.T) {
|
||||||
|
s := &Service{}
|
||||||
|
|
||||||
|
_, out, err := s.getSupportedLanguages(nil, nil, GetSupportedLanguagesInput{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getSupportedLanguages failed without initialized service state: %v", err)
|
||||||
|
}
|
||||||
|
if len(out.Languages) == 0 {
|
||||||
|
t.Fatal("expected supported languages to be returned")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_GetSupportedLanguages_Ugly_ReturnsIndependentSnapshots(t *testing.T) {
|
||||||
|
s, err := New(Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, first, err := s.getSupportedLanguages(nil, nil, GetSupportedLanguagesInput{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getSupportedLanguages failed: %v", err)
|
||||||
|
}
|
||||||
|
first.Languages[0].ID = "mutated"
|
||||||
|
|
||||||
|
_, second, err := s.getSupportedLanguages(nil, nil, GetSupportedLanguagesInput{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getSupportedLanguages failed on second call: %v", err)
|
||||||
|
}
|
||||||
|
if second.Languages[0].ID == "mutated" {
|
||||||
|
t.Fatal("expected a fresh supported languages snapshot")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_DetectLanguageFromPath_Good_KnownExtensions(t *testing.T) {
|
||||||
cases := map[string]string{
|
cases := map[string]string{
|
||||||
"main.go": "go",
|
"main.go": "go",
|
||||||
"index.tsx": "typescript",
|
"index.tsx": "typescript",
|
||||||
|
|
@ -163,7 +237,30 @@ func TestDetectLanguageFromPath_Good_KnownExtensions(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMedium_Good_ReadWrite(t *testing.T) {
|
func TestMcp_DetectLanguageFromPath_Bad_UnsupportedExtensionDefaultsPlaintext(t *testing.T) {
|
||||||
|
if got := detectLanguageFromPath("archive.unknown"); got != "plaintext" {
|
||||||
|
t.Fatalf("expected unsupported extension to be plaintext, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_DetectLanguageFromPath_Ugly_BoundaryPaths(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"": "plaintext",
|
||||||
|
"Dockerfile": "dockerfile",
|
||||||
|
"nested/Makefile": "plaintext",
|
||||||
|
"nested/file.TSX": "plaintext",
|
||||||
|
"nested/.env": "plaintext",
|
||||||
|
"nested/file.bash": "shell",
|
||||||
|
}
|
||||||
|
|
||||||
|
for path, want := range cases {
|
||||||
|
if got := detectLanguageFromPath(path); got != want {
|
||||||
|
t.Fatalf("detectLanguageFromPath(%q) = %q, want %q", path, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_Medium_Good_ReadWrite(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
s, err := New(Options{WorkspaceRoot: tmpDir})
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -193,7 +290,53 @@ func TestMedium_Good_ReadWrite(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMedium_Good_EnsureDir(t *testing.T) {
|
func TestMcp_Medium_Bad_ReadMissingFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.medium.Read("missing.txt"); err == nil {
|
||||||
|
t.Fatal("expected reading a missing file to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_Medium_Ugly_ConcurrentReadWrite(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const workers = 8
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errs := make(chan error, workers)
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
path := filepath.Join("concurrent", string(rune('a'+i))+".txt")
|
||||||
|
if err := s.medium.Write(path, "content"); err != nil {
|
||||||
|
errs <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := s.medium.Read(path); err != nil {
|
||||||
|
errs <- err
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(errs)
|
||||||
|
|
||||||
|
for err := range errs {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("concurrent medium access failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_Medium_Good_EnsureDir(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
s, err := New(Options{WorkspaceRoot: tmpDir})
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -216,7 +359,36 @@ func TestMedium_Good_EnsureDir(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFileExists_Good_FileAndDirectory(t *testing.T) {
|
func TestMcp_Medium_Bad_EnsureDirOverFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.medium.Write("same", "content"); err != nil {
|
||||||
|
t.Fatalf("Failed to write file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.medium.EnsureDir("same"); err == nil {
|
||||||
|
t.Fatal("expected EnsureDir over an existing file to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_Medium_Ugly_EnsureDirIdempotentNestedBoundary(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
if err := s.medium.EnsureDir("subdir/nested"); err != nil {
|
||||||
|
t.Fatalf("EnsureDir call %d failed: %v", i+1, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_FileExists_Good_FileAndDirectory(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
s, err := New(Options{WorkspaceRoot: tmpDir})
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -253,7 +425,31 @@ func TestFileExists_Good_FileAndDirectory(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListDirectory_Good_ReturnsDocumentedEntryPaths(t *testing.T) {
|
func TestMcp_FileExists_Bad_MissingPath(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, out, err := s.fileExists(nil, nil, FileExistsInput{Path: "missing.txt"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("fileExists(missing) failed: %v", err)
|
||||||
|
}
|
||||||
|
if out.Exists || out.IsDir {
|
||||||
|
t.Fatalf("expected missing path to be reported absent, got %+v", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_FileExists_Ugly_NilMedium(t *testing.T) {
|
||||||
|
s := &Service{}
|
||||||
|
|
||||||
|
if _, _, err := s.fileExists(nil, nil, FileExistsInput{Path: "anything"}); err == nil {
|
||||||
|
t.Fatal("expected fileExists to fail when medium is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_ListDirectory_Good_ReturnsDocumentedEntryPaths(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
s, err := New(Options{WorkspaceRoot: tmpDir})
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -281,7 +477,45 @@ func TestListDirectory_Good_ReturnsDocumentedEntryPaths(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMedium_Good_IsFile(t *testing.T) {
|
func TestMcp_ListDirectory_Bad_MissingDirectory(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, _, err := s.listDirectory(nil, nil, ListDirectoryInput{Path: "missing"}); err == nil {
|
||||||
|
t.Fatal("expected listing a missing directory to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_ListDirectory_Ugly_SortsEntries(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
for _, name := range []string{"b.txt", "a.txt", "c.txt"} {
|
||||||
|
if err := s.medium.Write(filepath.Join("nested", name), "content"); err != nil {
|
||||||
|
t.Fatalf("Failed to write %s: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, out, err := s.listDirectory(nil, nil, ListDirectoryInput{Path: "nested"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listDirectory failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(out.Entries) != 3 {
|
||||||
|
t.Fatalf("expected three entries, got %d", len(out.Entries))
|
||||||
|
}
|
||||||
|
for i, want := range []string{"a.txt", "b.txt", "c.txt"} {
|
||||||
|
if out.Entries[i].Name != want {
|
||||||
|
t.Fatalf("entry %d = %q, want %q", i, out.Entries[i].Name, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_Medium_Good_IsFile(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
s, err := New(Options{WorkspaceRoot: tmpDir})
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -302,7 +536,34 @@ func TestMedium_Good_IsFile(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveWorkspacePath_Good(t *testing.T) {
|
func TestMcp_Medium_Bad_IsFileEmptyPath(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.medium.IsFile("") {
|
||||||
|
t.Fatal("empty path should not be a file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_Medium_Ugly_IsFileDirectoryBoundary(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.medium.EnsureDir("nested"); err != nil {
|
||||||
|
t.Fatalf("Failed to create directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.medium.IsFile("nested") {
|
||||||
|
t.Fatal("directory should not be reported as a file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_ResolveWorkspacePath_Good(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
s, err := New(Options{WorkspaceRoot: tmpDir})
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -322,7 +583,7 @@ func TestResolveWorkspacePath_Good(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveWorkspacePath_Good_Unrestricted(t *testing.T) {
|
func TestMcp_ResolveWorkspacePath_Good_Unrestricted(t *testing.T) {
|
||||||
s, err := New(Options{Unrestricted: true})
|
s, err := New(Options{Unrestricted: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create service: %v", err)
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
|
@ -336,7 +597,33 @@ func TestResolveWorkspacePath_Good_Unrestricted(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSandboxing_Traversal_Sanitized(t *testing.T) {
|
func TestMcp_ResolveWorkspacePath_Bad_EmptyPath(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := s.resolveWorkspacePath(""); got != "" {
|
||||||
|
t.Fatalf("resolveWorkspacePath(empty) = %q, want empty", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_ResolveWorkspacePath_Ugly_TraversalSanitized(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := s.resolveWorkspacePath("../../secret.txt")
|
||||||
|
want := filepath.Join(tmpDir, "secret.txt")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("resolveWorkspacePath(traversal) = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMcp_Medium_Ugly_TraversalSanitized(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
s, err := New(Options{WorkspaceRoot: tmpDir})
|
s, err := New(Options{WorkspaceRoot: tmpDir})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -356,7 +643,7 @@ func TestSandboxing_Traversal_Sanitized(t *testing.T) {
|
||||||
// should validate inputs before calling Medium.
|
// should validate inputs before calling Medium.
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSandboxing_Symlinks_Blocked(t *testing.T) {
|
func TestMcp_Medium_Ugly_SymlinksBlocked(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
outsideDir := t.TempDir()
|
outsideDir := t.TempDir()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,17 +7,17 @@
|
||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"iter"
|
"iter"
|
||||||
"os"
|
"os" // Note: required for process stdout; core Fs/Env do not expose a stdio writer.
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -203,7 +203,7 @@ func (s *Service) ChannelSend(ctx context.Context, channel string, data any) {
|
||||||
if s == nil || s.server == nil {
|
if s == nil || s.server == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(channel) == "" {
|
if core.Trim(channel) == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = normalizeNotificationContext(ctx)
|
ctx = normalizeNotificationContext(ctx)
|
||||||
|
|
@ -218,7 +218,7 @@ func (s *Service) ChannelSendToSession(ctx context.Context, session *mcp.ServerS
|
||||||
if s == nil || s.server == nil || session == nil {
|
if s == nil || s.server == nil || session == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(channel) == "" {
|
if core.Trim(channel) == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = normalizeNotificationContext(ctx)
|
ctx = normalizeNotificationContext(ctx)
|
||||||
|
|
@ -275,6 +275,15 @@ func (s *Service) debugNotify(msg string, args ...any) {
|
||||||
s.logger.Debug(msg, args...)
|
s.logger.Debug(msg, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NotifySession sends a raw JSON-RPC notification to a specific MCP session.
|
||||||
|
//
|
||||||
|
// coremcp.NotifySession(ctx, session, "notifications/claude/channel", map[string]any{
|
||||||
|
// "content": "build failed", "meta": map[string]string{"severity": "high"},
|
||||||
|
// })
|
||||||
|
func NotifySession(ctx context.Context, session *mcp.ServerSession, method string, payload any) error {
|
||||||
|
return sendSessionNotification(ctx, session, method, payload)
|
||||||
|
}
|
||||||
|
|
||||||
func sendSessionNotification(ctx context.Context, session *mcp.ServerSession, method string, payload any) error {
|
func sendSessionNotification(ctx context.Context, session *mcp.ServerSession, method string, payload any) error {
|
||||||
if session == nil {
|
if session == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -353,8 +362,8 @@ func snapshotSessions(server *mcp.Server) []*mcp.ServerSession {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Slice(sessions, func(i, j int) bool {
|
slices.SortFunc(sessions, func(a, b *mcp.ServerSession) int {
|
||||||
return sessions[i].ID() < sessions[j].ID()
|
return cmp.Compare(a.ID(), b.ID())
|
||||||
})
|
})
|
||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,9 @@ package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
type processRuntime struct {
|
type processRuntime struct {
|
||||||
|
|
@ -50,19 +50,20 @@ func (s *Service) forgetProcessRuntime(id string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func isTestProcess(command string, args []string) bool {
|
func isTestProcess(command string, args []string) bool {
|
||||||
base := strings.ToLower(filepath.Base(command))
|
base := core.Lower(core.PathBase(command))
|
||||||
if base == "" {
|
if base == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
switch base {
|
switch base {
|
||||||
case "go":
|
case "go":
|
||||||
return len(args) > 0 && strings.EqualFold(args[0], "test")
|
return len(args) > 0 && core.Lower(args[0]) == "test"
|
||||||
case "cargo":
|
case "cargo":
|
||||||
return len(args) > 0 && strings.EqualFold(args[0], "test")
|
return len(args) > 0 && core.Lower(args[0]) == "test"
|
||||||
case "npm", "pnpm", "yarn", "bun":
|
case "npm", "pnpm", "yarn", "bun":
|
||||||
for _, arg := range args {
|
for _, arg := range args {
|
||||||
if strings.EqualFold(arg, "test") || strings.HasPrefix(strings.ToLower(arg), "test:") {
|
lower := core.Lower(arg)
|
||||||
|
if lower == "test" || core.HasPrefix(lower, "test:") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
61
pkg/mcp/progress.go
Normal file
61
pkg/mcp/progress.go
Normal file
|
|
@ -0,0 +1,61 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProgressTokenFromRequest extracts _meta.progressToken from an MCP tool call.
|
||||||
|
func ProgressTokenFromRequest(req *sdkmcp.CallToolRequest) any {
|
||||||
|
if req == nil || req.Params == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return req.Params.GetProgressToken()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendProgressNotification emits notifications/progress when the caller supplied
|
||||||
|
// _meta.progressToken. Calls without a token or MCP session are no-ops.
|
||||||
|
func SendProgressNotification(ctx context.Context, req *sdkmcp.CallToolRequest, progress float64, total float64, message string) error {
|
||||||
|
token := ProgressTokenFromRequest(req)
|
||||||
|
if req == nil || req.Session == nil || token == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return req.Session.NotifyProgress(ctx, &sdkmcp.ProgressNotificationParams{
|
||||||
|
ProgressToken: token,
|
||||||
|
Progress: progress,
|
||||||
|
Total: total,
|
||||||
|
Message: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProgressNotifier caches the request progress token for multi-step tools.
|
||||||
|
type ProgressNotifier struct {
|
||||||
|
ctx context.Context
|
||||||
|
req *sdkmcp.CallToolRequest
|
||||||
|
token any
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProgressNotifier prepares repeated notifications for a single tool call.
|
||||||
|
func NewProgressNotifier(ctx context.Context, req *sdkmcp.CallToolRequest) ProgressNotifier {
|
||||||
|
return ProgressNotifier{
|
||||||
|
ctx: ctx,
|
||||||
|
req: req,
|
||||||
|
token: ProgressTokenFromRequest(req),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send emits a progress notification when the tool call includes a token.
|
||||||
|
func (n ProgressNotifier) Send(progress float64, total float64, message string) error {
|
||||||
|
if n.req == nil || n.req.Session == nil || n.token == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return n.req.Session.NotifyProgress(n.ctx, &sdkmcp.ProgressNotificationParams{
|
||||||
|
ProgressToken: n.token,
|
||||||
|
Progress: progress,
|
||||||
|
Total: total,
|
||||||
|
Message: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
43
pkg/mcp/progress_test.go
Normal file
43
pkg/mcp/progress_test.go
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProgressTokenFromRequest_Good_ExtractsMetaToken(t *testing.T) {
|
||||||
|
req := &sdkmcp.CallToolRequest{Params: &sdkmcp.CallToolParamsRaw{}}
|
||||||
|
req.Params.SetProgressToken("dispatch-123")
|
||||||
|
|
||||||
|
if got := ProgressTokenFromRequest(req); got != "dispatch-123" {
|
||||||
|
t.Fatalf("expected progress token dispatch-123, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProgressTokenFromRequest_Good_NilSafe(t *testing.T) {
|
||||||
|
if got := ProgressTokenFromRequest(nil); got != nil {
|
||||||
|
t.Fatalf("expected nil token from nil request, got %v", got)
|
||||||
|
}
|
||||||
|
req := &sdkmcp.CallToolRequest{}
|
||||||
|
if got := ProgressTokenFromRequest(req); got != nil {
|
||||||
|
t.Fatalf("expected nil token from request without params, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendProgressNotification_Good_NoopsWithoutSession(t *testing.T) {
|
||||||
|
req := &sdkmcp.CallToolRequest{Params: &sdkmcp.CallToolParamsRaw{}}
|
||||||
|
req.Params.SetProgressToken("process-1")
|
||||||
|
|
||||||
|
if err := SendProgressNotification(context.Background(), req, 1, 2, "started"); err != nil {
|
||||||
|
t.Fatalf("expected no-op without session, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
notifier := NewProgressNotifier(context.Background(), req)
|
||||||
|
if err := notifier.Send(2, 2, "done"); err != nil {
|
||||||
|
t.Fatalf("expected no-op notifier without session, got error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -7,8 +7,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
core "dappco.re/go/core"
|
core "dappco.re/go/core"
|
||||||
"forge.lthn.ai/core/go-process"
|
"dappco.re/go/process"
|
||||||
"forge.lthn.ai/core/go-ws"
|
"dappco.re/go/ws"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Register is the service factory for core.WithService.
|
// Register is the service factory for core.WithService.
|
||||||
|
|
@ -98,6 +98,7 @@ func (s *Service) OnStartup(ctx context.Context) core.Result {
|
||||||
// HandleIPCEvents implements Core's IPC handler interface.
|
// HandleIPCEvents implements Core's IPC handler interface.
|
||||||
//
|
//
|
||||||
// c.ACTION(mcp.ChannelPush{Channel: "agent.status", Data: statusMap})
|
// c.ACTION(mcp.ChannelPush{Channel: "agent.status", Data: statusMap})
|
||||||
|
//
|
||||||
// Catches ChannelPush messages from other services and pushes them to Claude Code sessions.
|
// Catches ChannelPush messages from other services and pushes them to Claude Code sessions.
|
||||||
func (s *Service) HandleIPCEvents(c *core.Core, msg core.Message) core.Result {
|
func (s *Service) HandleIPCEvents(c *core.Core, msg core.Message) core.Result {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
@ -109,7 +110,7 @@ func (s *Service) HandleIPCEvents(c *core.Core, msg core.Message) core.Result {
|
||||||
|
|
||||||
switch ev := msg.(type) {
|
switch ev := msg.(type) {
|
||||||
case ChannelPush:
|
case ChannelPush:
|
||||||
s.ChannelSend(ctx, ev.Channel, ev.Data)
|
return s.handleChannelPushIPC(ctx, ev)
|
||||||
case process.ActionProcessStarted:
|
case process.ActionProcessStarted:
|
||||||
startedAt := time.Now()
|
startedAt := time.Now()
|
||||||
s.recordProcessRuntime(ev.ID, processRuntime{
|
s.recordProcessRuntime(ev.ID, processRuntime{
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dappco.re/go/core"
|
"dappco.re/go/core"
|
||||||
"forge.lthn.ai/core/go-process"
|
"dappco.re/go/process"
|
||||||
"forge.lthn.ai/core/go-ws"
|
"dappco.re/go/ws"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRegister_Good_WiresOptionalServices(t *testing.T) {
|
func TestRegister_Good_WiresOptionalServices(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,43 @@ type ToolRecord struct {
|
||||||
// return nil, ReadFileOutput{Path: "src/main.go"}, nil
|
// return nil, ReadFileOutput{Path: "src/main.go"}, nil
|
||||||
// })
|
// })
|
||||||
func AddToolRecorded[In, Out any](s *Service, server *mcp.Server, group string, t *mcp.Tool, h mcp.ToolHandlerFor[In, Out]) {
|
func AddToolRecorded[In, Out any](s *Service, server *mcp.Server, group string, t *mcp.Tool, h mcp.ToolHandlerFor[In, Out]) {
|
||||||
mcp.AddTool(server, t, h)
|
// Set inputSchema from struct reflection if not already set.
|
||||||
|
// Use server.AddTool (non-generic) to avoid auto-generated outputSchema.
|
||||||
|
// The go-sdk's generic mcp.AddTool generates outputSchema from the Out type,
|
||||||
|
// but Claude Code's protocol (2025-03-26) doesn't support outputSchema.
|
||||||
|
// Removing it reduces tools/list from 214KB to ~74KB.
|
||||||
|
if t.InputSchema == nil {
|
||||||
|
t.InputSchema = structSchema(new(In))
|
||||||
|
if t.InputSchema == nil {
|
||||||
|
t.InputSchema = map[string]any{"type": "object"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Wrap the typed handler into a generic ToolHandler.
|
||||||
|
wrapped := func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
var input In
|
||||||
|
if req != nil && len(req.Params.Arguments) > 0 {
|
||||||
|
if r := core.JSONUnmarshal(req.Params.Arguments, &input); !r.OK {
|
||||||
|
if err, ok := r.Value.(error); ok {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := s.authorizeToolAccess(ctx, req, t.Name, input); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result, output, err := h(ctx, req, input)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if result != nil {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
data := core.JSONMarshalString(output)
|
||||||
|
return &mcp.CallToolResult{
|
||||||
|
Content: []mcp.Content{&mcp.TextContent{Text: data}},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
server.AddTool(t, wrapped)
|
||||||
|
|
||||||
restHandler := func(ctx context.Context, body []byte) (any, error) {
|
restHandler := func(ctx context.Context, body []byte) (any, error) {
|
||||||
var input In
|
var input In
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-process"
|
"dappco.re/go/process"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestToolRegistry_Good_RecordsTools(t *testing.T) {
|
func TestToolRegistry_Good_RecordsTools(t *testing.T) {
|
||||||
|
|
@ -71,13 +71,19 @@ func TestToolRegistry_Good_ToolCount(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
tools := svc.Tools()
|
tools := svc.Tools()
|
||||||
// Built-in tools: file_read, file_write, file_delete, file_rename,
|
// Built-in tools (no ProcessService / WSHub / Subsystems):
|
||||||
// file_exists, file_edit, dir_list, dir_create, lang_detect, lang_list,
|
// files (8): file_read, file_write, file_delete, file_rename,
|
||||||
// metrics_record, metrics_query, rag_query, rag_ingest, rag_collections,
|
// file_exists, file_edit, dir_list, dir_create
|
||||||
// webview_connect, webview_disconnect, webview_navigate, webview_click,
|
// language (2): lang_detect, lang_list
|
||||||
// webview_type, webview_query, webview_console, webview_eval,
|
// metrics (2): metrics_record, metrics_query
|
||||||
// webview_screenshot, webview_wait
|
// rag (6): rag_query, rag_search, rag_ingest, rag_index,
|
||||||
const expectedCount = 25
|
// rag_retrieve, rag_collections
|
||||||
|
// webview (12): webview_connect, webview_disconnect, webview_navigate,
|
||||||
|
// webview_click, webview_type, webview_query,
|
||||||
|
// webview_console, webview_eval, webview_screenshot,
|
||||||
|
// webview_wait, webview_render, webview_update
|
||||||
|
// ws (3): ws_connect, ws_send, ws_close
|
||||||
|
const expectedCount = 33
|
||||||
if len(tools) != expectedCount {
|
if len(tools) != expectedCount {
|
||||||
t.Errorf("expected %d tools, got %d", expectedCount, len(tools))
|
t.Errorf("expected %d tools, got %d", expectedCount, len(tools))
|
||||||
for _, tr := range tools {
|
for _, tr := range tools {
|
||||||
|
|
@ -95,8 +101,8 @@ func TestToolRegistry_Good_GroupAssignment(t *testing.T) {
|
||||||
fileTools := []string{"file_read", "file_write", "file_delete", "file_rename", "file_exists", "file_edit", "dir_list", "dir_create"}
|
fileTools := []string{"file_read", "file_write", "file_delete", "file_rename", "file_exists", "file_edit", "dir_list", "dir_create"}
|
||||||
langTools := []string{"lang_detect", "lang_list"}
|
langTools := []string{"lang_detect", "lang_list"}
|
||||||
metricsTools := []string{"metrics_record", "metrics_query"}
|
metricsTools := []string{"metrics_record", "metrics_query"}
|
||||||
ragTools := []string{"rag_query", "rag_ingest", "rag_collections"}
|
ragTools := []string{"rag_query", "rag_search", "rag_ingest", "rag_index", "rag_retrieve", "rag_collections"}
|
||||||
webviewTools := []string{"webview_connect", "webview_disconnect", "webview_navigate", "webview_click", "webview_type", "webview_query", "webview_console", "webview_eval", "webview_screenshot", "webview_wait"}
|
webviewTools := []string{"webview_connect", "webview_disconnect", "webview_navigate", "webview_click", "webview_type", "webview_query", "webview_console", "webview_eval", "webview_screenshot", "webview_wait", "webview_render", "webview_update"}
|
||||||
|
|
||||||
byName := make(map[string]ToolRecord)
|
byName := make(map[string]ToolRecord)
|
||||||
for _, tr := range svc.Tools() {
|
for _, tr := range svc.Tools() {
|
||||||
|
|
@ -157,6 +163,18 @@ func TestToolRegistry_Good_GroupAssignment(t *testing.T) {
|
||||||
t.Errorf("tool %s: expected group 'webview', got %q", name, tr.Group)
|
t.Errorf("tool %s: expected group 'webview', got %q", name, tr.Group)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wsClientTools := []string{"ws_connect", "ws_send", "ws_close"}
|
||||||
|
for _, name := range wsClientTools {
|
||||||
|
tr, ok := byName[name]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("tool %s not found in registry", name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tr.Group != "ws" {
|
||||||
|
t.Errorf("tool %s: expected group 'ws', got %q", name, tr.Group)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestToolRegistry_Good_ToolRecordFields(t *testing.T) {
|
func TestToolRegistry_Good_ToolRecordFields(t *testing.T) {
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
core "dappco.re/go/core"
|
core "dappco.re/go/core"
|
||||||
"forge.lthn.ai/core/go-ai/ai"
|
"dappco.re/go/ai/ai"
|
||||||
"forge.lthn.ai/core/go-log"
|
"dappco.re/go/log"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-log"
|
"dappco.re/go/log"
|
||||||
"forge.lthn.ai/core/go-process"
|
"dappco.re/go/process"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -29,6 +29,32 @@ type ProcessStartInput struct {
|
||||||
Env []string `json:"env,omitempty"` // e.g. ["CGO_ENABLED=0"]
|
Env []string `json:"env,omitempty"` // e.g. ["CGO_ENABLED=0"]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProcessRunInput contains parameters for running a command to completion
|
||||||
|
// and returning its captured output.
|
||||||
|
//
|
||||||
|
// input := ProcessRunInput{
|
||||||
|
// Command: "go",
|
||||||
|
// Args: []string{"test", "./..."},
|
||||||
|
// Dir: "/home/user/project",
|
||||||
|
// Env: []string{"CGO_ENABLED=0"},
|
||||||
|
// }
|
||||||
|
type ProcessRunInput struct {
|
||||||
|
Command string `json:"command"` // e.g. "go"
|
||||||
|
Args []string `json:"args,omitempty"` // e.g. ["test", "./..."]
|
||||||
|
Dir string `json:"dir,omitempty"` // e.g. "/home/user/project"
|
||||||
|
Env []string `json:"env,omitempty"` // e.g. ["CGO_ENABLED=0"]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessRunOutput contains the result of running a process to completion.
|
||||||
|
//
|
||||||
|
// // out.ID == "proc-abc123", out.ExitCode == 0, out.Output == "PASS\n..."
|
||||||
|
type ProcessRunOutput struct {
|
||||||
|
ID string `json:"id"` // e.g. "proc-abc123"
|
||||||
|
ExitCode int `json:"exitCode"` // 0 on success
|
||||||
|
Output string `json:"output"` // combined stdout/stderr
|
||||||
|
Command string `json:"command"` // e.g. "go"
|
||||||
|
}
|
||||||
|
|
||||||
// ProcessStartOutput contains the result of starting a process.
|
// ProcessStartOutput contains the result of starting a process.
|
||||||
//
|
//
|
||||||
// // out.ID == "proc-abc123", out.PID == 54321, out.Command == "go"
|
// // out.ID == "proc-abc123", out.PID == 54321, out.Command == "go"
|
||||||
|
|
@ -146,6 +172,11 @@ func (s *Service) registerProcessTools(server *mcp.Server) bool {
|
||||||
Description: "Start a new external process. Returns process ID for tracking.",
|
Description: "Start a new external process. Returns process ID for tracking.",
|
||||||
}, s.processStart)
|
}, s.processStart)
|
||||||
|
|
||||||
|
addToolRecorded(s, server, "process", &mcp.Tool{
|
||||||
|
Name: "process_run",
|
||||||
|
Description: "Run a command to completion and return the captured output. Blocks until the process exits.",
|
||||||
|
}, s.processRun)
|
||||||
|
|
||||||
addToolRecorded(s, server, "process", &mcp.Tool{
|
addToolRecorded(s, server, "process", &mcp.Tool{
|
||||||
Name: "process_stop",
|
Name: "process_stop",
|
||||||
Description: "Gracefully stop a running process by ID.",
|
Description: "Gracefully stop a running process by ID.",
|
||||||
|
|
@ -224,6 +255,68 @@ func (s *Service) processStart(ctx context.Context, req *mcp.CallToolRequest, in
|
||||||
return nil, output, nil
|
return nil, output, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// processRun handles the process_run tool call.
|
||||||
|
// Executes the command to completion and returns the captured output.
|
||||||
|
func (s *Service) processRun(ctx context.Context, req *mcp.CallToolRequest, input ProcessRunInput) (*mcp.CallToolResult, ProcessRunOutput, error) {
|
||||||
|
if s.processService == nil {
|
||||||
|
return nil, ProcessRunOutput{}, log.E("processRun", "process service unavailable", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
progress := NewProgressNotifier(ctx, req)
|
||||||
|
s.logger.Security("MCP tool execution", "tool", "process_run", "command", input.Command, "args", input.Args, "dir", input.Dir, "user", log.Username())
|
||||||
|
|
||||||
|
if input.Command == "" {
|
||||||
|
return nil, ProcessRunOutput{}, log.E("processRun", "command cannot be empty", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := process.RunOptions{
|
||||||
|
Command: input.Command,
|
||||||
|
Args: input.Args,
|
||||||
|
Dir: s.resolveWorkspacePath(input.Dir),
|
||||||
|
Env: input.Env,
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = progress.Send(0, 2, "starting process")
|
||||||
|
proc, err := s.processService.StartWithOptions(ctx, opts)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("mcp: process run start failed", "command", input.Command, "err", err)
|
||||||
|
return nil, ProcessRunOutput{}, log.E("processRun", "failed to start process", err)
|
||||||
|
}
|
||||||
|
_ = progress.Send(1, 2, "process started")
|
||||||
|
|
||||||
|
info := proc.Info()
|
||||||
|
s.recordProcessRuntime(proc.ID, processRuntime{
|
||||||
|
Command: proc.Command,
|
||||||
|
Args: proc.Args,
|
||||||
|
Dir: info.Dir,
|
||||||
|
StartedAt: proc.StartedAt,
|
||||||
|
})
|
||||||
|
s.ChannelSend(ctx, ChannelProcessStart, map[string]any{
|
||||||
|
"id": proc.ID,
|
||||||
|
"pid": info.PID,
|
||||||
|
"command": proc.Command,
|
||||||
|
"args": proc.Args,
|
||||||
|
"dir": info.Dir,
|
||||||
|
"startedAt": proc.StartedAt,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Wait for completion (context-aware).
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = progress.Send(2, 2, "process cancelled")
|
||||||
|
return nil, ProcessRunOutput{}, log.E("processRun", "cancelled", ctx.Err())
|
||||||
|
case <-proc.Done():
|
||||||
|
}
|
||||||
|
_ = progress.Send(2, 2, "process completed")
|
||||||
|
|
||||||
|
return nil, ProcessRunOutput{
|
||||||
|
ID: proc.ID,
|
||||||
|
ExitCode: proc.ExitCode,
|
||||||
|
Output: proc.Output(),
|
||||||
|
Command: proc.Command,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// processStop handles the process_stop tool call.
|
// processStop handles the process_stop tool call.
|
||||||
func (s *Service) processStop(ctx context.Context, req *mcp.CallToolRequest, input ProcessStopInput) (*mcp.CallToolResult, ProcessStopOutput, error) {
|
func (s *Service) processStop(ctx context.Context, req *mcp.CallToolRequest, input ProcessStopInput) (*mcp.CallToolResult, ProcessStopOutput, error) {
|
||||||
if s.processService == nil {
|
if s.processService == nil {
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dappco.re/go/core"
|
"dappco.re/go/core"
|
||||||
"forge.lthn.ai/core/go-process"
|
"dappco.re/go/process"
|
||||||
)
|
)
|
||||||
|
|
||||||
// newTestProcessService creates a real process.Service backed by a core.Core for CI tests.
|
// newTestProcessService creates a real process.Service backed by a core.Core for CI tests.
|
||||||
|
|
|
||||||
|
|
@ -301,3 +301,57 @@ func TestRegisterProcessTools_Bad_NilService(t *testing.T) {
|
||||||
t.Error("Expected registerProcessTools to return false when processService is nil")
|
t.Error("Expected registerProcessTools to return false when processService is nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestToolsProcess_ProcessRunInput_Good exercises the process_run input DTO shape.
|
||||||
|
func TestToolsProcess_ProcessRunInput_Good(t *testing.T) {
|
||||||
|
input := ProcessRunInput{
|
||||||
|
Command: "echo",
|
||||||
|
Args: []string{"hello"},
|
||||||
|
Dir: "/tmp",
|
||||||
|
Env: []string{"FOO=bar"},
|
||||||
|
}
|
||||||
|
if input.Command != "echo" {
|
||||||
|
t.Errorf("expected command 'echo', got %q", input.Command)
|
||||||
|
}
|
||||||
|
if len(input.Args) != 1 || input.Args[0] != "hello" {
|
||||||
|
t.Errorf("expected args [hello], got %v", input.Args)
|
||||||
|
}
|
||||||
|
if input.Dir != "/tmp" {
|
||||||
|
t.Errorf("expected dir '/tmp', got %q", input.Dir)
|
||||||
|
}
|
||||||
|
if len(input.Env) != 1 {
|
||||||
|
t.Errorf("expected 1 env, got %d", len(input.Env))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsProcess_ProcessRunOutput_Good exercises the process_run output DTO shape.
|
||||||
|
func TestToolsProcess_ProcessRunOutput_Good(t *testing.T) {
|
||||||
|
output := ProcessRunOutput{
|
||||||
|
ID: "proc-1",
|
||||||
|
ExitCode: 0,
|
||||||
|
Output: "hello\n",
|
||||||
|
Command: "echo",
|
||||||
|
}
|
||||||
|
if output.ID != "proc-1" {
|
||||||
|
t.Errorf("expected id 'proc-1', got %q", output.ID)
|
||||||
|
}
|
||||||
|
if output.ExitCode != 0 {
|
||||||
|
t.Errorf("expected exit code 0, got %d", output.ExitCode)
|
||||||
|
}
|
||||||
|
if output.Output != "hello\n" {
|
||||||
|
t.Errorf("expected output 'hello\\n', got %q", output.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsProcess_ProcessRun_Bad rejects calls without a process service.
|
||||||
|
func TestToolsProcess_ProcessRun_Bad(t *testing.T) {
|
||||||
|
svc, err := New(Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = svc.processRun(t.Context(), nil, ProcessRunInput{Command: "echo", Args: []string{"hi"}})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when process service is unavailable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
core "dappco.re/go/core"
|
core "dappco.re/go/core"
|
||||||
"forge.lthn.ai/core/go-log"
|
"dappco.re/go/log"
|
||||||
"forge.lthn.ai/core/go-rag"
|
"dappco.re/go/rag"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -83,6 +83,30 @@ type RAGCollectionsInput struct {
|
||||||
ShowStats bool `json:"show_stats,omitempty"` // true to include point counts and status
|
ShowStats bool `json:"show_stats,omitempty"` // true to include point counts and status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RAGRetrieveInput contains parameters for retrieving chunks from a specific
|
||||||
|
// document source (rather than running a semantic query).
|
||||||
|
//
|
||||||
|
// input := RAGRetrieveInput{
|
||||||
|
// Source: "docs/services.md",
|
||||||
|
// Collection: "core-docs",
|
||||||
|
// Limit: 20,
|
||||||
|
// }
|
||||||
|
type RAGRetrieveInput struct {
|
||||||
|
Source string `json:"source"` // e.g. "docs/services.md"
|
||||||
|
Collection string `json:"collection,omitempty"` // e.g. "core-docs" (default: "hostuk-docs")
|
||||||
|
Limit int `json:"limit,omitempty"` // e.g. 20 (default: 50)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RAGRetrieveOutput contains document chunks for a specific source.
|
||||||
|
//
|
||||||
|
// // len(out.Chunks) == 12, out.Source == "docs/services.md"
|
||||||
|
type RAGRetrieveOutput struct {
|
||||||
|
Source string `json:"source"` // e.g. "docs/services.md"
|
||||||
|
Collection string `json:"collection"` // collection searched
|
||||||
|
Chunks []RAGQueryResult `json:"chunks"` // chunks for the source, ordered by chunkIndex
|
||||||
|
Count int `json:"count"` // number of chunks returned
|
||||||
|
}
|
||||||
|
|
||||||
// CollectionInfo contains information about a Qdrant collection.
|
// CollectionInfo contains information about a Qdrant collection.
|
||||||
//
|
//
|
||||||
// // ci.Name == "core-docs", ci.PointsCount == 1500, ci.Status == "green"
|
// // ci.Name == "core-docs", ci.PointsCount == 1500, ci.Status == "green"
|
||||||
|
|
@ -106,11 +130,28 @@ func (s *Service) registerRAGTools(server *mcp.Server) {
|
||||||
Description: "Query the RAG vector database for relevant documentation. Returns semantically similar content based on the query.",
|
Description: "Query the RAG vector database for relevant documentation. Returns semantically similar content based on the query.",
|
||||||
}, s.ragQuery)
|
}, s.ragQuery)
|
||||||
|
|
||||||
|
// rag_search is the spec-aligned alias for rag_query.
|
||||||
|
addToolRecorded(s, server, "rag", &mcp.Tool{
|
||||||
|
Name: "rag_search",
|
||||||
|
Description: "Semantic search across documents in the RAG vector database. Returns chunks ranked by similarity.",
|
||||||
|
}, s.ragQuery)
|
||||||
|
|
||||||
addToolRecorded(s, server, "rag", &mcp.Tool{
|
addToolRecorded(s, server, "rag", &mcp.Tool{
|
||||||
Name: "rag_ingest",
|
Name: "rag_ingest",
|
||||||
Description: "Ingest documents into the RAG vector database. Supports both single files and directories.",
|
Description: "Ingest documents into the RAG vector database. Supports both single files and directories.",
|
||||||
}, s.ragIngest)
|
}, s.ragIngest)
|
||||||
|
|
||||||
|
// rag_index is the spec-aligned alias for rag_ingest.
|
||||||
|
addToolRecorded(s, server, "rag", &mcp.Tool{
|
||||||
|
Name: "rag_index",
|
||||||
|
Description: "Index a document or directory into the RAG vector database.",
|
||||||
|
}, s.ragIngest)
|
||||||
|
|
||||||
|
addToolRecorded(s, server, "rag", &mcp.Tool{
|
||||||
|
Name: "rag_retrieve",
|
||||||
|
Description: "Retrieve chunks for a specific document source from the RAG vector database.",
|
||||||
|
}, s.ragRetrieve)
|
||||||
|
|
||||||
addToolRecorded(s, server, "rag", &mcp.Tool{
|
addToolRecorded(s, server, "rag", &mcp.Tool{
|
||||||
Name: "rag_collections",
|
Name: "rag_collections",
|
||||||
Description: "List all available collections in the RAG vector database.",
|
Description: "List all available collections in the RAG vector database.",
|
||||||
|
|
@ -216,6 +257,86 @@ func (s *Service) ragIngest(ctx context.Context, req *mcp.CallToolRequest, input
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ragRetrieve handles the rag_retrieve tool call.
|
||||||
|
// Returns chunks for a specific source path by querying the collection with
|
||||||
|
// the source path as the query text and then filtering results down to the
|
||||||
|
// matching source. This preserves the transport abstraction that the rest of
|
||||||
|
// the RAG tools use while producing the document-scoped view callers expect.
|
||||||
|
func (s *Service) ragRetrieve(ctx context.Context, req *mcp.CallToolRequest, input RAGRetrieveInput) (*mcp.CallToolResult, RAGRetrieveOutput, error) {
|
||||||
|
collection := input.Collection
|
||||||
|
if collection == "" {
|
||||||
|
collection = DefaultRAGCollection
|
||||||
|
}
|
||||||
|
limit := input.Limit
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("MCP tool execution", "tool", "rag_retrieve", "source", input.Source, "collection", collection, "limit", limit, "user", log.Username())
|
||||||
|
|
||||||
|
if input.Source == "" {
|
||||||
|
return nil, RAGRetrieveOutput{}, log.E("ragRetrieve", "source cannot be empty", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the source path as the query text — semantically related chunks
|
||||||
|
// will rank highly, and we then keep only chunks whose Source matches.
|
||||||
|
// Over-fetch by an order of magnitude so document-level limits are met
|
||||||
|
// even when the source appears beyond the top-K of the raw query.
|
||||||
|
overfetch := limit * 10
|
||||||
|
if overfetch < 100 {
|
||||||
|
overfetch = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := rag.QueryDocs(ctx, input.Source, collection, overfetch)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("mcp: rag retrieve query failed", "source", input.Source, "collection", collection, "err", err)
|
||||||
|
return nil, RAGRetrieveOutput{}, log.E("ragRetrieve", "failed to retrieve chunks", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := make([]RAGQueryResult, 0, limit)
|
||||||
|
for _, r := range results {
|
||||||
|
if r.Source != input.Source {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
chunks = append(chunks, RAGQueryResult{
|
||||||
|
Content: r.Text,
|
||||||
|
Source: r.Source,
|
||||||
|
Section: r.Section,
|
||||||
|
Category: r.Category,
|
||||||
|
ChunkIndex: r.ChunkIndex,
|
||||||
|
Score: r.Score,
|
||||||
|
})
|
||||||
|
if len(chunks) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sortChunksByIndex(chunks)
|
||||||
|
|
||||||
|
return nil, RAGRetrieveOutput{
|
||||||
|
Source: input.Source,
|
||||||
|
Collection: collection,
|
||||||
|
Chunks: chunks,
|
||||||
|
Count: len(chunks),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortChunksByIndex sorts chunks in ascending order of chunk index.
|
||||||
|
// Stable ordering keeps ties by their original position.
|
||||||
|
func sortChunksByIndex(chunks []RAGQueryResult) {
|
||||||
|
if len(chunks) <= 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Insertion sort keeps the code dependency-free and is fast enough
|
||||||
|
// for the small result sets rag_retrieve is designed for.
|
||||||
|
for i := 1; i < len(chunks); i++ {
|
||||||
|
j := i
|
||||||
|
for j > 0 && chunks[j-1].ChunkIndex > chunks[j].ChunkIndex {
|
||||||
|
chunks[j-1], chunks[j] = chunks[j], chunks[j-1]
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ragCollections handles the rag_collections tool call.
|
// ragCollections handles the rag_collections tool call.
|
||||||
func (s *Service) ragCollections(ctx context.Context, req *mcp.CallToolRequest, input RAGCollectionsInput) (*mcp.CallToolResult, RAGCollectionsOutput, error) {
|
func (s *Service) ragCollections(ctx context.Context, req *mcp.CallToolRequest, input RAGCollectionsInput) (*mcp.CallToolResult, RAGCollectionsOutput, error) {
|
||||||
s.logger.Info("MCP tool execution", "tool", "rag_collections", "show_stats", input.ShowStats, "user", log.Username())
|
s.logger.Info("MCP tool execution", "tool", "rag_collections", "show_stats", input.ShowStats, "user", log.Username())
|
||||||
|
|
|
||||||
|
|
@ -171,3 +171,66 @@ func TestRAGCollectionsInput_ShowStats(t *testing.T) {
|
||||||
t.Error("Expected ShowStats to be true")
|
t.Error("Expected ShowStats to be true")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestToolsRag_RAGRetrieveInput_Good exercises the rag_retrieve DTO defaults.
|
||||||
|
func TestToolsRag_RAGRetrieveInput_Good(t *testing.T) {
|
||||||
|
input := RAGRetrieveInput{
|
||||||
|
Source: "docs/index.md",
|
||||||
|
Collection: "core-docs",
|
||||||
|
Limit: 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Source != "docs/index.md" {
|
||||||
|
t.Errorf("expected source docs/index.md, got %q", input.Source)
|
||||||
|
}
|
||||||
|
if input.Limit != 20 {
|
||||||
|
t.Errorf("expected limit 20, got %d", input.Limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsRag_RAGRetrieveOutput_Good exercises the rag_retrieve output shape.
|
||||||
|
func TestToolsRag_RAGRetrieveOutput_Good(t *testing.T) {
|
||||||
|
output := RAGRetrieveOutput{
|
||||||
|
Source: "docs/index.md",
|
||||||
|
Collection: "core-docs",
|
||||||
|
Chunks: []RAGQueryResult{
|
||||||
|
{Content: "first", ChunkIndex: 0},
|
||||||
|
{Content: "second", ChunkIndex: 1},
|
||||||
|
},
|
||||||
|
Count: 2,
|
||||||
|
}
|
||||||
|
if output.Count != 2 {
|
||||||
|
t.Fatalf("expected count 2, got %d", output.Count)
|
||||||
|
}
|
||||||
|
if output.Chunks[1].ChunkIndex != 1 {
|
||||||
|
t.Fatalf("expected chunk 1, got %d", output.Chunks[1].ChunkIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsRag_SortChunksByIndex_Good verifies sort orders by chunk index ascending.
|
||||||
|
func TestToolsRag_SortChunksByIndex_Good(t *testing.T) {
|
||||||
|
chunks := []RAGQueryResult{
|
||||||
|
{ChunkIndex: 3},
|
||||||
|
{ChunkIndex: 1},
|
||||||
|
{ChunkIndex: 2},
|
||||||
|
}
|
||||||
|
sortChunksByIndex(chunks)
|
||||||
|
for i, want := range []int{1, 2, 3} {
|
||||||
|
if chunks[i].ChunkIndex != want {
|
||||||
|
t.Fatalf("index %d: expected chunk %d, got %d", i, want, chunks[i].ChunkIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsRag_RagRetrieve_Bad rejects empty source paths.
|
||||||
|
func TestToolsRag_RagRetrieve_Bad(t *testing.T) {
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = svc.ragRetrieve(t.Context(), nil, RAGRetrieveInput{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty source")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,24 +3,23 @@
|
||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
// Note: AX-6 — screenshot normalization needs bytes.NewReader for image.Decode on captured byte slices.
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"image"
|
"image"
|
||||||
"image/jpeg"
|
"image/jpeg"
|
||||||
_ "image/png"
|
_ "image/png"
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
core "dappco.re/go/core"
|
core "dappco.re/go/core"
|
||||||
"forge.lthn.ai/core/go-log"
|
"dappco.re/go/log"
|
||||||
"forge.lthn.ai/core/go-webview"
|
"dappco.re/go/webview"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// webviewMu protects webviewInstance from concurrent access.
|
// webviewMu protects webviewInstance from concurrent access.
|
||||||
var webviewMu sync.Mutex
|
var webviewMu core.Mutex
|
||||||
|
|
||||||
// webviewInstance holds the current webview connection.
|
// webviewInstance holds the current webview connection.
|
||||||
// This is managed by the MCP service.
|
// This is managed by the MCP service.
|
||||||
|
|
@ -271,6 +270,18 @@ func (s *Service) registerWebviewTools(server *mcp.Server) {
|
||||||
Name: "webview_wait",
|
Name: "webview_wait",
|
||||||
Description: "Wait for an element to appear by CSS selector.",
|
Description: "Wait for an element to appear by CSS selector.",
|
||||||
}, s.webviewWait)
|
}, s.webviewWait)
|
||||||
|
|
||||||
|
// Embedded UI rendering — for pushing HTML/state to connected clients
|
||||||
|
// without requiring a Chrome DevTools connection.
|
||||||
|
addToolRecorded(s, server, "webview", &mcp.Tool{
|
||||||
|
Name: "webview_render",
|
||||||
|
Description: "Render HTML in an embedded webview by ID. Broadcasts to connected clients via the webview.render channel.",
|
||||||
|
}, s.webviewRender)
|
||||||
|
|
||||||
|
addToolRecorded(s, server, "webview", &mcp.Tool{
|
||||||
|
Name: "webview_update",
|
||||||
|
Description: "Update the HTML, title, or state of an embedded webview by ID. Broadcasts to connected clients via the webview.update channel.",
|
||||||
|
}, s.webviewUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
// webviewConnect handles the webview_connect tool call.
|
// webviewConnect handles the webview_connect tool call.
|
||||||
|
|
@ -554,7 +565,7 @@ func (s *Service) webviewScreenshot(ctx context.Context, req *mcp.CallToolReques
|
||||||
if format == "" {
|
if format == "" {
|
||||||
format = "png"
|
format = "png"
|
||||||
}
|
}
|
||||||
format = strings.ToLower(format)
|
format = core.Lower(format)
|
||||||
|
|
||||||
data, err := webviewInstance.Screenshot()
|
data, err := webviewInstance.Screenshot()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -586,8 +597,8 @@ func normalizeScreenshotData(data []byte, format string) ([]byte, string, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
var buf bytes.Buffer
|
buf := core.NewBuffer()
|
||||||
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 90}); err != nil {
|
if err := jpeg.Encode(buf, img, &jpeg.Options{Quality: 90}); err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
return buf.Bytes(), "jpeg", nil
|
return buf.Bytes(), "jpeg", nil
|
||||||
|
|
@ -649,7 +660,7 @@ func waitForSelector(ctx context.Context, timeout time.Duration, selector string
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if !strings.Contains(err.Error(), "element not found") {
|
if !core.Contains(err.Error(), "element not found") {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
233
pkg/mcp/tools_webview_embed.go
Normal file
233
pkg/mcp/tools_webview_embed.go
Normal file
|
|
@ -0,0 +1,233 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
"dappco.re/go/log"
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WebviewRenderInput contains parameters for rendering an embedded
|
||||||
|
// HTML view. The named view is stored and broadcast so connected clients
|
||||||
|
// (Claude Code sessions, CoreGUI windows, HTTP/SSE subscribers) can
|
||||||
|
// display the content.
|
||||||
|
//
|
||||||
|
// input := WebviewRenderInput{
|
||||||
|
// ViewID: "dashboard",
|
||||||
|
// HTML: "<div id='app'>Loading...</div>",
|
||||||
|
// Title: "Agent Dashboard",
|
||||||
|
// Width: 1024,
|
||||||
|
// Height: 768,
|
||||||
|
// State: map[string]any{"theme": "dark"},
|
||||||
|
// }
|
||||||
|
type WebviewRenderInput struct {
|
||||||
|
ViewID string `json:"view_id"` // e.g. "dashboard"
|
||||||
|
HTML string `json:"html"` // rendered markup
|
||||||
|
Title string `json:"title,omitempty"` // e.g. "Agent Dashboard"
|
||||||
|
Width int `json:"width,omitempty"` // preferred width in pixels
|
||||||
|
Height int `json:"height,omitempty"` // preferred height in pixels
|
||||||
|
State map[string]any `json:"state,omitempty"` // initial view state
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebviewRenderOutput reports the result of rendering an embedded view.
|
||||||
|
//
|
||||||
|
// // out.Success == true, out.ViewID == "dashboard"
|
||||||
|
type WebviewRenderOutput struct {
|
||||||
|
Success bool `json:"success"` // true when the view was stored and broadcast
|
||||||
|
ViewID string `json:"view_id"` // echoed view identifier
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"` // when the view was rendered
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebviewUpdateInput contains parameters for updating the state of an
|
||||||
|
// existing embedded view. Callers may provide HTML to replace the markup,
|
||||||
|
// patch fields in the view state, or do both.
|
||||||
|
//
|
||||||
|
// input := WebviewUpdateInput{
|
||||||
|
// ViewID: "dashboard",
|
||||||
|
// HTML: "<div id='app'>Ready</div>",
|
||||||
|
// State: map[string]any{"count": 42},
|
||||||
|
// Merge: true,
|
||||||
|
// }
|
||||||
|
type WebviewUpdateInput struct {
|
||||||
|
ViewID string `json:"view_id"` // e.g. "dashboard"
|
||||||
|
HTML string `json:"html,omitempty"` // replacement markup (optional)
|
||||||
|
Title string `json:"title,omitempty"` // e.g. "Agent Dashboard"
|
||||||
|
State map[string]any `json:"state,omitempty"` // partial state update
|
||||||
|
Merge bool `json:"merge,omitempty"` // merge state (default) or replace when false
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebviewUpdateOutput reports the result of updating an embedded view.
|
||||||
|
//
|
||||||
|
// // out.Success == true, out.ViewID == "dashboard"
|
||||||
|
type WebviewUpdateOutput struct {
|
||||||
|
Success bool `json:"success"` // true when the view was updated and broadcast
|
||||||
|
ViewID string `json:"view_id"` // echoed view identifier
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"` // when the view was last updated
|
||||||
|
}
|
||||||
|
|
||||||
|
// embeddedView captures the live state of a rendered UI view. Instances
|
||||||
|
// are kept per ViewID inside embeddedViewRegistry.
|
||||||
|
type embeddedView struct {
|
||||||
|
ViewID string
|
||||||
|
Title string
|
||||||
|
HTML string
|
||||||
|
Width int
|
||||||
|
Height int
|
||||||
|
State map[string]any
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// embeddedViewRegistry stores the most recent render/update state for each
|
||||||
|
// view so new subscribers can pick up the current UI on connection.
|
||||||
|
// Operations are guarded by embeddedViewMu.
|
||||||
|
var (
|
||||||
|
embeddedViewMu sync.RWMutex
|
||||||
|
embeddedViewRegistry = map[string]*embeddedView{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChannelWebviewRender is the channel used to broadcast webview_render events.
|
||||||
|
const ChannelWebviewRender = "webview.render"
|
||||||
|
|
||||||
|
// ChannelWebviewUpdate is the channel used to broadcast webview_update events.
|
||||||
|
const ChannelWebviewUpdate = "webview.update"
|
||||||
|
|
||||||
|
// webviewRender handles the webview_render tool call.
|
||||||
|
func (s *Service) webviewRender(ctx context.Context, req *mcp.CallToolRequest, input WebviewRenderInput) (*mcp.CallToolResult, WebviewRenderOutput, error) {
|
||||||
|
s.logger.Info("MCP tool execution", "tool", "webview_render", "view", input.ViewID, "user", log.Username())
|
||||||
|
|
||||||
|
if core.Trim(input.ViewID) == "" {
|
||||||
|
return nil, WebviewRenderOutput{}, log.E("webviewRender", "view_id is required", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
view := &embeddedView{
|
||||||
|
ViewID: input.ViewID,
|
||||||
|
Title: input.Title,
|
||||||
|
HTML: input.HTML,
|
||||||
|
Width: input.Width,
|
||||||
|
Height: input.Height,
|
||||||
|
State: cloneStateMap(input.State),
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddedViewMu.Lock()
|
||||||
|
embeddedViewRegistry[input.ViewID] = view
|
||||||
|
embeddedViewMu.Unlock()
|
||||||
|
|
||||||
|
s.ChannelSend(ctx, ChannelWebviewRender, map[string]any{
|
||||||
|
"view_id": view.ViewID,
|
||||||
|
"title": view.Title,
|
||||||
|
"html": view.HTML,
|
||||||
|
"width": view.Width,
|
||||||
|
"height": view.Height,
|
||||||
|
"state": cloneStateMap(view.State),
|
||||||
|
"updatedAt": view.UpdatedAt,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, WebviewRenderOutput{
|
||||||
|
Success: true,
|
||||||
|
ViewID: view.ViewID,
|
||||||
|
UpdatedAt: view.UpdatedAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// webviewUpdate handles the webview_update tool call.
|
||||||
|
func (s *Service) webviewUpdate(ctx context.Context, req *mcp.CallToolRequest, input WebviewUpdateInput) (*mcp.CallToolResult, WebviewUpdateOutput, error) {
|
||||||
|
s.logger.Info("MCP tool execution", "tool", "webview_update", "view", input.ViewID, "user", log.Username())
|
||||||
|
|
||||||
|
if core.Trim(input.ViewID) == "" {
|
||||||
|
return nil, WebviewUpdateOutput{}, log.E("webviewUpdate", "view_id is required", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
embeddedViewMu.Lock()
|
||||||
|
view, ok := embeddedViewRegistry[input.ViewID]
|
||||||
|
if !ok {
|
||||||
|
// Updating a view that was never rendered creates one lazily so
|
||||||
|
// clients that reconnect mid-session get a consistent snapshot.
|
||||||
|
view = &embeddedView{ViewID: input.ViewID, State: map[string]any{}}
|
||||||
|
embeddedViewRegistry[input.ViewID] = view
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.HTML != "" {
|
||||||
|
view.HTML = input.HTML
|
||||||
|
}
|
||||||
|
if input.Title != "" {
|
||||||
|
view.Title = input.Title
|
||||||
|
}
|
||||||
|
if input.State != nil {
|
||||||
|
merge := input.Merge || len(view.State) == 0
|
||||||
|
if merge {
|
||||||
|
if view.State == nil {
|
||||||
|
view.State = map[string]any{}
|
||||||
|
}
|
||||||
|
for k, v := range input.State {
|
||||||
|
view.State[k] = v
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
view.State = cloneStateMap(input.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
view.UpdatedAt = now
|
||||||
|
snapshot := *view
|
||||||
|
snapshot.State = cloneStateMap(view.State)
|
||||||
|
embeddedViewMu.Unlock()
|
||||||
|
|
||||||
|
s.ChannelSend(ctx, ChannelWebviewUpdate, map[string]any{
|
||||||
|
"view_id": snapshot.ViewID,
|
||||||
|
"title": snapshot.Title,
|
||||||
|
"html": snapshot.HTML,
|
||||||
|
"width": snapshot.Width,
|
||||||
|
"height": snapshot.Height,
|
||||||
|
"state": snapshot.State,
|
||||||
|
"updatedAt": snapshot.UpdatedAt,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, WebviewUpdateOutput{
|
||||||
|
Success: true,
|
||||||
|
ViewID: snapshot.ViewID,
|
||||||
|
UpdatedAt: snapshot.UpdatedAt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneStateMap returns a shallow copy of a state map.
|
||||||
|
//
|
||||||
|
// cloned := cloneStateMap(map[string]any{"a": 1}) // cloned["a"] == 1
|
||||||
|
func cloneStateMap(in map[string]any) map[string]any {
|
||||||
|
if in == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]any, len(in))
|
||||||
|
for k, v := range in {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupEmbeddedView returns the current snapshot of an embedded view, if any.
|
||||||
|
//
|
||||||
|
// view, ok := lookupEmbeddedView("dashboard")
|
||||||
|
func lookupEmbeddedView(id string) (*embeddedView, bool) {
|
||||||
|
embeddedViewMu.RLock()
|
||||||
|
defer embeddedViewMu.RUnlock()
|
||||||
|
view, ok := embeddedViewRegistry[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
snapshot := *view
|
||||||
|
snapshot.State = cloneStateMap(view.State)
|
||||||
|
return &snapshot, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetEmbeddedViews clears the registry. Intended for tests.
|
||||||
|
func resetEmbeddedViews() {
|
||||||
|
embeddedViewMu.Lock()
|
||||||
|
defer embeddedViewMu.Unlock()
|
||||||
|
embeddedViewRegistry = map[string]*embeddedView{}
|
||||||
|
}
|
||||||
137
pkg/mcp/tools_webview_embed_test.go
Normal file
137
pkg/mcp/tools_webview_embed_test.go
Normal file
|
|
@ -0,0 +1,137 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestToolsWebviewEmbed_WebviewRender_Good registers a view and verifies the
|
||||||
|
// registry keeps the rendered HTML and state.
|
||||||
|
func TestToolsWebviewEmbed_WebviewRender_Good(t *testing.T) {
|
||||||
|
t.Cleanup(resetEmbeddedViews)
|
||||||
|
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, out, err := svc.webviewRender(context.Background(), nil, WebviewRenderInput{
|
||||||
|
ViewID: "dashboard",
|
||||||
|
HTML: "<p>hello</p>",
|
||||||
|
Title: "Demo",
|
||||||
|
State: map[string]any{"count": 1},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("webviewRender returned error: %v", err)
|
||||||
|
}
|
||||||
|
if !out.Success {
|
||||||
|
t.Fatal("expected Success=true")
|
||||||
|
}
|
||||||
|
if out.ViewID != "dashboard" {
|
||||||
|
t.Fatalf("expected view id 'dashboard', got %q", out.ViewID)
|
||||||
|
}
|
||||||
|
if out.UpdatedAt.IsZero() {
|
||||||
|
t.Fatal("expected non-zero UpdatedAt")
|
||||||
|
}
|
||||||
|
|
||||||
|
view, ok := lookupEmbeddedView("dashboard")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected view to be stored in registry")
|
||||||
|
}
|
||||||
|
if view.HTML != "<p>hello</p>" {
|
||||||
|
t.Fatalf("expected HTML '<p>hello</p>', got %q", view.HTML)
|
||||||
|
}
|
||||||
|
if view.State["count"] != 1 {
|
||||||
|
t.Fatalf("expected state.count=1, got %v", view.State["count"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsWebviewEmbed_WebviewRender_Bad ensures empty view IDs are rejected.
|
||||||
|
func TestToolsWebviewEmbed_WebviewRender_Bad(t *testing.T) {
|
||||||
|
t.Cleanup(resetEmbeddedViews)
|
||||||
|
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = svc.webviewRender(context.Background(), nil, WebviewRenderInput{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty view_id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsWebviewEmbed_WebviewUpdate_Good merges a state patch into the
|
||||||
|
// previously rendered view.
|
||||||
|
func TestToolsWebviewEmbed_WebviewUpdate_Good(t *testing.T) {
|
||||||
|
t.Cleanup(resetEmbeddedViews)
|
||||||
|
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = svc.webviewRender(context.Background(), nil, WebviewRenderInput{
|
||||||
|
ViewID: "dashboard",
|
||||||
|
HTML: "<p>hello</p>",
|
||||||
|
State: map[string]any{"count": 1},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("seed render failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, out, err := svc.webviewUpdate(context.Background(), nil, WebviewUpdateInput{
|
||||||
|
ViewID: "dashboard",
|
||||||
|
State: map[string]any{"theme": "dark"},
|
||||||
|
Merge: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("webviewUpdate returned error: %v", err)
|
||||||
|
}
|
||||||
|
if !out.Success {
|
||||||
|
t.Fatal("expected Success=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
view, ok := lookupEmbeddedView("dashboard")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected view to exist after update")
|
||||||
|
}
|
||||||
|
if view.State["count"] != 1 {
|
||||||
|
t.Fatalf("expected count to persist after merge, got %v", view.State["count"])
|
||||||
|
}
|
||||||
|
if view.State["theme"] != "dark" {
|
||||||
|
t.Fatalf("expected theme 'dark' after merge, got %v", view.State["theme"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsWebviewEmbed_WebviewUpdate_Ugly updates a view that was never
|
||||||
|
// rendered and verifies a fresh registry entry is created.
|
||||||
|
func TestToolsWebviewEmbed_WebviewUpdate_Ugly(t *testing.T) {
|
||||||
|
t.Cleanup(resetEmbeddedViews)
|
||||||
|
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, out, err := svc.webviewUpdate(context.Background(), nil, WebviewUpdateInput{
|
||||||
|
ViewID: "ghost",
|
||||||
|
HTML: "<p>new</p>",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("webviewUpdate returned error: %v", err)
|
||||||
|
}
|
||||||
|
if !out.Success {
|
||||||
|
t.Fatal("expected Success=true for lazy-create update")
|
||||||
|
}
|
||||||
|
|
||||||
|
view, ok := lookupEmbeddedView("ghost")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected ghost view to be created lazily")
|
||||||
|
}
|
||||||
|
if view.HTML != "<p>new</p>" {
|
||||||
|
t.Fatalf("expected HTML '<p>new</p>', got %q", view.HTML)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -11,7 +11,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-webview"
|
"dappco.re/go/webview"
|
||||||
)
|
)
|
||||||
|
|
||||||
// skipIfShort skips webview tests in short mode (go test -short).
|
// skipIfShort skips webview tests in short mode (go test -short).
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
core "dappco.re/go/core"
|
core "dappco.re/go/core"
|
||||||
"forge.lthn.ai/core/go-log"
|
"dappco.re/go/log"
|
||||||
"forge.lthn.ai/core/go-ws"
|
"dappco.re/go/ws"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
264
pkg/mcp/tools_ws_client.go
Normal file
264
pkg/mcp/tools_ws_client.go
Normal file
|
|
@ -0,0 +1,264 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
|
"dappco.re/go/log"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WSConnectInput contains parameters for opening an outbound WebSocket
|
||||||
|
// connection from the MCP server. Each connection is given a stable ID that
|
||||||
|
// subsequent ws_send and ws_close calls use to address it.
|
||||||
|
//
|
||||||
|
// input := WSConnectInput{URL: "wss://example.com/ws", Timeout: 10}
|
||||||
|
type WSConnectInput struct {
|
||||||
|
URL string `json:"url"` // e.g. "wss://example.com/ws"
|
||||||
|
Headers map[string]string `json:"headers,omitempty"` // custom request headers
|
||||||
|
Timeout int `json:"timeout,omitempty"` // handshake timeout in seconds (default: 30)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WSConnectOutput contains the result of opening a WebSocket connection.
|
||||||
|
//
|
||||||
|
// // out.Success == true, out.ID == "ws-0af3…"
|
||||||
|
type WSConnectOutput struct {
|
||||||
|
Success bool `json:"success"` // true when the handshake completed
|
||||||
|
ID string `json:"id"` // e.g. "ws-0af3…"
|
||||||
|
URL string `json:"url"` // the URL that was dialled
|
||||||
|
}
|
||||||
|
|
||||||
|
// WSSendInput contains parameters for sending a message on an open
|
||||||
|
// WebSocket connection.
|
||||||
|
//
|
||||||
|
// input := WSSendInput{ID: "ws-0af3…", Message: "ping"}
|
||||||
|
type WSSendInput struct {
|
||||||
|
ID string `json:"id"` // e.g. "ws-0af3…"
|
||||||
|
Message string `json:"message"` // payload to send
|
||||||
|
Binary bool `json:"binary,omitempty"` // true to send a binary frame (payload is base64 text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WSSendOutput contains the result of sending a message.
|
||||||
|
//
|
||||||
|
// // out.Success == true, out.ID == "ws-0af3…"
|
||||||
|
type WSSendOutput struct {
|
||||||
|
Success bool `json:"success"` // true when the message was written
|
||||||
|
ID string `json:"id"` // e.g. "ws-0af3…"
|
||||||
|
Bytes int `json:"bytes"` // number of bytes written
|
||||||
|
}
|
||||||
|
|
||||||
|
// WSCloseInput contains parameters for closing a WebSocket connection.
|
||||||
|
//
|
||||||
|
// input := WSCloseInput{ID: "ws-0af3…", Reason: "done"}
|
||||||
|
type WSCloseInput struct {
|
||||||
|
ID string `json:"id"` // e.g. "ws-0af3…"
|
||||||
|
Code int `json:"code,omitempty"` // close code (default: 1000 - normal closure)
|
||||||
|
Reason string `json:"reason,omitempty"` // human-readable reason
|
||||||
|
}
|
||||||
|
|
||||||
|
// WSCloseOutput contains the result of closing a WebSocket connection.
|
||||||
|
//
|
||||||
|
// // out.Success == true, out.ID == "ws-0af3…"
|
||||||
|
type WSCloseOutput struct {
|
||||||
|
Success bool `json:"success"` // true when the connection was closed
|
||||||
|
ID string `json:"id"` // e.g. "ws-0af3…"
|
||||||
|
Message string `json:"message,omitempty"` // e.g. "connection closed"
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsClientConn tracks an outbound WebSocket connection tied to a stable ID.
|
||||||
|
type wsClientConn struct {
|
||||||
|
ID string
|
||||||
|
URL string
|
||||||
|
conn *websocket.Conn
|
||||||
|
writeMu sync.Mutex
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsClientRegistry holds all live outbound WebSocket connections keyed by ID.
|
||||||
|
// Access is guarded by wsClientMu.
|
||||||
|
var (
|
||||||
|
wsClientMu sync.Mutex
|
||||||
|
wsClientRegistry = map[string]*wsClientConn{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerWSClientTools registers the outbound WebSocket client tools.
|
||||||
|
func (s *Service) registerWSClientTools(server *mcp.Server) {
|
||||||
|
addToolRecorded(s, server, "ws", &mcp.Tool{
|
||||||
|
Name: "ws_connect",
|
||||||
|
Description: "Open an outbound WebSocket connection. Returns a connection ID for subsequent ws_send and ws_close calls.",
|
||||||
|
}, s.wsConnect)
|
||||||
|
|
||||||
|
addToolRecorded(s, server, "ws", &mcp.Tool{
|
||||||
|
Name: "ws_send",
|
||||||
|
Description: "Send a text or binary message on an open WebSocket connection identified by ID.",
|
||||||
|
}, s.wsSend)
|
||||||
|
|
||||||
|
addToolRecorded(s, server, "ws", &mcp.Tool{
|
||||||
|
Name: "ws_close",
|
||||||
|
Description: "Close an open WebSocket connection identified by ID.",
|
||||||
|
}, s.wsClose)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsConnect handles the ws_connect tool call.
|
||||||
|
func (s *Service) wsConnect(ctx context.Context, req *mcp.CallToolRequest, input WSConnectInput) (*mcp.CallToolResult, WSConnectOutput, error) {
|
||||||
|
s.logger.Security("MCP tool execution", "tool", "ws_connect", "url", input.URL, "user", log.Username())
|
||||||
|
|
||||||
|
if core.Trim(input.URL) == "" {
|
||||||
|
return nil, WSConnectOutput{}, log.E("wsConnect", "url is required", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.Duration(input.Timeout) * time.Second
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := websocket.Dialer{
|
||||||
|
HandshakeTimeout: timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := http.Header{}
|
||||||
|
for k, v := range input.Headers {
|
||||||
|
headers.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, _, err := dialer.DialContext(dialCtx, input.URL, headers)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("mcp: ws connect failed", "url", input.URL, "err", err)
|
||||||
|
return nil, WSConnectOutput{}, log.E("wsConnect", "failed to connect", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id := newWSClientID()
|
||||||
|
client := &wsClientConn{
|
||||||
|
ID: id,
|
||||||
|
URL: input.URL,
|
||||||
|
conn: conn,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
wsClientMu.Lock()
|
||||||
|
wsClientRegistry[id] = client
|
||||||
|
wsClientMu.Unlock()
|
||||||
|
|
||||||
|
return nil, WSConnectOutput{
|
||||||
|
Success: true,
|
||||||
|
ID: id,
|
||||||
|
URL: input.URL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsSend handles the ws_send tool call.
|
||||||
|
func (s *Service) wsSend(ctx context.Context, req *mcp.CallToolRequest, input WSSendInput) (*mcp.CallToolResult, WSSendOutput, error) {
|
||||||
|
s.logger.Info("MCP tool execution", "tool", "ws_send", "id", input.ID, "binary", input.Binary, "user", log.Username())
|
||||||
|
|
||||||
|
if core.Trim(input.ID) == "" {
|
||||||
|
return nil, WSSendOutput{}, log.E("wsSend", "id is required", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
client, ok := getWSClient(input.ID)
|
||||||
|
if !ok {
|
||||||
|
return nil, WSSendOutput{}, log.E("wsSend", "connection not found", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
messageType := websocket.TextMessage
|
||||||
|
if input.Binary {
|
||||||
|
messageType = websocket.BinaryMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
client.writeMu.Lock()
|
||||||
|
err := client.conn.WriteMessage(messageType, []byte(input.Message))
|
||||||
|
client.writeMu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("mcp: ws send failed", "id", input.ID, "err", err)
|
||||||
|
return nil, WSSendOutput{}, log.E("wsSend", "failed to send message", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, WSSendOutput{
|
||||||
|
Success: true,
|
||||||
|
ID: input.ID,
|
||||||
|
Bytes: len(input.Message),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// wsClose handles the ws_close tool call.
|
||||||
|
func (s *Service) wsClose(ctx context.Context, req *mcp.CallToolRequest, input WSCloseInput) (*mcp.CallToolResult, WSCloseOutput, error) {
|
||||||
|
s.logger.Info("MCP tool execution", "tool", "ws_close", "id", input.ID, "user", log.Username())
|
||||||
|
|
||||||
|
if core.Trim(input.ID) == "" {
|
||||||
|
return nil, WSCloseOutput{}, log.E("wsClose", "id is required", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
wsClientMu.Lock()
|
||||||
|
client, ok := wsClientRegistry[input.ID]
|
||||||
|
if ok {
|
||||||
|
delete(wsClientRegistry, input.ID)
|
||||||
|
}
|
||||||
|
wsClientMu.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, WSCloseOutput{}, log.E("wsClose", "connection not found", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := input.Code
|
||||||
|
if code == 0 {
|
||||||
|
code = websocket.CloseNormalClosure
|
||||||
|
}
|
||||||
|
reason := input.Reason
|
||||||
|
if reason == "" {
|
||||||
|
reason = "closed"
|
||||||
|
}
|
||||||
|
|
||||||
|
client.writeMu.Lock()
|
||||||
|
_ = client.conn.WriteControl(
|
||||||
|
websocket.CloseMessage,
|
||||||
|
websocket.FormatCloseMessage(code, reason),
|
||||||
|
time.Now().Add(5*time.Second),
|
||||||
|
)
|
||||||
|
client.writeMu.Unlock()
|
||||||
|
_ = client.conn.Close()
|
||||||
|
|
||||||
|
return nil, WSCloseOutput{
|
||||||
|
Success: true,
|
||||||
|
ID: input.ID,
|
||||||
|
Message: "connection closed",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newWSClientID returns a fresh identifier for an outbound WebSocket client.
|
||||||
|
//
|
||||||
|
// id := newWSClientID() // "ws-0af3…"
|
||||||
|
func newWSClientID() string {
|
||||||
|
var buf [8]byte
|
||||||
|
_, _ = rand.Read(buf[:])
|
||||||
|
return "ws-" + hex.EncodeToString(buf[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWSClient returns a tracked outbound WebSocket client by ID, if any.
|
||||||
|
//
|
||||||
|
// client, ok := getWSClient("ws-0af3…")
|
||||||
|
func getWSClient(id string) (*wsClientConn, bool) {
|
||||||
|
wsClientMu.Lock()
|
||||||
|
defer wsClientMu.Unlock()
|
||||||
|
client, ok := wsClientRegistry[id]
|
||||||
|
return client, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetWSClients drops all tracked outbound WebSocket clients. Intended for tests.
|
||||||
|
func resetWSClients() {
|
||||||
|
wsClientMu.Lock()
|
||||||
|
defer wsClientMu.Unlock()
|
||||||
|
for id, client := range wsClientRegistry {
|
||||||
|
_ = client.conn.Close()
|
||||||
|
delete(wsClientRegistry, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
169
pkg/mcp/tools_ws_client_test.go
Normal file
169
pkg/mcp/tools_ws_client_test.go
Normal file
|
|
@ -0,0 +1,169 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestToolsWSClient_WSConnect_Good dials a test WebSocket server and verifies
|
||||||
|
// the handshake completes and a client ID is assigned.
|
||||||
|
func TestToolsWSClient_WSConnect_Good(t *testing.T) {
|
||||||
|
t.Cleanup(resetWSClients)
|
||||||
|
|
||||||
|
server := startTestWSServer(t)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, out, err := svc.wsConnect(context.Background(), nil, WSConnectInput{
|
||||||
|
URL: "ws" + strings.TrimPrefix(server.URL, "http") + "/ws",
|
||||||
|
Timeout: 5,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("wsConnect failed: %v", err)
|
||||||
|
}
|
||||||
|
if !out.Success {
|
||||||
|
t.Fatal("expected Success=true")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(out.ID, "ws-") {
|
||||||
|
t.Fatalf("expected ID prefix 'ws-', got %q", out.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = svc.wsClose(context.Background(), nil, WSCloseInput{ID: out.ID})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("wsClose failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsWSClient_WSConnect_Bad rejects empty URLs.
|
||||||
|
func TestToolsWSClient_WSConnect_Bad(t *testing.T) {
|
||||||
|
t.Cleanup(resetWSClients)
|
||||||
|
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = svc.wsConnect(context.Background(), nil, WSConnectInput{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty URL")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsWSClient_WSSendClose_Good sends a message on an open connection
|
||||||
|
// and then closes it.
|
||||||
|
func TestToolsWSClient_WSSendClose_Good(t *testing.T) {
|
||||||
|
t.Cleanup(resetWSClients)
|
||||||
|
|
||||||
|
server := startTestWSServer(t)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, conn, err := svc.wsConnect(context.Background(), nil, WSConnectInput{
|
||||||
|
URL: "ws" + strings.TrimPrefix(server.URL, "http") + "/ws",
|
||||||
|
Timeout: 5,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("wsConnect failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, sendOut, err := svc.wsSend(context.Background(), nil, WSSendInput{
|
||||||
|
ID: conn.ID,
|
||||||
|
Message: "ping",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("wsSend failed: %v", err)
|
||||||
|
}
|
||||||
|
if !sendOut.Success {
|
||||||
|
t.Fatal("expected Success=true for wsSend")
|
||||||
|
}
|
||||||
|
if sendOut.Bytes != 4 {
|
||||||
|
t.Fatalf("expected 4 bytes written, got %d", sendOut.Bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, closeOut, err := svc.wsClose(context.Background(), nil, WSCloseInput{ID: conn.ID})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("wsClose failed: %v", err)
|
||||||
|
}
|
||||||
|
if !closeOut.Success {
|
||||||
|
t.Fatal("expected Success=true for wsClose")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := getWSClient(conn.ID); ok {
|
||||||
|
t.Fatal("expected connection to be removed after close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsWSClient_WSSend_Bad rejects unknown connection IDs.
|
||||||
|
func TestToolsWSClient_WSSend_Bad(t *testing.T) {
|
||||||
|
t.Cleanup(resetWSClients)
|
||||||
|
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = svc.wsSend(context.Background(), nil, WSSendInput{ID: "ws-missing", Message: "x"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for unknown connection ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestToolsWSClient_WSClose_Bad rejects closes for unknown connection IDs.
|
||||||
|
func TestToolsWSClient_WSClose_Bad(t *testing.T) {
|
||||||
|
t.Cleanup(resetWSClients)
|
||||||
|
|
||||||
|
svc, err := New(Options{WorkspaceRoot: t.TempDir()})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = svc.wsClose(context.Background(), nil, WSCloseInput{ID: "ws-missing"})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for unknown connection ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startTestWSServer returns an httptest.Server running a minimal echo WebSocket
|
||||||
|
// handler used by the ws_connect/ws_send tests.
|
||||||
|
func startTestWSServer(t *testing.T) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
upgrader := websocket.Upgrader{
|
||||||
|
CheckOrigin: func(*http.Request) bool { return true },
|
||||||
|
}
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
for {
|
||||||
|
_, msg, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return httptest.NewServer(mux)
|
||||||
|
}
|
||||||
|
|
@ -3,7 +3,7 @@ package mcp
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-ws"
|
"dappco.re/go/ws"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestWSToolsRegistered_Good verifies that WebSocket tools are registered when hub is available.
|
// TestWSToolsRegistered_Good verifies that WebSocket tools are registered when hub is available.
|
||||||
|
|
|
||||||
476
pkg/mcp/transformer.go
Normal file
476
pkg/mcp/transformer.go
Normal file
|
|
@ -0,0 +1,476 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"mime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TransformerIn normalises an AI wire protocol request into a unified MCP
|
||||||
|
// request envelope.
|
||||||
|
type TransformerIn interface {
|
||||||
|
Detect(body []byte, contentType, path string) bool
|
||||||
|
Normalise(body []byte) (MCPRequest, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransformerOut converts an MCP result back into an AI wire protocol response.
|
||||||
|
type TransformerOut interface {
|
||||||
|
Transform(result MCPResult) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPRequest is the gateway's protocol-neutral JSON-RPC request shape.
|
||||||
|
type MCPRequest struct {
|
||||||
|
JSONRPC string `json:"jsonrpc,omitempty"`
|
||||||
|
ID any `json:"id,omitempty"`
|
||||||
|
Method string `json:"method,omitempty"`
|
||||||
|
Params map[string]any `json:"params,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPResult is the gateway's protocol-neutral JSON-RPC result shape.
|
||||||
|
type MCPResult struct {
|
||||||
|
JSONRPC string `json:"jsonrpc,omitempty"`
|
||||||
|
ID any `json:"id,omitempty"`
|
||||||
|
Result any `json:"result,omitempty"`
|
||||||
|
Error any `json:"error,omitempty"`
|
||||||
|
Content []MCPContent `json:"content,omitempty"`
|
||||||
|
ToolCalls []MCPToolCall `json:"tool_calls,omitempty"`
|
||||||
|
StopReason string `json:"stop_reason,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPContent represents text and tool-use content blocks in the neutral result.
|
||||||
|
type MCPContent struct {
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Input map[string]any `json:"input,omitempty"`
|
||||||
|
Arguments map[string]any `json:"arguments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPToolCall captures a model-requested tool invocation.
|
||||||
|
type MCPToolCall struct {
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Arguments map[string]any `json:"arguments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(#197 follow-up): add Ollama and LiteLLM concrete transformers once the
|
||||||
|
// OpenAI/Anthropic/MCP-native gateway surface has settled.
|
||||||
|
|
||||||
|
// NegotiateTransformer selects the inbound transformer using RFC §9.4 priority:
|
||||||
|
// explicit media type, path, body inspection, then MCP-native fallback. The
|
||||||
|
// honeypot is only selected for malformed or probe-like bodies that no concrete
|
||||||
|
// protocol claims.
|
||||||
|
func NegotiateTransformer(body []byte, contentType, path string) TransformerIn {
|
||||||
|
if headerHasMedia(contentType, "application/openai+json") {
|
||||||
|
return OpenAITransformer{}
|
||||||
|
}
|
||||||
|
if headerHasMedia(contentType, "application/anthropic+json") {
|
||||||
|
return AnthropicTransformer{}
|
||||||
|
}
|
||||||
|
if headerHasMedia(contentType, "application/mcp+json", "application/json-rpc", "application/jsonrpc+json") {
|
||||||
|
return MCPNativeTransformer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch normaliseGatewayPath(path) {
|
||||||
|
case "/v1/chat/completions":
|
||||||
|
return OpenAITransformer{}
|
||||||
|
case "/v1/messages":
|
||||||
|
return AnthropicTransformer{}
|
||||||
|
case "/mcp":
|
||||||
|
if (HoneypotTransformer{}).Detect(body, contentType, path) {
|
||||||
|
return HoneypotTransformer{}
|
||||||
|
}
|
||||||
|
return MCPNativeTransformer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (MCPNativeTransformer{}).Detect(body, "", "") {
|
||||||
|
return MCPNativeTransformer{}
|
||||||
|
}
|
||||||
|
if (OpenAITransformer{}).Detect(body, "", "") {
|
||||||
|
if looksAnthropicBody(body) {
|
||||||
|
return AnthropicTransformer{}
|
||||||
|
}
|
||||||
|
return OpenAITransformer{}
|
||||||
|
}
|
||||||
|
if (AnthropicTransformer{}).Detect(body, "", "") {
|
||||||
|
return AnthropicTransformer{}
|
||||||
|
}
|
||||||
|
if (HoneypotTransformer{}).Detect(body, contentType, path) {
|
||||||
|
return HoneypotTransformer{}
|
||||||
|
}
|
||||||
|
return MCPNativeTransformer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPNativeTransformer is the identity transformer for native MCP JSON-RPC.
|
||||||
|
type MCPNativeTransformer struct{}
|
||||||
|
|
||||||
|
func (MCPNativeTransformer) Detect(body []byte, contentType, path string) bool {
|
||||||
|
if headerHasMedia(contentType, "application/mcp+json", "application/json-rpc", "application/jsonrpc+json") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if normaliseGatewayPath(path) == "/mcp" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
obj, ok := decodeJSONObject(body)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, hasMethod := obj["method"].(string)
|
||||||
|
_, hasResult := obj["result"]
|
||||||
|
_, hasError := obj["error"]
|
||||||
|
return obj["jsonrpc"] == "2.0" && (hasMethod || hasResult || hasError)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MCPNativeTransformer) Normalise(body []byte) (MCPRequest, error) {
|
||||||
|
var req MCPRequest
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return MCPRequest{}, err
|
||||||
|
}
|
||||||
|
if req.JSONRPC == "" {
|
||||||
|
req.JSONRPC = "2.0"
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MCPNativeTransformer) Transform(result MCPResult) ([]byte, error) {
|
||||||
|
if result.JSONRPC == "" {
|
||||||
|
result.JSONRPC = "2.0"
|
||||||
|
}
|
||||||
|
return json.Marshal(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func headerHasMedia(header string, wants ...string) bool {
|
||||||
|
header = strings.TrimSpace(header)
|
||||||
|
if header == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
wantSet := make(map[string]struct{}, len(wants))
|
||||||
|
for _, want := range wants {
|
||||||
|
wantSet[strings.ToLower(strings.TrimSpace(want))] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range strings.Split(header, ",") {
|
||||||
|
media := strings.TrimSpace(part)
|
||||||
|
if parsed, _, err := mime.ParseMediaType(media); err == nil {
|
||||||
|
media = parsed
|
||||||
|
} else if semi := strings.IndexByte(media, ';'); semi >= 0 {
|
||||||
|
media = media[:semi]
|
||||||
|
}
|
||||||
|
media = strings.ToLower(strings.TrimSpace(media))
|
||||||
|
if _, ok := wantSet[media]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func normaliseGatewayPath(path string) string {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
if path == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if i := strings.IndexAny(path, "?#"); i >= 0 {
|
||||||
|
path = path[:i]
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(path, "/") {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
for strings.Contains(path, "//") {
|
||||||
|
path = strings.ReplaceAll(path, "//", "/")
|
||||||
|
}
|
||||||
|
if len(path) > 1 {
|
||||||
|
path = strings.TrimRight(path, "/")
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeJSONObject(body []byte) (map[string]any, bool) {
|
||||||
|
body = bytes.TrimSpace(body)
|
||||||
|
if len(body) == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
var obj map[string]any
|
||||||
|
if err := json.Unmarshal(body, &obj); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return obj, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasTopLevelFields(body []byte, fields ...string) bool {
|
||||||
|
obj, ok := decodeJSONObject(body)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, field := range fields {
|
||||||
|
if _, ok := obj[field]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksAnthropicBody(body []byte) bool {
|
||||||
|
obj, ok := decodeJSONObject(body)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, ok := obj["system"]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := obj["max_tokens"]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := obj["anthropic_version"]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, ok := obj["messages"].([]any)
|
||||||
|
if !ok || len(messages) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, raw := range messages {
|
||||||
|
msg, ok := raw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if role, _ := msg["role"].(string); role == "system" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if blocks, ok := msg["content"].([]any); ok {
|
||||||
|
for _, rawBlock := range blocks {
|
||||||
|
block, ok := rawBlock.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch block["type"] {
|
||||||
|
case "tool_use", "tool_result":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func messagesHaveNoSystemRole(body []byte) bool {
|
||||||
|
obj, ok := decodeJSONObject(body)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
messages, ok := obj["messages"].([]any)
|
||||||
|
if !ok || len(messages) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, raw := range messages {
|
||||||
|
msg, ok := raw.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if role, _ := msg["role"].(string); role == "system" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRawArgumentObject(raw json.RawMessage) map[string]any {
|
||||||
|
raw = bytes.TrimSpace(raw)
|
||||||
|
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var encoded string
|
||||||
|
if err := json.Unmarshal(raw, &encoded); err == nil {
|
||||||
|
return parseArgumentString(encoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
var args map[string]any
|
||||||
|
if err := json.Unmarshal(raw, &args); err == nil && args != nil {
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
return map[string]any{"_raw": string(raw)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseArgumentString(s string) map[string]any {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
var args map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(s), &args); err == nil && args != nil {
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
return map[string]any{"_raw": s}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapFromAny(v any) map[string]any {
|
||||||
|
switch typed := v.(type) {
|
||||||
|
case nil:
|
||||||
|
return map[string]any{}
|
||||||
|
case map[string]any:
|
||||||
|
if typed == nil {
|
||||||
|
return map[string]any{}
|
||||||
|
}
|
||||||
|
return typed
|
||||||
|
case json.RawMessage:
|
||||||
|
return parseRawArgumentObject(typed)
|
||||||
|
case string:
|
||||||
|
return parseArgumentString(typed)
|
||||||
|
default:
|
||||||
|
data, err := json.Marshal(typed)
|
||||||
|
if err != nil {
|
||||||
|
return map[string]any{"value": typed}
|
||||||
|
}
|
||||||
|
return parseRawArgumentObject(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractMCPText(result MCPResult) string {
|
||||||
|
var parts []string
|
||||||
|
for _, block := range result.Content {
|
||||||
|
if block.Text != "" && (block.Type == "" || block.Type == "text") {
|
||||||
|
parts = append(parts, block.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parts = append(parts, extractTextFromAny(result.Result)...)
|
||||||
|
return strings.Join(parts, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractTextFromAny(v any) []string {
|
||||||
|
switch typed := v.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
if typed == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{typed}
|
||||||
|
case []byte:
|
||||||
|
if len(typed) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{string(typed)}
|
||||||
|
case []MCPContent:
|
||||||
|
var out []string
|
||||||
|
for _, block := range typed {
|
||||||
|
if block.Text != "" && (block.Type == "" || block.Type == "text") {
|
||||||
|
out = append(out, block.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case []any:
|
||||||
|
var out []string
|
||||||
|
for _, item := range typed {
|
||||||
|
out = append(out, extractTextFromAny(item)...)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case []map[string]any:
|
||||||
|
var out []string
|
||||||
|
for _, item := range typed {
|
||||||
|
out = append(out, extractTextFromAny(item)...)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
case map[string]any:
|
||||||
|
for _, key := range []string{"text", "message", "output"} {
|
||||||
|
if text, ok := typed[key].(string); ok && text != "" {
|
||||||
|
return []string{text}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if content, ok := typed["content"]; ok {
|
||||||
|
return extractTextFromAny(content)
|
||||||
|
}
|
||||||
|
if result, ok := typed["result"]; ok {
|
||||||
|
return extractTextFromAny(result)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
data, err := json.Marshal(typed)
|
||||||
|
if err != nil || len(data) == 0 || bytes.Equal(data, []byte("null")) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{string(data)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractMCPToolCalls(result MCPResult) []MCPToolCall {
|
||||||
|
var calls []MCPToolCall
|
||||||
|
calls = append(calls, result.ToolCalls...)
|
||||||
|
for _, block := range result.Content {
|
||||||
|
if block.Type != "tool_use" && block.Name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
args := block.Input
|
||||||
|
if len(args) == 0 {
|
||||||
|
args = block.Arguments
|
||||||
|
}
|
||||||
|
calls = append(calls, MCPToolCall{ID: block.ID, Name: block.Name, Arguments: args})
|
||||||
|
}
|
||||||
|
calls = append(calls, extractToolCallsFromAny(result.Result)...)
|
||||||
|
return calls
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractToolCallsFromAny(v any) []MCPToolCall {
|
||||||
|
switch typed := v.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case []MCPToolCall:
|
||||||
|
return typed
|
||||||
|
case []MCPContent:
|
||||||
|
var calls []MCPToolCall
|
||||||
|
for _, block := range typed {
|
||||||
|
if block.Type == "tool_use" || block.Name != "" {
|
||||||
|
args := block.Input
|
||||||
|
if len(args) == 0 {
|
||||||
|
args = block.Arguments
|
||||||
|
}
|
||||||
|
calls = append(calls, MCPToolCall{ID: block.ID, Name: block.Name, Arguments: args})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return calls
|
||||||
|
case []any:
|
||||||
|
var calls []MCPToolCall
|
||||||
|
for _, item := range typed {
|
||||||
|
calls = append(calls, extractToolCallsFromAny(item)...)
|
||||||
|
}
|
||||||
|
return calls
|
||||||
|
case []map[string]any:
|
||||||
|
var calls []MCPToolCall
|
||||||
|
for _, item := range typed {
|
||||||
|
calls = append(calls, extractToolCallsFromAny(item)...)
|
||||||
|
}
|
||||||
|
return calls
|
||||||
|
case map[string]any:
|
||||||
|
for _, key := range []string{"tool_calls", "toolCalls"} {
|
||||||
|
if raw, ok := typed[key]; ok {
|
||||||
|
return extractToolCallsFromAny(raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if raw, ok := typed["content"]; ok {
|
||||||
|
return extractToolCallsFromAny(raw)
|
||||||
|
}
|
||||||
|
name, _ := typed["name"].(string)
|
||||||
|
if name == "" {
|
||||||
|
if fn, ok := typed["function"].(map[string]any); ok {
|
||||||
|
name, _ = fn["name"].(string)
|
||||||
|
args := mapFromAny(fn["arguments"])
|
||||||
|
id, _ := typed["id"].(string)
|
||||||
|
return []MCPToolCall{{ID: id, Name: name, Arguments: args}}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
id, _ := typed["id"].(string)
|
||||||
|
args := mapFromAny(typed["arguments"])
|
||||||
|
if len(args) == 0 {
|
||||||
|
args = mapFromAny(typed["input"])
|
||||||
|
}
|
||||||
|
return []MCPToolCall{{ID: id, Name: name, Arguments: args}}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
238
pkg/mcp/transformer_anthropic.go
Normal file
238
pkg/mcp/transformer_anthropic.go
Normal file
|
|
@ -0,0 +1,238 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AnthropicTransformer maps Anthropic Messages requests and responses.
|
||||||
|
type AnthropicTransformer struct{}
|
||||||
|
|
||||||
|
func (AnthropicTransformer) Detect(body []byte, contentType, path string) bool {
|
||||||
|
if headerHasMedia(contentType, "application/anthropic+json") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if normaliseGatewayPath(path) == "/v1/messages" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !hasTopLevelFields(body, "model", "messages") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return looksAnthropicBody(body) || messagesHaveNoSystemRole(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (AnthropicTransformer) Normalise(body []byte) (MCPRequest, error) {
|
||||||
|
var req anthropicMessagesRequest
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return MCPRequest{}, err
|
||||||
|
}
|
||||||
|
if req.Model == "" {
|
||||||
|
return MCPRequest{}, fmt.Errorf("anthropic messages request missing model")
|
||||||
|
}
|
||||||
|
if len(req.Messages) == 0 {
|
||||||
|
return MCPRequest{}, fmt.Errorf("anthropic messages request missing messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
params := map[string]any{
|
||||||
|
"source_format": "anthropic",
|
||||||
|
"model": req.Model,
|
||||||
|
"messages": normaliseAnthropicMessages(req.Messages),
|
||||||
|
}
|
||||||
|
if req.System != nil {
|
||||||
|
params["system"] = req.System
|
||||||
|
}
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
params["max_tokens"] = req.MaxTokens
|
||||||
|
}
|
||||||
|
if req.Temperature != nil {
|
||||||
|
params["temperature"] = req.Temperature
|
||||||
|
}
|
||||||
|
if req.Stream {
|
||||||
|
params["stream"] = req.Stream
|
||||||
|
}
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
params["tools"] = normaliseAnthropicTools(req.Tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := anthropicToolUsesFromMessages(req.Messages)
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
call := toolCalls[0]
|
||||||
|
params["name"] = call.Name
|
||||||
|
params["arguments"] = call.Arguments
|
||||||
|
params["tool_calls"] = toolCalls
|
||||||
|
return MCPRequest{JSONRPC: "2.0", Method: "tools/call", Params: params}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return MCPRequest{JSONRPC: "2.0", Method: "sampling/createMessage", Params: params}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (AnthropicTransformer) Transform(result MCPResult) ([]byte, error) {
|
||||||
|
text := extractMCPText(result)
|
||||||
|
toolCalls := extractMCPToolCalls(result)
|
||||||
|
|
||||||
|
content := make([]map[string]any, 0, 1+len(toolCalls))
|
||||||
|
if text != "" {
|
||||||
|
content = append(content, map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": text,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for i, call := range toolCalls {
|
||||||
|
id := call.ID
|
||||||
|
if id == "" {
|
||||||
|
id = fmt.Sprintf("toolu_%d", i)
|
||||||
|
}
|
||||||
|
content = append(content, map[string]any{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": id,
|
||||||
|
"name": call.Name,
|
||||||
|
"input": call.Arguments,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(content) == 0 {
|
||||||
|
content = append(content, map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
stopReason := "end_turn"
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
stopReason = "tool_use"
|
||||||
|
}
|
||||||
|
if result.StopReason != "" {
|
||||||
|
stopReason = result.StopReason
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]any{
|
||||||
|
"id": anthropicResponseID(result.ID),
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": "mcp-gateway",
|
||||||
|
"content": content,
|
||||||
|
"stop_reason": stopReason,
|
||||||
|
"stop_sequence": nil,
|
||||||
|
}
|
||||||
|
return json.Marshal(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
type anthropicMessagesRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
MaxTokens any `json:"max_tokens,omitempty"`
|
||||||
|
System any `json:"system,omitempty"`
|
||||||
|
Messages []anthropicMessage `json:"messages"`
|
||||||
|
Tools []anthropicTool `json:"tools,omitempty"`
|
||||||
|
Temperature any `json:"temperature,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type anthropicMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content any `json:"content,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type anthropicTool struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
InputSchema any `json:"input_schema,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func normaliseAnthropicMessages(messages []anthropicMessage) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(messages))
|
||||||
|
for _, msg := range messages {
|
||||||
|
item := map[string]any{
|
||||||
|
"role": msg.Role,
|
||||||
|
}
|
||||||
|
if msg.Content != nil {
|
||||||
|
item["content"] = msg.Content
|
||||||
|
}
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func normaliseAnthropicTools(tools []anthropicTool) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(tools))
|
||||||
|
for _, tool := range tools {
|
||||||
|
out = append(out, map[string]any{
|
||||||
|
"name": tool.Name,
|
||||||
|
"description": tool.Description,
|
||||||
|
"input_schema": tool.InputSchema,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func anthropicToolUsesFromMessages(messages []anthropicMessage) []MCPToolCall {
|
||||||
|
var calls []MCPToolCall
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
blocks := anthropicContentBlocks(messages[i].Content)
|
||||||
|
for _, block := range blocks {
|
||||||
|
if block.Type != "tool_use" || block.Name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
calls = append(calls, MCPToolCall{
|
||||||
|
ID: block.ID,
|
||||||
|
Name: block.Name,
|
||||||
|
Arguments: block.Input,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(calls) > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return calls
|
||||||
|
}
|
||||||
|
|
||||||
|
type anthropicContentBlock struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Input map[string]any `json:"input,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func anthropicContentBlocks(content any) []anthropicContentBlock {
|
||||||
|
switch typed := content.(type) {
|
||||||
|
case nil:
|
||||||
|
return nil
|
||||||
|
case []anthropicContentBlock:
|
||||||
|
return typed
|
||||||
|
case []any:
|
||||||
|
blocks := make([]anthropicContentBlock, 0, len(typed))
|
||||||
|
for _, item := range typed {
|
||||||
|
data, err := json.Marshal(item)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var block anthropicContentBlock
|
||||||
|
if err := json.Unmarshal(data, &block); err == nil {
|
||||||
|
blocks = append(blocks, block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return blocks
|
||||||
|
case map[string]any:
|
||||||
|
data, err := json.Marshal(typed)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var block anthropicContentBlock
|
||||||
|
if err := json.Unmarshal(data, &block); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []anthropicContentBlock{block}
|
||||||
|
case string:
|
||||||
|
return []anthropicContentBlock{{Type: "text", Text: typed}}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func anthropicResponseID(id any) string {
|
||||||
|
if id == nil {
|
||||||
|
return "msg_mcp"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("msg_%v", id)
|
||||||
|
}
|
||||||
112
pkg/mcp/transformer_honeypot.go
Normal file
112
pkg/mcp/transformer_honeypot.go
Normal file
|
|
@ -0,0 +1,112 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HoneypotTransformer absorbs malformed or probe-like input and returns a
|
||||||
|
// plausible synthetic response without dispatching to real tools.
|
||||||
|
type HoneypotTransformer struct{}
|
||||||
|
|
||||||
|
func (HoneypotTransformer) Detect(body []byte, contentType, path string) bool {
|
||||||
|
trimmed := bytes.TrimSpace(body)
|
||||||
|
if len(trimmed) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !json.Valid(trimmed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var obj map[string]any
|
||||||
|
if err := json.Unmarshal(trimmed, &obj); err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return looksProbeLike(trimmed, contentType, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (HoneypotTransformer) Normalise(body []byte) (MCPRequest, error) {
|
||||||
|
params := map[string]any{
|
||||||
|
"source_format": "honeypot",
|
||||||
|
"raw": honeypotSnippet(body),
|
||||||
|
"malformed": !json.Valid(bytes.TrimSpace(body)),
|
||||||
|
}
|
||||||
|
return MCPRequest{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
Method: "honeypot/respond",
|
||||||
|
Params: params,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (HoneypotTransformer) Transform(result MCPResult) ([]byte, error) {
|
||||||
|
text := extractMCPText(result)
|
||||||
|
if text == "" {
|
||||||
|
text = "Request received. The gateway is processing the available context and will return compatible MCP output when a valid protocol envelope is provided."
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]any{
|
||||||
|
"id": honeypotResponseID(result.ID),
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 0,
|
||||||
|
"model": "mcp-gateway",
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": text,
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"usage": map[string]any{
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return json.Marshal(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func looksProbeLike(body []byte, contentType, path string) bool {
|
||||||
|
haystack := strings.ToLower(strings.Join([]string{
|
||||||
|
string(body),
|
||||||
|
contentType,
|
||||||
|
path,
|
||||||
|
}, "\n"))
|
||||||
|
for _, marker := range []string{
|
||||||
|
"ignore previous",
|
||||||
|
"system prompt",
|
||||||
|
"developer message",
|
||||||
|
"/etc/passwd",
|
||||||
|
"../../",
|
||||||
|
"dump secrets",
|
||||||
|
"jailbreak",
|
||||||
|
"prompt injection",
|
||||||
|
} {
|
||||||
|
if strings.Contains(haystack, marker) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func honeypotSnippet(body []byte) string {
|
||||||
|
s := string(bytes.TrimSpace(body))
|
||||||
|
const max = 4096
|
||||||
|
if len(s) <= max {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:max]
|
||||||
|
}
|
||||||
|
|
||||||
|
func honeypotResponseID(id any) string {
|
||||||
|
if id == nil {
|
||||||
|
return "chatcmpl-honeypot"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("chatcmpl-honeypot-%v", id)
|
||||||
|
}
|
||||||
247
pkg/mcp/transformer_openai.go
Normal file
247
pkg/mcp/transformer_openai.go
Normal file
|
|
@ -0,0 +1,247 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAITransformer maps OpenAI Chat Completions requests and responses.
|
||||||
|
type OpenAITransformer struct{}
|
||||||
|
|
||||||
|
func (OpenAITransformer) Detect(body []byte, contentType, path string) bool {
|
||||||
|
if headerHasMedia(contentType, "application/openai+json") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if normaliseGatewayPath(path) == "/v1/chat/completions" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return hasTopLevelFields(body, "model", "messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (OpenAITransformer) Normalise(body []byte) (MCPRequest, error) {
|
||||||
|
var req openAIChatCompletionRequest
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return MCPRequest{}, err
|
||||||
|
}
|
||||||
|
if req.Model == "" {
|
||||||
|
return MCPRequest{}, fmt.Errorf("openai chat completion request missing model")
|
||||||
|
}
|
||||||
|
if len(req.Messages) == 0 {
|
||||||
|
return MCPRequest{}, fmt.Errorf("openai chat completion request missing messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
params := map[string]any{
|
||||||
|
"source_format": "openai",
|
||||||
|
"model": req.Model,
|
||||||
|
"messages": normaliseOpenAIMessages(req.Messages),
|
||||||
|
}
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
params["tools"] = normaliseOpenAITools(req.Tools)
|
||||||
|
}
|
||||||
|
if req.ToolChoice != nil {
|
||||||
|
params["tool_choice"] = req.ToolChoice
|
||||||
|
}
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
params["max_tokens"] = req.MaxTokens
|
||||||
|
}
|
||||||
|
if req.MaxCompletionTokens != nil {
|
||||||
|
params["max_completion_tokens"] = req.MaxCompletionTokens
|
||||||
|
}
|
||||||
|
if req.Temperature != nil {
|
||||||
|
params["temperature"] = req.Temperature
|
||||||
|
}
|
||||||
|
if req.Stream {
|
||||||
|
params["stream"] = req.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := openAIToolCallsFromMessages(req.Messages)
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
call := toolCalls[0]
|
||||||
|
params["name"] = call.Name
|
||||||
|
params["arguments"] = call.Arguments
|
||||||
|
params["tool_calls"] = toolCalls
|
||||||
|
return MCPRequest{JSONRPC: "2.0", Method: "tools/call", Params: params}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return MCPRequest{JSONRPC: "2.0", Method: "sampling/createMessage", Params: params}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (OpenAITransformer) Transform(result MCPResult) ([]byte, error) {
|
||||||
|
text := extractMCPText(result)
|
||||||
|
toolCalls := extractMCPToolCalls(result)
|
||||||
|
|
||||||
|
message := map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
}
|
||||||
|
if text != "" {
|
||||||
|
message["content"] = text
|
||||||
|
} else if len(toolCalls) > 0 {
|
||||||
|
message["content"] = nil
|
||||||
|
} else {
|
||||||
|
message["content"] = ""
|
||||||
|
}
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
message["tool_calls"] = openAIToolCallsFromMCP(toolCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
if result.StopReason != "" {
|
||||||
|
finishReason = result.StopReason
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]any{
|
||||||
|
"id": openAIResponseID(result.ID),
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 0,
|
||||||
|
"model": "mcp-gateway",
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": message,
|
||||||
|
"finish_reason": finishReason,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return json.Marshal(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIChatCompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []openAIMessage `json:"messages"`
|
||||||
|
Tools []openAITool `json:"tools,omitempty"`
|
||||||
|
ToolChoice any `json:"tool_choice,omitempty"`
|
||||||
|
MaxTokens any `json:"max_tokens,omitempty"`
|
||||||
|
MaxCompletionTokens any `json:"max_completion_tokens,omitempty"`
|
||||||
|
Temperature any `json:"temperature,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content any `json:"content,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAITool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function openAIFunctionMetadata `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIFunctionMetadata struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Parameters any `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIToolCall struct {
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
Function openAIFunctionCall `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAIFunctionCall struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments json.RawMessage `json:"arguments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func normaliseOpenAIMessages(messages []openAIMessage) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(messages))
|
||||||
|
for _, msg := range messages {
|
||||||
|
item := map[string]any{
|
||||||
|
"role": msg.Role,
|
||||||
|
}
|
||||||
|
if msg.Content != nil {
|
||||||
|
item["content"] = msg.Content
|
||||||
|
}
|
||||||
|
if msg.Name != "" {
|
||||||
|
item["name"] = msg.Name
|
||||||
|
}
|
||||||
|
if msg.ToolCallID != "" {
|
||||||
|
item["tool_call_id"] = msg.ToolCallID
|
||||||
|
}
|
||||||
|
if len(msg.ToolCalls) > 0 {
|
||||||
|
item["tool_calls"] = openAIToolCallsFromMessages([]openAIMessage{msg})
|
||||||
|
}
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func normaliseOpenAITools(tools []openAITool) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(tools))
|
||||||
|
for _, tool := range tools {
|
||||||
|
if tool.Type != "" && tool.Type != "function" {
|
||||||
|
out = append(out, map[string]any{
|
||||||
|
"type": tool.Type,
|
||||||
|
"function": tool.Function,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
item := map[string]any{
|
||||||
|
"name": tool.Function.Name,
|
||||||
|
"description": tool.Function.Description,
|
||||||
|
"input_schema": tool.Function.Parameters,
|
||||||
|
}
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIToolCallsFromMessages(messages []openAIMessage) []MCPToolCall {
|
||||||
|
var calls []MCPToolCall
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
msg := messages[i]
|
||||||
|
if len(msg.ToolCalls) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, call := range msg.ToolCalls {
|
||||||
|
if call.Function.Name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
calls = append(calls, MCPToolCall{
|
||||||
|
ID: call.ID,
|
||||||
|
Name: call.Function.Name,
|
||||||
|
Arguments: parseRawArgumentObject(call.Function.Arguments),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return calls
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIToolCallsFromMCP(calls []MCPToolCall) []map[string]any {
|
||||||
|
out := make([]map[string]any, 0, len(calls))
|
||||||
|
for i, call := range calls {
|
||||||
|
id := call.ID
|
||||||
|
if id == "" {
|
||||||
|
id = fmt.Sprintf("call_%d", i)
|
||||||
|
}
|
||||||
|
args, err := json.Marshal(call.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
args = []byte("{}")
|
||||||
|
}
|
||||||
|
out = append(out, map[string]any{
|
||||||
|
"id": id,
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]any{
|
||||||
|
"name": call.Name,
|
||||||
|
"arguments": string(args),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIResponseID(id any) string {
|
||||||
|
if id == nil {
|
||||||
|
return "chatcmpl-mcp"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("chatcmpl-%v", id)
|
||||||
|
}
|
||||||
191
pkg/mcp/transformer_test.go
Normal file
191
pkg/mcp/transformer_test.go
Normal file
|
|
@ -0,0 +1,191 @@
|
||||||
|
// SPDX-License-Identifier: EUPL-1.2
|
||||||
|
|
||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNegotiate_OpenAI_Good(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)
|
||||||
|
|
||||||
|
if _, ok := NegotiateTransformer(body, "", "/v1/chat/completions").(OpenAITransformer); !ok {
|
||||||
|
t.Fatal("expected OpenAITransformer for chat completions path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNegotiate_Anthropic_Good(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-3-5-sonnet","max_tokens":128,"messages":[{"role":"user","content":"hello"}]}`)
|
||||||
|
|
||||||
|
if _, ok := NegotiateTransformer(body, "", "/v1/messages").(AnthropicTransformer); !ok {
|
||||||
|
t.Fatal("expected AnthropicTransformer for messages path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNegotiate_MCPNative_Good(t *testing.T) {
|
||||||
|
body := []byte(`{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}`)
|
||||||
|
|
||||||
|
if _, ok := NegotiateTransformer(body, "application/mcp+json", "/mcp").(MCPNativeTransformer); !ok {
|
||||||
|
t.Fatal("expected MCPNativeTransformer for native MCP request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAITransformer_Normalise_Good(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": null,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "file_read",
|
||||||
|
"arguments": "{\"path\":\"README.md\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
req, err := (OpenAITransformer{}).Normalise(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Normalise failed: %v", err)
|
||||||
|
}
|
||||||
|
if req.JSONRPC != "2.0" {
|
||||||
|
t.Fatalf("expected JSON-RPC 2.0, got %q", req.JSONRPC)
|
||||||
|
}
|
||||||
|
if req.Method != "tools/call" {
|
||||||
|
t.Fatalf("expected tools/call, got %q", req.Method)
|
||||||
|
}
|
||||||
|
if req.Params["source_format"] != "openai" {
|
||||||
|
t.Fatalf("expected source_format openai, got %v", req.Params["source_format"])
|
||||||
|
}
|
||||||
|
if req.Params["model"] != "gpt-4o" {
|
||||||
|
t.Fatalf("expected model to be preserved, got %v", req.Params["model"])
|
||||||
|
}
|
||||||
|
if req.Params["name"] != "file_read" {
|
||||||
|
t.Fatalf("expected tool name file_read, got %v", req.Params["name"])
|
||||||
|
}
|
||||||
|
args, ok := req.Params["arguments"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected argument map, got %T", req.Params["arguments"])
|
||||||
|
}
|
||||||
|
if args["path"] != "README.md" {
|
||||||
|
t.Fatalf("expected README.md path, got %v", args["path"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAITransformer_Transform_Good(t *testing.T) {
|
||||||
|
data, err := (OpenAITransformer{}).Transform(MCPResult{
|
||||||
|
ID: 7,
|
||||||
|
Result: map[string]any{
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "done"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Transform failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(data, &resp); err != nil {
|
||||||
|
t.Fatalf("response is not JSON: %v", err)
|
||||||
|
}
|
||||||
|
if resp["object"] != "chat.completion" {
|
||||||
|
t.Fatalf("expected chat.completion object, got %v", resp["object"])
|
||||||
|
}
|
||||||
|
choices := resp["choices"].([]any)
|
||||||
|
message := choices[0].(map[string]any)["message"].(map[string]any)
|
||||||
|
if message["content"] != "done" {
|
||||||
|
t.Fatalf("expected content done, got %v", message["content"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicTransformer_Normalise_Good(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"model": "claude-3-5-sonnet",
|
||||||
|
"max_tokens": 256,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "toolu_1",
|
||||||
|
"name": "file_read",
|
||||||
|
"input": {"path":"README.md"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
req, err := (AnthropicTransformer{}).Normalise(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Normalise failed: %v", err)
|
||||||
|
}
|
||||||
|
if req.Method != "tools/call" {
|
||||||
|
t.Fatalf("expected tools/call, got %q", req.Method)
|
||||||
|
}
|
||||||
|
if req.Params["source_format"] != "anthropic" {
|
||||||
|
t.Fatalf("expected source_format anthropic, got %v", req.Params["source_format"])
|
||||||
|
}
|
||||||
|
if req.Params["name"] != "file_read" {
|
||||||
|
t.Fatalf("expected tool name file_read, got %v", req.Params["name"])
|
||||||
|
}
|
||||||
|
args, ok := req.Params["arguments"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected argument map, got %T", req.Params["arguments"])
|
||||||
|
}
|
||||||
|
if args["path"] != "README.md" {
|
||||||
|
t.Fatalf("expected README.md path, got %v", args["path"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnthropicTransformer_Transform_Good(t *testing.T) {
|
||||||
|
data, err := (AnthropicTransformer{}).Transform(MCPResult{
|
||||||
|
ID: "abc",
|
||||||
|
Content: []MCPContent{{Type: "text", Text: "done"}},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Transform failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.Unmarshal(data, &resp); err != nil {
|
||||||
|
t.Fatalf("response is not JSON: %v", err)
|
||||||
|
}
|
||||||
|
if resp["type"] != "message" {
|
||||||
|
t.Fatalf("expected message type, got %v", resp["type"])
|
||||||
|
}
|
||||||
|
content := resp["content"].([]any)
|
||||||
|
first := content[0].(map[string]any)
|
||||||
|
if first["text"] != "done" {
|
||||||
|
t.Fatalf("expected text done, got %v", first["text"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHoneypotTransformer_Detect_FallbackOnGarbage(t *testing.T) {
|
||||||
|
body := []byte(`{not-json`)
|
||||||
|
|
||||||
|
if !(HoneypotTransformer{}).Detect(body, "", "/probe") {
|
||||||
|
t.Fatal("expected honeypot to detect malformed input")
|
||||||
|
}
|
||||||
|
if _, ok := NegotiateTransformer(body, "", "/probe").(HoneypotTransformer); !ok {
|
||||||
|
t.Fatal("expected negotiation to select honeypot for malformed input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNegotiate_Priority_Ugly(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-3-5-sonnet","max_tokens":128,"messages":[{"role":"user","content":"hello"}]}`)
|
||||||
|
|
||||||
|
if _, ok := NegotiateTransformer(body, "application/openai+json", "/v1/messages").(OpenAITransformer); !ok {
|
||||||
|
t.Fatal("expected explicit OpenAI media type to beat path/body inspection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -4,14 +4,18 @@ package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/subtle"
|
// Note: AX-6 — HTTP transport boundary needs streaming JSON encode/decode against ResponseWriter and MaxBytesReader.
|
||||||
|
"encoding/json"
|
||||||
|
// Note: AX-6 — structural HTTP transport requires binding an explicit TCP listener.
|
||||||
"net"
|
"net"
|
||||||
|
// Note: AX-6 — structural HTTP transport boundary requires handlers, requests, status codes, and server lifecycle APIs.
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
coreerr "forge.lthn.ai/core/go-log"
|
core "dappco.re/go/core"
|
||||||
|
api "dappco.re/go/api"
|
||||||
|
coreerr "dappco.re/go/log"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -32,12 +36,18 @@ const DefaultHTTPAddr = "127.0.0.1:9101"
|
||||||
// svc.ServeHTTP(ctx, "0.0.0.0:9101")
|
// svc.ServeHTTP(ctx, "0.0.0.0:9101")
|
||||||
//
|
//
|
||||||
// Endpoint /mcp: GET (SSE stream), POST (JSON-RPC), DELETE (terminate session).
|
// Endpoint /mcp: GET (SSE stream), POST (JSON-RPC), DELETE (terminate session).
|
||||||
|
//
|
||||||
|
// Additional endpoints:
|
||||||
|
// - POST /mcp/auth: exchange API token for JWT
|
||||||
|
// - /v1/tools/<tool_name>: auto-mounted REST bridge for MCP tools
|
||||||
|
// - /health: unauthenticated health endpoint
|
||||||
|
// - /.well-known/mcp-servers.json: MCP portal discovery
|
||||||
func (s *Service) ServeHTTP(ctx context.Context, addr string) error {
|
func (s *Service) ServeHTTP(ctx context.Context, addr string) error {
|
||||||
if addr == "" {
|
if addr == "" {
|
||||||
addr = DefaultHTTPAddr
|
addr = DefaultHTTPAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
authToken := os.Getenv("MCP_AUTH_TOKEN")
|
authToken := core.Env("MCP_AUTH_TOKEN")
|
||||||
|
|
||||||
handler := mcp.NewStreamableHTTPHandler(
|
handler := mcp.NewStreamableHTTPHandler(
|
||||||
func(r *http.Request) *mcp.Server {
|
func(r *http.Request) *mcp.Server {
|
||||||
|
|
@ -48,13 +58,25 @@ func (s *Service) ServeHTTP(ctx context.Context, addr string) error {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
toolBridge := api.NewToolBridge("/v1/tools")
|
||||||
|
BridgeToAPI(s, toolBridge)
|
||||||
|
toolEngine := gin.New()
|
||||||
|
toolBridge.RegisterRoutes(toolEngine.Group("/v1/tools"))
|
||||||
|
toolHandler := withAuth(authToken, toolEngine)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.Handle("/mcp", withAuth(authToken, handler))
|
mux.Handle("/mcp", withAuth(authToken, handler))
|
||||||
|
mux.Handle("/v1/tools", toolHandler)
|
||||||
|
mux.Handle("/v1/tools/", toolHandler)
|
||||||
|
mux.HandleFunc("/mcp/auth", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
serveMCPAuthExchange(w, r, authToken)
|
||||||
|
})
|
||||||
|
mux.HandleFunc("/.well-known/mcp-servers.json", handleMCPDiscovery)
|
||||||
|
|
||||||
// Health check (no auth)
|
// Health check (no auth)
|
||||||
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Write([]byte(`{"status":"ok"}`))
|
_ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"})
|
||||||
})
|
})
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", addr)
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
|
@ -72,7 +94,7 @@ func (s *Service) ServeHTTP(ctx context.Context, addr string) error {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
server.Shutdown(shutdownCtx)
|
_ = server.Shutdown(shutdownCtx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||||
|
|
@ -81,31 +103,185 @@ func (s *Service) ServeHTTP(ctx context.Context, addr string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mcpAuthExchangeRequest struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
Workspace string `json:"workspace"`
|
||||||
|
Entitlements []string `json:"entitlements"`
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type mcpAuthExchangeResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type mcpDiscoveryResponse struct {
|
||||||
|
Servers []mcpDiscoveryServer `json:"servers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type mcpDiscoveryServer struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Connection map[string]any `json:"connection"`
|
||||||
|
Capabilities []string `json:"capabilities"`
|
||||||
|
UseWhen []string `json:"use_when"`
|
||||||
|
RelatedServers []string `json:"related_servers"`
|
||||||
|
}
|
||||||
|
|
||||||
// withAuth wraps an http.Handler with Bearer token authentication.
|
// withAuth wraps an http.Handler with Bearer token authentication.
|
||||||
// If token is empty, authentication is disabled for local development.
|
// If token is empty, authentication is disabled for local development.
|
||||||
func withAuth(token string, next http.Handler) http.Handler {
|
func withAuth(token string, next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.TrimSpace(token) == "" {
|
if core.Trim(token) == "" {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
auth := r.Header.Get("Authorization")
|
claims, err := parseAuthClaims(r.Header.Get("Authorization"), token)
|
||||||
if !strings.HasPrefix(auth, "Bearer ") {
|
if err != nil {
|
||||||
http.Error(w, `{"error":"missing Bearer token"}`, http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
provided := strings.TrimSpace(strings.TrimPrefix(auth, "Bearer "))
|
|
||||||
if len(provided) == 0 {
|
|
||||||
http.Error(w, `{"error":"missing Bearer token"}`, http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 {
|
|
||||||
http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized)
|
http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if claims != nil {
|
||||||
|
r = r.WithContext(withAuthClaims(r.Context(), claims))
|
||||||
|
}
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func serveMCPAuthExchange(w http.ResponseWriter, r *http.Request, apiToken string) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
apiToken = core.Trim(apiToken)
|
||||||
|
if apiToken == "" {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_ = json.NewEncoder(w).Encode(api.Fail("unauthorized", "authentication is not configured"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req mcpAuthExchangeRequest
|
||||||
|
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 10<<20)).Decode(&req); err != nil {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_ = json.NewEncoder(w).Encode(api.Fail("invalid_request", "invalid JSON payload"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
providedToken := core.Trim(extractBearerToken(r.Header.Get("Authorization")))
|
||||||
|
if providedToken == "" {
|
||||||
|
providedToken = core.Trim(req.Token)
|
||||||
|
}
|
||||||
|
if providedToken == "" {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_ = json.NewEncoder(w).Encode(api.Fail("invalid_request", "missing token"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := parseAuthClaims("Bearer "+providedToken, apiToken); err != nil {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_ = json.NewEncoder(w).Encode(api.Fail("unauthorized", "invalid API token"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := currentAuthConfig(apiToken)
|
||||||
|
now := time.Now()
|
||||||
|
claims := authClaims{
|
||||||
|
Workspace: core.Trim(req.Workspace),
|
||||||
|
Entitlements: dedupeEntitlements(req.Entitlements),
|
||||||
|
Subject: core.Trim(req.Sub),
|
||||||
|
IssuedAt: now.Unix(),
|
||||||
|
ExpiresAt: now.Unix() + int64(cfg.ttl.Seconds()),
|
||||||
|
}
|
||||||
|
|
||||||
|
minted, err := mintJWTToken(claims, cfg)
|
||||||
|
if err != nil {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_ = json.NewEncoder(w).Encode(api.Fail("token_error", "failed to mint token"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(mcpAuthExchangeResponse{
|
||||||
|
AccessToken: minted,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: int64(cfg.ttl.Seconds()),
|
||||||
|
ExpiresAt: claims.ExpiresAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func dedupeEntitlements(entitlements []string) []string {
|
||||||
|
if len(entitlements) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
seen := make(map[string]struct{}, len(entitlements))
|
||||||
|
out := make([]string, 0, len(entitlements))
|
||||||
|
for _, ent := range entitlements {
|
||||||
|
e := core.Trim(ent)
|
||||||
|
if e == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[e]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[e] = struct{}{}
|
||||||
|
out = append(out, e)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleMCPDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := mcpDiscoveryResponse{
|
||||||
|
Servers: []mcpDiscoveryServer{
|
||||||
|
{
|
||||||
|
ID: "core-agent",
|
||||||
|
Name: "Core Agent",
|
||||||
|
Description: "Dispatch agents, manage workspaces, search OpenBrain",
|
||||||
|
Connection: map[string]any{
|
||||||
|
"type": "stdio",
|
||||||
|
"command": "core-agent",
|
||||||
|
"args": []string{"mcp"},
|
||||||
|
},
|
||||||
|
Capabilities: []string{"tools", "resources"},
|
||||||
|
UseWhen: []string{
|
||||||
|
"Need to dispatch work to Codex/Claude/Gemini",
|
||||||
|
"Need workspace status",
|
||||||
|
"Need semantic search",
|
||||||
|
},
|
||||||
|
RelatedServers: []string{"core-mcp"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "core-mcp",
|
||||||
|
Name: "Core MCP",
|
||||||
|
Description: "File ops, process and build tools, RAG search, webview, dashboards — the agent-facing MCP framework.",
|
||||||
|
Connection: map[string]any{
|
||||||
|
"type": "stdio",
|
||||||
|
"command": "core-mcp",
|
||||||
|
},
|
||||||
|
Capabilities: []string{"tools", "resources", "logging"},
|
||||||
|
UseWhen: []string{
|
||||||
|
"Need to read/write files inside a workspace",
|
||||||
|
"Need to start or monitor processes",
|
||||||
|
"Need to run RAG queries or index documents",
|
||||||
|
"Need to render or update an embedded dashboard view",
|
||||||
|
},
|
||||||
|
RelatedServers: []string{"core-agent"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_ = json.NewEncoder(w).Encode(api.Fail("server_error", "failed to encode discovery payload"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-log"
|
"dappco.re/go/log"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,12 @@ package mcp
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
goio "io"
|
goio "io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
core "dappco.re/go/core"
|
||||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
)
|
)
|
||||||
|
|
@ -31,7 +31,7 @@ var diagWriter goio.Writer = os.Stderr
|
||||||
func diagPrintf(format string, args ...any) {
|
func diagPrintf(format string, args ...any) {
|
||||||
diagMu.Lock()
|
diagMu.Lock()
|
||||||
defer diagMu.Unlock()
|
defer diagMu.Unlock()
|
||||||
fmt.Fprintf(diagWriter, format, args...)
|
core.Print(diagWriter, format, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setDiagWriter swaps the diagnostic writer and returns the previous one.
|
// setDiagWriter swaps the diagnostic writer and returns the previous one.
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"forge.lthn.ai/core/go-io"
|
"dappco.re/go/io"
|
||||||
"forge.lthn.ai/core/go-log"
|
"dappco.re/go/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ServeUnix starts a Unix domain socket server for the MCP service.
|
// ServeUnix starts a Unix domain socket server for the MCP service.
|
||||||
|
|
|
||||||
26
tests/cli/mcp/Taskfile.yaml
Normal file
26
tests/cli/mcp/Taskfile.yaml
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
version: "3"
|
||||||
|
|
||||||
|
tasks:
|
||||||
|
default:
|
||||||
|
deps:
|
||||||
|
- build
|
||||||
|
- vet
|
||||||
|
- test
|
||||||
|
|
||||||
|
build:
|
||||||
|
desc: Compile every package + binary in mcp.
|
||||||
|
dir: ../../..
|
||||||
|
cmds:
|
||||||
|
- GOWORK=off go build ./...
|
||||||
|
|
||||||
|
vet:
|
||||||
|
desc: Run go vet across the module.
|
||||||
|
dir: ../../..
|
||||||
|
cmds:
|
||||||
|
- GOWORK=off go vet ./...
|
||||||
|
|
||||||
|
test:
|
||||||
|
desc: Run unit tests.
|
||||||
|
dir: ../../..
|
||||||
|
cmds:
|
||||||
|
- GOWORK=off go test -count=1 ./...
|
||||||
Loading…
Add table
Reference in a new issue