"""PostgreSQL connection pooling and subscriber/newsletter persistence helpers."""
import atexit
import logging
import os
import threading
import time
from contextlib import contextmanager

import psycopg2
from psycopg2 import pool, IntegrityError, OperationalError
from dotenv import load_dotenv

load_dotenv()

# Global connection pool, created lazily by get_connection_pool().
_connection_pool = None
_pool_lock = threading.Lock()

# Process-wide counters exposed through get_pool_stats():
#   connections_created -- successful checkouts from SimpleRobustPool.getconn()
#   connections_failed  -- checkout attempts that raised
#   pool_recreated      -- ThreadedConnectionPool instances built (the initial
#                          creation is counted as well)
_pool_stats = {
    'connections_created': 0,
    'connections_failed': 0,
    'pool_recreated': 0,
}
# "+= 1" on a dict entry is not atomic across threads, so updates go through a lock.
_stats_lock = threading.Lock()

logger = logging.getLogger(__name__)


def _increment_stat(key):
    """Atomically add one to ``_pool_stats[key]``."""
    with _stats_lock:
        _pool_stats[key] += 1


def _safe_rollback(conn):
    """Roll back *conn*, ignoring failures (e.g. the connection is already dead).

    Used in error paths so a failing rollback cannot mask the original error.
    """
    try:
        conn.rollback()
    except Exception as e:
        logger.debug(f"Rollback failed: {e}")


class SimpleRobustPool:
    """Simplified robust connection pool.

    Wraps ``psycopg2.pool.ThreadedConnectionPool``. It adds a liveness check on
    checkout, a short retry loop, and, as a last resort, a rebuild of the
    underlying pool.

    Args:
        minconn: forwarded to ThreadedConnectionPool as ``minconn``.
        maxconn: forwarded to ThreadedConnectionPool as ``maxconn``.
        **kwargs: connection parameters the pool passes on when it opens
            connections (host, port, database, user, ...).
    """

    def __init__(self, minconn=3, maxconn=15, **kwargs):
        self.minconn = minconn
        self.maxconn = maxconn
        self.kwargs = kwargs
        self.pool = None
        # Serializes rebuilds so concurrently failing callers do not each tear
        # down the pool that another caller has just created.
        self._lock = threading.Lock()
        self._create_pool()

    def _create_pool(self):
        """Create (or recreate) the underlying pool.

        Any previous pool is closed first. That also closes connections that
        are still checked out from it.

        Raises:
            Exception: whatever ThreadedConnectionPool raises while opening its
                initial connections.
        """
        if self.pool is not None and not self.pool.closed:
            try:
                self.pool.closeall()
            except Exception as e:
                # We are replacing it anyway; just record the problem.
                logger.warning(f"Error closing previous connection pool: {e}")
        try:
            self.pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=self.minconn,
                maxconn=self.maxconn,
                **self.kwargs
            )
        except Exception as e:
            logger.error(f"Failed to create connection pool: {e}")
            raise
        _increment_stat('pool_recreated')
        logger.info(f"Connection pool created: {self.minconn}-{self.maxconn} connections")

    def _recreate_pool(self, failed_pool):
        """Rebuild the pool unless another thread already replaced *failed_pool*.

        Returns:
            The pool that should be used from now on.
        """
        with self._lock:
            if self.pool is failed_pool:
                self._create_pool()
            return self.pool

    def _test_connection(self, conn):
        """Return True if *conn* looks usable. Never raises.

        Uses ``poll()`` rather than a query so no transaction is started.
        """
        try:
            if conn.closed:
                return False
            conn.poll()
            return conn.status in (
                psycopg2.extensions.STATUS_READY,
                psycopg2.extensions.STATUS_BEGIN,
            )
        except Exception:
            return False

    @staticmethod
    def _discard(source_pool, conn):
        """Close *conn* and release its slot in *source_pool*, best-effort."""
        try:
            source_pool.putconn(conn, close=True)
        except Exception:
            # The pool is closed or does not own conn: at least close the socket.
            try:
                conn.close()
            except Exception:
                pass

    def getconn(self, retry_count=2):
        """Check out a live connection, making up to *retry_count* attempts.

        Each attempt takes a connection from the pool and checks it with
        ``_test_connection``. A dead connection is closed and the next attempt
        starts immediately. If the checkout itself raises, the next attempt
        waits 0.5 s.

        On the final attempt, a failed checkout triggers a pool rebuild.
        The exception is plain exhaustion (all ``maxconn`` connections in use
        on an open pool): a rebuild would close connections other threads
        still hold, so none is done.

        Args:
            retry_count: number of attempts; 0 or less makes no attempt.

        Returns:
            A psycopg2 connection; hand it back with ``putconn``.

        Raises:
            Exception: "Failed to get connection after retries" when every
                attempt failed (chained to the last checkout error, if any).
                Also re-raises the error from a failed pool rebuild.
        """
        last_error = None
        for attempt in range(retry_count):
            source = self.pool
            try:
                conn = source.getconn()
            except Exception as e:
                last_error = e
                logger.error(f"Error getting connection (attempt {attempt + 1}): {e}")
                _increment_stat('connections_failed')
                if attempt < retry_count - 1:
                    # Wait before retry
                    time.sleep(0.5)
                    continue
                if isinstance(e, psycopg2.pool.PoolError) and not source.closed:
                    # Exhausted rather than broken: a rebuild would close
                    # connections that other threads are using right now.
                    break
                # Last attempt failed, try to recreate pool
                logger.warning("Recreating connection pool due to failures")
                try:
                    source = self._recreate_pool(source)
                    conn = source.getconn()
                except Exception as recreate_error:
                    logger.error(f"Failed to recreate pool: {recreate_error}")
                    raise
                if self._test_connection(conn):
                    _increment_stat('connections_created')
                    return conn
                # Don't leak the slot when even a fresh connection is unusable.
                self._discard(source, conn)
                break

            if not self._test_connection(conn):
                logger.warning("Got bad connection, discarding and retrying")
                self._discard(source, conn)
                continue

            _increment_stat('connections_created')
            return conn

        raise Exception("Failed to get connection after retries") from last_error

    def putconn(self, conn, close=False):
        """Return *conn* to the pool. Already-closed connections are discarded.

        Never raises. On failure (e.g. *conn* came from a pool that has since
        been rebuilt), the error is logged and the connection is closed.
        """
        try:
            self.pool.putconn(conn, close=close or bool(conn.closed))
        except Exception as e:
            logger.error(f"Error returning connection to pool: {e}")
            try:
                conn.close()
            except Exception:
                pass

    def closeall(self):
        """Close every connection of the underlying pool, idle or checked out."""
        if self.pool is not None and not self.pool.closed:
            self.pool.closeall()


def get_connection_pool():
    """Return the process-wide SimpleRobustPool, creating it on first use.

    Pool size, connection parameters and connect timeout come from the
    environment (DB_POOL_MIN, DB_POOL_MAX, PG_HOST, PG_PORT, PG_DATABASE,
    PG_USER, PG_PASSWORD, DB_CONNECT_TIMEOUT).
    """
    global _connection_pool

    with _pool_lock:
        if _connection_pool is None:
            try:
                _connection_pool = SimpleRobustPool(
                    minconn=int(os.getenv('DB_POOL_MIN', 3)),
                    maxconn=int(os.getenv('DB_POOL_MAX', 15)),
                    host=os.getenv("PG_HOST"),
                    port=int(os.getenv("PG_PORT", 5432)),
                    database=os.getenv("PG_DATABASE"),
                    user=os.getenv("PG_USER"),
                    password=os.getenv("PG_PASSWORD"),
                    connect_timeout=int(os.getenv('DB_CONNECT_TIMEOUT', 10)),
                    application_name="rideaware_newsletter"
                )
                logger.info("Database connection pool initialized successfully")
            except Exception as e:
                logger.error(f"Error creating connection pool: {e}")
                raise

    return _connection_pool


@contextmanager
def get_db_connection():
    """Context manager that lends a pooled connection.

    On an exception inside the ``with`` body, the connection is rolled back
    and the exception re-raised. The connection always goes back to the same
    pool it was taken from.
    """
    conn = None
    db_pool = None
    try:
        db_pool = get_connection_pool()
        conn = db_pool.getconn()
        yield conn
    except Exception as e:
        logger.error(f"Database connection error: {e}")
        if conn is not None:
            _safe_rollback(conn)
        raise
    finally:
        if conn is not None:
            try:
                db_pool.putconn(conn)
            except Exception as e:
                logger.error(f"Error returning connection: {e}")


def get_connection():
    """Get a connection from the pool (legacy interface).

    The caller must hand it back with ``return_connection``.
    """
    try:
        db_pool = get_connection_pool()
        return db_pool.getconn()
    except Exception as e:
        logger.error(f"Error getting connection from pool: {e}")
        raise


def return_connection(conn):
    """Return a connection to the pool (legacy interface). Errors are logged, not raised."""
    try:
        db_pool = get_connection_pool()
        db_pool.putconn(conn)
    except Exception as e:
        logger.error(f"Error returning connection to pool: {e}")


def close_all_connections():
    """Close all pooled connections and forget the global pool.

    The next get_connection_pool() call builds a fresh pool.
    """
    global _connection_pool

    with _pool_lock:
        if _connection_pool:
            try:
                _connection_pool.closeall()
                logger.info("All database connections closed")
            except Exception as e:
                logger.error(f"Error closing connections: {e}")
            finally:
                _connection_pool = None


