feat(db): add robust psycopg2 pool with health checks, keepalives, and safer caching behavior

This commit is contained in:
Cipher Vance 2025-08-31 13:30:16 -05:00
parent 46093d2b56
commit 0bab1b5dd8

View file

@ -1,27 +1,32 @@
import atexit
import logging
import os
import threading
import time
from contextlib import contextmanager

import psycopg2
from psycopg2 import IntegrityError, OperationalError, pool
from psycopg2 import extensions
from dotenv import load_dotenv
load_dotenv()

logger = logging.getLogger(__name__)
# NOTE(review): setting a level on a library module's logger overrides the
# application's logging configuration; consider leaving this to the app.
logger.setLevel(logging.INFO)

# Process-wide pool singleton; all access is guarded by _pool_lock.
_connection_pool = None
_pool_lock = threading.Lock()

# Lightweight counters exposed via get_pool_stats().
_pool_stats = {
    "connections_checked_out": 0,
    "connections_failed": 0,
    "pool_recreated": 0,
}
class SimpleRobustPool:
    """
    Thread-safe, robust connection pool wrapper for psycopg2.

    Wraps a ThreadedConnectionPool and adds:
      * a health check (real ``SELECT 1`` round-trip) before handing out
        a connection,
      * automatic discard of closed or transaction-aborted connections,
      * one-shot pool recreation after repeated checkout failures.
    """

    def __init__(self, minconn=3, maxconn=15, **kwargs):
        # kwargs are forwarded verbatim to psycopg2.connect()
        # (host, port, database, user, password, keepalives, ...).
        self.minconn = minconn
        self.maxconn = maxconn
        self.kwargs = kwargs
        self.pool = None
        self._create_pool()

    def _create_pool(self):
        """Create or recreate the underlying connection pool."""
        try:
            if self.pool:
                try:
                    self.pool.closeall()
                except Exception:
                    # Best effort: the old pool may already be unusable.
                    pass
            self.pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=self.minconn,
                maxconn=self.maxconn,
                **self.kwargs,
            )
            _pool_stats["pool_recreated"] += 1
            logger.info(
                "DB pool created: min=%s max=%s host=%s db=%s",
                self.minconn,
                self.maxconn,
                self.kwargs.get("host"),
                self.kwargs.get("database"),
            )
        except Exception as e:
            logger.error(f"Failed to create connection pool: {e}")
            raise

    def _healthy(self, conn) -> bool:
        """
        Validate *conn* with a real round-trip.

        Returns False for closed connections and for connections whose
        aborted transaction cannot be rolled back. Never raises.
        """
        try:
            if conn.closed:
                return False
            # Recover from an aborted transaction before probing.
            if conn.get_transaction_status() == extensions.TRANSACTION_STATUS_INERROR:
                try:
                    conn.rollback()
                except Exception:
                    return False
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
                cur.fetchone()
            # Roll back the probe's implicit transaction so callers receive
            # a clean connection. Deliberately NOT forcing autocommit on:
            # flipping conn.autocommit here would silently change commit/
            # rollback semantics for every checked-out connection.
            if not conn.autocommit:
                conn.rollback()
            return True
        except Exception:
            return False

    def getconn(self, retry_count=2):
        """
        Check out a healthy connection.

        Tries up to ``retry_count + 1`` times, discarding unhealthy
        connections; after exhausting retries, recreates the pool once
        and makes a final attempt. Raises the last underlying error if
        no healthy connection can be obtained.
        """
        last_err = None
        for attempt in range(retry_count + 1):
            try:
                conn = self.pool.getconn()
                if not self._healthy(conn):
                    logger.warning("Discarding unhealthy connection")
                    try:
                        self.pool.putconn(conn, close=True)
                    except Exception:
                        pass
                    last_err = last_err or Exception("Unhealthy connection")
                    time.sleep(0.2)
                    continue
                _pool_stats["connections_checked_out"] += 1
                return conn
            except Exception as e:
                last_err = e
                _pool_stats["connections_failed"] += 1
                logger.error(
                    "Error getting connection (attempt %s/%s): %s",
                    attempt + 1,
                    retry_count + 1,
                    e,
                )
                time.sleep(0.3)

        # All regular attempts failed: rebuild the pool and try once more.
        logger.warning("Recreating DB pool due to repeated failures")
        self._create_pool()
        try:
            conn = self.pool.getconn()
            if self._healthy(conn):
                _pool_stats["connections_checked_out"] += 1
                return conn
            self.pool.putconn(conn, close=True)
        except Exception as e:
            last_err = e
        raise last_err or Exception("Failed to get a healthy DB connection")

    def putconn(self, conn, close=False):
        """Return *conn* to the pool, closing it if it is broken."""
        try:
            bad = (
                conn.closed
                or conn.get_transaction_status()
                == extensions.TRANSACTION_STATUS_INERROR
            )
            if bad:
                try:
                    conn.rollback()
                except Exception:
                    pass
                close = True
            self.pool.putconn(conn, close=close)
        except Exception as e:
            logger.error(f"Error returning connection to pool: {e}")
            # Last resort: make sure the socket is not leaked.
            try:
                conn.close()
            except Exception:
                pass

    def closeall(self):
        """Close every pooled connection (best effort)."""
        if self.pool:
            try:
                self.pool.closeall()
            except Exception:
                pass
