agent/hermes/plugins/openbrain_context.py
Snider 711e2eef72 feat(hermes): add openbrain_context.py ContextEngine plugin
Python plugin implementing the Hermes ContextEngine interface, backed
by OpenBrain. compress() performs centrality-ranked retrieval over a
candidate pool pulled via brain_recall rather than linear turn
truncation, and falls back to naive head+tail truncation when recall
is unavailable so the caller never sees a raised exception.

Closes tasks.lthn.sh/view.php?id=74
Co-authored-by: Codex <noreply@openai.com>
Co-authored-by: Virgil <virgil@lethean.io>
2026-04-23 17:29:50 +01:00


# SPDX-License-Identifier: EUPL-1.2
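"""OpenBrain-backed ContextEngine plugin for Hermes.

Ranks conversation turns by graph centrality against memories recalled
from OpenBrain, keeping the first (system) and last (current user)
turns as anchors and dropping low-centrality middle turns until the
token budget fits. Degrades to naive head+tail truncation whenever
recall is unavailable.

Minimal usage sketch (endpoint values are illustrative, not real
deployments):

    engine = OpenBrainContextEngine(
        brain_url="http://localhost:8080",
        api_key="sk-example",
        qdrant_url="http://localhost:6333",
        pg_dsn="postgresql://hermes@localhost:5432/openbrain",
        workspace_id=1,
    )
    if engine.is_available():
        turns = engine.compress(turns, token_budget=4096)
"""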
from __future__ import annotations
import importlib
import json
import math
import shlex
import socket
import sys
from collections.abc import Iterable
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode, urlparse
from urllib.request import Request, urlopen
try:
    import requests  # type: ignore
except ImportError:  # pragma: no cover - exercised through fallbacks
    requests = None

try:
    import httpx  # type: ignore
except ImportError:  # pragma: no cover - exercised through fallbacks
    httpx = None

try:
    import networkx as nx  # type: ignore
except ImportError:  # pragma: no cover - exercised through manual fallback
    nx = None


class OpenBrainContextEngine:
    def __init__(
        self,
        brain_url: str,
        api_key: str,
        qdrant_url: str,
        pg_dsn: str,
        workspace_id: int,
        org: str | None = None,
    ) -> None:
        self.brain_url = brain_url.rstrip("/")
        self.api_key = api_key
        self.qdrant_url = qdrant_url.rstrip("/")
        self.pg_dsn = pg_dsn
        self.workspace_id = workspace_id
        self.org = org
        self._initialised = False
        self._spawn = None
        self._similarity_threshold = 0.6

    def is_available(self) -> bool:
        return self._qdrant_reachable() and self._postgres_reachable()

    def initialize(self) -> None:
        if self._initialised:
            return
        self._spawn = self._load_core_spawn()
        self._initialised = True

    def compress(
        self,
        turns: list[dict],
        *,
        token_budget: int,
        query_hint: str | None = None,
        top_k: int = 20,
        candidate_pool: int = 200,
    ) -> list[dict]:
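        """Return a subset of ``turns`` that fits ``token_budget``.

        Pipeline: derive a recall query, pull a candidate memory pool
        from OpenBrain, build a similarity graph over turns plus
        memories, rank turns by degree centrality, then keep whichever
        of the top-k or budget-greedy selections retains more turns.
        Any recall failure falls back to head+tail truncation.
        """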
        ordered_turns = list(turns)
        if len(ordered_turns) <= 2:
            return ordered_turns
        self.initialize()
        budget = max(int(token_budget), 0)
        total_tokens = self._estimate_total_tokens(ordered_turns)
        if total_tokens <= budget:
            return ordered_turns
        query = self._derive_query(ordered_turns, query_hint)
        if not query:
            return self._naive_head_tail(ordered_turns, budget)
        try:
            candidates = self._recall_candidates(query, max(int(candidate_pool), 1))
        except Exception as exc:
            self._warn(f"OpenBrain recall failed in compress(); falling back to head+tail truncation: {exc}")
            return self._naive_head_tail(ordered_turns, budget)
        if not candidates:
            self._warn("OpenBrain recall returned no candidates in compress(); falling back to head+tail truncation.")
            return self._naive_head_tail(ordered_turns, budget)
        anchor_indices = {0, len(ordered_turns) - 1}
        nodes = self._build_turn_nodes(ordered_turns) + self._build_candidate_nodes(candidates)
        centrality, affinity = self._graph_scores(nodes)
        ranked_turns = self._rank_turn_indices(ordered_turns, centrality, affinity, anchor_indices)
        rank_selected = self._select_ranked_turns(ranked_turns, anchor_indices, max(int(top_k), 0))
        rank_selected = self._trim_selection_to_budget(rank_selected, ranked_turns, ordered_turns, budget, anchor_indices)
        budget_selected = self._select_budget_turns(ranked_turns, ordered_turns, budget, anchor_indices)
        if len(budget_selected) > len(rank_selected):
            selected = budget_selected
        else:
            selected = rank_selected
        return [turn for index, turn in enumerate(ordered_turns) if index in selected]

    def system_prompt_block(self) -> str:
        return (
            "Librarian/Cartographer compresses chat history with centrality-ranked OpenBrain recall, "
            "keeping the system anchor and current user turn while dropping low-centrality turns "
            "when the token budget is tight."
        )

    def _qdrant_reachable(self) -> bool:
        try:
            status = self._request_status("GET", self._qdrant_probe_url(), timeout=1.5)
        except Exception:
            return False
        return 200 <= status < 500

    def _postgres_reachable(self) -> bool:
        host, port = self._postgres_target()
        if not host:
            return False
        timeout = 1.5
        try:
            if host.startswith("/"):
                socket_path = f"{host.rstrip('/')}/.s.PGSQL.{port}"
                with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
                    client.settimeout(timeout)
                    client.connect(socket_path)
            else:
                with socket.create_connection((host, port), timeout=timeout):
                    pass
        except OSError:
            return False
        return True

    def _qdrant_probe_url(self) -> str:
        if not self.qdrant_url:
            return ""
        parsed = urlparse(self.qdrant_url)
        path = parsed.path or ""
        if not path or path == "/":
            return f"{self.qdrant_url}/collections"
        return self.qdrant_url

    def _postgres_target(self) -> tuple[str | None, int]:
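        """Parse host/port from either a URL DSN or a ``key=value`` DSN.

        Supports ``postgres://``/``postgresql://`` URLs as well as
        libpq keyword strings such as ``host=db port=5433``; multi-host
        values keep only the first host.
        """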
        parsed = urlparse(self.pg_dsn)
        if parsed.scheme in {"postgres", "postgresql"}:
            host = parsed.hostname
            port = parsed.port or 5432
            return host, port
        parts: dict[str, str] = {}
        for token in shlex.split(self.pg_dsn):
            if "=" not in token:
                continue
            key, value = token.split("=", 1)
            parts[key.strip()] = value.strip()
        host = parts.get("host") or parts.get("hostaddr") or "localhost"
        port_value = parts.get("port", "5432")
        try:
            port = int(port_value)
        except ValueError:
            port = 5432
        if "," in host:
            host = host.split(",", 1)[0]
        return host, port
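    # Resolve core.task.spawn when the host package is importable; the
    # engine still works standalone because compress() tolerates None here.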
    def _load_core_spawn(self):
        try:
            task_module = importlib.import_module("core.task")
        except ImportError:
            return None
        return getattr(task_module, "spawn", None)

    def _recall_candidates(self, query: str, candidate_pool: int) -> list[dict]:
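        """POST the query to the OpenBrain ``recall`` endpoint.

        HTTP statuses >= 400 are surfaced as ``OSError`` so compress()
        can treat transport and server failures uniformly and fall back.
        """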
        payload = {
            "query": query,
            "top_k": candidate_pool,
            "filter": self._clean_mapping(
                {
                    "workspace_id": self.workspace_id,
                    "org": self.org,
                }
            ),
        }
        response = self._request_json(
            "POST",
            self._brain_endpoint("recall"),
            json_body=payload,
            headers=self._auth_headers(),
        )
        status = int(response.get("status", 200))
        if status >= 400:
            raise OSError(f"recall returned status {status}")
        return self._extract_candidates(response)

    def _extract_candidates(self, response: dict) -> list[dict]:
        sources: list[object] = [response]
        data = response.get("data")
        if isinstance(data, dict):
            sources.append(data)
        items: list[dict] = []
        for source in sources:
            if not isinstance(source, dict):
                continue
            for key in ("memories", "results", "items", "matches", "data"):
                value = source.get(key)
                if isinstance(value, list):
                    items.extend(item for item in value if isinstance(item, dict))
        return [self._normalise_candidate(item, index) for index, item in enumerate(items)]

    def _normalise_candidate(self, item: dict, index: int) -> dict:
        payload = item.get("payload")
        if not isinstance(payload, dict):
            payload = {}
        candidate_id = item.get("id") or item.get("memory_id") or item.get("uuid") or f"memory:{index}"
        text = (
            payload.get("content")
            or item.get("content")
            or payload.get("text")
            or item.get("text")
            or ""
        )
        return {
            "id": str(candidate_id),
            "text": self._stringify_text(text),
            "payload": payload,
            "score": self._float_value(item.get("score"), item.get("similarity"), default=0.0),
            "vector": self._coerce_vector(
                item.get("vector")
                or item.get("embedding")
                or payload.get("vector")
                or payload.get("embedding")
            ),
        }

    def _build_turn_nodes(self, turns: list[dict]) -> list[dict]:
        nodes: list[dict] = []
        for index, turn in enumerate(turns):
            node_id = turn.get("id") or turn.get("uuid") or f"turn:{index}"
            nodes.append(
                {
                    "node_id": str(node_id),
                    "turn_index": index,
                    "kind": "turn",
                    "text": self._turn_text(turn),
                    "vector": self._coerce_vector(turn.get("vector") or turn.get("embedding")),
                    "score": 0.0,
                }
            )
        return nodes

    def _build_candidate_nodes(self, candidates: list[dict]) -> list[dict]:
        return [
            {
                "node_id": candidate["id"],
                "turn_index": None,
                "kind": "memory",
                "text": candidate["text"],
                "vector": candidate["vector"],
                "score": candidate["score"],
            }
            for candidate in candidates
        ]

    def _graph_scores(self, nodes: list[dict]) -> tuple[dict[str, float], dict[str, float]]:
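        """Build an undirected similarity graph and score every node.

        Edges connect node pairs whose similarity clears the 0.6
        threshold; ``affinity`` records, per node, the strongest
        cross-kind edge (turn <-> memory), which later breaks
        centrality ties in ranking.
        """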
        adjacency = {node["node_id"]: set() for node in nodes}
        affinity = {node["node_id"]: 0.0 for node in nodes}
        for left_index in range(len(nodes)):
            left = nodes[left_index]
            for right_index in range(left_index + 1, len(nodes)):
                right = nodes[right_index]
                similarity = self._node_similarity(left, right)
                if similarity < self._similarity_threshold:
                    continue
                adjacency[left["node_id"]].add(right["node_id"])
                adjacency[right["node_id"]].add(left["node_id"])
                if left["kind"] != right["kind"]:
                    affinity[left["node_id"]] = max(affinity[left["node_id"]], similarity)
                    affinity[right["node_id"]] = max(affinity[right["node_id"]], similarity)
        centrality = self._degree_centrality(nodes, adjacency)
        return centrality, affinity

    def _degree_centrality(self, nodes: list[dict], adjacency: dict[str, set[str]]) -> dict[str, float]:
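        """Degree centrality: networkx when installed, else degree / (n - 1).

        The manual fallback uses the same normalisation as
        ``nx.degree_centrality``, so rankings are stable either way.
        """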
        if nx is not None:
            graph = nx.Graph()
            for node in nodes:
                graph.add_node(node["node_id"])
            for node_id, edges in adjacency.items():
                for edge in edges:
                    graph.add_edge(node_id, edge)
            return dict(nx.degree_centrality(graph))
        normaliser = max(len(nodes) - 1, 1)
        return {node_id: len(edges) / normaliser for node_id, edges in adjacency.items()}

    def _rank_turn_indices(
        self,
        turns: list[dict],
        centrality: dict[str, float],
        affinity: dict[str, float],
        anchor_indices: set[int],
    ) -> list[int]:
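        """Order non-anchor turns best-first.

        Sort key: centrality desc, then cross-kind affinity desc, then
        estimated tokens asc (cheaper turns win ties), then turn index.
        """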
        ranked: list[tuple[float, float, float, int]] = []
        for index, turn in enumerate(turns):
            if index in anchor_indices:
                continue
            node_id = str(turn.get("id") or turn.get("uuid") or f"turn:{index}")
            ranked.append(
                (
                    -centrality.get(node_id, 0.0),
                    -affinity.get(node_id, 0.0),
                    self._estimate_tokens(turn),
                    index,
                )
            )
        ranked.sort()
        return [index for _, _, _, index in ranked]

    def _select_ranked_turns(self, ranked_turns: list[int], anchor_indices: set[int], top_k: int) -> set[int]:
        selected = set(anchor_indices)
        for index in ranked_turns[:top_k]:
            selected.add(index)
        return selected

    def _select_budget_turns(
        self,
        ranked_turns: list[int],
        turns: list[dict],
        token_budget: int,
        anchor_indices: set[int],
    ) -> set[int]:
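        """Greedily add ranked turns while they fit the remaining budget.

        Over-budget turns are skipped rather than stopping the scan, so
        a cheap lower-ranked turn can still use leftover budget.
        """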
        selected = set(anchor_indices)
        used_tokens = self._selection_tokens(selected, turns)
        for index in ranked_turns:
            turn_tokens = self._estimate_tokens(turns[index])
            if used_tokens + turn_tokens > token_budget and selected:
                continue
            selected.add(index)
            used_tokens += turn_tokens
        return selected

    def _trim_selection_to_budget(
        self,
        selected: set[int],
        ranked_turns: list[int],
        turns: list[dict],
        token_budget: int,
        anchor_indices: set[int],
    ) -> set[int]:
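        """Evict worst-ranked non-anchor turns until the budget holds.

        Anchors are never removed, so the result can still exceed the
        budget when the anchors alone are too large.
        """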
        trimmed = set(selected)
        if self._selection_tokens(trimmed, turns) <= token_budget:
            return trimmed
        removable = [index for index in reversed(ranked_turns) if index in trimmed and index not in anchor_indices]
        for index in removable:
            trimmed.remove(index)
            if self._selection_tokens(trimmed, turns) <= token_budget:
                break
        return trimmed

    def _selection_tokens(self, selected: set[int], turns: list[dict]) -> int:
        return sum(self._estimate_tokens(turns[index]) for index in selected)

    def _naive_head_tail(self, turns: list[dict], token_budget: int) -> list[dict]:
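        """Fallback: keep the first and last turns, then fill from the tail.

        Walks backwards from the second-to-last turn so the most recent
        context survives when the budget is tight.
        """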
        if not turns:
            return []
        if len(turns) <= 2:
            return list(turns)
        selected = {0, len(turns) - 1}
        used_tokens = self._selection_tokens(selected, turns)
        for index in range(len(turns) - 2, 0, -1):
            turn_tokens = self._estimate_tokens(turns[index])
            if used_tokens + turn_tokens > token_budget and selected:
                break
            selected.add(index)
            used_tokens += turn_tokens
        return [turn for index, turn in enumerate(turns) if index in selected]

    def _derive_query(self, turns: list[dict], query_hint: str | None) -> str:
        if query_hint and query_hint.strip():
            return query_hint.strip()
        last_text = self._turn_text(turns[-1])
        if last_text:
            return last_text
        if len(turns) >= 2:
            previous = self._turn_text(turns[-2])
            parts = [part for part in (previous, last_text) if part]
            return "\n".join(parts)
        return ""

    def _estimate_total_tokens(self, turns: list[dict]) -> int:
        return sum(self._estimate_tokens(turn) for turn in turns)

    def _estimate_tokens(self, turn: dict) -> int:
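        # Rough heuristic: ~4 characters per token, a common English average.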
        return len(self._turn_text(turn)) // 4

    def _turn_text(self, turn: dict) -> str:
        content = turn.get("content")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = [self._stringify_text(item) for item in content]
            return " ".join(part for part in parts if part)
        if content is None:
            return ""
        return self._stringify_text(content)

    def _stringify_text(self, value: object) -> str:
        if isinstance(value, str):
            return value
        if value is None:
            return ""
        try:
            return json.dumps(value, sort_keys=True, default=str)
        except TypeError:
            return str(value)

    def _coerce_vector(self, value: object) -> list[float] | None:
        if not isinstance(value, Iterable) or isinstance(value, (str, bytes, dict)):
            return None
        vector: list[float] = []
        for item in value:
            try:
                vector.append(float(item))
            except (TypeError, ValueError):
                return None
        return vector or None

    def _node_similarity(self, left: dict, right: dict) -> float:
        left_vector = left.get("vector")
        right_vector = right.get("vector")
        if left_vector and right_vector:
            cosine = self._cosine_similarity(left_vector, right_vector)
            if cosine is not None:
                return cosine
        return self._jaccard_similarity(left.get("text", ""), right.get("text", ""))

    def _cosine_similarity(self, left: list[float], right: list[float]) -> float | None:
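        """cos(a, b) = (a . b) / (|a| * |b|), or None when undefined.

        Returns None for mismatched lengths or zero-magnitude vectors
        so the caller can fall back to text overlap.
        """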
        if len(left) != len(right) or not left:
            return None
        numerator = sum(a * b for a, b in zip(left, right))
        left_norm = math.sqrt(sum(a * a for a in left))
        right_norm = math.sqrt(sum(b * b for b in right))
        if left_norm == 0.0 or right_norm == 0.0:
            return None
        return numerator / (left_norm * right_norm)

    def _jaccard_similarity(self, left_text: str, right_text: str) -> float:
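        """Token-set Jaccard overlap: |A & B| / |A | B|, in [0, 1]."""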
        left_tokens = self._token_set(left_text)
        right_tokens = self._token_set(right_text)
        if not left_tokens or not right_tokens:
            return 0.0
        union = left_tokens | right_tokens
        if not union:
            return 0.0
        return len(left_tokens & right_tokens) / len(union)

    def _token_set(self, text: str) -> set[str]:
        lowered = text.lower()
        cleaned = "".join(character if character.isalnum() else " " for character in lowered)
        return {token for token in cleaned.split() if token}

    def _float_value(self, *values: object, default: float) -> float:
        for value in values:
            try:
                return float(value)
            except (TypeError, ValueError):
                continue
        return default

    def _warn(self, message: str) -> None:
        print(message, file=sys.stderr)

    def _clean_mapping(self, values: dict) -> dict:
        return {key: value for key, value in values.items() if value is not None and value != ""}

    def _brain_endpoint(self, suffix: str) -> str:
        return f"{self.brain_url}/v1/brain/{suffix.lstrip('/')}"

    def _auth_headers(self) -> dict[str, str]:
        return {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _request_status(
        self,
        method: str,
        url: str,
        *,
        params: dict | None = None,
        json_body: dict | None = None,
        headers: dict | None = None,
        timeout: float = 5.0,
    ) -> int:
        status, _ = self._raw_request(
            method,
            url,
            params=params,
            json_body=json_body,
            headers=headers,
            timeout=timeout,
        )
        return status

    def _request_json(
        self,
        method: str,
        url: str,
        *,
        params: dict | None = None,
        json_body: dict | None = None,
        headers: dict | None = None,
        timeout: float = 5.0,
    ) -> dict:
        status, text = self._raw_request(
            method,
            url,
            params=params,
            json_body=json_body,
            headers=headers,
            timeout=timeout,
        )
        if not text:
            return {"status": status}
        try:
            payload = json.loads(text)
        except json.JSONDecodeError:
            return {"status": status, "data": text}
        if isinstance(payload, dict):
            payload.setdefault("status", status)
            return payload
        return {"status": status, "data": payload}

    def _raw_request(
        self,
        method: str,
        url: str,
        *,
        params: dict | None = None,
        json_body: dict | None = None,
        headers: dict | None = None,
        timeout: float = 5.0,
    ) -> tuple[int, str]:
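        """Issue an HTTP request via the first available client.

        Preference order: requests, then httpx, then the stdlib urllib
        fallback, so the plugin has no hard HTTP dependency.
        """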
        if requests is not None:
            response = requests.request(
                method,
                url,
                params=params,
                json=json_body,
                headers=headers,
                timeout=timeout,
            )
            return response.status_code, response.text
        if httpx is not None:
            response = httpx.request(
                method,
                url,
                params=params,
                json=json_body,
                headers=headers,
                timeout=timeout,
            )
            return response.status_code, response.text
        return self._urllib_request(
            method,
            url,
            params=params,
            json_body=json_body,
            headers=headers,
            timeout=timeout,
        )

    def _urllib_request(
        self,
        method: str,
        url: str,
        *,
        params: dict | None = None,
        json_body: dict | None = None,
        headers: dict | None = None,
        timeout: float = 5.0,
    ) -> tuple[int, str]:
        request_headers = dict(headers or {})
        request_url = url
        if params:
            query_string = urlencode(params, doseq=True)
            separator = "&" if "?" in request_url else "?"
            request_url = f"{request_url}{separator}{query_string}"
        data = None
        if json_body is not None:
            request_headers.setdefault("Content-Type", "application/json")
            data = json.dumps(json_body).encode("utf-8")
        request = Request(request_url, data=data, headers=request_headers, method=method)
        try:
            with urlopen(request, timeout=timeout) as response:
                return response.getcode(), response.read().decode("utf-8")
        except HTTPError as exc:
            return exc.code, exc.read().decode("utf-8")
        except URLError as exc:
            raise OSError(str(exc)) from exc