feat(hermes): add openbrain_context.py ContextEngine plugin
Python plugin implementing Hermes ContextEngine backed by OpenBrain. compress() does centrality-ranked retrieval over a candidate pool pulled via brain_recall rather than linear turn truncation. Falls back to naive head+tail truncation when recall is unavailable so the caller never sees a raised exception. Closes tasks.lthn.sh/view.php?id=74 Co-authored-by: Codex <noreply@openai.com> Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
5a851c2f4a
commit
711e2eef72
2 changed files with 806 additions and 0 deletions
661
hermes/plugins/openbrain_context.py
Normal file
661
hermes/plugins/openbrain_context.py
Normal file
|
|
@ -0,0 +1,661 @@
|
|||
# SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import math
|
||||
import shlex
|
||||
import socket
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
try:
|
||||
import requests # type: ignore
|
||||
except ImportError: # pragma: no cover - exercised through fallbacks
|
||||
requests = None
|
||||
|
||||
try:
|
||||
import httpx # type: ignore
|
||||
except ImportError: # pragma: no cover - exercised through fallbacks
|
||||
httpx = None
|
||||
|
||||
try:
|
||||
import networkx as nx # type: ignore
|
||||
except ImportError: # pragma: no cover - exercised through manual fallback
|
||||
nx = None
|
||||
|
||||
|
||||
class OpenBrainContextEngine:
    """Hermes ContextEngine backed by OpenBrain recall.

    ``compress()`` ranks conversation turns by graph centrality against a
    pool of memories recalled from the OpenBrain ``recall`` endpoint and
    drops low-ranked turns to fit a token budget.  When recall fails or
    returns nothing, it degrades to naive head+tail truncation so callers
    never see a raised exception.
    """

    def __init__(
        self,
        brain_url: str,
        api_key: str,
        qdrant_url: str,
        pg_dsn: str,
        workspace_id: int,
        org: str | None = None,
    ) -> None:
        """Store connection settings; no network calls happen here.

        Args:
            brain_url: Base URL of the OpenBrain API (trailing ``/`` stripped).
            api_key: Bearer token sent in the ``Authorization`` header.
            qdrant_url: Base URL used only for the reachability probe.
            pg_dsn: Postgres DSN (URL or ``key=value`` form) used only for the
                reachability probe.
            workspace_id: Sent as a recall filter.
            org: Optional org recall filter; omitted from the filter when
                ``None`` or empty.
        """
        # Trailing slashes are stripped so endpoint suffixes can be joined safely.
        self.brain_url = brain_url.rstrip("/")
        self.api_key = api_key
        self.qdrant_url = qdrant_url.rstrip("/")
        self.pg_dsn = pg_dsn
        self.workspace_id = workspace_id
        self.org = org
        self._initialised = False
        # Resolved lazily by initialize(); stays None when core.task is absent.
        self._spawn = None
        # Minimum similarity for an edge in the centrality graph.
        self._similarity_threshold = 0.6

    def is_available(self) -> bool:
        """Return True when both the Qdrant HTTP probe and the Postgres socket probe succeed."""
        return self._qdrant_reachable() and self._postgres_reachable()

    def initialize(self) -> None:
        """Idempotently resolve optional collaborators (``core.task.spawn`` if importable)."""
        if self._initialised:
            return

        self._spawn = self._load_core_spawn()
        self._initialised = True

    def compress(
        self,
        turns: list[dict],
        *,
        token_budget: int,
        query_hint: str | None = None,
        top_k: int = 20,
        candidate_pool: int = 200,
    ) -> list[dict]:
        """Compress ``turns`` to roughly fit ``token_budget`` estimated tokens.

        The first and last turns are always kept (anchors).  Middle turns are
        ranked by degree centrality in a similarity graph that mixes turns
        with recalled memories, then selected by rank (trimmed to budget) or
        greedily by budget — whichever keeps more turns.  On recall failure
        or empty recall, falls back to ``_naive_head_tail`` after warning.

        Returns:
            The surviving turns, in their original order.
        """
        ordered_turns = list(turns)
        # Nothing to drop between the two anchors.
        if len(ordered_turns) <= 2:
            return ordered_turns

        self.initialize()

        budget = max(int(token_budget), 0)
        total_tokens = self._estimate_total_tokens(ordered_turns)
        # Already within budget: return untouched.
        if total_tokens <= budget:
            return ordered_turns

        query = self._derive_query(ordered_turns, query_hint)
        if not query:
            return self._naive_head_tail(ordered_turns, budget)

        try:
            candidates = self._recall_candidates(query, max(int(candidate_pool), 1))
        except Exception as exc:
            # Deliberately broad: any recall failure degrades to truncation.
            self._warn(f"OpenBrain recall failed in compress(); falling back to head+tail truncation: {exc}")
            return self._naive_head_tail(ordered_turns, budget)

        if not candidates:
            self._warn("OpenBrain recall returned no candidates in compress(); falling back to head+tail truncation.")
            return self._naive_head_tail(ordered_turns, budget)

        anchor_indices = {0, len(ordered_turns) - 1}
        nodes = self._build_turn_nodes(ordered_turns) + self._build_candidate_nodes(candidates)
        centrality, affinity = self._graph_scores(nodes)
        ranked_turns = self._rank_turn_indices(ordered_turns, centrality, affinity, anchor_indices)

        # Strategy A: take top_k by rank, then trim back down to budget.
        rank_selected = self._select_ranked_turns(ranked_turns, anchor_indices, max(int(top_k), 0))
        rank_selected = self._trim_selection_to_budget(rank_selected, ranked_turns, ordered_turns, budget, anchor_indices)
        # Strategy B: greedily pack turns by rank until the budget is spent.
        budget_selected = self._select_budget_turns(ranked_turns, ordered_turns, budget, anchor_indices)

        # Prefer whichever strategy preserves more turns.
        if len(budget_selected) > len(rank_selected):
            selected = budget_selected
        else:
            selected = rank_selected

        return [turn for index, turn in enumerate(ordered_turns) if index in selected]

    def system_prompt_block(self) -> str:
        """Return the one-paragraph description of this engine for the system prompt."""
        return (
            "Librarian/Cartographer compresses chat history with centrality-ranked OpenBrain recall, "
            "keeping the system anchor and current user turn while dropping low-centrality turns "
            "when the token budget is tight."
        )

    def _qdrant_reachable(self) -> bool:
        """Probe the Qdrant endpoint; any HTTP status below 500 counts as reachable."""
        try:
            status = self._request_status("GET", self._qdrant_probe_url(), timeout=1.5)
        except Exception:
            # Any transport error means unreachable.
            return False
        return 200 <= status < 500

    def _postgres_reachable(self) -> bool:
        """Return True when a raw socket connection to the Postgres target succeeds."""
        host, port = self._postgres_target()
        if not host:
            return False

        timeout = 1.5

        try:
            if host.startswith("/"):
                # Unix-socket host: use the conventional Postgres socket file name.
                socket_path = f"{host.rstrip('/')}/.s.PGSQL.{port}"
                with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
                    client.settimeout(timeout)
                    client.connect(socket_path)
            else:
                with socket.create_connection((host, port), timeout=timeout):
                    pass
        except OSError:
            return False

        return True

    def _qdrant_probe_url(self) -> str:
        """Return the URL to probe: ``/collections`` when the base URL has no path."""
        if not self.qdrant_url:
            return ""

        parsed = urlparse(self.qdrant_url)
        path = parsed.path or ""

        if not path or path == "/":
            return f"{self.qdrant_url}/collections"

        # A caller-supplied path is assumed to already be a probe target.
        return self.qdrant_url

    def _postgres_target(self) -> tuple[str | None, int]:
        """Extract (host, port) from the DSN, supporting URL and ``key=value`` forms.

        Defaults to ``localhost:5432`` pieces when absent; only the first host
        of a comma-separated host list is probed.
        """
        parsed = urlparse(self.pg_dsn)
        if parsed.scheme in {"postgres", "postgresql"}:
            host = parsed.hostname
            port = parsed.port or 5432
            return host, port

        # libpq-style "key=value key=value" DSN.
        parts: dict[str, str] = {}
        for token in shlex.split(self.pg_dsn):
            if "=" not in token:
                continue
            key, value = token.split("=", 1)
            parts[key.strip()] = value.strip()

        host = parts.get("host") or parts.get("hostaddr") or "localhost"
        port_value = parts.get("port", "5432")

        try:
            port = int(port_value)
        except ValueError:
            port = 5432

        if "," in host:
            host = host.split(",", 1)[0]

        return host, port

    def _load_core_spawn(self):
        """Return ``core.task.spawn`` when importable, else None (optional dependency)."""
        try:
            task_module = importlib.import_module("core.task")
        except ImportError:
            return None
        return getattr(task_module, "spawn", None)

    def _recall_candidates(self, query: str, candidate_pool: int) -> list[dict]:
        """POST to the brain ``recall`` endpoint and return normalised candidates.

        Raises:
            OSError: when the response status is >= 400 (also propagates any
                transport errors raised by the request helpers).
        """
        payload = {
            "query": query,
            "top_k": candidate_pool,
            # Drop None/empty filter values so the server does not see them.
            "filter": self._clean_mapping(
                {
                    "workspace_id": self.workspace_id,
                    "org": self.org,
                }
            ),
        }
        response = self._request_json(
            "POST",
            self._brain_endpoint("recall"),
            json_body=payload,
            headers=self._auth_headers(),
        )
        status = int(response.get("status", 200))
        if status >= 400:
            raise OSError(f"recall returned status {status}")
        return self._extract_candidates(response)

    def _extract_candidates(self, response: dict) -> list[dict]:
        """Collect candidate dicts from the several list shapes recall may return."""
        sources: list[object] = [response]

        # The list may live at the top level or nested under "data".
        data = response.get("data")
        if isinstance(data, dict):
            sources.append(data)

        items: list[dict] = []
        for source in sources:
            if not isinstance(source, dict):
                continue

            for key in ("memories", "results", "items", "matches", "data"):
                value = source.get(key)
                if isinstance(value, list):
                    items.extend(item for item in value if isinstance(item, dict))

        return [self._normalise_candidate(item, index) for index, item in enumerate(items)]

    def _normalise_candidate(self, item: dict, index: int) -> dict:
        """Coerce one raw recall item into the flat candidate shape used internally.

        Falls back across several id/text/vector field names; missing ids get
        a positional ``memory:<index>`` id.
        """
        payload = item.get("payload")
        if not isinstance(payload, dict):
            payload = {}

        candidate_id = item.get("id") or item.get("memory_id") or item.get("uuid") or f"memory:{index}"
        text = (
            payload.get("content")
            or item.get("content")
            or payload.get("text")
            or item.get("text")
            or ""
        )

        return {
            "id": str(candidate_id),
            "text": self._stringify_text(text),
            "payload": payload,
            "score": self._float_value(item.get("score"), item.get("similarity"), default=0.0),
            "vector": self._coerce_vector(
                item.get("vector")
                or item.get("embedding")
                or payload.get("vector")
                or payload.get("embedding")
            ),
        }

    def _build_turn_nodes(self, turns: list[dict]) -> list[dict]:
        """Wrap each turn as a graph node carrying its index, text, and optional vector."""
        nodes: list[dict] = []
        for index, turn in enumerate(turns):
            node_id = turn.get("id") or turn.get("uuid") or f"turn:{index}"
            nodes.append(
                {
                    "node_id": str(node_id),
                    "turn_index": index,
                    "kind": "turn",
                    "text": self._turn_text(turn),
                    "vector": self._coerce_vector(turn.get("vector") or turn.get("embedding")),
                    "score": 0.0,
                }
            )
        return nodes

    def _build_candidate_nodes(self, candidates: list[dict]) -> list[dict]:
        """Wrap each recalled memory as a graph node (no turn index)."""
        return [
            {
                "node_id": candidate["id"],
                "turn_index": None,
                "kind": "memory",
                "text": candidate["text"],
                "vector": candidate["vector"],
                "score": candidate["score"],
            }
            for candidate in candidates
        ]

    def _graph_scores(self, nodes: list[dict]) -> tuple[dict[str, float], dict[str, float]]:
        """Build a similarity graph over all nodes and score each node.

        Returns:
            ``(centrality, affinity)`` keyed by node id: degree centrality of
            the thresholded similarity graph, and the best similarity each
            node achieved against a node of the *other* kind (turn↔memory).
        """
        adjacency = {node["node_id"]: set() for node in nodes}
        affinity = {node["node_id"]: 0.0 for node in nodes}

        # All-pairs similarity; edges only above the threshold.
        for left_index in range(len(nodes)):
            left = nodes[left_index]
            for right_index in range(left_index + 1, len(nodes)):
                right = nodes[right_index]
                similarity = self._node_similarity(left, right)
                if similarity < self._similarity_threshold:
                    continue

                adjacency[left["node_id"]].add(right["node_id"])
                adjacency[right["node_id"]].add(left["node_id"])

                # Cross-kind edges record how well a turn matches a memory.
                if left["kind"] != right["kind"]:
                    affinity[left["node_id"]] = max(affinity[left["node_id"]], similarity)
                    affinity[right["node_id"]] = max(affinity[right["node_id"]], similarity)

        centrality = self._degree_centrality(nodes, adjacency)
        return centrality, affinity

    def _degree_centrality(self, nodes: list[dict], adjacency: dict[str, set[str]]) -> dict[str, float]:
        """Degree centrality via networkx when installed, else degree/(n-1) by hand."""
        if nx is not None:
            graph = nx.Graph()
            for node in nodes:
                graph.add_node(node["node_id"])
            for node_id, edges in adjacency.items():
                for edge in edges:
                    graph.add_edge(node_id, edge)
            return dict(nx.degree_centrality(graph))

        normaliser = max(len(nodes) - 1, 1)
        return {node_id: len(edges) / normaliser for node_id, edges in adjacency.items()}

    def _rank_turn_indices(
        self,
        turns: list[dict],
        centrality: dict[str, float],
        affinity: dict[str, float],
        anchor_indices: set[int],
    ) -> list[int]:
        """Rank non-anchor turn indices best-first.

        Sort key: higher centrality, then higher affinity, then cheaper
        (fewer estimated tokens), then earlier position.
        """
        ranked: list[tuple[float, float, float, int]] = []

        for index, turn in enumerate(turns):
            if index in anchor_indices:
                continue

            # Must mirror the node id scheme used in _build_turn_nodes.
            node_id = str(turn.get("id") or turn.get("uuid") or f"turn:{index}")
            ranked.append(
                (
                    -centrality.get(node_id, 0.0),
                    -affinity.get(node_id, 0.0),
                    self._estimate_tokens(turn),
                    index,
                )
            )

        ranked.sort()
        return [index for _, _, _, index in ranked]

    def _select_ranked_turns(self, ranked_turns: list[int], anchor_indices: set[int], top_k: int) -> set[int]:
        """Return the anchors plus the ``top_k`` best-ranked turn indices."""
        selected = set(anchor_indices)
        for index in ranked_turns[:top_k]:
            selected.add(index)
        return selected

    def _select_budget_turns(
        self,
        ranked_turns: list[int],
        turns: list[dict],
        token_budget: int,
        anchor_indices: set[int],
    ) -> set[int]:
        """Greedily add ranked turns that still fit the remaining budget.

        Uses ``continue`` rather than ``break`` on overflow so a cheaper,
        lower-ranked turn can still be packed in later.
        """
        selected = set(anchor_indices)
        used_tokens = self._selection_tokens(selected, turns)

        for index in ranked_turns:
            turn_tokens = self._estimate_tokens(turns[index])
            if used_tokens + turn_tokens > token_budget and selected:
                continue
            selected.add(index)
            used_tokens += turn_tokens

        return selected

    def _trim_selection_to_budget(
        self,
        selected: set[int],
        ranked_turns: list[int],
        turns: list[dict],
        token_budget: int,
        anchor_indices: set[int],
    ) -> set[int]:
        """Drop the worst-ranked non-anchor picks until the selection fits the budget.

        Anchors are never removed, so the result can still exceed the budget
        when the anchors alone are too large.
        """
        trimmed = set(selected)
        if self._selection_tokens(trimmed, turns) <= token_budget:
            return trimmed

        # Worst-ranked first (ranked_turns is best-first).
        removable = [index for index in reversed(ranked_turns) if index in trimmed and index not in anchor_indices]
        for index in removable:
            trimmed.remove(index)
            if self._selection_tokens(trimmed, turns) <= token_budget:
                break

        return trimmed

    def _selection_tokens(self, selected: set[int], turns: list[dict]) -> int:
        """Total estimated tokens across the selected turn indices."""
        return sum(self._estimate_tokens(turns[index]) for index in selected)

    def _naive_head_tail(self, turns: list[dict], token_budget: int) -> list[dict]:
        """Fallback: keep first and last turns, then fill from the tail backwards.

        Stops adding as soon as the next turn would overflow the budget.
        """
        if not turns:
            return []

        if len(turns) <= 2:
            return list(turns)

        selected = {0, len(turns) - 1}
        used_tokens = self._selection_tokens(selected, turns)

        # Walk from the second-to-last turn toward the head (recency first).
        for index in range(len(turns) - 2, 0, -1):
            turn_tokens = self._estimate_tokens(turns[index])
            if used_tokens + turn_tokens > token_budget and selected:
                break
            selected.add(index)
            used_tokens += turn_tokens

        return [turn for index, turn in enumerate(turns) if index in selected]

    def _derive_query(self, turns: list[dict], query_hint: str | None) -> str:
        """Choose the recall query: explicit hint, else the last turn's text.

        The two-turn join branch only runs when the last turn's text is
        empty, so in practice it yields just the previous turn's text.
        Returns "" when no usable text exists (caller then truncates naively).
        """
        if query_hint and query_hint.strip():
            return query_hint.strip()

        last_text = self._turn_text(turns[-1])
        if last_text:
            return last_text

        if len(turns) >= 2:
            previous = self._turn_text(turns[-2])
            parts = [part for part in (previous, last_text) if part]
            return "\n".join(parts)

        return ""

    def _estimate_total_tokens(self, turns: list[dict]) -> int:
        """Sum of the per-turn token estimates."""
        return sum(self._estimate_tokens(turn) for turn in turns)

    def _estimate_tokens(self, turn: dict) -> int:
        """Crude token estimate: ~4 characters per token (floor division)."""
        return len(self._turn_text(turn)) // 4

    def _turn_text(self, turn: dict) -> str:
        """Extract a flat text string from a turn's ``content`` field.

        Handles plain strings, lists of parts (joined with spaces), ``None``
        (empty string), and anything else via JSON stringification.
        """
        content = turn.get("content")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = [self._stringify_text(item) for item in content]
            return " ".join(part for part in parts if part)
        if content is None:
            return ""
        return self._stringify_text(content)

    def _stringify_text(self, value: object) -> str:
        """Render any value as text; non-strings go through deterministic JSON."""
        if isinstance(value, str):
            return value
        if value is None:
            return ""
        try:
            # sort_keys keeps the output stable; default=str absorbs odd types.
            return json.dumps(value, sort_keys=True, default=str)
        except TypeError:
            return str(value)

    def _coerce_vector(self, value: object) -> list[float] | None:
        """Return ``value`` as a list of floats, or None when it is not a numeric sequence.

        Strings, bytes, and dicts are iterable but never vectors; an empty
        vector also collapses to None.
        """
        if not isinstance(value, Iterable) or isinstance(value, (str, bytes, dict)):
            return None

        vector: list[float] = []
        for item in value:
            try:
                vector.append(float(item))
            except (TypeError, ValueError):
                return None

        return vector or None

    def _node_similarity(self, left: dict, right: dict) -> float:
        """Similarity between two nodes: cosine on vectors when both exist, else Jaccard on text."""
        left_vector = left.get("vector")
        right_vector = right.get("vector")
        if left_vector and right_vector:
            cosine = self._cosine_similarity(left_vector, right_vector)
            if cosine is not None:
                return cosine

        return self._jaccard_similarity(left.get("text", ""), right.get("text", ""))

    def _cosine_similarity(self, left: list[float], right: list[float]) -> float | None:
        """Cosine similarity, or None for mismatched lengths, empty, or zero-norm vectors."""
        if len(left) != len(right) or not left:
            return None

        numerator = sum(a * b for a, b in zip(left, right))
        left_norm = math.sqrt(sum(a * a for a in left))
        right_norm = math.sqrt(sum(b * b for b in right))
        if left_norm == 0.0 or right_norm == 0.0:
            return None

        return numerator / (left_norm * right_norm)

    def _jaccard_similarity(self, left_text: str, right_text: str) -> float:
        """Jaccard similarity of the two texts' token sets (0.0 when either is empty)."""
        left_tokens = self._token_set(left_text)
        right_tokens = self._token_set(right_text)
        if not left_tokens or not right_tokens:
            return 0.0

        union = left_tokens | right_tokens
        if not union:
            return 0.0

        return len(left_tokens & right_tokens) / len(union)

    def _token_set(self, text: str) -> set[str]:
        """Lowercased alphanumeric tokens of ``text`` (punctuation treated as whitespace)."""
        lowered = text.lower()
        cleaned = "".join(character if character.isalnum() else " " for character in lowered)
        return {token for token in cleaned.split() if token}

    def _float_value(self, *values: object, default: float) -> float:
        """Return the first value coercible to float, else ``default``."""
        for value in values:
            try:
                return float(value)
            except (TypeError, ValueError):
                continue
        return default

    def _warn(self, message: str) -> None:
        """Emit a warning on stderr (tests assert on this stream)."""
        print(message, file=sys.stderr)

    def _clean_mapping(self, values: dict) -> dict:
        """Copy ``values`` without None or empty-string entries."""
        return {key: value for key, value in values.items() if value is not None and value != ""}

    def _brain_endpoint(self, suffix: str) -> str:
        """Join ``suffix`` onto the ``/v1/brain/`` API prefix."""
        return f"{self.brain_url}/v1/brain/{suffix.lstrip('/')}"

    def _auth_headers(self) -> dict[str, str]:
        """Standard JSON + bearer-token headers for brain requests."""
        return {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _request_status(
        self,
        method: str,
        url: str,
        *,
        params: dict | None = None,
        json_body: dict | None = None,
        headers: dict | None = None,
        timeout: float = 5.0,
    ) -> int:
        """Perform a request and return only the HTTP status code."""
        status, _ = self._raw_request(
            method,
            url,
            params=params,
            json_body=json_body,
            headers=headers,
            timeout=timeout,
        )
        return status

    def _request_json(
        self,
        method: str,
        url: str,
        *,
        params: dict | None = None,
        json_body: dict | None = None,
        headers: dict | None = None,
        timeout: float = 5.0,
    ) -> dict:
        """Perform a request and always return a dict carrying a ``status`` key.

        Non-JSON or non-dict bodies are wrapped under ``data``; a dict body
        gets ``status`` merged in without overwriting an existing key.
        """
        status, text = self._raw_request(
            method,
            url,
            params=params,
            json_body=json_body,
            headers=headers,
            timeout=timeout,
        )

        if not text:
            return {"status": status}

        try:
            payload = json.loads(text)
        except json.JSONDecodeError:
            return {"status": status, "data": text}

        if isinstance(payload, dict):
            payload.setdefault("status", status)
            return payload

        return {"status": status, "data": payload}

    def _raw_request(
        self,
        method: str,
        url: str,
        *,
        params: dict | None = None,
        json_body: dict | None = None,
        headers: dict | None = None,
        timeout: float = 5.0,
    ) -> tuple[int, str]:
        """Issue an HTTP request via the best available client.

        Preference order: ``requests``, then ``httpx``, then stdlib urllib —
        the first two are module-level optional imports that are None when
        not installed.  Returns ``(status_code, body_text)``.
        """
        if requests is not None:
            response = requests.request(
                method,
                url,
                params=params,
                json=json_body,
                headers=headers,
                timeout=timeout,
            )
            return response.status_code, response.text

        if httpx is not None:
            response = httpx.request(
                method,
                url,
                params=params,
                json=json_body,
                headers=headers,
                timeout=timeout,
            )
            return response.status_code, response.text

        return self._urllib_request(
            method,
            url,
            params=params,
            json_body=json_body,
            headers=headers,
            timeout=timeout,
        )

    def _urllib_request(
        self,
        method: str,
        url: str,
        *,
        params: dict | None = None,
        json_body: dict | None = None,
        headers: dict | None = None,
        timeout: float = 5.0,
    ) -> tuple[int, str]:
        """Stdlib fallback HTTP client returning ``(status_code, body_text)``.

        HTTP error statuses are returned (not raised); transport failures
        are re-raised as OSError so callers catch one exception family.
        """
        request_headers = dict(headers or {})
        request_url = url

        if params:
            query_string = urlencode(params, doseq=True)
            # Append correctly whether or not the URL already has a query.
            separator = "&" if "?" in request_url else "?"
            request_url = f"{request_url}{separator}{query_string}"

        data = None
        if json_body is not None:
            request_headers.setdefault("Content-Type", "application/json")
            data = json.dumps(json_body).encode("utf-8")

        request = Request(request_url, data=data, headers=request_headers, method=method)

        try:
            with urlopen(request, timeout=timeout) as response:
                return response.getcode(), response.read().decode("utf-8")
        except HTTPError as exc:
            # 4xx/5xx still carry a usable status and body.
            return exc.code, exc.read().decode("utf-8")
        except URLError as exc:
            raise OSError(str(exc)) from exc
|
||||
145
tests/test_openbrain_context.py
Normal file
145
tests/test_openbrain_context.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
# SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import patch
|
||||
|
||||
from hermes.plugins.openbrain_context import OpenBrainContextEngine
|
||||
|
||||
|
||||
def make_engine() -> OpenBrainContextEngine:
    """Build an engine wired to fixed example endpoints for the tests."""
    settings = {
        "brain_url": "https://brain.example",
        "api_key": "test-key",
        "qdrant_url": "https://qdrant.example",
        "pg_dsn": "postgresql://brain:secret@postgres.example:5432/openbrain",
        "workspace_id": 74,
        "org": "lthn",
    }
    return OpenBrainContextEngine(**settings)
|
||||
|
||||
|
||||
def make_turns() -> list[dict]:
    """Six-turn conversation fixture: two anchors plus on/off-topic middles."""
    rows = [
        ("turn-0", "system", "System context for Hermes and safety rules across the workspace."),
        ("turn-1", "assistant", "Old chatter about office snacks, coffee orders, and travel timings."),
        ("turn-2", "assistant", "OpenBrain qdrant recall centrality retrieval graph ranking memory compression context."),
        ("turn-3", "user", "Another diversion about keyboard colours, umbrellas, and station weather."),
        ("turn-4", "assistant", "Context compression should keep qdrant recall centrality retrieval graph ranking context."),
        ("turn-5", "user", "Please implement Hermes context compression with qdrant recall centrality retrieval."),
    ]
    return [
        {"id": turn_id, "role": role, "content": content}
        for turn_id, role, content in rows
    ]
|
||||
|
||||
|
||||
def recall_payload() -> dict:
    """Canned recall response: two relevant memories plus one distractor."""
    memories = [
        ("mem-1", "qdrant recall centrality retrieval graph ranking context compression", 0.97),
        ("mem-2", "openbrain qdrant recall centrality retrieval graph ranking memory compression context", 0.95),
        ("mem-3", "garden picnic sandwiches clouds trains umbrellas", 0.10),
    ]
    return {
        "status": 200,
        "data": {
            "memories": [
                {"id": memory_id, "content": content, "score": score}
                for memory_id, content, score in memories
            ]
        },
    }
|
||||
|
||||
|
||||
def test_is_available_happy() -> None:
    """is_available() is True when both the Qdrant and Postgres probes pass."""
    engine = make_engine()

    qdrant_ok = patch.object(engine, "_request_status", return_value=200)
    postgres_ok = patch(
        "hermes.plugins.openbrain_context.socket.create_connection",
        return_value=nullcontext(object()),
    )
    with qdrant_ok:
        with postgres_ok:
            assert engine.is_available() is True
|
||||
|
||||
|
||||
def test_is_available_qdrant_down() -> None:
    """is_available() is False when the Qdrant probe raises, even if Postgres is up."""
    engine = make_engine()

    qdrant_down = patch.object(engine, "_request_status", side_effect=OSError("down"))
    postgres_ok = patch(
        "hermes.plugins.openbrain_context.socket.create_connection",
        return_value=nullcontext(object()),
    )
    with qdrant_down:
        with postgres_ok:
            assert engine.is_available() is False
|
||||
|
||||
|
||||
def test_compress_with_short_input_returns_turns_unchanged() -> None:
    """Two turns or fewer are returned untouched regardless of the budget."""
    engine = make_engine()
    history = [
        {"role": "system", "content": "System prompt"},
        {"role": "user", "content": "Current request"},
    ]

    result = engine.compress(history, token_budget=1)

    assert result == history
|
||||
|
||||
|
||||
def test_compress_with_recall_candidates_keeps_central_turns() -> None:
    """Turns similar to recalled memories survive; off-topic turns are dropped."""
    engine = make_engine()
    history = make_turns()

    with patch.object(engine, "_request_json", return_value=recall_payload()):
        result = engine.compress(history, token_budget=80, top_k=2, candidate_pool=10)

    kept_ids = [turn["id"] for turn in result]
    assert kept_ids == ["turn-0", "turn-2", "turn-4", "turn-5"]
|
||||
|
||||
|
||||
def test_compress_preserves_first_and_last_turn_always() -> None:
    """The first (system) and last (current user) turns are always retained."""
    engine = make_engine()
    history = make_turns()

    with patch.object(engine, "_request_json", return_value=recall_payload()):
        result = engine.compress(history, token_budget=36, top_k=1, candidate_pool=10)

    assert result[0] == history[0]
    assert result[-1] == history[-1]
    assert history[0] in result
    assert history[-1] in result
|
||||
|
||||
|
||||
def test_compress_falls_back_gracefully_when_recall_fails(capsys) -> None:
    """Recall errors degrade to head+tail truncation and warn on stderr."""
    engine = make_engine()
    history = make_turns()

    with patch.object(engine, "_request_json", side_effect=OSError("down")):
        result = engine.compress(history, token_budget=59, candidate_pool=10)

    assert [turn["id"] for turn in result] == ["turn-0", "turn-4", "turn-5"]
    assert "falling back to head+tail truncation" in capsys.readouterr().err
|
||||
Loading…
Add table
Reference in a new issue