def get_connection_pool():
    """
    Initialize (lazily, thread-safely) and return the global pool.

    Connection kwargs include TCP keepalives and a statement_timeout so
    dead or stuck sessions are detected server- and client-side.
    All settings are overridable via environment variables.
    """
    global _connection_pool
    with _pool_lock:
        if _connection_pool is None:
            conn_kwargs = dict(
                host=os.getenv("PG_HOST"),
                port=int(os.getenv("PG_PORT", 5432)),
                database=os.getenv("PG_DATABASE"),
                user=os.getenv("PG_USER"),
                password=os.getenv("PG_PASSWORD"),
                connect_timeout=int(os.getenv("DB_CONNECT_TIMEOUT", 10)),
                application_name="rideaware_app",
                # TCP keepalives: detect dead peers without waiting for OS
                # defaults (which can be hours).
                keepalives=1,
                keepalives_idle=int(os.getenv("PG_KEEPALIVE_IDLE", 30)),
                keepalives_interval=int(os.getenv("PG_KEEPALIVE_INTERVAL", 10)),
                keepalives_count=int(os.getenv("PG_KEEPALIVE_COUNT", 3)),
                # Server-side guard against runaway queries (milliseconds).
                options="-c statement_timeout={}".format(
                    int(os.getenv("PG_STATEMENT_TIMEOUT_MS", 5000))
                ),
            )
            _connection_pool = SimpleRobustPool(
                minconn=int(os.getenv("DB_POOL_MIN", 3)),
                maxconn=int(os.getenv("DB_POOL_MAX", 15)),
                **conn_kwargs,
            )
            logger.info("Database connection pool initialized")
    return _connection_pool
