# landing/database.py
import os
import time
import logging
import threading
import atexit
from contextlib import contextmanager
import psycopg2
from psycopg2 import pool, IntegrityError
from psycopg2 import extensions
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
_connection_pool = None
_pool_lock = threading.Lock()
_pool_stats = {
"connections_checked_out": 0,
"connections_failed": 0,
"pool_recreated": 0,
}
class SimpleRobustPool:
    """
    Thread-safe, robust connection pool wrapper for psycopg2
    with connection health checks and automatic recovery.

    Wraps a ``psycopg2.pool.ThreadedConnectionPool``: every connection
    handed out by :meth:`getconn` is validated with a real round-trip
    query, and the underlying pool is rebuilt once if repeated checkout
    attempts fail.
    """

    def __init__(self, minconn=3, maxconn=15, **kwargs):
        # kwargs are forwarded verbatim to psycopg2.connect() via the pool.
        self.minconn = minconn
        self.maxconn = maxconn
        self.kwargs = kwargs
        self.pool = None
        self._create_pool()

    def _create_pool(self):
        """Create or recreate the underlying connection pool.

        Raises whatever psycopg2 raises if the pool cannot be created.
        """
        try:
            if self.pool:
                # Best-effort teardown of the previous pool before replacing it.
                try:
                    self.pool.closeall()
                except Exception:
                    pass
            self.pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=self.minconn,
                maxconn=self.maxconn,
                **self.kwargs,
            )
            _pool_stats["pool_recreated"] += 1
            logger.info(
                "DB pool created: min=%s max=%s host=%s db=%s",
                self.minconn,
                self.maxconn,
                self.kwargs.get("host"),
                self.kwargs.get("database"),
            )
        except Exception as e:
            logger.error(f"Failed to create connection pool: {e}")
            raise

    def _healthy(self, conn) -> bool:
        """
        Validate a connection with a real round-trip and make sure it's
        not in an aborted transaction.

        FIX: the previous implementation forced ``conn.autocommit = True``
        before probing, which permanently switched every pooled connection
        to autocommit and silently disabled transactional semantics for
        all consumers (their commit()/rollback() calls became no-ops).
        The probe now runs in the connection's current mode and rolls
        itself back, handing the connection out idle and unmodified.
        """
        try:
            if conn.closed:
                return False
            if conn.get_transaction_status() == extensions.TRANSACTION_STATUS_INERROR:
                # Clear an aborted transaction before probing.
                try:
                    conn.rollback()
                except Exception:
                    return False
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
                cur.fetchone()
            if not conn.autocommit:
                # The probe may have opened an implicit transaction;
                # roll it back so the connection is handed out idle.
                conn.rollback()
            return True
        except Exception:
            return False

    def getconn(self, retry_count=2):
        """
        Get a connection, validating it before handing out.
        On repeated failures, recreate the pool once.

        Raises the last underlying error if no healthy connection can be
        obtained even after the pool has been rebuilt.
        """
        last_err = None
        for attempt in range(retry_count + 1):
            try:
                conn = self.pool.getconn()
                if not self._healthy(conn):
                    logger.warning("Discarding unhealthy connection")
                    # Count discarded connections too, not only getconn() errors.
                    _pool_stats["connections_failed"] += 1
                    try:
                        self.pool.putconn(conn, close=True)
                    except Exception:
                        pass
                    last_err = last_err or Exception("Unhealthy connection")
                    time.sleep(0.2)
                    continue
                _pool_stats["connections_checked_out"] += 1
                return conn
            except Exception as e:
                last_err = e
                _pool_stats["connections_failed"] += 1
                logger.error(
                    "Error getting connection (attempt %s/%s): %s",
                    attempt + 1,
                    retry_count + 1,
                    e,
                )
                time.sleep(0.3)
        # All retries exhausted: rebuild the pool once and make a final attempt.
        logger.warning("Recreating DB pool due to repeated failures")
        self._create_pool()
        try:
            conn = self.pool.getconn()
            if self._healthy(conn):
                _pool_stats["connections_checked_out"] += 1
                return conn
            self.pool.putconn(conn, close=True)
        except Exception as e:
            last_err = e
        raise last_err or Exception("Failed to get a healthy DB connection")

    def putconn(self, conn, close=False):
        """Return a connection to the pool, closing it if it's bad."""
        try:
            bad = (
                conn.closed
                or conn.get_transaction_status()
                == extensions.TRANSACTION_STATUS_INERROR
            )
            if bad:
                try:
                    conn.rollback()
                except Exception:
                    pass
                close = True
            self.pool.putconn(conn, close=close)
        except Exception as e:
            logger.error(f"Error returning connection to pool: {e}")
            # Last resort: make sure the underlying socket is not leaked.
            try:
                conn.close()
            except Exception:
                pass

    def closeall(self):
        """Close every connection held by the underlying pool (best effort)."""
        if self.pool:
            try:
                self.pool.closeall()
            except Exception:
                pass
