feat(db): add robust psycopg2 pool with health checks, keepalives, and safer caching behavior

This commit is contained in:
Cipher Vance 2025-08-31 13:30:16 -05:00
parent 46093d2b56
commit 0bab1b5dd8
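
For reviewers, a minimal usage sketch of the new pool's primary interface (illustrative only; the module name "db" and the placeholder address are assumptions, not part of this commit):

    from db import get_db_connection, add_email

    # Preferred path: the context manager handles checkout, health
    # validation, rollback on error, and return to the pool.
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT COUNT(*) FROM subscribers")
            print(cur.fetchone()[0])

    # Convenience helper built on the same pool.
    add_email("demo@example.com")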


@@ -1,172 +1,220 @@
import os
import logging
import threading
import time
import atexit
from contextlib import contextmanager

import psycopg2
from psycopg2 import pool, IntegrityError
from psycopg2 import extensions
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

_connection_pool = None
_pool_lock = threading.Lock()
_pool_stats = {
    "connections_checked_out": 0,
    "connections_failed": 0,
    "pool_recreated": 0,
}

class SimpleRobustPool:
    """
    Thread-safe, robust connection pool wrapper for psycopg2
    with connection health checks and automatic recovery.
    """

    def __init__(self, minconn=3, maxconn=15, **kwargs):
        self.minconn = minconn
        self.maxconn = maxconn
        self.kwargs = kwargs
        self.pool = None
        self._create_pool()

    def _create_pool(self):
        """Create or recreate the underlying connection pool."""
        try:
            if self.pool:
                try:
                    self.pool.closeall()
                except Exception:
                    pass
            self.pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=self.minconn,
                maxconn=self.maxconn,
                **self.kwargs,
            )
            _pool_stats["pool_recreated"] += 1
            logger.info(
                "DB pool created: min=%s max=%s host=%s db=%s",
                self.minconn,
                self.maxconn,
                self.kwargs.get("host"),
                self.kwargs.get("database"),
            )
        except Exception as e:
            logger.error(f"Failed to create connection pool: {e}")
            raise

    def _healthy(self, conn) -> bool:
        """
        Validate a connection with a real round-trip and
        make sure it's not in an aborted transaction.
        """
        try:
            if conn.closed:
                return False
            txn_status = conn.get_transaction_status()
            if txn_status == extensions.TRANSACTION_STATUS_INERROR:
                try:
                    conn.rollback()
                except Exception:
                    return False
            if not conn.autocommit:
                conn.autocommit = True
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
                cur.fetchone()
            return True
        except Exception:
            return False
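
    # Note: get_transaction_status() returns one of the psycopg2.extensions
    # TRANSACTION_STATUS_* constants (IDLE, ACTIVE, INTRANS, INERROR,
    # UNKNOWN). Only INERROR is treated as recoverable via rollback here;
    # any connection that fails the SELECT 1 round-trip is discarded by
    # getconn() below.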

    def getconn(self, retry_count=2):
        """
        Get a connection, validating it before handing out.
        On repeated failures, recreate the pool once.
        """
        last_err = None
        for attempt in range(retry_count + 1):
            try:
                conn = self.pool.getconn()
                if not self._healthy(conn):
                    logger.warning("Discarding unhealthy connection")
                    try:
                        self.pool.putconn(conn, close=True)
                    except Exception:
                        pass
                    last_err = last_err or Exception("Unhealthy connection")
                    time.sleep(0.2)
                    continue
                _pool_stats["connections_checked_out"] += 1
                return conn
            except Exception as e:
                last_err = e
                _pool_stats["connections_failed"] += 1
                logger.error(
                    "Error getting connection (attempt %s/%s): %s",
                    attempt + 1,
                    retry_count + 1,
                    e,
                )
                time.sleep(0.3)

        # All retries failed: rebuild the pool once and try a final checkout.
        logger.warning("Recreating DB pool due to repeated failures")
        self._create_pool()
        try:
            conn = self.pool.getconn()
            if self._healthy(conn):
                _pool_stats["connections_checked_out"] += 1
                return conn
            self.pool.putconn(conn, close=True)
        except Exception as e:
            last_err = e
        raise last_err or Exception("Failed to get a healthy DB connection")
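
    # Illustrative timing under the defaults: retry_count=2 yields up to
    # three validated checkout attempts (with 0.2-0.3s pauses between
    # them), after which the pool is rebuilt exactly once before the last
    # error is raised.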

    def putconn(self, conn, close=False):
        """Return a connection to the pool, closing it if it's bad."""
        try:
            bad = (
                conn.closed
                or conn.get_transaction_status()
                == extensions.TRANSACTION_STATUS_INERROR
            )
            if bad:
                try:
                    conn.rollback()
                except Exception:
                    pass
                close = True
            self.pool.putconn(conn, close=close)
        except Exception as e:
            logger.error(f"Error returning connection to pool: {e}")
            try:
                conn.close()
            except Exception:
                pass

    def closeall(self):
        """Close all pooled connections."""
        if self.pool:
            try:
                self.pool.closeall()
            except Exception:
                pass

def get_connection_pool():
    """
    Initialize and return the global connection pool.
    Includes TCP keepalives and a statement_timeout to detect dead or stuck sessions.
    """
    global _connection_pool
    with _pool_lock:
        if _connection_pool is None:
            conn_kwargs = dict(
                host=os.getenv("PG_HOST"),
                port=int(os.getenv("PG_PORT", 5432)),
                database=os.getenv("PG_DATABASE"),
                user=os.getenv("PG_USER"),
                password=os.getenv("PG_PASSWORD"),
                connect_timeout=int(os.getenv("DB_CONNECT_TIMEOUT", 10)),
                application_name="rideaware_app",
                keepalives=1,
                keepalives_idle=int(os.getenv("PG_KEEPALIVE_IDLE", 30)),
                keepalives_interval=int(os.getenv("PG_KEEPALIVE_INTERVAL", 10)),
                keepalives_count=int(os.getenv("PG_KEEPALIVE_COUNT", 3)),
                options="-c statement_timeout={}".format(
                    int(os.getenv("PG_STATEMENT_TIMEOUT_MS", 5000))
                ),
            )
            _connection_pool = SimpleRobustPool(
                minconn=int(os.getenv("DB_POOL_MIN", 3)),
                maxconn=int(os.getenv("DB_POOL_MAX", 15)),
                **conn_kwargs,
            )
            logger.info("Database connection pool initialized")
        return _connection_pool
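
# Example environment configuration consumed above (illustrative values;
# the database/user names are placeholders, set real ones in .env):
#
#   PG_HOST=localhost
#   PG_PORT=5432
#   PG_DATABASE=rideaware
#   PG_USER=rideaware
#   PG_PASSWORD=change-me
#   DB_POOL_MIN=3
#   DB_POOL_MAX=15
#   DB_CONNECT_TIMEOUT=10
#   PG_KEEPALIVE_IDLE=30
#   PG_KEEPALIVE_INTERVAL=10
#   PG_KEEPALIVE_COUNT=3
#   PG_STATEMENT_TIMEOUT_MS=5000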

@contextmanager
def get_db_connection():
    """
    Context manager that yields a healthy connection and
    ensures proper cleanup/rollback on errors.
    """
    conn = None
    try:
        pool = get_connection_pool()
        conn = pool.getconn()
        yield conn
    except Exception as e:
        logger.error(f"Database connection error: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception:
                pass
        raise
    finally:
        if conn:
            try:
@@ -176,16 +224,10 @@ def get_db_connection():
                logger.error(f"Error returning connection: {e}")

def get_connection():
    """Get a connection from the pool (legacy interface)."""
    pool = get_connection_pool()
    return pool.getconn()


def return_connection(conn):
    """Return a connection to the pool (legacy interface)."""
    try:
        pool = get_connection_pool()
        pool.putconn(conn)
@@ -193,149 +235,173 @@ def return_connection(conn):
        logger.error(f"Error returning connection to pool: {e}")

def close_all_connections():
    """Close all connections in the pool."""
    global _connection_pool
    with _pool_lock:
        if _connection_pool:
            try:
                _connection_pool.closeall()
                logger.info("All DB connections closed")
            except Exception as e:
                logger.error(f"Error closing connections: {e}")
            finally:
                _connection_pool = None


def get_pool_stats():
    """Get connection pool statistics."""
    return _pool_stats.copy()

def column_exists(cursor, table_name, column_name) -> bool:
    """Check whether a column exists on a table."""
    cursor.execute(
        """
        SELECT EXISTS (
            SELECT 1
            FROM information_schema.columns
            WHERE table_name = %s AND column_name = %s
        )
        """,
        (table_name, column_name),
    )
    return bool(cursor.fetchone()[0])

def init_db():
    """
    Initialize tables and indexes idempotently.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS subscribers (
                        id SERIAL PRIMARY KEY,
                        email TEXT UNIQUE NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
                cur.execute(
                    """
                    CREATE TABLE IF NOT EXISTS newsletters (
                        id SERIAL PRIMARY KEY,
                        subject TEXT NOT NULL,
                        body TEXT NOT NULL,
                        sent_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
                cur.execute(
                    "CREATE INDEX IF NOT EXISTS idx_newsletters_sent_at ON newsletters(sent_at DESC)"
                )
                cur.execute(
                    "CREATE INDEX IF NOT EXISTS idx_subscribers_email ON subscribers(email)"
                )
                cur.execute(
                    "CREATE INDEX IF NOT EXISTS idx_subscribers_created_at ON subscribers(created_at DESC)"
                )
            conn.commit()
            logger.info("DB schema ensured successfully")
        except Exception as e:
            logger.error(f"Error initializing DB: {e}")
            conn.rollback()
            raise
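
# init_db() is safe to run on every boot: CREATE TABLE/INDEX IF NOT EXISTS
# makes repeated calls no-ops. Illustrative startup hook (the RUN_DB_INIT
# flag is hypothetical, not defined by this module):
#
#   if os.getenv("RUN_DB_INIT") == "1":
#       init_db()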

def add_email(email: str) -> bool:
    """
    Insert a subscriber email.
    Returns True if inserted, False if duplicate or error.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO subscribers (email) VALUES (%s)",
                    (email,),
                )
            conn.commit()
            logger.info("Email added: %s", email)
            return True
        except IntegrityError:
            conn.rollback()
            logger.info("Email already exists: %s", email)
            return False
        except Exception as e:
            conn.rollback()
            logger.error("Error adding email %s: %s", email, e)
            return False

def remove_email(email: str) -> bool:
    """
    Delete a subscriber email.
    Returns True if removed, False if not found or error.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM subscribers WHERE email = %s",
                    (email,),
                )
                removed = cur.rowcount > 0
            conn.commit()
            if removed:
                logger.info("Email removed: %s", email)
            else:
                logger.info("Email not found for removal: %s", email)
            return removed
        except Exception as e:
            conn.rollback()
            logger.error("Error removing email %s: %s", email, e)
            return False
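
# Return-value contract sketch for the helpers above (illustrative; the
# address is a placeholder):
#
#   if add_email("user@example.com"):
#       ...  # newly subscribed
#   else:
#       ...  # duplicate, or the insert failed (both paths return False)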

def get_subscriber_count() -> int:
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT COUNT(*) FROM subscribers")
                (count,) = cur.fetchone()
                return int(count)
        except Exception as e:
            logger.error(f"Error getting subscriber count: {e}")
            return 0

def fetch_all_newsletters(limit: int = 100):
    """
    Return a list of newsletters ordered by sent_at desc.
    Each item: dict(id, subject, body, sent_at)
    """
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, subject, body, sent_at
                FROM newsletters
                ORDER BY sent_at DESC
                LIMIT %s
                """,
                (limit,),
            )
            rows = cur.fetchall()
    return [
        {"id": r[0], "subject": r[1], "body": r[2], "sent_at": r[3]}
        for r in rows
    ]


def fetch_newsletter_by_id(newsletter_id: int):
    """Return one newsletter as a dict, or None if it doesn't exist."""
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, subject, body, sent_at
                FROM newsletters
                WHERE id = %s
                """,
                (newsletter_id,),
            )
            row = cur.fetchone()
    if not row:
        return None
    return {"id": row[0], "subject": row[1], "body": row[2], "sent_at": row[3]}

atexit.register(close_all_connections)
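
# Minimal smoke test (illustrative, not part of the module's API): requires
# a reachable Postgres configured via the PG_* variables above, and assumes
# this file is saved as db.py. Run with: python db.py
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    init_db()
    print("subscribers:", get_subscriber_count())
    print("latest subjects:", [n["subject"] for n in fetch_all_newsletters(limit=5)])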