"""The derivation graph: evidence and claims linked by provenance edges.

The graph is a DAG by construction: a claim's inputs must already exist when
it is added, and claim ids are content-addressed, so cycles cannot form.

Responsibilities:

* index evidence and claims by id;
* maintain reverse ("dependents") edges so that refuting or retracting any
  node can mechanically enumerate every claim downstream of it
  (:meth:`DerivationGraph.descendants`);
* track exactly one ACTIVE claim per *claim key*
  (see :func:`mnema.derive.model.claim_key`) so re-derivations supersede
  rather than duplicate.

The graph holds **state**, not policy: status transitions are performed by
the engine, which is also responsible for emitting the corresponding signed
operations to the log.  Rebuilding a graph from a log replay therefore
produces an identical structure.
"""

from __future__ import annotations

from collections import defaultdict, deque
from typing import Dict, Iterable, List, Optional, Set

from mnema.derive.model import Claim, ClaimStatus, EvidenceRecord


class UnknownNodeError(KeyError):
    """Raised when an evidence/claim id is not present in the graph."""


class DerivationGraph:
    """In-memory index of evidence, claims, and their provenance edges."""

    def __init__(self) -> None:
        self._evidence: Dict[str, EvidenceRecord] = {}
        # evidence id -> retraction reason (None when no reason was given)
        self._retracted_evidence: Dict[str, Optional[str]] = {}
        self._claims: Dict[str, Claim] = {}
        # input node id (evidence or claim) -> ids of claims listing it as input
        self._dependents: Dict[str, Set[str]] = defaultdict(set)
        # claim key -> id of the claim currently indexed as ACTIVE for that key
        self._active_by_key: Dict[str, str] = {}

    # ------------------------------------------------------------------ #
    # Evidence
    # ------------------------------------------------------------------ #

    def add_evidence(self, ev: EvidenceRecord) -> bool:
        """Add an evidence record.

        Returns False (and leaves the stored record untouched) if the id
        already exists, True otherwise.
        """
        if ev.evidence_id in self._evidence:
            return False
        self._evidence[ev.evidence_id] = ev
        return True

    def get_evidence(self, evidence_id: str) -> EvidenceRecord:
        """Return the evidence record with this id.

        Raises:
            UnknownNodeError: if no such evidence exists.
        """
        try:
            return self._evidence[evidence_id]
        except KeyError:
            raise UnknownNodeError(evidence_id) from None

    def has_evidence(self, evidence_id: str) -> bool:
        """True if evidence with this id exists (retracted or not)."""
        return evidence_id in self._evidence

    def evidence(
        self, kind: Optional[str] = None, include_retracted: bool = False
    ) -> List[EvidenceRecord]:
        """Evidence records, sorted by ``(observed_at, evidence_id)``.

        Args:
            kind: if given, only records whose ``kind`` equals it are returned.
            include_retracted: when False (the default) retracted records are
                omitted; when True they are included.
        """
        out = [
            ev
            for ev in self._evidence.values()
            if (kind is None or ev.kind == kind)
            and (include_retracted or ev.evidence_id not in self._retracted_evidence)
        ]
        out.sort(key=lambda e: (e.observed_at, e.evidence_id))
        return out

    def retract_evidence(self, evidence_id: str, reason: Optional[str] = None) -> bool:
        """Mark evidence retracted (it stays in the log/graph for audit).

        Returns True if this call changed state, False if the evidence was
        already retracted (the original reason is kept).

        Raises:
            UnknownNodeError: if no such evidence exists.
        """
        if evidence_id not in self._evidence:
            raise UnknownNodeError(evidence_id)
        if evidence_id in self._retracted_evidence:
            return False
        self._retracted_evidence[evidence_id] = reason
        return True

    def is_retracted(self, evidence_id: str) -> bool:
        """True if the evidence id has been retracted."""
        return evidence_id in self._retracted_evidence

    # ------------------------------------------------------------------ #
    # Claims
    # ------------------------------------------------------------------ #

    def add_claim(self, claim: Claim) -> None:
        """Add a claim and index its provenance edges.

        Every id in ``claim.inputs`` gains ``claim.claim_id`` as a dependent.
        This method does not itself check that the inputs exist; the
        module-level invariant relies on callers adding inputs first.
        If the claim is ACTIVE it becomes the indexed active claim for its
        key, replacing any previous entry (the previously indexed claim's
        own status is not changed here).

        Raises:
            ValueError: if a claim with the same id was already added.
        """
        if claim.claim_id in self._claims:
            raise ValueError(f"duplicate claim id: {claim.claim_id}")
        self._claims[claim.claim_id] = claim
        for inp in claim.inputs:
            self._dependents[inp].add(claim.claim_id)
        if claim.status is ClaimStatus.ACTIVE:
            self._active_by_key[claim.key()] = claim.claim_id

    def get_claim(self, claim_id: str) -> Claim:
        """Return the claim with this id.

        Raises:
            UnknownNodeError: if no such claim exists.
        """
        try:
            return self._claims[claim_id]
        except KeyError:
            raise UnknownNodeError(claim_id) from None

    def has_claim(self, claim_id: str) -> bool:
        """True if a claim with this id exists (any status)."""
        return claim_id in self._claims

    def claims(
        self,
        predicate: Optional[str] = None,
        status: Optional[ClaimStatus] = ClaimStatus.ACTIVE,
    ) -> List[Claim]:
        """Claims filtered by predicate and status (``status=None`` for all).

        Defaults to ACTIVE claims only. Results are sorted by
        ``(derived_at, claim_id)``.
        """
        out = [
            c
            for c in self._claims.values()
            if (predicate is None or c.predicate == predicate)
            and (status is None or c.status is status)
        ]
        out.sort(key=lambda c: (c.derived_at, c.claim_id))
        return out

    def active_claim_for_key(self, key: str) -> Optional[str]:
        """Id of the claim indexed as ACTIVE for ``key``, or None."""
        return self._active_by_key.get(key)

    def mark_status(
        self, claim_id: str, status: ClaimStatus, reason: Optional[str] = None
    ) -> None:
        """Set a claim's status and reason, keeping the active-key index in sync.

        Moving a claim to ACTIVE makes it the indexed active claim for its
        key; any claim previously indexed there keeps its own status
        (demoting it is left to the engine). Moving the indexed claim away
        from ACTIVE removes its key from the index.

        Raises:
            UnknownNodeError: if no such claim exists.
        """
        claim = self.get_claim(claim_id)
        claim.status = status
        claim.status_reason = reason
        key = claim.key()
        if status is ClaimStatus.ACTIVE:
            self._active_by_key[key] = claim_id
        elif self._active_by_key.get(key) == claim_id:
            # Only drop the index entry if it still points at this claim; a
            # newer claim may already have taken over the key.
            del self._active_by_key[key]

    # ------------------------------------------------------------------ #
    # Traversal
    # ------------------------------------------------------------------ #

    def dependents_of(self, node_id: str) -> List[str]:
        """Claim ids that list ``node_id`` directly among their inputs (sorted)."""
        return sorted(self._dependents.get(node_id, ()))

    def descendants(self, node_id: str) -> List[str]:
        """All claim ids transitively derived from ``node_id``, in causal order.

        ``node_id`` may be an evidence id or a claim id; it is not included
        in the result. An unknown id yields ``[]``.

        The result is a topological order of the downstream sub-graph: every
        claim appears after every claim it (transitively) depends on. When
        the engine cascades an invalidation, every claim is therefore
        invalidated before anything derived from it, so the audit log reads
        causally. Ties are broken in FIFO order with ids sorted at each
        step, so for tree-shaped sub-graphs the order equals a sorted BFS.
        """
        dependents = self._dependents

        # Pass 1: collect every claim reachable from node_id. Use .get() so a
        # read-only traversal never grows the defaultdict.
        reachable: Set[str] = set()
        stack: List[str] = list(dependents.get(node_id, ()))
        while stack:
            cid = stack.pop()
            if cid in reachable:
                continue
            reachable.add(cid)
            stack.extend(dependents.get(cid, ()))

        # Pass 2: Kahn's algorithm restricted to the reachable sub-graph.
        # Plain BFS is not sufficient: a claim that depends both directly on
        # node_id and on a deeper descendant would be emitted before that
        # descendant. Edges from node_id itself are not counted, so direct
        # dependents with no other in-subgraph input start ready.
        pending: Dict[str, int] = dict.fromkeys(reachable, 0)
        for parent in reachable:
            for child in dependents.get(parent, ()):
                pending[child] += 1

        queue: deque[str] = deque(sorted(c for c in reachable if pending[c] == 0))
        order: List[str] = []
        while queue:
            cid = queue.popleft()
            order.append(cid)
            for child in sorted(dependents.get(cid, ())):
                pending[child] -= 1
                if pending[child] == 0:
                    queue.append(child)

        if len(order) < len(reachable):
            # Only possible if the DAG invariant was broken (a cycle). A
            # cascade must never silently miss nodes, so append the remaining
            # ones in a deterministic order.
            emitted = set(order)
            order.extend(sorted(reachable - emitted))
        return order

    def inputs_of(self, claim_id: str) -> List[str]:
        """A fresh list of the claim's input ids, in their stored order.

        Raises:
            UnknownNodeError: if no such claim exists.
        """
        return list(self.get_claim(claim_id).inputs)

    # ------------------------------------------------------------------ #
    # Introspection
    # ------------------------------------------------------------------ #

    def stats(self) -> dict:
        """Node counts: evidence, retracted evidence, claims (total and per
        status value) and the number of keys with an indexed active claim."""
        by_status: Dict[str, int] = defaultdict(int)
        for c in self._claims.values():
            by_status[c.status.value] += 1
        return {
            "evidence": len(self._evidence),
            "evidence_retracted": len(self._retracted_evidence),
            "claims": len(self._claims),
            "claims_by_status": dict(by_status),
            "active_keys": len(self._active_by_key),
        }