From 599d0b6298dcb2bb9fc0a10eb37faa973eac5719 Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 2 Apr 2026 10:43:54 +0000 Subject: [PATCH] feat(brain): add direct list support Co-Authored-By: Virgil --- pkg/mcp/brain/direct.go | 70 +++++++++++++++++++++++++++++++ pkg/mcp/brain/direct_test.go | 81 +++++++++++++++++++++++++++++++++++- 2 files changed, 149 insertions(+), 2 deletions(-) diff --git a/pkg/mcp/brain/direct.go b/pkg/mcp/brain/direct.go index 3a7115f..b568f0e 100644 --- a/pkg/mcp/brain/direct.go +++ b/pkg/mcp/brain/direct.go @@ -9,6 +9,7 @@ import ( "fmt" goio "io" "net/http" + "net/url" "os" "strings" "time" @@ -83,6 +84,11 @@ func (s *DirectSubsystem) RegisterTools(server *mcp.Server) { Name: "brain_forget", Description: "Remove a memory from OpenBrain by ID.", }, s.forget) + + mcp.AddTool(server, &mcp.Tool{ + Name: "brain_list", + Description: "List memories in OpenBrain with optional filtering by project, type, and agent.", + }, s.list) } // Shutdown implements mcp.SubsystemWithShutdown. @@ -231,3 +237,67 @@ func (s *DirectSubsystem) forget(ctx context.Context, _ *mcp.CallToolRequest, in Timestamp: time.Now(), }, nil } + +func (s *DirectSubsystem) list(ctx context.Context, _ *mcp.CallToolRequest, input ListInput) (*mcp.CallToolResult, ListOutput, error) { + limit := input.Limit + if limit == 0 { + limit = 50 + } + + values := url.Values{} + if input.Project != "" { + values.Set("project", input.Project) + } + if input.Type != "" { + values.Set("type", input.Type) + } + if input.AgentID != "" { + values.Set("agent_id", input.AgentID) + } + values.Set("limit", fmt.Sprintf("%d", limit)) + + result, err := s.apiCall(ctx, http.MethodGet, "/v1/brain/list?"+values.Encode(), nil) + if err != nil { + return nil, ListOutput{}, err + } + + var memories []Memory + if mems, ok := result["memories"].([]any); ok { + for _, m := range mems { + if mm, ok := m.(map[string]any); ok { + mem := Memory{ + Content: fmt.Sprintf("%v", mm["content"]), + Type: fmt.Sprintf("%v", mm["type"]), + Project: fmt.Sprintf("%v", mm["project"]), + AgentID: fmt.Sprintf("%v", mm["agent_id"]), + CreatedAt: fmt.Sprintf("%v", mm["created_at"]), + } + if id, ok := mm["id"].(string); ok { + mem.ID = id + } + if score, ok := mm["score"].(float64); ok { + mem.Confidence = score + } + if source, ok := mm["source"].(string); ok { + mem.Tags = append(mem.Tags, "source:"+source) + } + memories = append(memories, mem) + } + } + } + + if s.onChannel != nil { + s.onChannel(ctx, "brain.list.complete", map[string]any{ + "project": input.Project, + "type": input.Type, + "agent": input.AgentID, + "limit": limit, + }) + } + + return nil, ListOutput{ + Success: true, + Count: len(memories), + Memories: memories, + }, nil +} diff --git a/pkg/mcp/brain/direct_test.go b/pkg/mcp/brain/direct_test.go index db3c5bd..486d43f 100644 --- a/pkg/mcp/brain/direct_test.go +++ b/pkg/mcp/brain/direct_test.go @@ -207,8 +207,8 @@ func TestDirectRecall_Good(t *testing.T) { s := newTestDirect(srv.URL) _, out, err := s.recall(context.Background(), nil, RecallInput{ - Query: "scoring algorithm", - TopK: 5, + Query: "scoring algorithm", + TopK: 5, Filter: RecallFilter{Project: "eaas"}, }) if err != nil { @@ -303,3 +303,80 @@ func TestDirectForget_Bad_ApiError(t *testing.T) { t.Error("expected error on 404") } } + +// --- list tool tests --- + +func TestDirectList_Good(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("expected GET, got %s", r.Method) + } + if got := r.URL.Query().Get("project"); got != "eaas" { + t.Errorf("expected project=eaas, got %q", got) + } + if got := r.URL.Query().Get("type"); got != "decision" { + t.Errorf("expected type=decision, got %q", got) + } + if got := r.URL.Query().Get("agent_id"); got != "virgil" { + t.Errorf("expected agent_id=virgil, got %q", got) + } + if got := r.URL.Query().Get("limit"); got != "20" { + t.Errorf("expected limit=20, got %q", got) + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]any{ + "memories": []any{ + map[string]any{ + "id": "mem-1", + "content": "use qdrant", + "type": "decision", + "project": "eaas", + "agent_id": "virgil", + "score": 0.88, + "created_at": "2026-03-01T00:00:00Z", + }, + }, + }) + })) + defer srv.Close() + + s := newTestDirect(srv.URL) + _, out, err := s.list(context.Background(), nil, ListInput{ + Project: "eaas", + Type: "decision", + AgentID: "virgil", + Limit: 20, + }) + if err != nil { + t.Fatalf("list failed: %v", err) + } + if !out.Success || out.Count != 1 { + t.Fatalf("expected 1 memory, got %+v", out) + } + if out.Memories[0].ID != "mem-1" { + t.Errorf("expected id=mem-1, got %q", out.Memories[0].ID) + } + if out.Memories[0].Confidence != 0.88 { + t.Errorf("expected score=0.88, got %f", out.Memories[0].Confidence) + } +} + +func TestDirectList_Good_DefaultLimit(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.URL.Query().Get("limit"); got != "50" { + t.Errorf("expected limit=50, got %q", got) + } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]any{"memories": []any{}}) + })) + defer srv.Close() + + s := newTestDirect(srv.URL) + _, out, err := s.list(context.Background(), nil, ListInput{}) + if err != nil { + t.Fatalf("list failed: %v", err) + } + if !out.Success || out.Count != 0 { + t.Fatalf("expected empty list, got %+v", out) + } +}