def get_connection_pool():
    """
    Lazily initialize (under a lock) and return the global connection pool.

    Connection parameters come from environment variables; TCP keepalives
    and a statement_timeout are configured so dead or stuck sessions are
    detected instead of hanging forever.
    """
    global _connection_pool
    with _pool_lock:
        if _connection_pool is not None:
            return _connection_pool
        timeout_ms = int(os.getenv("PG_STATEMENT_TIMEOUT_MS", 5000))
        conn_kwargs = {
            "host": os.getenv("PG_HOST"),
            "port": int(os.getenv("PG_PORT", 5432)),
            "database": os.getenv("PG_DATABASE"),
            "user": os.getenv("PG_USER"),
            "password": os.getenv("PG_PASSWORD"),
            "connect_timeout": int(os.getenv("DB_CONNECT_TIMEOUT", 10)),
            "application_name": "rideaware_app",
            "keepalives": 1,
            "keepalives_idle": int(os.getenv("PG_KEEPALIVE_IDLE", 30)),
            "keepalives_interval": int(os.getenv("PG_KEEPALIVE_INTERVAL", 10)),
            "keepalives_count": int(os.getenv("PG_KEEPALIVE_COUNT", 3)),
            "options": "-c statement_timeout={}".format(timeout_ms),
        }
        _connection_pool = SimpleRobustPool(
            minconn=int(os.getenv("DB_POOL_MIN", 3)),
            maxconn=int(os.getenv("DB_POOL_MAX", 15)),
            **conn_kwargs,
        )
        logger.info("Database connection pool initialized")
        return _connection_pool
@contextmanager
def get_db_connection():
    """
    Context manager that yields a healthy connection and ensures proper
    cleanup/rollback on errors.

    FIX: the ``finally`` clause previously called ``get_connection_pool()``
    a second time; if ``close_all_connections()`` had run in the meantime,
    that re-created a brand-new pool just to receive a connection it never
    issued. The pool handle obtained at checkout is now reused for checkin.
    """
    conn = None
    pool = None
    try:
        pool = get_connection_pool()
        conn = pool.getconn()
        yield conn
    except Exception as e:
        logger.error(f"Database connection error: {e}")
        if conn:
            # Leave the connection in a clean state before re-raising.
            try:
                conn.rollback()
            except Exception:
                pass
        raise
    finally:
        if conn and pool:
            try:
                pool.putconn(conn)
            except Exception as e:
                logger.error(f"Error returning connection: {e}")
def get_connection():
    """Check out a raw connection from the global pool (caller must return it)."""
    return get_connection_pool().getconn()
def return_connection(conn):
    """Hand *conn* back to the global pool, logging (never raising) on failure."""
    try:
        get_connection_pool().putconn(conn)
    except Exception as e:
        logger.error(f"Error returning connection to pool: {e}")
def close_all_connections():
    """Close every pooled connection and drop the global pool reference.

    Safe to call multiple times; registered with atexit at module import.
    """
    global _connection_pool
    with _pool_lock:
        if not _connection_pool:
            return
        try:
            _connection_pool.closeall()
            logger.info("All DB connections closed")
        except Exception as e:
            logger.error(f"Error closing connections: {e}")
        finally:
            # Always forget the pool so the next checkout builds a fresh one.
            _connection_pool = None
def get_pool_stats():
    """Return a snapshot copy of the pool counters (callers can't mutate them)."""
    return dict(_pool_stats)
def column_exists(cursor, table_name, column_name) -> bool:
    """Return True if information_schema reports *column_name* on *table_name*."""
    query = """
        SELECT EXISTS (
            SELECT 1
            FROM information_schema.columns
            WHERE table_name = %s AND column_name = %s
        )
    """
    cursor.execute(query, (table_name, column_name))
    (exists,) = cursor.fetchone()
    return bool(exists)
def init_db():
    """
    Initialize tables and indexes idempotently.

    Every statement uses CREATE ... IF NOT EXISTS, so re-running is safe.
    Raises on failure after rolling back and logging.
    """
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS subscribers (
            id SERIAL PRIMARY KEY,
            email TEXT UNIQUE NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS newsletters (
            id SERIAL PRIMARY KEY,
            subject TEXT NOT NULL,
            body TEXT NOT NULL,
            sent_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_newsletters_sent_at ON newsletters(sent_at DESC)",
        "CREATE INDEX IF NOT EXISTS idx_subscribers_email ON subscribers(email)",
        "CREATE INDEX IF NOT EXISTS idx_subscribers_created_at ON subscribers(created_at DESC)",
    )
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                for statement in ddl_statements:
                    cur.execute(statement)
            conn.commit()
            logger.info("DB schema ensured successfully")
        except Exception as e:
            logger.error(f"Error initializing DB: {e}")
            conn.rollback()
            raise
def add_email(email: str) -> bool:
    """
    Insert a subscriber email.

    Uses ``INSERT ... ON CONFLICT (email) DO NOTHING`` (the UNIQUE
    constraint on subscribers.email supports it), so a duplicate is
    detected via ``rowcount`` instead of catching IntegrityError — the
    transaction never enters the aborted state and needs no rollback
    on the duplicate path.

    Returns True if inserted, False if duplicate or error.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "INSERT INTO subscribers (email) VALUES (%s) "
                    "ON CONFLICT (email) DO NOTHING",
                    (email,),
                )
                inserted = cur.rowcount > 0
            conn.commit()
            if inserted:
                logger.info("Email added: %s", email)
            else:
                logger.info("Email already exists: %s", email)
            return inserted
        except Exception as e:
            conn.rollback()
            logger.error("Error adding email %s: %s", email, e)
            return False
def remove_email(email: str) -> bool:
    """
    Delete a subscriber email.
    Returns True if removed, False if not found or error.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM subscribers WHERE email = %s",
                    (email,),
                )
                removed = cur.rowcount > 0
            conn.commit()
            if removed:
                logger.info("Email removed: %s", email)
            else:
                logger.info("Email not found for removal: %s", email)
            return removed
        except Exception as e:
            conn.rollback()
            logger.error("Error removing email %s: %s", email, e)
            return False
def get_subscriber_count() -> int:
    """Return the total number of subscriber rows, or 0 on any database error."""
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT COUNT(*) FROM subscribers")
                return int(cur.fetchone()[0])
        except Exception as e:
            logger.error(f"Error getting subscriber count: {e}")
            return 0
def fetch_all_newsletters(limit: int = 100):
    """
    Return a list of newsletters ordered by sent_at desc.
    Each item: dict(id, subject, body, sent_at)
    """
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, subject, body, sent_at
                FROM newsletters
                ORDER BY sent_at DESC
                LIMIT %s
                """,
                (limit,),
            )
            columns = ("id", "subject", "body", "sent_at")
            return [dict(zip(columns, record)) for record in cur.fetchall()]
def fetch_newsletter_by_id(newsletter_id: int):
    """Return one newsletter as dict(id, subject, body, sent_at), or None if absent."""
    with get_db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT id, subject, body, sent_at
                FROM newsletters
                WHERE id = %s
                """,
                (newsletter_id,),
            )
            row = cur.fetchone()
    if not row:
        return None
    return dict(zip(("id", "subject", "body", "sent_at"), row))
# Release all pooled connections automatically on interpreter shutdown.
atexit.register(close_all_connections)