landing/database.py

341 lines
No EOL
12 KiB
Python

import os
import psycopg2
from psycopg2 import pool, IntegrityError, OperationalError
from dotenv import load_dotenv
import logging
import threading
import time
from contextlib import contextmanager
# Load environment variables (e.g. from a .env file) before the os.getenv
# calls further down read the PG_* / DB_* settings.
load_dotenv()

logger = logging.getLogger(__name__)

# Process-wide pool singleton: created lazily by get_connection_pool() and
# cleared by close_all_connections(), both of which hold _pool_lock.
_connection_pool = None
_pool_lock = threading.Lock()

# Counters bumped by SimpleRobustPool; get_pool_stats() hands out a copy.
_pool_stats = dict.fromkeys(
    ("connections_created", "connections_failed", "pool_recreated"), 0
)
class SimpleRobustPool:
    """Self-healing wrapper around ``psycopg2.pool.ThreadedConnectionPool``.

    On checkout it verifies the connection is usable, retries, and as a last
    resort rebuilds the underlying pool. Pool *exhaustion* (every connection
    checked out) is deliberately not a reason to rebuild: rebuilding closes
    every connection of the old pool, including ones other callers are using.

    Args:
        minconn: connections the underlying pool opens up front.
        maxconn: upper bound on simultaneously open connections.
        **kwargs: connection parameters forwarded to psycopg2.
    """

    def __init__(self, minconn=3, maxconn=15, **kwargs):
        self.minconn = minconn
        self.maxconn = maxconn
        self.kwargs = kwargs
        self.pool = None
        self._create_pool()

    def _create_pool(self):
        """Build a fresh underlying pool, closing the previous one if any.

        WARNING: closing the previous pool also closes connections that are
        currently checked out from it.

        Raises:
            Whatever psycopg2 raises when the new pool cannot be created (it
            is logged first); ``self.pool`` then still refers to the old pool.
        """
        try:
            if self.pool is not None:
                try:
                    self.pool.closeall()
                except Exception as close_error:
                    # Best effort: the old pool is being discarded anyway.
                    logger.debug(f"Ignoring error while closing old pool: {close_error}")
            self.pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=self.minconn,
                maxconn=self.maxconn,
                **self.kwargs,
            )
            # NOTE: counted for the very first pool too, not only rebuilds.
            _pool_stats['pool_recreated'] += 1
            logger.info(f"Connection pool created: {self.minconn}-{self.maxconn} connections")
        except Exception as e:
            logger.error(f"Failed to create connection pool: {e}")
            raise

    def _test_connection(self, conn):
        """Return True if *conn* is open and in the READY or BEGIN state.

        ``poll()`` is used instead of running a query so the check does not
        touch transaction state; any error it raises marks the connection
        as unusable.
        """
        try:
            if conn.closed:
                return False
            conn.poll()
            return conn.status in (
                psycopg2.extensions.STATUS_READY,
                psycopg2.extensions.STATUS_BEGIN,
            )
        except Exception:
            return False

    def _is_exhausted(self, error):
        """Return True if *error* means the pool is open but has no free slot."""
        return (
            isinstance(error, psycopg2.pool.PoolError)
            and not getattr(self.pool, "closed", True)
        )

    def _recreate_and_getconn(self):
        """Rebuild the pool and return a healthy connection from it, or None.

        Raises:
            Whatever pool creation or checkout raises (logged first).
        """
        logger.warning("Recreating connection pool due to failures")
        try:
            self._create_pool()
            conn = self.pool.getconn()
        except Exception as recreate_error:
            logger.error(f"Failed to recreate pool: {recreate_error}")
            raise
        if self._test_connection(conn):
            return conn
        # Give the bad connection back (closed) instead of leaking it.
        self.putconn(conn, close=True)
        return None

    def getconn(self, retry_count=2):
        """Check out a connection that passes ``_test_connection``.

        Makes up to *retry_count* attempts. A connection failing the liveness
        test is closed and the next attempt starts immediately; a checkout
        error waits 0.5 s before retrying. If the final attempt raises, the
        pool is rebuilt once and a connection is taken from the new pool,
        unless the error was pool exhaustion, which is re-raised unchanged.

        Raises:
            psycopg2.pool.PoolError: the pool is still exhausted on the final
                attempt.
            Exception: the rebuild itself failed (its error is re-raised), or
                no healthy connection could be obtained.
        """
        for attempt in range(retry_count):
            try:
                conn = self.pool.getconn()
            except Exception as e:
                logger.error(f"Error getting connection (attempt {attempt + 1}): {e}")
                _pool_stats['connections_failed'] += 1
                if attempt < retry_count - 1:
                    time.sleep(0.5)
                    continue
                if self._is_exhausted(e):
                    # All slots are held by other callers; rebuilding would
                    # close their connections mid-use, so surface the error.
                    raise
                conn = self._recreate_and_getconn()
                if conn is not None:
                    return conn
                break

            if self._test_connection(conn):
                _pool_stats['connections_created'] += 1
                return conn

            logger.warning("Got bad connection, discarding and retrying")
            self.putconn(conn, close=True)

        raise Exception("Failed to get connection after retries")

    def putconn(self, conn, close=False):
        """Return *conn* to the pool, closing it if asked to or already closed.

        If the pool refuses the connection (e.g. it came from a pool that has
        since been rebuilt), the error is logged and the connection is closed
        directly so it is not leaked.
        """
        try:
            self.pool.putconn(conn, close=close or bool(conn.closed))
        except Exception as e:
            logger.error(f"Error returning connection to pool: {e}")
            try:
                conn.close()
            except Exception:
                pass  # Nothing more can be done with a broken connection.

    def closeall(self):
        """Close every connection held by the underlying pool."""
        if self.pool is not None:
            self.pool.closeall()
def get_connection_pool():
    """Return the shared SimpleRobustPool, building it on the first call.

    Settings come from the environment: DB_POOL_MIN / DB_POOL_MAX (defaults
    3 / 15), PG_HOST, PG_PORT (default 5432), PG_DATABASE, PG_USER,
    PG_PASSWORD and DB_CONNECT_TIMEOUT (default 10). The existence check and
    the creation both run under ``_pool_lock``, so concurrent first callers
    end up sharing one pool.

    Raises:
        Any error from parsing the settings or building the pool, after it
        has been logged.
    """
    global _connection_pool
    with _pool_lock:
        if _connection_pool is not None:
            return _connection_pool
        try:
            getenv = os.getenv
            _connection_pool = SimpleRobustPool(
                minconn=int(getenv('DB_POOL_MIN', 3)),
                maxconn=int(getenv('DB_POOL_MAX', 15)),
                host=getenv("PG_HOST"),
                port=int(getenv("PG_PORT", 5432)),
                database=getenv("PG_DATABASE"),
                user=getenv("PG_USER"),
                password=getenv("PG_PASSWORD"),
                connect_timeout=int(getenv('DB_CONNECT_TIMEOUT', 10)),
                application_name="rideaware_newsletter",
            )
            logger.info("Database connection pool initialized successfully")
        except Exception as e:
            logger.error(f"Error creating connection pool: {e}")
            raise
        return _connection_pool
@contextmanager
def get_db_connection():
    """Yield a pooled connection and always hand it back afterwards.

    If checkout or the ``with`` body raises, the error is logged, the
    connection (if any) is rolled back, and the exception propagates.

    The connection is returned to the same pool object it was taken from.
    Looking the pool up again at release time would build a brand-new pool
    if the global one had been closed in the meantime, for example by the
    atexit hook.
    """
    db_pool = None
    conn = None
    try:
        db_pool = get_connection_pool()
        conn = db_pool.getconn()
        yield conn
    except Exception as e:
        logger.error(f"Database connection error: {e}")
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_error:
                # The connection may already be dead; the release below
                # closes it in that case.
                logger.warning(f"Rollback failed: {rollback_error}")
        raise
    finally:
        if conn is not None:
            try:
                db_pool.putconn(conn)
            except Exception as e:
                logger.error(f"Error returning connection: {e}")
def get_connection():
    """Check out a raw connection (legacy API; pair with return_connection).

    Any failure is logged and re-raised to the caller.
    """
    try:
        return get_connection_pool().getconn()
    except Exception as e:
        logger.error(f"Error getting connection from pool: {e}")
        raise
def return_connection(conn):
    """Hand a connection from get_connection() back (legacy API).

    If the global pool has already been shut down (e.g. by
    close_all_connections), the connection is closed directly. This avoids
    opening a brand-new pool just to receive it. ``None`` is ignored.
    Errors are logged, never raised.
    """
    if conn is None:
        return
    current_pool = _connection_pool
    try:
        if current_pool is None:
            conn.close()
        else:
            current_pool.putconn(conn)
    except Exception as e:
        logger.error(f"Error returning connection to pool: {e}")
def close_all_connections():
    """Close every pooled connection and drop the global pool.

    Does nothing when no pool exists. The global reference is cleared even
    if closing fails, so the next get_connection_pool() call builds anew.
    """
    global _connection_pool
    with _pool_lock:
        if not _connection_pool:
            return
        try:
            _connection_pool.closeall()
        except Exception as e:
            logger.error(f"Error closing connections: {e}")
        else:
            logger.info("All database connections closed")
        finally:
            _connection_pool = None
def get_pool_stats():
    """Return a shallow snapshot of the pool counters (safe to mutate)."""
    return dict(_pool_stats)
