"""Database engine, session factory, and schema helpers for fan_passport."""

from collections.abc import Generator

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine, make_url
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from sqlalchemy.pool import StaticPool

from fan_passport.config import get_settings


class Base(DeclarativeBase):
    """Shared declarative base for all ORM models."""


def _sqlite_connect_args(database_url: str) -> dict[str, bool]:
    """Return driver ``connect_args`` appropriate for *database_url*.

    SQLite URLs get ``check_same_thread=False`` so a pooled connection may be
    used from a thread other than the one that opened it; every other backend
    gets no extra arguments.
    """
    if database_url.startswith("sqlite"):
        return {"check_same_thread": False}
    return {}


def _is_in_memory_sqlite(database_url: str) -> bool:
    """Return True if *database_url* names an in-memory SQLite database.

    Parsing the URL (instead of comparing against a fixed list of strings)
    also recognises equivalent spellings such as ``sqlite+pysqlite://`` or
    ``sqlite:///:memory:`` followed by query parameters.
    """
    url = make_url(database_url)
    return url.get_backend_name() == "sqlite" and url.database in (
        None,
        "",
        ":memory:",
    )


def _make_engine(database_url: str) -> Engine:
    """Build an Engine for *database_url*.

    An in-memory SQLite database exists only for the lifetime of the
    connection that created it, so such URLs are given a ``StaticPool``.
    Every checkout then reuses the same single connection, and all sessions
    see the same tables.
    """
    kwargs: dict[str, object] = {
        "connect_args": _sqlite_connect_args(database_url),
        # Kept for SQLAlchemy 1.4 compatibility; it is the only allowed value in 2.x.
        "future": True,
    }
    if _is_in_memory_sqlite(database_url):
        kwargs["poolclass"] = StaticPool
    return create_engine(database_url, **kwargs)


# Default engine, built from application settings at import time.
# configure_database() replaces it.
_engine: Engine = _make_engine(get_settings().database_url)

# Global session factory. With expire_on_commit=False, loaded attribute values
# are not expired on commit, so objects stay readable after their session commits.
SessionLocal = sessionmaker(
    bind=_engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
    future=True,
)


def configure_database(database_url: str) -> Engine:
    """Rebind the global session factory.

    This is used by tests and by app factories that are passed explicit
    settings.

    The new engine is built before anything is swapped, so if building it
    raises, the current configuration is left untouched. The previous engine
    is then disposed so its pooled connections are closed instead of leaking.
    For an in-memory SQLite database, this also discards its data. Sessions
    still open against the previous engine should be closed before calling
    this.

    Args:
        database_url: SQLAlchemy URL for the new database.

    Returns:
        The newly created engine, now bound to ``SessionLocal``.
    """
    global _engine
    previous = _engine
    _engine = _make_engine(database_url)
    SessionLocal.configure(bind=_engine)
    previous.dispose()
    return _engine


def get_engine() -> Engine:
    """Return the module's current engine (the one bound to ``SessionLocal``)."""
    return _engine


def init_db() -> None:
    """Create all tables for MVP/local deployments.

    Production deployments should use a migration tool pointed at Base.metadata.
    """
    import fan_passport.models  # noqa: F401 - required so metadata knows all tables

    Base.metadata.create_all(bind=_engine)


def drop_db() -> None:
    """Drop every table registered on ``Base.metadata`` from the current engine."""
    import fan_passport.models  # noqa: F401 - required so metadata knows all tables

    Base.metadata.drop_all(bind=_engine)


def get_session() -> Generator[Session, None, None]:
    """Yield a new session from ``SessionLocal`` and close it afterwards.

    The ``finally`` clause closes the session when the generator is
    exhausted, is closed, or has an exception thrown into it.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()