feat(brain): add direct list support
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
b82d399349
commit
599d0b6298
2 changed files with 149 additions and 2 deletions
|
|
@ -9,6 +9,7 @@ import (
|
|||
"fmt"
|
||||
goio "io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -83,6 +84,11 @@ func (s *DirectSubsystem) RegisterTools(server *mcp.Server) {
|
|||
Name: "brain_forget",
|
||||
Description: "Remove a memory from OpenBrain by ID.",
|
||||
}, s.forget)
|
||||
|
||||
mcp.AddTool(server, &mcp.Tool{
|
||||
Name: "brain_list",
|
||||
Description: "List memories in OpenBrain with optional filtering by project, type, and agent.",
|
||||
}, s.list)
|
||||
}
|
||||
|
||||
// Shutdown implements mcp.SubsystemWithShutdown.
|
||||
|
|
@ -231,3 +237,67 @@ func (s *DirectSubsystem) forget(ctx context.Context, _ *mcp.CallToolRequest, in
|
|||
Timestamp: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DirectSubsystem) list(ctx context.Context, _ *mcp.CallToolRequest, input ListInput) (*mcp.CallToolResult, ListOutput, error) {
|
||||
limit := input.Limit
|
||||
if limit == 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
if input.Project != "" {
|
||||
values.Set("project", input.Project)
|
||||
}
|
||||
if input.Type != "" {
|
||||
values.Set("type", input.Type)
|
||||
}
|
||||
if input.AgentID != "" {
|
||||
values.Set("agent_id", input.AgentID)
|
||||
}
|
||||
values.Set("limit", fmt.Sprintf("%d", limit))
|
||||
|
||||
result, err := s.apiCall(ctx, http.MethodGet, "/v1/brain/list?"+values.Encode(), nil)
|
||||
if err != nil {
|
||||
return nil, ListOutput{}, err
|
||||
}
|
||||
|
||||
var memories []Memory
|
||||
if mems, ok := result["memories"].([]any); ok {
|
||||
for _, m := range mems {
|
||||
if mm, ok := m.(map[string]any); ok {
|
||||
mem := Memory{
|
||||
Content: fmt.Sprintf("%v", mm["content"]),
|
||||
Type: fmt.Sprintf("%v", mm["type"]),
|
||||
Project: fmt.Sprintf("%v", mm["project"]),
|
||||
AgentID: fmt.Sprintf("%v", mm["agent_id"]),
|
||||
CreatedAt: fmt.Sprintf("%v", mm["created_at"]),
|
||||
}
|
||||
if id, ok := mm["id"].(string); ok {
|
||||
mem.ID = id
|
||||
}
|
||||
if score, ok := mm["score"].(float64); ok {
|
||||
mem.Confidence = score
|
||||
}
|
||||
if source, ok := mm["source"].(string); ok {
|
||||
mem.Tags = append(mem.Tags, "source:"+source)
|
||||
}
|
||||
memories = append(memories, mem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.onChannel != nil {
|
||||
s.onChannel(ctx, "brain.list.complete", map[string]any{
|
||||
"project": input.Project,
|
||||
"type": input.Type,
|
||||
"agent": input.AgentID,
|
||||
"limit": limit,
|
||||
})
|
||||
}
|
||||
|
||||
return nil, ListOutput{
|
||||
Success: true,
|
||||
Count: len(memories),
|
||||
Memories: memories,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -207,8 +207,8 @@ func TestDirectRecall_Good(t *testing.T) {
|
|||
|
||||
s := newTestDirect(srv.URL)
|
||||
_, out, err := s.recall(context.Background(), nil, RecallInput{
|
||||
Query: "scoring algorithm",
|
||||
TopK: 5,
|
||||
Query: "scoring algorithm",
|
||||
TopK: 5,
|
||||
Filter: RecallFilter{Project: "eaas"},
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -303,3 +303,80 @@ func TestDirectForget_Bad_ApiError(t *testing.T) {
|
|||
t.Error("expected error on 404")
|
||||
}
|
||||
}
|
||||
|
||||
// --- list tool tests ---
|
||||
|
||||
func TestDirectList_Good(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
t.Errorf("expected GET, got %s", r.Method)
|
||||
}
|
||||
if got := r.URL.Query().Get("project"); got != "eaas" {
|
||||
t.Errorf("expected project=eaas, got %q", got)
|
||||
}
|
||||
if got := r.URL.Query().Get("type"); got != "decision" {
|
||||
t.Errorf("expected type=decision, got %q", got)
|
||||
}
|
||||
if got := r.URL.Query().Get("agent_id"); got != "virgil" {
|
||||
t.Errorf("expected agent_id=virgil, got %q", got)
|
||||
}
|
||||
if got := r.URL.Query().Get("limit"); got != "20" {
|
||||
t.Errorf("expected limit=20, got %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"memories": []any{
|
||||
map[string]any{
|
||||
"id": "mem-1",
|
||||
"content": "use qdrant",
|
||||
"type": "decision",
|
||||
"project": "eaas",
|
||||
"agent_id": "virgil",
|
||||
"score": 0.88,
|
||||
"created_at": "2026-03-01T00:00:00Z",
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
s := newTestDirect(srv.URL)
|
||||
_, out, err := s.list(context.Background(), nil, ListInput{
|
||||
Project: "eaas",
|
||||
Type: "decision",
|
||||
AgentID: "virgil",
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("list failed: %v", err)
|
||||
}
|
||||
if !out.Success || out.Count != 1 {
|
||||
t.Fatalf("expected 1 memory, got %+v", out)
|
||||
}
|
||||
if out.Memories[0].ID != "mem-1" {
|
||||
t.Errorf("expected id=mem-1, got %q", out.Memories[0].ID)
|
||||
}
|
||||
if out.Memories[0].Confidence != 0.88 {
|
||||
t.Errorf("expected score=0.88, got %f", out.Memories[0].Confidence)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectList_Good_DefaultLimit(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.URL.Query().Get("limit"); got != "50" {
|
||||
t.Errorf("expected limit=50, got %q", got)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{"memories": []any{}})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
s := newTestDirect(srv.URL)
|
||||
_, out, err := s.list(context.Background(), nil, ListInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("list failed: %v", err)
|
||||
}
|
||||
if !out.Success || out.Count != 0 {
|
||||
t.Fatalf("expected empty list, got %+v", out)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue