diff --git a/database.py b/database.py index 5a3454a..5ecde96 100644 --- a/database.py +++ b/database.py @@ -1,172 +1,220 @@ import os -import psycopg2 -from psycopg2 import pool, IntegrityError, OperationalError -from dotenv import load_dotenv +import time import logging import threading -import time +import atexit from contextlib import contextmanager +import psycopg2 +from psycopg2 import pool, IntegrityError +from psycopg2 import extensions +from dotenv import load_dotenv load_dotenv() -# Global connection pool +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + _connection_pool = None _pool_lock = threading.Lock() _pool_stats = { - 'connections_created': 0, - 'connections_failed': 0, - 'pool_recreated': 0 + "connections_checked_out": 0, + "connections_failed": 0, + "pool_recreated": 0, } -logger = logging.getLogger(__name__) - class SimpleRobustPool: - """Simplified robust connection pool""" - + """ + Thread-safe, robust connection pool wrapper for psycopg2 + with connection health checks and automatic recovery. + """ + def __init__(self, minconn=3, maxconn=15, **kwargs): self.minconn = minconn self.maxconn = maxconn self.kwargs = kwargs self.pool = None self._create_pool() - + def _create_pool(self): - """Create or recreate the connection pool""" + """Create or recreate the underlying connection pool.""" try: if self.pool: try: self.pool.closeall() - except: + except Exception: pass - + self.pool = psycopg2.pool.ThreadedConnectionPool( minconn=self.minconn, maxconn=self.maxconn, - **self.kwargs + **self.kwargs, + ) + _pool_stats["pool_recreated"] += 1 + logger.info( + "DB pool created: min=%s max=%s host=%s db=%s", + self.minconn, + self.maxconn, + self.kwargs.get("host"), + self.kwargs.get("database"), ) - _pool_stats['pool_recreated'] += 1 - logger.info(f"Connection pool created: {self.minconn}-{self.maxconn} connections") - except Exception as e: logger.error(f"Failed to create connection pool: {e}") raise - - def _test_connection(self, conn): - """Simple connection test without transaction conflicts""" + + def _healthy(self, conn) -> bool: + """ + Validate a connection with a real round-trip and + make sure it's not in an aborted transaction. + """ try: if conn.closed: return False - - # Simple test that doesn't interfere with transactions - conn.poll() - return conn.status == psycopg2.extensions.STATUS_READY or conn.status == psycopg2.extensions.STATUS_BEGIN - + + txn_status = conn.get_transaction_status() + if txn_status == extensions.TRANSACTION_STATUS_INERROR: + try: + conn.rollback() + except Exception: + return False + + if not conn.autocommit: + conn.autocommit = True + + with conn.cursor() as cur: + cur.execute("SELECT 1") + cur.fetchone() + return True except Exception: return False - + def getconn(self, retry_count=2): - """Get a connection with simplified retry logic""" - for attempt in range(retry_count): + """ + Get a connection, validating it before handing out. + On repeated failures, recreate the pool once. + """ + last_err = None + + for attempt in range(retry_count + 1): try: conn = self.pool.getconn() - - # Simple connection test - if not self._test_connection(conn): - logger.warning("Got bad connection, discarding and retrying") + if not self._healthy(conn): + logger.warning("Discarding unhealthy connection") try: self.pool.putconn(conn, close=True) - except: + except Exception: pass + last_err = last_err or Exception("Unhealthy connection") + time.sleep(0.2) continue - - _pool_stats['connections_created'] += 1 + + _pool_stats["connections_checked_out"] += 1 return conn - + except Exception as e: - logger.error(f"Error getting connection (attempt {attempt + 1}): {e}") - _pool_stats['connections_failed'] += 1 - - if attempt == retry_count - 1: - # Last attempt failed, try to recreate pool - logger.warning("Recreating connection pool due to failures") - try: - self._create_pool() - conn = self.pool.getconn() - if self._test_connection(conn): - return conn - except Exception as recreate_error: - logger.error(f"Failed to recreate pool: {recreate_error}") - raise - - # Wait before retry - time.sleep(0.5) - - raise Exception("Failed to get connection after retries") - - def putconn(self, conn, close=False): - """Return connection to pool""" + last_err = e + _pool_stats["connections_failed"] += 1 + logger.error( + "Error getting connection (attempt %s/%s): %s", + attempt + 1, + retry_count + 1, + e, + ) + time.sleep(0.3) + + logger.warning("Recreating DB pool due to repeated failures") + self._create_pool() try: - # Check if connection should be closed - if conn.closed or close: + conn = self.pool.getconn() + if self._healthy(conn): + _pool_stats["connections_checked_out"] += 1 + return conn + self.pool.putconn(conn, close=True) + except Exception as e: + last_err = e + + raise last_err or Exception("Failed to get a healthy DB connection") + + def putconn(self, conn, close=False): + """Return a connection to the pool, closing it if it's bad.""" + try: + bad = ( + conn.closed + or conn.get_transaction_status() + == extensions.TRANSACTION_STATUS_INERROR + ) + if bad: + try: + conn.rollback() + except Exception: + pass close = True - + self.pool.putconn(conn, close=close) - except Exception as e: logger.error(f"Error returning connection to pool: {e}") try: conn.close() - except: + except Exception: pass - + def closeall(self): - """Close all connections""" if self.pool: - self.pool.closeall() + try: + self.pool.closeall() + except Exception: + pass def get_connection_pool(): - """Initialize and return the connection pool""" + """ + Initialize and return the global connection pool. + Includes TCP keepalives and statement_timeout to detect dead/stuck sessions. + """ global _connection_pool with _pool_lock: if _connection_pool is None: - try: - _connection_pool = SimpleRobustPool( - minconn=int(os.getenv('DB_POOL_MIN', 3)), - maxconn=int(os.getenv('DB_POOL_MAX', 15)), - host=os.getenv("PG_HOST"), - port=int(os.getenv("PG_PORT", 5432)), - database=os.getenv("PG_DATABASE"), - user=os.getenv("PG_USER"), - password=os.getenv("PG_PASSWORD"), - connect_timeout=int(os.getenv('DB_CONNECT_TIMEOUT', 10)), - application_name="rideaware_newsletter" - ) - logger.info("Database connection pool initialized successfully") - - except Exception as e: - logger.error(f"Error creating connection pool: {e}") - raise - + conn_kwargs = dict( + host=os.getenv("PG_HOST"), + port=int(os.getenv("PG_PORT", 5432)), + database=os.getenv("PG_DATABASE"), + user=os.getenv("PG_USER"), + password=os.getenv("PG_PASSWORD"), + connect_timeout=int(os.getenv("DB_CONNECT_TIMEOUT", 10)), + application_name="rideaware_app", + keepalives=1, + keepalives_idle=int(os.getenv("PG_KEEPALIVE_IDLE", 30)), + keepalives_interval=int(os.getenv("PG_KEEPALIVE_INTERVAL", 10)), + keepalives_count=int(os.getenv("PG_KEEPALIVE_COUNT", 3)), + options="-c statement_timeout={}".format( + int(os.getenv("PG_STATEMENT_TIMEOUT_MS", 5000)) + ), + ) + _connection_pool = SimpleRobustPool( + minconn=int(os.getenv("DB_POOL_MIN", 3)), + maxconn=int(os.getenv("DB_POOL_MAX", 15)), + **conn_kwargs, + ) + logger.info("Database connection pool initialized") return _connection_pool @contextmanager def get_db_connection(): - """Context manager for database connections""" + """ + Context manager that yields a healthy connection and + ensures proper cleanup/rollback on errors. + """ conn = None try: pool = get_connection_pool() conn = pool.getconn() yield conn - except Exception as e: logger.error(f"Database connection error: {e}") if conn: try: conn.rollback() - except: + except Exception: pass raise - finally: if conn: try: @@ -176,16 +224,10 @@ def get_db_connection(): logger.error(f"Error returning connection: {e}") def get_connection(): - """Get a connection from the pool (legacy interface)""" - try: - pool = get_connection_pool() - return pool.getconn() - except Exception as e: - logger.error(f"Error getting connection from pool: {e}") - raise + pool = get_connection_pool() + return pool.getconn() def return_connection(conn): - """Return a connection to the pool (legacy interface)""" try: pool = get_connection_pool() pool.putconn(conn) @@ -193,149 +235,173 @@ def return_connection(conn): logger.error(f"Error returning connection to pool: {e}") def close_all_connections(): - """Close all connections in the pool""" global _connection_pool with _pool_lock: if _connection_pool: try: _connection_pool.closeall() - logger.info("All database connections closed") + logger.info("All DB connections closed") except Exception as e: logger.error(f"Error closing connections: {e}") finally: _connection_pool = None def get_pool_stats(): - """Get connection pool statistics""" return _pool_stats.copy() -def column_exists(cursor, table_name, column_name): - """Check if a column exists in a table""" - cursor.execute(""" +def column_exists(cursor, table_name, column_name) -> bool: + cursor.execute( + """ SELECT EXISTS ( - SELECT 1 - FROM information_schema.columns + SELECT 1 + FROM information_schema.columns WHERE table_name = %s AND column_name = %s ) - """, (table_name, column_name)) - return cursor.fetchone()[0] - -def index_exists(cursor, index_name): - """Check if an index exists""" - cursor.execute(""" - SELECT EXISTS ( - SELECT 1 FROM pg_class c - JOIN pg_namespace n ON n.oid = c.relnamespace - WHERE c.relname = %s AND n.nspname = 'public' - ) - """, (index_name,)) - return cursor.fetchone()[0] + """, + (table_name, column_name), + ) + return bool(cursor.fetchone()[0]) def init_db(): - """Initialize database tables and indexes""" + """ + Initialize tables and indexes idempotently. + """ with get_db_connection() as conn: try: - with conn.cursor() as cursor: - # Create subscribers table - cursor.execute(""" + with conn.cursor() as cur: + cur.execute( + """ CREATE TABLE IF NOT EXISTS subscribers ( id SERIAL PRIMARY KEY, - email TEXT UNIQUE NOT NULL + email TEXT UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) - """) - - # Add created_at column if it doesn't exist - if not column_exists(cursor, 'subscribers', 'created_at'): - cursor.execute(""" - ALTER TABLE subscribers - ADD COLUMN created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - """) - logger.info("Added created_at column to subscribers table") - - # Create newsletters table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS newsletters( + """ + ) + + cur.execute( + """ + CREATE TABLE IF NOT EXISTS newsletters ( id SERIAL PRIMARY KEY, subject TEXT NOT NULL, body TEXT NOT NULL, sent_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) - """) - - # Create indexes only if they don't exist - indexes = [ - ("idx_newsletters_sent_at", "CREATE INDEX IF NOT EXISTS idx_newsletters_sent_at ON newsletters(sent_at DESC)"), - ("idx_subscribers_email", "CREATE INDEX IF NOT EXISTS idx_subscribers_email ON subscribers(email)"), - ("idx_subscribers_created_at", "CREATE INDEX IF NOT EXISTS idx_subscribers_created_at ON subscribers(created_at DESC)") - ] - - for index_name, create_sql in indexes: - cursor.execute(create_sql) - logger.info(f"Ensured index {index_name} exists") - - conn.commit() - logger.info("Database tables and indexes initialized successfully") - + """ + ) + + cur.execute( + "CREATE INDEX IF NOT EXISTS idx_newsletters_sent_at ON newsletters(sent_at DESC)" + ) + cur.execute( + "CREATE INDEX IF NOT EXISTS idx_subscribers_email ON subscribers(email)" + ) + cur.execute( + "CREATE INDEX IF NOT EXISTS idx_subscribers_created_at ON subscribers(created_at DESC)" + ) + + conn.commit() + logger.info("DB schema ensured successfully") except Exception as e: - logger.error(f"Error initializing database: {e}") + logger.error(f"Error initializing DB: {e}") conn.rollback() raise -def add_email(email): - """Add email to subscribers with robust connection handling""" +def add_email(email: str) -> bool: + """ + Insert a subscriber email. + Returns True if inserted, False if duplicate or error. + """ with get_db_connection() as conn: try: - with conn.cursor() as cursor: - cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,)) - conn.commit() - logger.info(f"Email added successfully: {email}") - return True - + with conn.cursor() as cur: + cur.execute( + "INSERT INTO subscribers (email) VALUES (%s)", + (email,), + ) + conn.commit() + logger.info("Email added: %s", email) + return True except IntegrityError: - # Email already exists conn.rollback() - logger.info(f"Email already exists: {email}") + logger.info("Email already exists: %s", email) return False - except Exception as e: conn.rollback() - logger.error(f"Error adding email {email}: {e}") + logger.error("Error adding email %s: %s", email, e) return False -def remove_email(email): - """Remove email from subscribers with robust connection handling""" +def remove_email(email: str) -> bool: + """ + Delete a subscriber email. + Returns True if removed, False if not found or error. + """ with get_db_connection() as conn: try: - with conn.cursor() as cursor: - cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,)) - conn.commit() - rows_affected = cursor.rowcount - - if rows_affected > 0: - logger.info(f"Email removed successfully: {email}") - return True - else: - logger.info(f"Email not found for removal: {email}") - return False - + with conn.cursor() as cur: + cur.execute( + "DELETE FROM subscribers WHERE email = %s", + (email,), + ) + removed = cur.rowcount > 0 + conn.commit() + if removed: + logger.info("Email removed: %s", email) + else: + logger.info("Email not found for removal: %s", email) + return removed except Exception as e: conn.rollback() - logger.error(f"Error removing email {email}: {e}") + logger.error("Error removing email %s: %s", email, e) return False -def get_subscriber_count(): - """Get total number of subscribers""" +def get_subscriber_count() -> int: with get_db_connection() as conn: try: - with conn.cursor() as cursor: - cursor.execute("SELECT COUNT(*) FROM subscribers") - count = cursor.fetchone()[0] - return count - + with conn.cursor() as cur: + cur.execute("SELECT COUNT(*) FROM subscribers") + (count,) = cur.fetchone() + return int(count) except Exception as e: logger.error(f"Error getting subscriber count: {e}") return 0 -# Cleanup function for graceful shutdown -import atexit +def fetch_all_newsletters(limit: int = 100): + """ + Return a list of newsletters ordered by sent_at desc. + Each item: dict(id, subject, body, sent_at) + """ + with get_db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT id, subject, body, sent_at + FROM newsletters + ORDER BY sent_at DESC + LIMIT %s + """, + (limit,), + ) + rows = cur.fetchall() + return [ + {"id": r[0], "subject": r[1], "body": r[2], "sent_at": r[3]} + for r in rows + ] + +def fetch_newsletter_by_id(newsletter_id: int): + with get_db_connection() as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT id, subject, body, sent_at + FROM newsletters + WHERE id = %s + """, + (newsletter_id,), + ) + row = cur.fetchone() + if not row: + return None + return {"id": row[0], "subject": row[1], "body": row[2], "sent_at": row[3]} + atexit.register(close_all_connections) \ No newline at end of file