#!/usr/bin/env python3
"""
Database connection and session management for GNSS Guard Server
"""

import logging
from contextlib import contextmanager
from typing import Generator

from sqlalchemy import create_engine, event
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool, StaticPool

from config import get_config
from models import Base

logger = logging.getLogger("gnss_guard.server.database")

# Global engine and session factory, created lazily on first use.
# NOTE(review): lazy initialisation is not guarded by a lock; two threads making
# the very first call concurrently could each build an engine. Presumably
# init_db() runs at startup before requests are served — confirm.
_engine = None
_SessionLocal = None


def _is_sqlite_memory_url(database_url: str) -> bool:
    """Return True if *database_url* refers to an in-memory SQLite database.

    Recognises ``sqlite://`` (no database path), ``sqlite:///:memory:``,
    ``file::memory:`` style names, and URI-style URLs carrying ``mode=memory``.
    """
    url = make_url(database_url)
    database = url.database
    if not database or ":memory:" in database:
        return True
    return url.query.get("mode") == "memory"


def get_engine():
    """Get or create the database engine.

    The engine is built on the first call from ``get_config().database_url``
    and cached in the module-level ``_engine``. Later calls return the cached
    object.

    SQLite URLs (local development):
        ``check_same_thread=False`` lets connections be used from threads other
        than the one that opened them.

        An in-memory database uses ``StaticPool`` (one shared connection),
        because every new connection to ``:memory:`` would otherwise see a
        fresh, empty database.

        File-based databases keep SQLAlchemy's default pool. Concurrent
        sessions then get separate connections and independent transactions,
        instead of sharing one connection and committing or rolling back each
        other's work.

    All other URLs (PostgreSQL expected):
        ``QueuePool`` (5 connections + 10 overflow), with ``pool_pre_ping`` so
        stale connections are detected before use.
    """
    global _engine
    if _engine is None:
        config = get_config()
        database_url = config.database_url

        if database_url.startswith("sqlite"):
            engine_kwargs = {
                "connect_args": {"check_same_thread": False},
                "echo": config.debug,
            }
            if _is_sqlite_memory_url(database_url):
                # A single shared connection is the only way to keep an
                # in-memory database alive across sessions.
                engine_kwargs["poolclass"] = StaticPool
            _engine = create_engine(database_url, **engine_kwargs)
            logger.info("SQLite database engine created: %s", database_url)
        else:
            # PostgreSQL with connection pooling
            _engine = create_engine(
                database_url,
                poolclass=QueuePool,
                pool_size=5,
                max_overflow=10,
                pool_pre_ping=True,  # Verify connections before using
                echo=config.debug,
            )
            # Log only the host/database part after '@' so credentials embedded
            # in the URL are not written to the logs.
            logger.info(
                "Database engine created for: %s", database_url.split("@")[-1]
            )
    return _engine


def get_session_factory():
    """Get or create the session factory bound to the shared engine.

    The factory is cached in the module-level ``_SessionLocal``. Sessions it
    creates do not autoflush; callers commit explicitly.
    """
    global _SessionLocal
    if _SessionLocal is None:
        engine = get_engine()
        _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    return _SessionLocal


def init_db():
    """Initialize the database by creating all tables registered on ``Base.metadata``.

    ``create_all`` skips tables that already exist, so repeated calls are safe.
    """
    engine = get_engine()
    Base.metadata.create_all(bind=engine)
    logger.info("Database tables created/verified")


def get_db() -> Generator[Session, None, None]:
    """
    Dependency for FastAPI to get a database session.

    Yields a session and always closes it afterwards. It does NOT commit:
    request handlers must call ``db.commit()`` themselves. Uncommitted work is
    discarded when the session is closed.
    """
    SessionLocal = get_session_factory()
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


@contextmanager
def get_db_session() -> Generator[Session, None, None]:
    """
    Context manager for database sessions (for use outside FastAPI dependencies).

    Commits when the ``with`` body finishes normally. If the body (or the
    commit) raises, it rolls back and re-raises. The session is always closed.
    """
    SessionLocal = get_session_factory()
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()