# agent/hermes/plugins/openbrain_memory.py

# SPDX-License-Identifier: EUPL-1.2
from __future__ import annotations
import atexit
import importlib
import json
import shlex
import socket
import threading
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode, urlparse
from urllib.request import Request, urlopen
# Optional HTTP clients: prefer `requests`, then `httpx`; when neither is
# installed the provider falls back to the stdlib urllib implementation.
try:
    import requests  # type: ignore
except ImportError:  # pragma: no cover - exercised through fallbacks
    requests = None
try:
    import httpx  # type: ignore
except ImportError:  # pragma: no cover - exercised through fallbacks
    httpx = None
# Memory classifications accepted by the Brain API; referenced by the
# tool schemas in OpenBrainMemoryProvider.get_tool_schemas().
VALID_MEMORY_TYPES = [
    "fact",
    "decision",
    "observation",
    "convention",
    "research",
    "plan",
    "bug",
    "architecture",
    "documentation",
    "service",
    "pattern",
    "context",
    "procedure",
]
class OpenBrainMemoryProvider:
    """Hermes memory plugin backed by the shared OpenBrain knowledge store."""

    def __init__(
        self,
        brain_url: str,
        api_key: str,
        qdrant_url: str,
        pg_dsn: str,
        workspace_id: int,
        org: str | None = None,
    ) -> None:
        # Normalise service URLs so later endpoint joins never double a slash.
        self.brain_url = brain_url.rstrip("/")
        self.qdrant_url = qdrant_url.rstrip("/")
        self.api_key = api_key
        self.pg_dsn = pg_dsn
        self.workspace_id = workspace_id
        self.org = org
        # Lazy-initialisation state, populated by initialize() on first use.
        self._initialised = False
        self._spawn = None
        # Handles for in-flight background writes, drained at session end.
        self._pending_writes: list[Any] = []
        self._pending_lock = threading.Lock()
def is_available(self) -> bool:
return self._qdrant_reachable() and self._postgres_reachable()
def initialize(self) -> None:
if self._initialised:
return
self._spawn = self._load_core_spawn()
atexit.register(self.on_session_end)
self._initialised = True
    def get_tool_schemas(self) -> list[dict]:
        """Describe the brain_* tools (JSON Schema) exposed to the model.

        Four tools are advertised -- brain_remember, brain_recall,
        brain_forget, and brain_list -- mirroring the Brain API request
        contracts, including the VALID_MEMORY_TYPES enum.
        """
        return [
            {
                "name": "brain_remember",
                "description": (
                    "Store a memory in the shared OpenBrain knowledge store. "
                    "Use this for durable decisions, observations, conventions, "
                    "research, plans, bugs, or architecture notes."
                ),
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "content": {
                            "type": "string",
                            "description": "The knowledge to remember.",
                            "maxLength": 50000,
                        },
                        "type": {
                            "type": "string",
                            "description": "Memory type classification.",
                            "enum": VALID_MEMORY_TYPES,
                        },
                        "tags": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "Optional tags for categorisation.",
                        },
                        "project": {
                            "type": "string",
                            "description": "Optional project scope.",
                        },
                        "org": {
                            "type": "string",
                            "description": "Optional organisation scope.",
                        },
                        "confidence": {
                            "type": "number",
                            "description": "Confidence score from 0.0 to 1.0.",
                            "minimum": 0.0,
                            "maximum": 1.0,
                        },
                        "supersedes": {
                            "type": "string",
                            "format": "uuid",
                            "description": "UUID of an older memory this entry replaces.",
                        },
                        "expires_in": {
                            "type": "integer",
                            "description": "Hours until the memory expires.",
                            "minimum": 1,
                        },
                    },
                    "required": ["content", "type"],
                },
            },
            {
                "name": "brain_recall",
                "description": (
                    "Semantic search across the shared OpenBrain knowledge store. "
                    "Returns memories ranked by similarity to the query."
                ),
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Natural-language search query.",
                            "maxLength": 2000,
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Maximum results to return.",
                            "minimum": 1,
                            "maximum": 20,
                            "default": 5,
                        },
                        "top_k": {
                            "type": "integer",
                            "description": "Alias for limit used by the Brain API.",
                            "minimum": 1,
                            "maximum": 20,
                        },
                        "workspace_id": {
                            "type": "integer",
                            "description": "Workspace scope for the recall request.",
                            "minimum": 1,
                        },
                        "org": {
                            "type": "string",
                            "description": "Optional organisation filter.",
                        },
                        "project": {
                            "type": "string",
                            "description": "Optional project filter.",
                        },
                        # Accepts either one type or a list of types.
                        "type": {
                            "description": "Optional memory type filter.",
                            "oneOf": [
                                {"type": "string", "enum": VALID_MEMORY_TYPES},
                                {
                                    "type": "array",
                                    "items": {
                                        "type": "string",
                                        "enum": VALID_MEMORY_TYPES,
                                    },
                                },
                            ],
                        },
                        "keywords": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "Keywords that should be present in matching memories.",
                        },
                        "boost_keywords": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "Keywords that should receive additional ranking weight.",
                        },
                        "agent_id": {
                            "type": "string",
                            "description": "Optional originating-agent filter.",
                        },
                        "min_confidence": {
                            "type": "number",
                            "description": "Minimum confidence threshold.",
                            "minimum": 0.0,
                            "maximum": 1.0,
                        },
                    },
                    "required": ["query"],
                },
            },
            {
                "name": "brain_forget",
                "description": "Remove a memory from the shared OpenBrain knowledge store by UUID.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "id": {
                            "type": "string",
                            "format": "uuid",
                            "description": "UUID of the memory to remove.",
                        },
                        "reason": {
                            "type": "string",
                            "description": "Optional reason for forgetting this memory.",
                            "maxLength": 500,
                        },
                    },
                    "required": ["id"],
                },
            },
            {
                "name": "brain_list",
                "description": (
                    "List memories in the shared OpenBrain knowledge store. "
                    "Supports filtering by project, type, agent, and limit."
                ),
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "workspace_id": {
                            "type": "integer",
                            "description": "Workspace scope for the list request.",
                            "minimum": 1,
                        },
                        "org": {
                            "type": "string",
                            "description": "Optional organisation scope.",
                        },
                        "project": {
                            "type": "string",
                            "description": "Optional project scope.",
                        },
                        "type": {
                            "type": "string",
                            "description": "Optional memory type filter.",
                            "enum": VALID_MEMORY_TYPES,
                        },
                        "agent_id": {
                            "type": "string",
                            "description": "Optional originating-agent filter.",
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Maximum results to return.",
                            "minimum": 1,
                            "maximum": 100,
                            "default": 20,
                        },
                    },
                },
            },
        ]
