feat(brain): add direct list support

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-02 10:43:54 +00:00
parent b82d399349
commit 599d0b6298
2 changed files with 149 additions and 2 deletions

View file

@ -9,6 +9,7 @@ import (
"fmt"
goio "io"
"net/http"
"net/url"
"os"
"strings"
"time"
@ -83,6 +84,11 @@ func (s *DirectSubsystem) RegisterTools(server *mcp.Server) {
Name: "brain_forget",
Description: "Remove a memory from OpenBrain by ID.",
}, s.forget)
mcp.AddTool(server, &mcp.Tool{
Name: "brain_list",
Description: "List memories in OpenBrain with optional filtering by project, type, and agent.",
}, s.list)
}
// Shutdown implements mcp.SubsystemWithShutdown.
@ -231,3 +237,67 @@ func (s *DirectSubsystem) forget(ctx context.Context, _ *mcp.CallToolRequest, in
Timestamp: time.Now(),
}, nil
}
func (s *DirectSubsystem) list(ctx context.Context, _ *mcp.CallToolRequest, input ListInput) (*mcp.CallToolResult, ListOutput, error) {
limit := input.Limit
if limit == 0 {
limit = 50
}
values := url.Values{}
if input.Project != "" {
values.Set("project", input.Project)
}
if input.Type != "" {
values.Set("type", input.Type)
}
if input.AgentID != "" {
values.Set("agent_id", input.AgentID)
}
values.Set("limit", fmt.Sprintf("%d", limit))
result, err := s.apiCall(ctx, http.MethodGet, "/v1/brain/list?"+values.Encode(), nil)
if err != nil {
return nil, ListOutput{}, err
}
var memories []Memory
if mems, ok := result["memories"].([]any); ok {
for _, m := range mems {
if mm, ok := m.(map[string]any); ok {
mem := Memory{
Content: fmt.Sprintf("%v", mm["content"]),
Type: fmt.Sprintf("%v", mm["type"]),
Project: fmt.Sprintf("%v", mm["project"]),
AgentID: fmt.Sprintf("%v", mm["agent_id"]),
CreatedAt: fmt.Sprintf("%v", mm["created_at"]),
}
if id, ok := mm["id"].(string); ok {
mem.ID = id
}
if score, ok := mm["score"].(float64); ok {
mem.Confidence = score
}
if source, ok := mm["source"].(string); ok {
mem.Tags = append(mem.Tags, "source:"+source)
}
memories = append(memories, mem)
}
}
}
if s.onChannel != nil {
s.onChannel(ctx, "brain.list.complete", map[string]any{
"project": input.Project,
"type": input.Type,
"agent": input.AgentID,
"limit": limit,
})
}
return nil, ListOutput{
Success: true,
Count: len(memories),
Memories: memories,
}, nil
}

View file

@ -207,8 +207,8 @@ func TestDirectRecall_Good(t *testing.T) {
s := newTestDirect(srv.URL)
_, out, err := s.recall(context.Background(), nil, RecallInput{
Query: "scoring algorithm",
TopK: 5,
Query: "scoring algorithm",
TopK: 5,
Filter: RecallFilter{Project: "eaas"},
})
if err != nil {
@ -303,3 +303,80 @@ func TestDirectForget_Bad_ApiError(t *testing.T) {
t.Error("expected error on 404")
}
}
// --- list tool tests ---
func TestDirectList_Good(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
t.Errorf("expected GET, got %s", r.Method)
}
if got := r.URL.Query().Get("project"); got != "eaas" {
t.Errorf("expected project=eaas, got %q", got)
}
if got := r.URL.Query().Get("type"); got != "decision" {
t.Errorf("expected type=decision, got %q", got)
}
if got := r.URL.Query().Get("agent_id"); got != "virgil" {
t.Errorf("expected agent_id=virgil, got %q", got)
}
if got := r.URL.Query().Get("limit"); got != "20" {
t.Errorf("expected limit=20, got %q", got)
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{
"memories": []any{
map[string]any{
"id": "mem-1",
"content": "use qdrant",
"type": "decision",
"project": "eaas",
"agent_id": "virgil",
"score": 0.88,
"created_at": "2026-03-01T00:00:00Z",
},
},
})
}))
defer srv.Close()
s := newTestDirect(srv.URL)
_, out, err := s.list(context.Background(), nil, ListInput{
Project: "eaas",
Type: "decision",
AgentID: "virgil",
Limit: 20,
})
if err != nil {
t.Fatalf("list failed: %v", err)
}
if !out.Success || out.Count != 1 {
t.Fatalf("expected 1 memory, got %+v", out)
}
if out.Memories[0].ID != "mem-1" {
t.Errorf("expected id=mem-1, got %q", out.Memories[0].ID)
}
if out.Memories[0].Confidence != 0.88 {
t.Errorf("expected score=0.88, got %f", out.Memories[0].Confidence)
}
}
func TestDirectList_Good_DefaultLimit(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.URL.Query().Get("limit"); got != "50" {
t.Errorf("expected limit=50, got %q", got)
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{"memories": []any{}})
}))
defer srv.Close()
s := newTestDirect(srv.URL)
_, out, err := s.list(context.Background(), nil, ListInput{})
if err != nil {
t.Fatalf("list failed: %v", err)
}
if !out.Success || out.Count != 0 {
t.Fatalf("expected empty list, got %+v", out)
}
}