def column_exists(cursor, table_name, column_name, schema=None):
    """Return whether *table_name* has a column named *column_name*.

    Args:
        cursor: an open psycopg2 cursor.
        table_name: unqualified table name.
        column_name: column to look for.
        schema: schema to search. Defaults to the session's current schema,
            which is where unqualified CREATE TABLE statements put tables.

    Returns:
        The boolean produced by PostgreSQL's EXISTS.
    """
    # Filtering on the schema matters: information_schema.columns lists every
    # schema the role can see, so a same-named table elsewhere would
    # otherwise produce a false positive.
    cursor.execute(
        """
        SELECT EXISTS (
            SELECT 1
            FROM information_schema.columns
            WHERE table_schema = COALESCE(%s, current_schema())
              AND table_name = %s
              AND column_name = %s
        )
        """,
        (schema, table_name, column_name),
    )
    return cursor.fetchone()[0]
def index_exists(cursor, index_name, schema='public'):
    """Return whether an index called *index_name* exists in *schema*.

    Only relations whose ``pg_class.relkind`` marks them as an index count:
    'i' is an ordinary index and 'I' a partitioned index. A table, view or
    sequence with the same name is therefore not mistaken for an index.

    Args:
        cursor: an open psycopg2 cursor.
        index_name: unqualified index name.
        schema: namespace to search (default 'public').

    Returns:
        The boolean produced by PostgreSQL's EXISTS.
    """
    cursor.execute(
        """
        SELECT EXISTS (
            SELECT 1
            FROM pg_class c
            JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relname = %s
              AND n.nspname = %s
              AND c.relkind IN ('i', 'I')
        )
        """,
        (index_name, schema),
    )
    return cursor.fetchone()[0]
def init_db():
    """Create the subscribers/newsletters tables and their indexes.

    Every statement is guarded by IF NOT EXISTS or an explicit existence
    check, so the function can be re-run. All DDL runs in one transaction on
    a pooled connection. It is committed at the end; on failure it is rolled
    back, logged and the error re-raised.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS subscribers (
                        id SERIAL PRIMARY KEY,
                        email TEXT UNIQUE NOT NULL
                    )
                """)

                # created_at is added separately so that an existing
                # subscribers table without it (presumably from an older
                # schema) is migrated in place.
                if not column_exists(cursor, 'subscribers', 'created_at'):
                    cursor.execute("""
                        ALTER TABLE subscribers
                        ADD COLUMN created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    """)
                    logger.info("Added created_at column to subscribers table")

                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS newsletters(
                        id SERIAL PRIMARY KEY,
                        subject TEXT NOT NULL,
                        body TEXT NOT NULL,
                        sent_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # No separate index on subscribers.email: the UNIQUE
                # constraint above already creates one, and a second B-tree
                # on the same column only costs disk space and write time.
                # Databases initialised earlier may still carry a redundant
                # idx_subscribers_email; it is left untouched here.
                indexes = [
                    ("idx_newsletters_sent_at",
                     "CREATE INDEX IF NOT EXISTS idx_newsletters_sent_at ON newsletters(sent_at DESC)"),
                    ("idx_subscribers_created_at",
                     "CREATE INDEX IF NOT EXISTS idx_subscribers_created_at ON subscribers(created_at DESC)"),
                ]
                for index_name, create_sql in indexes:
                    cursor.execute(create_sql)
                    logger.info(f"Ensured index {index_name} exists")

            conn.commit()
            logger.info("Database tables and indexes initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing database: {e}")
            conn.rollback()
            raise
def add_email(email):
    """Insert *email* into subscribers.

    Returns:
        True when the row was inserted. False when the insert hit an
        integrity error (logged as "already exists") or any other error;
        in both failure cases the transaction is rolled back first.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,))
            conn.commit()
        except IntegrityError:
            conn.rollback()
            logger.info(f"Email already exists: {email}")
            return False
        except Exception as e:
            conn.rollback()
            logger.error(f"Error adding email {email}: {e}")
            return False
        logger.info(f"Email added successfully: {email}")
        return True
def remove_email(email):
    """Delete *email* from subscribers.

    Returns:
        True if at least one row was deleted. False if nothing matched, or
        if the statement failed (the failure is rolled back and logged).
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute("DELETE FROM subscribers WHERE email = %s", (email,))
                conn.commit()
                deleted = cur.rowcount
        except Exception as e:
            conn.rollback()
            logger.error(f"Error removing email {email}: {e}")
            return False
        if deleted <= 0:
            logger.info(f"Email not found for removal: {email}")
            return False
        logger.info(f"Email removed successfully: {email}")
        return True
def get_subscriber_count():
    """Return the number of subscriber rows, or 0 if the query fails."""
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT COUNT(*) FROM subscribers")
                (total,) = cur.fetchone()
        except Exception as e:
            logger.error(f"Error getting subscriber count: {e}")
            return 0
        return total
# Graceful shutdown: register close_all_connections so the global pool's
# connections are closed when the interpreter exits normally. Registration
# happens as a side effect of importing this module.
import atexit
atexit.register(close_all_connections)