"""Shared HTTP, caching, and CSV helpers for the analysis toolkit.""" from __future__ import annotations import csv import hashlib import json import os import sys import time from pathlib import Path from typing import Iterable, Mapping, Sequence import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry USER_AGENT = "fablepool-ha-ecosystem-tools/0.1 (+https://github.com/fablepool)" DEFAULT_CACHE_DIR = Path(__file__).resolve().parent.parent / ".cache" DEFAULT_TTL_SECONDS = 24 * 3600 # one day; analytics/issue data drifts slowly # --------------------------------------------------------------------------- # HTTP session # --------------------------------------------------------------------------- def make_session() -> requests.Session: """Return a requests session with retries on transient server errors.""" session = requests.Session() retry = Retry( total=4, backoff_factor=1.5, status_forcelist=(500, 502, 503, 504), allowed_methods=("GET",), raise_on_status=False, ) adapter = HTTPAdapter(max_retries=retry) session.mount("https://", adapter) session.mount("http://", adapter) session.headers["User-Agent"] = USER_AGENT return session def github_token() -> str | None: """Read a GitHub token from GITHUB_TOKEN or GH_TOKEN (optional).""" return os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN") or None def github_headers(token: str | None = None) -> dict[str, str]: headers = { "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", } tok = token if token is not None else github_token() if tok: headers["Authorization"] = f"Bearer {tok}" return headers def request_with_backoff( session: requests.Session, url: str, *, params: Mapping[str, str] | None = None, headers: Mapping[str, str] | None = None, timeout: float = 30.0, max_tries: int = 6, ) -> requests.Response: """GET with explicit handling for GitHub primary/secondary rate limits. Transient 5xx retries are already handled by the adapter; here we handle 403/429 responses that carry ``Retry-After`` or ``X-RateLimit-Reset``. """ last_resp: requests.Response | None = None for attempt in range(1, max_tries + 1): resp = session.get(url, params=params, headers=headers, timeout=timeout) last_resp = resp if resp.status_code not in (403, 429): return resp retry_after = resp.headers.get("Retry-After") remaining = resp.headers.get("X-RateLimit-Remaining") reset = resp.headers.get("X-RateLimit-Reset") wait: float | None = None if retry_after is not None: try: wait = float(retry_after) except ValueError: wait = None if wait is None and remaining == "0" and reset is not None: try: wait = max(0.0, float(int(reset)) - time.time()) + 2.0 except ValueError: wait = None if wait is None: # 403 that is not a rate limit (e.g. abuse detection without # headers, or genuine permission error): brief escalating pause. wait = 10.0 * attempt if wait > 900: raise RuntimeError( f"GitHub rate limit on {url} requires waiting {wait:.0f}s; " "set GITHUB_TOKEN for a 5000 req/h budget and re-run " "(the on-disk cache preserves completed work)." ) print( f" rate-limited ({resp.status_code}); sleeping {wait:.0f}s " f"(attempt {attempt}/{max_tries})", file=sys.stderr, ) time.sleep(wait) assert last_resp is not None return last_resp # --------------------------------------------------------------------------- # File cache # --------------------------------------------------------------------------- class FileCache: """Tiny content cache keyed by URL, storing status + body as JSON.""" def __init__(self, directory: Path | str = DEFAULT_CACHE_DIR, ttl_seconds: float = DEFAULT_TTL_SECONDS) -> None: self.directory = Path(directory) self.ttl_seconds = ttl_seconds self.directory.mkdir(parents=True, exist_ok=True) def _path(self, url: str) -> Path: digest = hashlib.sha256(url.encode("utf-8")).hexdigest() return self.directory / f"{digest}.json" def get(self, url: str) -> dict | None: path = self._path(url) if not path.exists(): return None if self.ttl_seconds > 0: age = time.time() - path.stat().st_mtime if age > self.ttl_seconds: return None try: return json.loads(path.read_text(encoding="utf-8")) except (json.JSONDecodeError, OSError): return None def put(self, url: str, status: int, text: str) -> None: envelope = {"url": url, "status": status, "text": text, "fetched_at": time.time()} self._path(url).write_text( json.dumps(envelope), encoding="utf-8" ) def cached_get( session: requests.Session, cache: FileCache, url: str, *, headers: Mapping[str, str] | None = None, params: Mapping[str, str] | None = None, cacheable_statuses: Sequence[int] = (200, 404), ) -> tuple[int, str]: """GET through the cache. Returns (status_code, body_text). Only parameter-free URLs are cached (params bypass the cache) so that search queries with paging never collide. """ if params is None: hit = cache.get(url) if hit is not None: return int(hit["status"]), str(hit["text"]) resp = request_with_backoff(session, url, params=params, headers=headers) if params is None and resp.status_code in cacheable_statuses: cache.put(url, resp.status_code, resp.text) return resp.status_code, resp.text # --------------------------------------------------------------------------- # CSV helpers # --------------------------------------------------------------------------- def write_csv(path: Path | str, fieldnames: Sequence[str], rows: Iterable[Mapping[str, object]]) -> int: """Write rows to CSV, creating parent dirs. Returns row count.""" path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) count = 0 with path.open("w", newline="", encoding="utf-8") as handle: writer = csv.DictWriter(handle, fieldnames=list(fieldnames), extrasaction="ignore") writer.writeheader() for row in rows: writer.writerow(row) count += 1 return count def read_csv(path: Path | str) -> list[dict[str, str]]: with Path(path).open(newline="", encoding="utf-8") as handle: return list(csv.DictReader(handle)) def read_domains_csv(path: Path | str, limit: int = 0) -> list[str]: """Read a list of integration domains from a CSV with a 'domain' column.""" rows = read_csv(path) if not rows: return [] column = None for candidate in ("domain", "integration", "component"): if candidate in rows[0]: column = candidate break if column is None: raise ValueError( f"{path}: expected a 'domain' (or 'integration') column; " f"found {list(rows[0].keys())}" ) seen: set[str] = set() domains: list[str] = [] for row in rows: domain = (row.get(column) or "").strip() if domain and domain not in seen: seen.add(domain) domains.append(domain) if limit and len(domains) >= limit: break return domains