"""PostgreSQL connection pooling and subscriber/newsletter data access helpers."""

import os
import time
import logging
import threading
import atexit
from contextlib import contextmanager

import psycopg2
from psycopg2 import pool, IntegrityError
from psycopg2 import extensions
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Process-wide pool wrapper, created lazily by get_connection_pool().
_connection_pool = None
# Guards creation/teardown of _connection_pool.
_pool_lock = threading.Lock()

_pool_stats = {
    "connections_checked_out": 0,
    "connections_failed": 0,
    # NOTE: also incremented by the very first pool creation, not only by
    # later recreations.
    "pool_recreated": 0,
}
# `dict[key] += 1` is a read-modify-write and is not atomic across threads,
# so counter updates and snapshots go through this lock.
_stats_lock = threading.Lock()


def _bump_stat(name):
    """Atomically increment the counter ``name`` in ``_pool_stats``."""
    with _stats_lock:
        _pool_stats[name] += 1


class SimpleRobustPool:
    """
    Thread-safe, robust connection pool wrapper for psycopg2 with
    connection health checks and automatic recovery.

    Wraps a ``psycopg2.pool.ThreadedConnectionPool``. Every connection is
    validated with a ``SELECT 1`` round-trip before it is handed out. If
    connections repeatedly cannot be obtained for reasons other than plain
    pool exhaustion, the underlying pool is recreated once.
    """

    def __init__(self, minconn=3, maxconn=15, **kwargs):
        self.minconn = minconn
        self.maxconn = maxconn
        # Passed verbatim to psycopg2.connect() by the underlying pool.
        self.kwargs = kwargs
        self.pool = None
        # Serializes pool recreation so concurrent callers cannot tear down
        # a pool that another thread has just built.
        self._recreate_lock = threading.Lock()
        self._create_pool()

    def _create_pool(self):
        """Create or recreate the underlying connection pool.

        Any existing pool is closed first (best effort). Closing it closes
        every connection it owns, including ones currently checked out.

        Raises:
            Exception: whatever ``ThreadedConnectionPool`` raises when it
                cannot open its initial ``minconn`` connections.
        """
        try:
            if self.pool:
                try:
                    self.pool.closeall()
                except Exception:
                    pass  # best effort: the old pool is being discarded
            self.pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=self.minconn,
                maxconn=self.maxconn,
                **self.kwargs,
            )
            _bump_stat("pool_recreated")
            logger.info(
                "DB pool created: min=%s max=%s host=%s db=%s",
                self.minconn,
                self.maxconn,
                self.kwargs.get("host"),
                self.kwargs.get("database"),
            )
        except Exception as e:
            logger.error(f"Failed to create connection pool: {e}")
            raise

    def _healthy(self, conn) -> bool:
        """
        Validate a connection with a real round-trip and make sure it's not
        in an aborted transaction.

        Returns:
            True if the connection is open and ``SELECT 1`` succeeded,
            False on any failure (never raises).
        """
        try:
            if conn.closed:
                return False
            txn_status = conn.get_transaction_status()
            if txn_status == extensions.TRANSACTION_STATUS_INERROR:
                try:
                    conn.rollback()
                except Exception:
                    return False
            # NOTE(review): this permanently switches pooled connections to
            # autocommit, so callers' commit()/rollback() calls are
            # effectively no-ops and multi-statement work is not atomic.
            # Left as-is because code outside this module may rely on writes
            # being persisted without an explicit commit().
            if not conn.autocommit:
                conn.autocommit = True
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
                cur.fetchone()
            return True
        except Exception:
            return False

    @staticmethod
    def _is_exhaustion(pool_obj, err) -> bool:
        """Return True if ``err`` means ``pool_obj`` is open but fully checked out."""
        # ThreadedConnectionPool.getconn raises PoolError both for
        # "exhausted" and for "closed"; only the former leaves the pool usable.
        return isinstance(err, psycopg2.pool.PoolError) and not getattr(
            pool_obj, "closed", False
        )

    def getconn(self, retry_count=2):
        """
        Get a connection, validating it before handing out.

        Makes ``retry_count + 1`` attempts. If they all fail and at least one
        failure was something other than pool exhaustion (connect error,
        unhealthy connection, closed pool), the underlying pool is recreated
        once and a final attempt is made.

        Pool exhaustion alone never triggers recreation. Recreating closes
        every connection the old pool owns, including connections that
        other threads are actively using.

        Raises:
            psycopg2.pool.PoolError: if every attempt found the pool exhausted.
            Exception: the last error seen if no healthy connection could be
                obtained.
        """
        last_err = None
        only_exhausted = True
        current = self.pool
        for attempt in range(retry_count + 1):
            # Bind the pool once per attempt so getconn/putconn (and the
            # recreation check below) all refer to the same pool object.
            current = self.pool
            try:
                conn = current.getconn()
                if not self._healthy(conn):
                    logger.warning("Discarding unhealthy connection")
                    try:
                        current.putconn(conn, close=True)
                    except Exception:
                        pass
                    only_exhausted = False
                    last_err = last_err or Exception("Unhealthy connection")
                    time.sleep(0.2)
                    continue
                _bump_stat("connections_checked_out")
                return conn
            except Exception as e:
                last_err = e
                if not self._is_exhaustion(current, e):
                    only_exhausted = False
                _bump_stat("connections_failed")
                logger.error(
                    "Error getting connection (attempt %s/%s): %s",
                    attempt + 1,
                    retry_count + 1,
                    e,
                )
                time.sleep(0.3)

        if only_exhausted and last_err is not None:
            # The pool is healthy, just fully checked out; rebuilding it
            # would close other threads' in-use connections.
            raise last_err

        with self._recreate_lock:
            # Another thread may already have replaced the pool we failed
            # against; only rebuild if it is still the same object.
            if self.pool is current:
                logger.warning("Recreating DB pool due to repeated failures")
                self._create_pool()
            fresh = self.pool

        try:
            conn = fresh.getconn()
            if self._healthy(conn):
                _bump_stat("connections_checked_out")
                return conn
            fresh.putconn(conn, close=True)
        except Exception as e:
            last_err = e

        raise last_err or Exception("Failed to get a healthy DB connection")

    def putconn(self, conn, close=False):
        """Return a connection to the pool, closing it if it's bad.

        A connection that is closed or stuck in an aborted transaction is
        rolled back (best effort) and closed rather than reused. Never raises:
        on failure the error is logged and the connection is closed.
        """
        try:
            bad = (
                conn.closed
                or conn.get_transaction_status()
                == extensions.TRANSACTION_STATUS_INERROR
            )
            if bad:
                try:
                    conn.rollback()
                except Exception:
                    pass
                close = True
            self.pool.putconn(conn, close=close)
        except Exception as e:
            # e.g. the connection came from a pool that was since recreated.
            logger.error(f"Error returning connection to pool: {e}")
            try:
                conn.close()
            except Exception:
                pass

    def closeall(self):
        """Close every connection owned by the underlying pool (best effort)."""
        if self.pool:
            try:
                self.pool.closeall()
            except Exception:
                pass