def get_pool_stats():
    """Return a snapshot copy of the connection pool counters."""
    with _stats_lock:
        return _pool_stats.copy()


def column_exists(cursor, table_name, column_name):
    """Return True if *table_name* in the current schema has *column_name*.

    The lookup is limited to ``current_schema()``, which is where the
    unqualified CREATE TABLE statements in init_db create their tables.
    A same-named table in another schema therefore cannot give a false
    positive.
    """
    cursor.execute("""
        SELECT EXISTS (
            SELECT 1 FROM information_schema.columns
            WHERE table_schema = current_schema()
              AND table_name = %s
              AND column_name = %s
        )
    """, (table_name, column_name))
    return cursor.fetchone()[0]


def index_exists(cursor, index_name):
    """Return True if a relation named *index_name* exists in schema 'public'.

    Note: pg_class holds every relation kind, so a table with this name also
    counts.
    """
    cursor.execute("""
        SELECT EXISTS (
            SELECT 1 FROM pg_class c
            JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relname = %s AND n.nspname = 'public'
        )
    """, (index_name,))
    return cursor.fetchone()[0]


def init_db():
    """Create the subscribers/newsletters tables and their indexes if missing.

    Also adds ``subscribers.created_at`` to tables created before that column
    existed. Everything runs in one transaction.

    Raises:
        Exception: any database error, after the transaction is rolled back.
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cursor:
                # Create subscribers table
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS subscribers (
                        id SERIAL PRIMARY KEY,
                        email TEXT UNIQUE NOT NULL
                    )
                """)

                # Add created_at column if it doesn't exist
                if not column_exists(cursor, 'subscribers', 'created_at'):
                    cursor.execute("""
                        ALTER TABLE subscribers
                        ADD COLUMN created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    """)
                    logger.info("Added created_at column to subscribers table")

                # Create newsletters table
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS newsletters(
                        id SERIAL PRIMARY KEY,
                        subject TEXT NOT NULL,
                        body TEXT NOT NULL,
                        sent_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # NOTE(review): the UNIQUE constraint on subscribers.email already
                # creates a unique index, so idx_subscribers_email looks redundant.
                # It is kept here so existing deployments see no schema change.
                indexes = [
                    ("idx_newsletters_sent_at",
                     "CREATE INDEX IF NOT EXISTS idx_newsletters_sent_at ON newsletters(sent_at DESC)"),
                    ("idx_subscribers_email",
                     "CREATE INDEX IF NOT EXISTS idx_subscribers_email ON subscribers(email)"),
                    ("idx_subscribers_created_at",
                     "CREATE INDEX IF NOT EXISTS idx_subscribers_created_at ON subscribers(created_at DESC)"),
                ]

                for index_name, create_sql in indexes:
                    cursor.execute(create_sql)
                    logger.info(f"Ensured index {index_name} exists")

            conn.commit()
            logger.info("Database tables and indexes initialized successfully")

        except Exception as e:
            logger.error(f"Error initializing database: {e}")
            # A failing rollback must not replace the original error.
            _safe_rollback(conn)
            raise


def add_email(email):
    """Insert *email* into subscribers.

    Returns:
        True if inserted. False if the address already exists or any other
        database error occurred (the error is logged).
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cursor:
                cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,))
            conn.commit()
            logger.info(f"Email added successfully: {email}")
            return True

        except IntegrityError:
            # Email already exists
            _safe_rollback(conn)
            logger.info(f"Email already exists: {email}")
            return False

        except Exception as e:
            _safe_rollback(conn)
            logger.error(f"Error adding email {email}: {e}")
            return False


def remove_email(email):
    """Delete *email* from subscribers.

    Returns:
        True if at least one row was deleted. False if the address was not
        found or a database error occurred (the error is logged).
    """
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cursor:
                cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,))
                conn.commit()
                rows_affected = cursor.rowcount

            if rows_affected > 0:
                logger.info(f"Email removed successfully: {email}")
                return True
            logger.info(f"Email not found for removal: {email}")
            return False

        except Exception as e:
            _safe_rollback(conn)
            logger.error(f"Error removing email {email}: {e}")
            return False


def get_subscriber_count():
    """Return the number of subscribers, or 0 if the query fails (the error is logged)."""
    with get_db_connection() as conn:
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT COUNT(*) FROM subscribers")
                count = cursor.fetchone()[0]
            return count

        except Exception as e:
            logger.error(f"Error getting subscriber count: {e}")
            return 0


# Cleanup function for graceful shutdown
atexit.register(close_all_connections)