def handle_tool_call(self, name: str, args: dict) -> dict:
self.initialize()
request_args = dict(args or {})
if name == "brain_remember":
payload = self._with_context_defaults(request_args, include_workspace=False)
return self._request_json(
"POST",
self._brain_endpoint("remember"),
json_body=payload,
headers=self._auth_headers(),
)
if name == "brain_recall":
payload = self._with_context_defaults(request_args)
if "limit" in payload and "top_k" not in payload:
payload["top_k"] = payload["limit"]
return self._request_json(
"POST",
self._brain_endpoint("recall"),
json_body=payload,
headers=self._auth_headers(),
)
if name == "brain_forget":
memory_id = request_args.get("id")
if not memory_id:
raise ValueError("brain_forget requires an id")
payload = self._clean_mapping({"reason": request_args.get("reason")})
return self._request_json(
"DELETE",
self._brain_endpoint(f"forget/{memory_id}"),
json_body=payload or None,
headers=self._auth_headers(),
)
if name == "brain_list":
params = self._with_context_defaults(request_args)
return self._request_json(
"GET",
self._brain_endpoint("list"),
params=params,
headers=self._auth_headers(),
)
raise ValueError(f"Unsupported tool call: {name}")
def sync_turn(self, turn: dict) -> None:
self.initialize()
payload = self._build_turn_memory(turn)
if self._dispatch_pending_write(payload):
return
try:
self._request_json(
"POST",
self._brain_endpoint("remember"),
json_body=payload,
headers=self._auth_headers(),
timeout=1.0,
)
except Exception:
return
def system_prompt_block(self) -> str:
return (
"Librarian keeps compact shared memory for the workspace: durable facts, "
"decisions, observations, conventions, plans, bugs, architecture notes, "
"and session context stored with project scope, tags, and confidence."
)
def on_session_end(self) -> None:
with self._pending_lock:
pending = list(self._pending_writes)
self._pending_writes.clear()
for handle in pending:
self._await_pending(handle)
def _qdrant_reachable(self) -> bool:
try:
status = self._request_status("GET", self._qdrant_probe_url(), timeout=1.5)
except Exception:
return False
return 200 <= status < 500
def _postgres_reachable(self) -> bool:
host, port = self._postgres_target()
if not host:
return False
timeout = 1.5
try:
if host.startswith("/"):
socket_path = f"{host.rstrip('/')}/.s.PGSQL.{port}"
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
client.settimeout(timeout)
client.connect(socket_path)
else:
with socket.create_connection((host, port), timeout=timeout):
pass
except OSError:
return False
return True
def _qdrant_probe_url(self) -> str:
if not self.qdrant_url:
return ""
parsed = urlparse(self.qdrant_url)
path = parsed.path or ""
if not path or path == "/":
return f"{self.qdrant_url}/collections"
return self.qdrant_url
def _postgres_target(self) -> tuple[str | None, int]:
parsed = urlparse(self.pg_dsn)
if parsed.scheme in {"postgres", "postgresql"}:
host = parsed.hostname
port = parsed.port or 5432
return host, port
parts: dict[str, str] = {}
for token in shlex.split(self.pg_dsn):
if "=" not in token:
continue
key, value = token.split("=", 1)
parts[key.strip()] = value.strip()
host = parts.get("host") or parts.get("hostaddr") or "localhost"
port_value = parts.get("port", "5432")
try:
port = int(port_value)
except ValueError:
port = 5432
if "," in host:
host = host.split(",", 1)[0]
return host, port
def _load_core_spawn(self):
try:
task_module = importlib.import_module("core.task")
except ImportError:
return None
return getattr(task_module, "spawn", None)
def _dispatch_pending_write(self, payload: dict) -> bool:
if not callable(self._spawn):
return False
handle = None
try:
handle = self._spawn(self._post_turn_memory, payload)
except TypeError:
try:
handle = self._spawn(lambda: self._post_turn_memory(payload))
except Exception:
return False
except Exception:
return False
if handle is not None:
with self._pending_lock:
self._pending_writes.append(handle)
return True
def _await_pending(self, handle: Any) -> None:
waiters = (
("result", (2.0,)),
("join", (2.0,)),
("wait", (2.0,)),
)
for name, args in waiters:
waiter = getattr(handle, name, None)
if callable(waiter):
try:
waiter(*args)
except TypeError:
waiter()
except Exception:
pass
return
def _post_turn_memory(self, payload: dict) -> None:
self._request_json(
"POST",
self._brain_endpoint("remember"),
json_body=payload,
headers=self._auth_headers(),
timeout=1.0,
)
def _build_turn_memory(self, turn: dict) -> dict:
tags = list(turn.get("tags") or [])
if "hermes" not in tags:
tags.append("hermes")
if "session-turn" not in tags:
tags.append("session-turn")
payload = {
"content": json.dumps(turn, sort_keys=True, default=str),
"type": turn.get("type") or "context",
"tags": tags,
"project": turn.get("project"),
"org": turn.get("org", self.org),
"confidence": turn.get("confidence", 0.6),
"workspace_id": turn.get("workspace_id", self.workspace_id),
}
return self._clean_mapping(payload)
def _with_context_defaults(self, values: dict, include_workspace: bool = True) -> dict:
payload = dict(values)
if include_workspace and "workspace_id" not in payload:
payload["workspace_id"] = self.workspace_id
if self.org and "org" not in payload:
payload["org"] = self.org
return self._clean_mapping(payload)
def _clean_mapping(self, values: dict) -> dict:
return {key: value for key, value in values.items() if value is not None and value != ""}
def _brain_endpoint(self, suffix: str) -> str:
return f"{self.brain_url}/v1/brain/{suffix.lstrip('/')}"
def _auth_headers(self) -> dict[str, str]:
return {
"Accept": "application/json",
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
def _request_status(
self,
method: str,
url: str,
*,
params: dict | None = None,
json_body: dict | None = None,
headers: dict | None = None,
timeout: float = 5.0,
) -> int:
status, _ = self._raw_request(
method,
url,
params=params,
json_body=json_body,
headers=headers,
timeout=timeout,
)
return status
def _request_json(
self,
method: str,
url: str,
*,
params: dict | None = None,
json_body: dict | None = None,
headers: dict | None = None,
timeout: float = 5.0,
) -> dict:
status, text = self._raw_request(
method,
url,
params=params,
json_body=json_body,
headers=headers,
timeout=timeout,
)
if not text:
return {"status": status}
try:
payload = json.loads(text)
except json.JSONDecodeError:
return {"status": status, "data": text}
if isinstance(payload, dict):
payload.setdefault("status", status)
return payload
return {"status": status, "data": payload}
def _raw_request(
self,
method: str,
url: str,
*,
params: dict | None = None,
json_body: dict | None = None,
headers: dict | None = None,
timeout: float = 5.0,
) -> tuple[int, str]:
if requests is not None:
response = requests.request(
method,
url,
params=params,
json=json_body,
headers=headers,
timeout=timeout,
)
return response.status_code, response.text
if httpx is not None:
response = httpx.request(
method,
url,
params=params,
json=json_body,
headers=headers,
timeout=timeout,
)
return response.status_code, response.text
return self._urllib_request(
method,
url,
params=params,
json_body=json_body,
headers=headers,
timeout=timeout,
)
def _urllib_request(
self,
method: str,
url: str,
*,
params: dict | None = None,
json_body: dict | None = None,
headers: dict | None = None,
timeout: float = 5.0,
) -> tuple[int, str]:
request_headers = dict(headers or {})
request_url = url
if params:
query_string = urlencode(params, doseq=True)
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
data = None
if json_body is not None:
request_headers.setdefault("Content-Type", "application/json")
data = json.dumps(json_body).encode("utf-8")
request = Request(request_url, data=data, headers=request_headers, method=method)
try:
with urlopen(request, timeout=timeout) as response:
return response.getcode(), response.read().decode("utf-8")
except HTTPError as exc:
return exc.code, exc.read().decode("utf-8")
except URLError as exc:
raise OSError(str(exc)) from exc