def get_connection_pool():
    """
    Initialize and return the global connection pool.

    Includes TCP keepalives and statement_timeout to detect dead/stuck
    sessions. Settings come from environment variables (PG_*, DB_*).
    """
    global _connection_pool
    with _pool_lock:
        if _connection_pool is None:
            conn_kwargs = dict(
                host=os.getenv("PG_HOST"),
                port=int(os.getenv("PG_PORT", 5432)),
                database=os.getenv("PG_DATABASE"),
                user=os.getenv("PG_USER"),
                password=os.getenv("PG_PASSWORD"),
                connect_timeout=int(os.getenv("DB_CONNECT_TIMEOUT", 10)),
                application_name="rideaware_app",
                keepalives=1,
                keepalives_idle=int(os.getenv("PG_KEEPALIVE_IDLE", 30)),
                keepalives_interval=int(os.getenv("PG_KEEPALIVE_INTERVAL", 10)),
                keepalives_count=int(os.getenv("PG_KEEPALIVE_COUNT", 3)),
                options="-c statement_timeout={}".format(
                    int(os.getenv("PG_STATEMENT_TIMEOUT_MS", 5000))
                ),
            )
            _connection_pool = SimpleRobustPool(
                minconn=int(os.getenv("DB_POOL_MIN", 3)),
                maxconn=int(os.getenv("DB_POOL_MAX", 15)),
                **conn_kwargs,
            )
            logger.info("Database connection pool initialized")
        return _connection_pool


