# SPDX-License-Identifier: EUPL-1.2 from __future__ import annotations import importlib import json import math import shlex import socket import sys from collections.abc import Iterable from urllib.error import HTTPError, URLError from urllib.parse import urlencode, urlparse from urllib.request import Request, urlopen try: import requests # type: ignore except ImportError: # pragma: no cover - exercised through fallbacks requests = None try: import httpx # type: ignore except ImportError: # pragma: no cover - exercised through fallbacks httpx = None try: import networkx as nx # type: ignore except ImportError: # pragma: no cover - exercised through manual fallback nx = None class OpenBrainContextEngine: def __init__( self, brain_url: str, api_key: str, qdrant_url: str, pg_dsn: str, workspace_id: int, org: str | None = None, ) -> None: self.brain_url = brain_url.rstrip("/") self.api_key = api_key self.qdrant_url = qdrant_url.rstrip("/") self.pg_dsn = pg_dsn self.workspace_id = workspace_id self.org = org self._initialised = False self._spawn = None self._similarity_threshold = 0.6 def is_available(self) -> bool: return self._qdrant_reachable() and self._postgres_reachable() def initialize(self) -> None: if self._initialised: return self._spawn = self._load_core_spawn() self._initialised = True def compress( self, turns: list[dict], *, token_budget: int, query_hint: str | None = None, top_k: int = 20, candidate_pool: int = 200, ) -> list[dict]: ordered_turns = list(turns) if len(ordered_turns) <= 2: return ordered_turns self.initialize() budget = max(int(token_budget), 0) total_tokens = self._estimate_total_tokens(ordered_turns) if total_tokens <= budget: return ordered_turns query = self._derive_query(ordered_turns, query_hint) if not query: return self._naive_head_tail(ordered_turns, budget) try: candidates = self._recall_candidates(query, max(int(candidate_pool), 1)) except Exception as exc: self._warn(f"OpenBrain recall failed in compress(); falling back to head+tail truncation: {exc}") return self._naive_head_tail(ordered_turns, budget) if not candidates: self._warn("OpenBrain recall returned no candidates in compress(); falling back to head+tail truncation.") return self._naive_head_tail(ordered_turns, budget) anchor_indices = {0, len(ordered_turns) - 1} nodes = self._build_turn_nodes(ordered_turns) + self._build_candidate_nodes(candidates) centrality, affinity = self._graph_scores(nodes) ranked_turns = self._rank_turn_indices(ordered_turns, centrality, affinity, anchor_indices) rank_selected = self._select_ranked_turns(ranked_turns, anchor_indices, max(int(top_k), 0)) rank_selected = self._trim_selection_to_budget(rank_selected, ranked_turns, ordered_turns, budget, anchor_indices) budget_selected = self._select_budget_turns(ranked_turns, ordered_turns, budget, anchor_indices) if len(budget_selected) > len(rank_selected): selected = budget_selected else: selected = rank_selected return [turn for index, turn in enumerate(ordered_turns) if index in selected] def system_prompt_block(self) -> str: return ( "Librarian/Cartographer compresses chat history with centrality-ranked OpenBrain recall, " "keeping the system anchor and current user turn while dropping low-centrality turns " "when the token budget is tight." ) def _qdrant_reachable(self) -> bool: try: status = self._request_status("GET", self._qdrant_probe_url(), timeout=1.5) except Exception: return False return 200 <= status < 500 def _postgres_reachable(self) -> bool: host, port = self._postgres_target() if not host: return False timeout = 1.5 try: if host.startswith("/"): socket_path = f"{host.rstrip('/')}/.s.PGSQL.{port}" with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client: client.settimeout(timeout) client.connect(socket_path) else: with socket.create_connection((host, port), timeout=timeout): pass except OSError: return False return True def _qdrant_probe_url(self) -> str: if not self.qdrant_url: return "" parsed = urlparse(self.qdrant_url) path = parsed.path or "" if not path or path == "/": return f"{self.qdrant_url}/collections" return self.qdrant_url def _postgres_target(self) -> tuple[str | None, int]: parsed = urlparse(self.pg_dsn) if parsed.scheme in {"postgres", "postgresql"}: host = parsed.hostname port = parsed.port or 5432 return host, port parts: dict[str, str] = {} for token in shlex.split(self.pg_dsn): if "=" not in token: continue key, value = token.split("=", 1) parts[key.strip()] = value.strip() host = parts.get("host") or parts.get("hostaddr") or "localhost" port_value = parts.get("port", "5432") try: port = int(port_value) except ValueError: port = 5432 if "," in host: host = host.split(",", 1)[0] return host, port def _load_core_spawn(self): try: task_module = importlib.import_module("core.task") except ImportError: return None return getattr(task_module, "spawn", None) def _recall_candidates(self, query: str, candidate_pool: int) -> list[dict]: payload = { "query": query, "top_k": candidate_pool, "filter": self._clean_mapping( { "workspace_id": self.workspace_id, "org": self.org, } ), } response = self._request_json( "POST", self._brain_endpoint("recall"), json_body=payload, headers=self._auth_headers(), ) status = int(response.get("status", 200)) if status >= 400: raise OSError(f"recall returned status {status}") return self._extract_candidates(response) def _extract_candidates(self, response: dict) -> list[dict]: sources: list[object] = [response] data = response.get("data") if isinstance(data, dict): sources.append(data) items: list[dict] = [] for source in sources: if not isinstance(source, dict): continue for key in ("memories", "results", "items", "matches", "data"): value = source.get(key) if isinstance(value, list): items.extend(item for item in value if isinstance(item, dict)) return [self._normalise_candidate(item, index) for index, item in enumerate(items)] def _normalise_candidate(self, item: dict, index: int) -> dict: payload = item.get("payload") if not isinstance(payload, dict): payload = {} candidate_id = item.get("id") or item.get("memory_id") or item.get("uuid") or f"memory:{index}" text = ( payload.get("content") or item.get("content") or payload.get("text") or item.get("text") or "" ) return { "id": str(candidate_id), "text": self._stringify_text(text), "payload": payload, "score": self._float_value(item.get("score"), item.get("similarity"), default=0.0), "vector": self._coerce_vector( item.get("vector") or item.get("embedding") or payload.get("vector") or payload.get("embedding") ), } def _build_turn_nodes(self, turns: list[dict]) -> list[dict]: nodes: list[dict] = [] for index, turn in enumerate(turns): node_id = turn.get("id") or turn.get("uuid") or f"turn:{index}" nodes.append( { "node_id": str(node_id), "turn_index": index, "kind": "turn", "text": self._turn_text(turn), "vector": self._coerce_vector(turn.get("vector") or turn.get("embedding")), "score": 0.0, } ) return nodes def _build_candidate_nodes(self, candidates: list[dict]) -> list[dict]: return [ { "node_id": candidate["id"], "turn_index": None, "kind": "memory", "text": candidate["text"], "vector": candidate["vector"], "score": candidate["score"], } for candidate in candidates ] def _graph_scores(self, nodes: list[dict]) -> tuple[dict[str, float], dict[str, float]]: adjacency = {node["node_id"]: set() for node in nodes} affinity = {node["node_id"]: 0.0 for node in nodes} for left_index in range(len(nodes)): left = nodes[left_index] for right_index in range(left_index + 1, len(nodes)): right = nodes[right_index] similarity = self._node_similarity(left, right) if similarity < self._similarity_threshold: continue adjacency[left["node_id"]].add(right["node_id"]) adjacency[right["node_id"]].add(left["node_id"]) if left["kind"] != right["kind"]: affinity[left["node_id"]] = max(affinity[left["node_id"]], similarity) affinity[right["node_id"]] = max(affinity[right["node_id"]], similarity) centrality = self._degree_centrality(nodes, adjacency) return centrality, affinity def _degree_centrality(self, nodes: list[dict], adjacency: dict[str, set[str]]) -> dict[str, float]: if nx is not None: graph = nx.Graph() for node in nodes: graph.add_node(node["node_id"]) for node_id, edges in adjacency.items(): for edge in edges: graph.add_edge(node_id, edge) return dict(nx.degree_centrality(graph)) normaliser = max(len(nodes) - 1, 1) return {node_id: len(edges) / normaliser for node_id, edges in adjacency.items()} def _rank_turn_indices( self, turns: list[dict], centrality: dict[str, float], affinity: dict[str, float], anchor_indices: set[int], ) -> list[int]: ranked: list[tuple[float, float, float, int]] = [] for index, turn in enumerate(turns): if index in anchor_indices: continue node_id = str(turn.get("id") or turn.get("uuid") or f"turn:{index}") ranked.append( ( -centrality.get(node_id, 0.0), -affinity.get(node_id, 0.0), self._estimate_tokens(turn), index, ) ) ranked.sort() return [index for _, _, _, index in ranked] def _select_ranked_turns(self, ranked_turns: list[int], anchor_indices: set[int], top_k: int) -> set[int]: selected = set(anchor_indices) for index in ranked_turns[:top_k]: selected.add(index) return selected def _select_budget_turns( self, ranked_turns: list[int], turns: list[dict], token_budget: int, anchor_indices: set[int], ) -> set[int]: selected = set(anchor_indices) used_tokens = self._selection_tokens(selected, turns) for index in ranked_turns: turn_tokens = self._estimate_tokens(turns[index]) if used_tokens + turn_tokens > token_budget and selected: continue selected.add(index) used_tokens += turn_tokens return selected def _trim_selection_to_budget( self, selected: set[int], ranked_turns: list[int], turns: list[dict], token_budget: int, anchor_indices: set[int], ) -> set[int]: trimmed = set(selected) if self._selection_tokens(trimmed, turns) <= token_budget: return trimmed removable = [index for index in reversed(ranked_turns) if index in trimmed and index not in anchor_indices] for index in removable: trimmed.remove(index) if self._selection_tokens(trimmed, turns) <= token_budget: break return trimmed def _selection_tokens(self, selected: set[int], turns: list[dict]) -> int: return sum(self._estimate_tokens(turns[index]) for index in selected) def _naive_head_tail(self, turns: list[dict], token_budget: int) -> list[dict]: if not turns: return [] if len(turns) <= 2: return list(turns) selected = {0, len(turns) - 1} used_tokens = self._selection_tokens(selected, turns) for index in range(len(turns) - 2, 0, -1): turn_tokens = self._estimate_tokens(turns[index]) if used_tokens + turn_tokens > token_budget and selected: break selected.add(index) used_tokens += turn_tokens return [turn for index, turn in enumerate(turns) if index in selected] def _derive_query(self, turns: list[dict], query_hint: str | None) -> str: if query_hint and query_hint.strip(): return query_hint.strip() last_text = self._turn_text(turns[-1]) if last_text: return last_text if len(turns) >= 2: previous = self._turn_text(turns[-2]) parts = [part for part in (previous, last_text) if part] return "\n".join(parts) return "" def _estimate_total_tokens(self, turns: list[dict]) -> int: return sum(self._estimate_tokens(turn) for turn in turns) def _estimate_tokens(self, turn: dict) -> int: return len(self._turn_text(turn)) // 4 def _turn_text(self, turn: dict) -> str: content = turn.get("content") if isinstance(content, str): return content if isinstance(content, list): parts = [self._stringify_text(item) for item in content] return " ".join(part for part in parts if part) if content is None: return "" return self._stringify_text(content) def _stringify_text(self, value: object) -> str: if isinstance(value, str): return value if value is None: return "" try: return json.dumps(value, sort_keys=True, default=str) except TypeError: return str(value) def _coerce_vector(self, value: object) -> list[float] | None: if not isinstance(value, Iterable) or isinstance(value, (str, bytes, dict)): return None vector: list[float] = [] for item in value: try: vector.append(float(item)) except (TypeError, ValueError): return None return vector or None def _node_similarity(self, left: dict, right: dict) -> float: left_vector = left.get("vector") right_vector = right.get("vector") if left_vector and right_vector: cosine = self._cosine_similarity(left_vector, right_vector) if cosine is not None: return cosine return self._jaccard_similarity(left.get("text", ""), right.get("text", "")) def _cosine_similarity(self, left: list[float], right: list[float]) -> float | None: if len(left) != len(right) or not left: return None numerator = sum(a * b for a, b in zip(left, right)) left_norm = math.sqrt(sum(a * a for a in left)) right_norm = math.sqrt(sum(b * b for b in right)) if left_norm == 0.0 or right_norm == 0.0: return None return numerator / (left_norm * right_norm) def _jaccard_similarity(self, left_text: str, right_text: str) -> float: left_tokens = self._token_set(left_text) right_tokens = self._token_set(right_text) if not left_tokens or not right_tokens: return 0.0 union = left_tokens | right_tokens if not union: return 0.0 return len(left_tokens & right_tokens) / len(union) def _token_set(self, text: str) -> set[str]: lowered = text.lower() cleaned = "".join(character if character.isalnum() else " " for character in lowered) return {token for token in cleaned.split() if token} def _float_value(self, *values: object, default: float) -> float: for value in values: try: return float(value) except (TypeError, ValueError): continue return default def _warn(self, message: str) -> None: print(message, file=sys.stderr) def _clean_mapping(self, values: dict) -> dict: return {key: value for key, value in values.items() if value is not None and value != ""} def _brain_endpoint(self, suffix: str) -> str: return f"{self.brain_url}/v1/brain/{suffix.lstrip('/')}" def _auth_headers(self) -> dict[str, str]: return { "Accept": "application/json", "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } def _request_status( self, method: str, url: str, *, params: dict | None = None, json_body: dict | None = None, headers: dict | None = None, timeout: float = 5.0, ) -> int: status, _ = self._raw_request( method, url, params=params, json_body=json_body, headers=headers, timeout=timeout, ) return status def _request_json( self, method: str, url: str, *, params: dict | None = None, json_body: dict | None = None, headers: dict | None = None, timeout: float = 5.0, ) -> dict: status, text = self._raw_request( method, url, params=params, json_body=json_body, headers=headers, timeout=timeout, ) if not text: return {"status": status} try: payload = json.loads(text) except json.JSONDecodeError: return {"status": status, "data": text} if isinstance(payload, dict): payload.setdefault("status", status) return payload return {"status": status, "data": payload} def _raw_request( self, method: str, url: str, *, params: dict | None = None, json_body: dict | None = None, headers: dict | None = None, timeout: float = 5.0, ) -> tuple[int, str]: if requests is not None: response = requests.request( method, url, params=params, json=json_body, headers=headers, timeout=timeout, ) return response.status_code, response.text if httpx is not None: response = httpx.request( method, url, params=params, json=json_body, headers=headers, timeout=timeout, ) return response.status_code, response.text return self._urllib_request( method, url, params=params, json_body=json_body, headers=headers, timeout=timeout, ) def _urllib_request( self, method: str, url: str, *, params: dict | None = None, json_body: dict | None = None, headers: dict | None = None, timeout: float = 5.0, ) -> tuple[int, str]: request_headers = dict(headers or {}) request_url = url if params: query_string = urlencode(params, doseq=True) separator = "&" if "?" in request_url else "?" request_url = f"{request_url}{separator}{query_string}" data = None if json_body is not None: request_headers.setdefault("Content-Type", "application/json") data = json.dumps(json_body).encode("utf-8") request = Request(request_url, data=data, headers=request_headers, method=method) try: with urlopen(request, timeout=timeout) as response: return response.getcode(), response.read().decode("utf-8") except HTTPError as exc: return exc.code, exc.read().decode("utf-8") except URLError as exc: raise OSError(str(exc)) from exc