@contextmanager
def get_db_connection():
    """
    Context manager that yields a healthy pooled connection.

    Rolls back on error, re-raises the original exception, and always
    returns the connection to the pool in the finally block.
    """
    conn = None
    try:
        pool = get_connection_pool()
        conn = pool.getconn()
        yield conn
    except Exception as e:
        logger.error(f"Database connection error: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception:
                pass
        raise
    finally:
        if conn:
            try:
                pool.putconn(conn)
            except Exception as e:
                logger.error(f"Error returning connection: {e}")
def get_connection():
    """Check out a connection from the pool (legacy interface).

    Caller must return it with return_connection(). Prefer the
    get_db_connection() context manager for new code.
    """
    pool = get_connection_pool()
    return pool.getconn()
def return_connection(conn):
    """Return a connection to the pool (legacy interface); never raises."""
    try:
        pool = get_connection_pool()
        pool.putconn(conn)
    except Exception as e:
        logger.error(f"Error returning connection to pool: {e}")
def close_all_connections():
    """Close the global pool and reset it so it can be re-created later."""
    global _connection_pool
    with _pool_lock:
        if _connection_pool:
            try:
                _connection_pool.closeall()
                logger.info("All DB connections closed")
            except Exception as e:
                logger.error(f"Error closing connections: {e}")
            finally:
                # Always drop the reference, even if closeall() failed,
                # so the next get_connection_pool() builds a fresh pool.
                _connection_pool = None
def get_pool_stats():
    """Return a snapshot copy of the pool counters (safe to mutate)."""
    return _pool_stats.copy()
def column_exists(cursor, table_name, column_name): def column_exists(cursor, table_name, column_name) -> bool:
"""Check if a column exists in a table""" cursor.execute(
cursor.execute(""" """
SELECT EXISTS ( SELECT EXISTS (
SELECT 1 SELECT 1
FROM information_schema.columns FROM information_schema.columns
WHERE table_name = %s AND column_name = %s WHERE table_name = %s AND column_name = %s
) )
""", (table_name, column_name)) """,
return cursor.fetchone()[0] (table_name, column_name),
def index_exists(cursor, index_name):
"""Check if an index exists"""
cursor.execute("""
SELECT EXISTS (
SELECT 1 FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE c.relname = %s AND n.nspname = 'public'
) )
""", (index_name,)) return bool(cursor.fetchone()[0])
return cursor.fetchone()[0]
def init_db():
    """
    Initialize tables and indexes idempotently.

    Safe to run on every startup: uses IF NOT EXISTS everywhere and
    backfills subscribers.created_at on databases created before the
    column was added to the CREATE TABLE statement.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS subscribers (
                        id SERIAL PRIMARY KEY,
                        email TEXT UNIQUE NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
                # Migration for pre-existing databases: CREATE TABLE IF NOT
                # EXISTS never adds new columns to an existing table, so
                # older deployments would be missing created_at without this.
                if not column_exists(cur, "subscribers", "created_at"):
                    cur.execute(
                        """
                        ALTER TABLE subscribers
                        ADD COLUMN created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        """
                    )
                    logger.info("Added created_at column to subscribers table")
                cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS newsletters (
                        id SERIAL PRIMARY KEY,
                        subject TEXT NOT NULL,
                        body TEXT NOT NULL,
                        sent_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
                cur.execute(
                    "CREATE INDEX IF NOT EXISTS idx_newsletters_sent_at ON newsletters(sent_at DESC)"
                )
                cur.execute(
                    "CREATE INDEX IF NOT EXISTS idx_subscribers_email ON subscribers(email)"
                )
                cur.execute(
                    "CREATE INDEX IF NOT EXISTS idx_subscribers_created_at ON subscribers(created_at DESC)"
                )
            conn.commit()
            logger.info("DB schema ensured successfully")
        except Exception as e:
            logger.error(f"Error initializing DB: {e}")
            conn.rollback()
            raise
def add_email(email: str) -> bool:
    """
    Insert a subscriber email.

    Returns True if inserted, False if the email already exists
    (unique-constraint violation) or on any other error.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO subscribers (email) VALUES (%s)",
                    (email,),
                )
            conn.commit()
            logger.info("Email added: %s", email)
            return True
        except IntegrityError:
            # Duplicate email: expected, not an error condition.
            conn.rollback()
            logger.info("Email already exists: %s", email)
            return False
        except Exception as e:
            conn.rollback()
            logger.error("Error adding email %s: %s", email, e)
            return False
def remove_email(email: str) -> bool:
    """
    Delete a subscriber email.

    Returns True if a row was removed, False if the email was not found
    or an error occurred.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM subscribers WHERE email = %s",
                    (email,),
                )
                # rowcount must be read before the cursor is closed.
                removed = cur.rowcount > 0
            conn.commit()
            if removed:
                logger.info("Email removed: %s", email)
            else:
                logger.info("Email not found for removal: %s", email)
            return removed
        except Exception as e:
            conn.rollback()
            logger.error("Error removing email %s: %s", email, e)
            return False
def get_subscriber_count() -> int:
    """Return the total number of subscribers, or 0 on error."""
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT COUNT(*) FROM subscribers")
                (count,) = cur.fetchone()
            return int(count)
        except Exception as e:
            logger.error(f"Error getting subscriber count: {e}")
            return 0
def fetch_all_newsletters(limit: int = 100):
    """
    Return newsletters ordered by sent_at descending.

    Each item is a dict with keys: id, subject, body, sent_at.
    *limit* caps the number of rows returned (default 100).
    """
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, subject, body, sent_at
                FROM newsletters
                ORDER BY sent_at DESC
                LIMIT %s
                """,
                (limit,),
            )
            rows = cur.fetchall()
    return [
        {"id": r[0], "subject": r[1], "body": r[2], "sent_at": r[3]}
        for r in rows
    ]
def fetch_newsletter_by_id(newsletter_id: int):
    """
    Return one newsletter as a dict (id, subject, body, sent_at),
    or None when no row matches *newsletter_id*.
    """
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, subject, body, sent_at
                FROM newsletters
                WHERE id = %s
                """,
                (newsletter_id,),
            )
            row = cur.fetchone()
    if not row:
        return None
    return {"id": row[0], "subject": row[1], "body": row[2], "sent_at": row[3]}
# Ensure pooled connections are closed cleanly on interpreter shutdown.
atexit.register(close_all_connections)