From 96f4243713cec0166ec1576eeb2d693ab50c9982 Mon Sep 17 00:00:00 2001
From: Cipher Vance
Date: Sun, 31 Aug 2025 09:48:10 -0500
Subject: [PATCH] feat: enhance database reliability, add rate limiting, and
 improve email compatibility

---
 database.py                       | 427 +++++++++++++++++++-----------
 requirements.txt                  |   3 +-
 server.py                         |  51 +++-
 templates/confirmation_email.html | 308 +++++++++------------
 4 files changed, 441 insertions(+), 348 deletions(-)

diff --git a/database.py b/database.py
index d8059fe..5a3454a 100644
--- a/database.py
+++ b/database.py
@@ -1,64 +1,213 @@
 import os
 import psycopg2
-from psycopg2 import pool, IntegrityError
+from psycopg2 import pool, IntegrityError, OperationalError
 from dotenv import load_dotenv
 import logging
+import threading
+import time
+from contextlib import contextmanager
 
 load_dotenv()
 
 # Global connection pool
 _connection_pool = None
+_pool_lock = threading.Lock()
+_pool_stats = {
+    'connections_created': 0,
+    'connections_failed': 0,
+    'pool_recreated': 0
+}
+
+logger = logging.getLogger(__name__)
+
+class SimpleRobustPool:
+    """Simplified robust connection pool"""
+
+    def __init__(self, minconn=3, maxconn=15, **kwargs):
+        self.minconn = minconn
+        self.maxconn = maxconn
+        self.kwargs = kwargs
+        self.pool = None
+        self._create_pool()
+
+    def _create_pool(self):
+        """Create or recreate the connection pool"""
+        try:
+            if self.pool:
+                try:
+                    self.pool.closeall()
+                except Exception:
+                    pass
+
+            self.pool = psycopg2.pool.ThreadedConnectionPool(
+                minconn=self.minconn,
+                maxconn=self.maxconn,
+                **self.kwargs
+            )
+            _pool_stats['pool_recreated'] += 1
+            logger.info(f"Connection pool created: {self.minconn}-{self.maxconn} connections")
+
+        except Exception as e:
+            logger.error(f"Failed to create connection pool: {e}")
+            raise
+
+    def _test_connection(self, conn):
+        """Simple connection test without transaction conflicts"""
+        try:
+            if conn.closed:
+                return False
+
+            # Simple test that doesn't interfere with transactions
+            conn.poll()
+            return conn.status in (
+                psycopg2.extensions.STATUS_READY,
+                psycopg2.extensions.STATUS_BEGIN,
+            )
+
+        except Exception:
+            return False
+
+    def getconn(self, retry_count=2):
+        """Get a connection with simplified retry logic"""
+        for attempt in range(retry_count):
+            try:
+                conn = self.pool.getconn()
+
+                # Simple connection test
+                if not self._test_connection(conn):
+                    logger.warning("Got bad connection, discarding and retrying")
+                    try:
+                        self.pool.putconn(conn, close=True)
+                    except Exception:
+                        pass
+                    continue
+
+                _pool_stats['connections_created'] += 1
+                return conn
+
+            except Exception as e:
+                logger.error(f"Error getting connection (attempt {attempt + 1}): {e}")
+                _pool_stats['connections_failed'] += 1
+
+                if attempt == retry_count - 1:
+                    # Last attempt failed, try to recreate pool
+                    logger.warning("Recreating connection pool due to failures")
+                    try:
+                        self._create_pool()
+                        conn = self.pool.getconn()
+                        if self._test_connection(conn):
+                            return conn
+                    except Exception as recreate_error:
+                        logger.error(f"Failed to recreate pool: {recreate_error}")
+                        raise
+
+                # Wait before retry
+                time.sleep(0.5)
+
+        raise OperationalError("Failed to get connection after retries")
+
+    def putconn(self, conn, close=False):
+        """Return connection to pool"""
+        try:
+            # Close the connection if asked to, or if it is already dead
+            close = close or conn.closed
+            self.pool.putconn(conn, close=close)
+
+        except Exception as e:
+            logger.error(f"Error returning connection to pool: {e}")
+            try:
+                conn.close()
+            except Exception:
+                pass
+
+    def closeall(self):
+        """Close all connections"""
+        if self.pool:
+            self.pool.closeall()
 
 def get_connection_pool():
     """Initialize and return the connection pool"""
     global _connection_pool
-    if _connection_pool is None:
-        try:
-            _connection_pool = psycopg2.pool.ThreadedConnectionPool(
-                minconn=2,
-                maxconn=20,
-                host=os.getenv("PG_HOST"),
-                port=os.getenv("PG_PORT", 5432),
-                dbname=os.getenv("PG_DATABASE"),
-                user=os.getenv("PG_USER"),
-                password=os.getenv("PG_PASSWORD"),
-                connect_timeout=5
-            )
-            logging.info("Database connection pool created successfully")
-        except Exception as e:
-            logging.error(f"Error creating connection pool: {e}")
-            raise
-    return _connection_pool
+    with _pool_lock:
+        if _connection_pool is None:
+            try:
+                _connection_pool = SimpleRobustPool(
+                    minconn=int(os.getenv('DB_POOL_MIN', 3)),
+                    maxconn=int(os.getenv('DB_POOL_MAX', 15)),
+                    host=os.getenv("PG_HOST"),
+                    port=int(os.getenv("PG_PORT", 5432)),
+                    database=os.getenv("PG_DATABASE"),
+                    user=os.getenv("PG_USER"),
+                    password=os.getenv("PG_PASSWORD"),
+                    connect_timeout=int(os.getenv('DB_CONNECT_TIMEOUT', 10)),
+                    application_name="rideaware_newsletter"
+                )
+                logger.info("Database connection pool initialized successfully")
+
+            except Exception as e:
+                logger.error(f"Error creating connection pool: {e}")
+                raise
+
+    return _connection_pool
 
-def get_connection():
-    """Get a connection from the pool"""
+@contextmanager
+def get_db_connection():
+    """Context manager for database connections"""
+    conn = None
     try:
         pool = get_connection_pool()
         conn = pool.getconn()
-        if conn.closed:
-            # Connection is closed, remove it and get a new one
-            pool.putconn(conn, close=True)
-            conn = pool.getconn()
-        return conn
+        yield conn
+
     except Exception as e:
-        logging.error(f"Error getting connection from pool: {e}")
+        logger.error(f"Database connection error: {e}")
+        if conn:
+            try:
+                conn.rollback()
+            except Exception:
+                pass
+        raise
+
+    finally:
+        if conn:
+            try:
+                pool = get_connection_pool()
+                pool.putconn(conn)
+            except Exception as e:
+                logger.error(f"Error returning connection: {e}")
+
+def get_connection():
+    """Get a connection from the pool (legacy interface)"""
+    try:
+        pool = get_connection_pool()
+        return pool.getconn()
+    except Exception as e:
+        logger.error(f"Error getting connection from pool: {e}")
         raise
 
 def return_connection(conn):
-    """Return a connection to the pool"""
+    """Return a connection to the pool (legacy interface)"""
     try:
         pool = get_connection_pool()
         pool.putconn(conn)
     except Exception as e:
-        logging.error(f"Error returning connection to pool: {e}")
+        logger.error(f"Error returning connection to pool: {e}")
 
 def close_all_connections():
     """Close all connections in the pool"""
     global _connection_pool
-    if _connection_pool:
-        _connection_pool.closeall()
-        _connection_pool = None
-        logging.info("All database connections closed")
+    with _pool_lock:
+        if _connection_pool:
+            try:
+                _connection_pool.closeall()
+                logger.info("All database connections closed")
+            except Exception as e:
+                logger.error(f"Error closing connections: {e}")
+            finally:
+                _connection_pool = None
+
+def get_pool_stats():
+    """Get connection pool statistics"""
+    return _pool_stats.copy()
 
 def column_exists(cursor, table_name, column_name):
     """Check if a column exists in a table"""
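Review note: a minimal usage sketch for the new `get_db_connection` context manager (not part of this patch; the helper name and query are illustrative, and assume the `created_at` column added by the migration below):

    from database import get_db_connection

    def count_recent_subscribers(cutoff):
        # Checkout, rollback-on-error, and checkin are all handled by
        # get_db_connection(); the caller only manages the cursor.
        with get_db_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    "SELECT COUNT(*) FROM subscribers WHERE created_at >= %s",
                    (cutoff,),
                )
                return cursor.fetchone()[0]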
@@ -84,138 +233,108 @@ def index_exists(cursor, index_name):
 
 def init_db():
     """Initialize database tables and indexes"""
-    conn = None
-    try:
-        conn = get_connection()
-        cursor = conn.cursor()
-
-        # Create subscribers table
-        cursor.execute("""
-            CREATE TABLE IF NOT EXISTS subscribers (
-                id SERIAL PRIMARY KEY,
-                email TEXT UNIQUE NOT NULL
-            )
-        """)
-
-        # Add created_at column if it doesn't exist
-        if not column_exists(cursor, 'subscribers', 'created_at'):
-            cursor.execute("""
-                ALTER TABLE subscribers
-                ADD COLUMN created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-            """)
-            logging.info("Added created_at column to subscribers table")
-
-        # Create newsletters table
-        cursor.execute("""
-            CREATE TABLE IF NOT EXISTS newsletters(
-                id SERIAL PRIMARY KEY,
-                subject TEXT NOT NULL,
-                body TEXT NOT NULL,
-                sent_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-            )
-        """)
-
-        # Create indexes only if they don't exist
-        if not index_exists(cursor, 'idx_newsletters_sent_at'):
-            cursor.execute("CREATE INDEX idx_newsletters_sent_at ON newsletters(sent_at DESC)")
-            logging.info("Created index idx_newsletters_sent_at")
-
-        if not index_exists(cursor, 'idx_subscribers_email'):
-            cursor.execute("CREATE INDEX idx_subscribers_email ON subscribers(email)")
-            logging.info("Created index idx_subscribers_email")
-
-        if not index_exists(cursor, 'idx_subscribers_created_at'):
-            cursor.execute("CREATE INDEX idx_subscribers_created_at ON subscribers(created_at DESC)")
-            logging.info("Created index idx_subscribers_created_at")
-
-        conn.commit()
-        cursor.close()
-        logging.info("Database tables and indexes initialized successfully")
-
-    except Exception as e:
-        logging.error(f"Error initializing database: {e}")
-        if conn:
+    with get_db_connection() as conn:
+        try:
+            with conn.cursor() as cursor:
+                # Create subscribers table
+                cursor.execute("""
+                    CREATE TABLE IF NOT EXISTS subscribers (
+                        id SERIAL PRIMARY KEY,
+                        email TEXT UNIQUE NOT NULL
+                    )
+                """)
+
+                # Add created_at column if it doesn't exist
+                if not column_exists(cursor, 'subscribers', 'created_at'):
+                    cursor.execute("""
+                        ALTER TABLE subscribers
+                        ADD COLUMN created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                    """)
+                    logger.info("Added created_at column to subscribers table")
+
+                # Create newsletters table
+                cursor.execute("""
+                    CREATE TABLE IF NOT EXISTS newsletters(
+                        id SERIAL PRIMARY KEY,
+                        subject TEXT NOT NULL,
+                        body TEXT NOT NULL,
+                        sent_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                    )
+                """)
+
+                # Create indexes only if they don't exist
+                indexes = [
+                    ("idx_newsletters_sent_at", "CREATE INDEX IF NOT EXISTS idx_newsletters_sent_at ON newsletters(sent_at DESC)"),
+                    ("idx_subscribers_email", "CREATE INDEX IF NOT EXISTS idx_subscribers_email ON subscribers(email)"),
+                    ("idx_subscribers_created_at", "CREATE INDEX IF NOT EXISTS idx_subscribers_created_at ON subscribers(created_at DESC)")
+                ]
+
+                for index_name, create_sql in indexes:
+                    cursor.execute(create_sql)
+                    logger.info(f"Ensured index {index_name} exists")
+
+            conn.commit()
+            logger.info("Database tables and indexes initialized successfully")
+
+        except Exception as e:
+            logger.error(f"Error initializing database: {e}")
             conn.rollback()
-        raise
-    finally:
-        if conn:
-            return_connection(conn)
+            raise
 
 def add_email(email):
-    """Add email to subscribers with connection pooling"""
-    conn = None
-    try:
-        conn = get_connection()
-        cursor = conn.cursor()
-        cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,))
-        conn.commit()
-        cursor.close()
-        logging.info(f"Email added successfully: {email}")
-        return True
-
-    except IntegrityError:
-        # Email already exists
-        if conn:
+    """Add email to subscribers with robust connection handling"""
+    with get_db_connection() as conn:
+        try:
+            with conn.cursor() as cursor:
+                cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,))
+            conn.commit()
+            logger.info(f"Email added successfully: {email}")
+            return True
+
+        except IntegrityError:
+            # Email already exists
             conn.rollback()
-        logging.info(f"Email already exists: {email}")
-        return False
-
-    except Exception as e:
-        if conn:
-            conn.rollback()
-        logging.error(f"Error adding email {email}: {e}")
-        return False
-
-    finally:
-        if conn:
-            return_connection(conn)
-
-def remove_email(email):
-    """Remove email from subscribers with connection pooling"""
-    conn = None
-    try:
-        conn = get_connection()
-        cursor = conn.cursor()
-        cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,))
-        conn.commit()
-        rows_affected = cursor.rowcount
-        cursor.close()
-
-        if rows_affected > 0:
-            logging.info(f"Email removed successfully: {email}")
-            return True
-        else:
-            logging.info(f"Email not found for removal: {email}")
+            logger.info(f"Email already exists: {email}")
             return False
 
-    except Exception as e:
-        if conn:
+        except Exception as e:
             conn.rollback()
-        logging.error(f"Error removing email {email}: {e}")
-        return False
-
-    finally:
-        if conn:
-            return_connection(conn)
+            logger.error(f"Error adding email {email}: {e}")
+            return False
+
+def remove_email(email):
+    """Remove email from subscribers with robust connection handling"""
+    with get_db_connection() as conn:
+        try:
+            with conn.cursor() as cursor:
+                cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,))
+                conn.commit()
+                rows_affected = cursor.rowcount
+
+                if rows_affected > 0:
+                    logger.info(f"Email removed successfully: {email}")
+                    return True
+                else:
+                    logger.info(f"Email not found for removal: {email}")
+                    return False
+
+        except Exception as e:
+            conn.rollback()
+            logger.error(f"Error removing email {email}: {e}")
+            return False
 
 def get_subscriber_count():
     """Get total number of subscribers"""
-    conn = None
-    try:
-        conn = get_connection()
-        cursor = conn.cursor()
-        cursor.execute("SELECT COUNT(*) FROM subscribers")
-        count = cursor.fetchone()[0]
-        cursor.close()
-        return count
-
-    except Exception as e:
-        logging.error(f"Error getting subscriber count: {e}")
-        return 0
-
-    finally:
-        if conn:
-            return_connection(conn)
+    with get_db_connection() as conn:
+        try:
+            with conn.cursor() as cursor:
+                cursor.execute("SELECT COUNT(*) FROM subscribers")
+                count = cursor.fetchone()[0]
+                return count
+
+        except Exception as e:
+            logger.error(f"Error getting subscriber count: {e}")
+            return 0
 
 # Cleanup function for graceful shutdown
 import atexit
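Review note: a quick smoke test for the reworked helpers (illustrative, not part of this patch; assumes the PG_* environment variables point at a disposable database):

    from database import init_db, add_email, remove_email, get_subscriber_count

    init_db()
    assert add_email("demo@example.com") is True      # fresh insert
    assert add_email("demo@example.com") is False     # duplicate -> IntegrityError path
    print("subscribers:", get_subscriber_count())
    assert remove_email("demo@example.com") is True
    assert remove_email("demo@example.com") is False  # already removed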
diff --git a/requirements.txt b/requirements.txt
index 68767e9..16fa5b8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 gunicorn
-flask
 python-dotenv
 psycopg2-binary
+Flask
+Flask-Limiter
\ No newline at end of file
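Review note on the limiter configured in server.py below: storage_uri="memory://" keeps counters in process memory, so each gunicorn worker enforces its own independent limits and the counters reset on restart. If shared limits are ever needed, Flask-Limiter accepts an external store; a sketch assuming a reachable Redis instance and its client package (neither is part of this patch):

    import os
    from flask_limiter import Limiter
    from flask_limiter.util import get_remote_address

    limiter = Limiter(
        key_func=get_remote_address,
        default_limits=["1000 per hour", "100 per minute"],
        # e.g. RATELIMIT_STORAGE_URI=redis://localhost:6379/0
        storage_uri=os.getenv("RATELIMIT_STORAGE_URI", "memory://"),
    )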
diff --git a/server.py b/server.py
index d1a3dd0..980c24d 100644
--- a/server.py
+++ b/server.py
@@ -5,6 +5,8 @@ from threading import Thread
 import smtplib
 from email.mime.text import MIMEText
 from flask import Flask, render_template, request, jsonify, g
+from flask_limiter import Limiter
+from flask_limiter.util import get_remote_address
 from dotenv import load_dotenv
 from database import init_db, get_connection, return_connection, add_email, remove_email
 
@@ -17,8 +19,19 @@ SMTP_PASSWORD = os.getenv('SMTP_PASSWORD')
 
 app = Flask(__name__)
 
+# Rate limiting setup
+limiter = Limiter(
+    key_func=get_remote_address,
+    app=app,
+    default_limits=["1000 per hour", "100 per minute"],
+    storage_uri="memory://"
+)
+
 # Configure logging
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s %(levelname)s [%(name)s] %(message)s'
+)
 
 # Cache configuration
 _newsletter_cache = {}
@@ -132,24 +145,20 @@ def after_request(response):
     return response
 
-def send_confirmation_email(to_address: str, unsubscribe_link: str):
+def send_confirmation_email(to_address: str, html_body: str):
     """
     Sends the HTML confirmation email to `to_address`.
-    This runs inside its own SMTP_SSL connection with reduced timeout.
+    html_body is pre-rendered to avoid Flask context issues.
     """
     try:
         subject = "Thanks for subscribing!"
-        html_body = render_template(
-            "confirmation_email.html",
-            unsubscribe_link=unsubscribe_link
-        )
 
         msg = MIMEText(html_body, "html", "utf-8")
         msg["Subject"] = subject
         msg["From"] = SMTP_USER
         msg["To"] = to_address
 
-        with smtplib.SMTP_SSL(SMTP_SERVER, SMTP_PORT, timeout=5) as server:
+        with smtplib.SMTP_SSL(SMTP_SERVER, SMTP_PORT, timeout=10) as server:
             server.login(SMTP_USER, SMTP_PASSWORD)
             server.sendmail(SMTP_USER, [to_address], msg.as_string())
 
@@ -158,11 +167,11 @@ except Exception as e:
         app.logger.error(f"Failed to send email to {to_address}: {e}")
 
-def send_confirmation_async(email, unsubscribe_link):
+def send_confirmation_async(email, html_body):
     """
     Wrapper for threading.Thread target.
     """
-    send_confirmation_email(email, unsubscribe_link)
+    send_confirmation_email(email, html_body)
 
 @app.route("/", methods=["GET"])
 def index():
@@ -170,6 +179,7 @@ return render_template("index.html")
 
 @app.route("/subscribe", methods=["POST"])
+@limiter.limit("5 per minute")  # Strict rate limit for subscriptions
 def subscribe():
     """Subscribe endpoint with optimized database handling"""
     data = request.get_json() or {}
@@ -184,12 +194,17 @@ try:
         if add_email(email):
+            # Render the template in the main thread (with Flask context)
             unsubscribe_link = f"{request.url_root}unsubscribe?email={email}"
+            html_body = render_template(
+                "confirmation_email.html",
+                unsubscribe_link=unsubscribe_link
+            )
 
-            # Start email sending in background thread
+            # Start email sending in background thread with pre-rendered HTML
             Thread(
                 target=send_confirmation_async,
-                args=(email, unsubscribe_link),
+                args=(email, html_body),
                 daemon=True
             ).start()
 
@@ -202,6 +217,7 @@ return jsonify(error="Internal server error"), 500
 
 @app.route("/unsubscribe", methods=["GET"])
+@limiter.limit("10 per minute")
 def unsubscribe():
     """Unsubscribe endpoint with optimized database handling"""
     email = request.args.get("email")
@@ -220,6 +236,7 @@ return "Internal server error", 500
 
 @app.route("/newsletters", methods=["GET"])
+@limiter.limit("30 per minute")
 def newsletters():
     """
     List all newsletters (newest first) with caching for better performance.
@@ -232,6 +249,7 @@ return "Internal server error", 500
 
 @app.route("/newsletter/<int:newsletter_id>", methods=["GET"])
+@limiter.limit("60 per minute")
 def newsletter_detail(newsletter_id):
     """
     Show a single newsletter by its ID with caching.
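Review note: the stricter subscribe limit can be exercised locally (illustrative; assumes the dev server is running on localhost:5000 and the requests package is installed — limits are keyed per client IP):

    import requests

    # /subscribe allows 5 requests per minute, so the sixth request in a
    # burst should return HTTP 429 via the handler added below.
    for i in range(6):
        r = requests.post(
            "http://localhost:5000/subscribe",
            json={"email": f"ratelimit-test-{i}@example.com"},
        )
        print(i + 1, r.status_code)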
@@ -248,6 +266,7 @@ def newsletter_detail(newsletter_id):
         return "Internal server error", 500
 
 @app.route("/admin/clear-cache", methods=["POST"])
+@limiter.limit("5 per minute")
 def clear_cache():
     """Admin endpoint to clear newsletter cache"""
     try:
@@ -258,10 +277,16 @@ def clear_cache():
         return jsonify(error="Failed to clear cache"), 500
 
 @app.route("/health", methods=["GET"])
+@limiter.limit("120 per minute")
 def health_check():
     """Health check endpoint for monitoring"""
     return jsonify(status="healthy", timestamp=time.time()), 200
 
+# Rate limit error handler
+@app.errorhandler(429)
+def ratelimit_handler(e):
+    return jsonify(error="Rate limit exceeded. Please try again later."), 429
+
 # Error handlers
 @app.errorhandler(404)
 def not_found(error):
@@ -280,4 +305,4 @@ except Exception as e:
     raise
 
 if __name__ == "__main__":
-    app.run(host="0.0.0.0", debug=True)
+    app.run(host="0.0.0.0", debug=True)
\ No newline at end of file
diff --git a/templates/confirmation_email.html b/templates/confirmation_email.html
index 823cc32..8a42ba5 100644
--- a/templates/confirmation_email.html
+++ b/templates/confirmation_email.html
@@ -13,6 +13,7 @@
+
-
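Review note: get_pool_stats() is added to database.py but nothing reads it yet. A possible follow-up (hypothetical endpoint, written against the app, limiter, and jsonify names defined in server.py) would expose the counters for monitoring:

    from database import get_pool_stats

    @app.route("/admin/pool-stats", methods=["GET"])
    @limiter.limit("10 per minute")
    def pool_stats():
        # Returns the counters kept in _pool_stats: connections_created,
        # connections_failed, and pool_recreated.
        return jsonify(get_pool_stats()), 200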