@contextmanager
def get_db_connection():
    """
    Context manager that yields a healthy connection and ensures proper
    cleanup/rollback on errors.

    Any exception (including one raised inside the ``with`` body) is logged,
    the connection is rolled back (best effort), and the exception re-raised.
    """
    conn = None
    db_pool = None
    try:
        db_pool = get_connection_pool()
        conn = db_pool.getconn()
        yield conn
    except Exception as e:
        logger.error(f"Database connection error: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception:
                pass
        raise
    finally:
        if conn:
            try:
                # Return to the wrapper that issued the connection. Looking up
                # the global again could build a brand-new pool if
                # close_all_connections() ran in the meantime.
                db_pool.putconn(conn)
            except Exception as e:
                logger.error(f"Error returning connection: {e}")


def get_connection():
    """Check out a validated connection; pair with return_connection()."""
    db_pool = get_connection_pool()
    return db_pool.getconn()


def return_connection(conn):
    """Return a connection obtained via get_connection() (never raises)."""
    try:
        db_pool = get_connection_pool()
        db_pool.putconn(conn)
    except Exception as e:
        logger.error(f"Error returning connection to pool: {e}")


def close_all_connections():
    """Close the global pool's connections and reset it (registered atexit)."""
    global _connection_pool
    with _pool_lock:
        if _connection_pool:
            try:
                _connection_pool.closeall()
                logger.info("All DB connections closed")
            except Exception as e:
                logger.error(f"Error closing connections: {e}")
            finally:
                _connection_pool = None


def get_pool_stats():
    """Return a snapshot copy of the pool counters."""
    with _stats_lock:
        return _pool_stats.copy()


def column_exists(cursor, table_name, column_name) -> bool:
    """Return True if ``table_name`` has a column named ``column_name``.

    NOTE(review): the query does not filter on table_schema, so a same-named
    table in any visible schema will match.
    """
    cursor.execute(
        """
        SELECT EXISTS (
            SELECT 1 FROM information_schema.columns
            WHERE table_name = %s AND column_name = %s
        )
        """,
        (table_name, column_name),
    )
    return bool(cursor.fetchone()[0])


def init_db():
    """
    Initialize tables and indexes idempotently.

    Raises:
        Exception: any database error, after logging and rollback.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS subscribers (
                        id SERIAL PRIMARY KEY,
                        email TEXT UNIQUE NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
                cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS newsletters (
                        id SERIAL PRIMARY KEY,
                        subject TEXT NOT NULL,
                        body TEXT NOT NULL,
                        sent_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
                cur.execute(
                    "CREATE INDEX IF NOT EXISTS idx_newsletters_sent_at ON newsletters(sent_at DESC)"
                )
                cur.execute(
                    "CREATE INDEX IF NOT EXISTS idx_subscribers_email ON subscribers(email)"
                )
                cur.execute(
                    "CREATE INDEX IF NOT EXISTS idx_subscribers_created_at ON subscribers(created_at DESC)"
                )
            conn.commit()
            logger.info("DB schema ensured successfully")
        except Exception as e:
            logger.error(f"Error initializing DB: {e}")
            conn.rollback()
            raise


def add_email(email: str) -> bool:
    """
    Insert a subscriber email.

    Returns True if inserted, False if duplicate or error.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO subscribers (email) VALUES (%s)",
                    (email,),
                )
            conn.commit()
            logger.info("Email added: %s", email)
            return True
        except IntegrityError:
            # UNIQUE constraint on subscribers.email.
            conn.rollback()
            logger.info("Email already exists: %s", email)
            return False
        except Exception as e:
            conn.rollback()
            logger.error("Error adding email %s: %s", email, e)
            return False


def remove_email(email: str) -> bool:
    """
    Delete a subscriber email.

    Returns True if removed, False if not found or error.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM subscribers WHERE email = %s",
                    (email,),
                )
                removed = cur.rowcount > 0
            conn.commit()
            if removed:
                logger.info("Email removed: %s", email)
            else:
                logger.info("Email not found for removal: %s", email)
            return removed
        except Exception as e:
            conn.rollback()
            logger.error("Error removing email %s: %s", email, e)
            return False


def get_subscriber_count() -> int:
    """Return the number of subscribers, or 0 if the query fails."""
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT COUNT(*) FROM subscribers")
                (count,) = cur.fetchone()
                return int(count)
        except Exception as e:
            logger.error(f"Error getting subscriber count: {e}")
            return 0


def fetch_all_newsletters(limit: int = 100):
    """
    Return a list of newsletters ordered by sent_at desc.

    Each item: dict(id, subject, body, sent_at)
    """
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, subject, body, sent_at
                FROM newsletters
                ORDER BY sent_at DESC
                LIMIT %s
                """,
                (limit,),
            )
            rows = cur.fetchall()
    return [
        {"id": r[0], "subject": r[1], "body": r[2], "sent_at": r[3]}
        for r in rows
    ]


def fetch_newsletter_by_id(newsletter_id: int):
    """Return one newsletter as dict(id, subject, body, sent_at), or None."""
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, subject, body, sent_at
                FROM newsletters
                WHERE id = %s
                """,
                (newsletter_id,),
            )
            row = cur.fetchone()
    if not row:
        return None
    return {"id": row[0], "subject": row[1], "body": row[2], "sent_at": row[3]}


atexit.register(close_all_connections)