diff --git a/database.py b/database.py index b90e0ef..e5a85fb 100644 --- a/database.py +++ b/database.py @@ -1,54 +1,53 @@ import os -import logging import psycopg2 -from psycopg2 import IntegrityError +from psycopg2 import IntegrityError, pool, OperationalError from dotenv import load_dotenv from werkzeug.security import generate_password_hash load_dotenv() -# Logging setup -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) +try: + DB_MIN_CONN = int(os.getenv("DB_MIN_CONN", 1)) + DB_MAX_CONN = int(os.getenv("DB_MAX_CONN", 10)) + + conn_pool = pool.ThreadedConnectionPool( + minconn=DB_MIN_CONN, + maxconn=DB_MAX_CONN, + host=os.getenv("PG_HOST"), + port=os.getenv("PG_PORT"), + dbname=os.getenv("PG_DATABASE"), + user=os.getenv("PG_USER"), + password=os.getenv("PG_PASSWORD"), + connect_timeout=10, + ) +except OperationalError: + raise +except Exception: + raise def get_connection(): - """Return a new connection to the PostgreSQL database.""" - try: - conn = psycopg2.connect( - host=os.getenv("PG_HOST"), - port=os.getenv("PG_PORT"), - dbname=os.getenv("PG_DATABASE"), - user=os.getenv("PG_USER"), - password=os.getenv("PG_PASSWORD"), - connect_timeout=10, - ) - return conn - except Exception as e: - logger.error(f"Database connection error: {e}") - raise + """Get a connection from the connection pool.""" + return conn_pool.getconn() def init_db(): - """Initialize the database tables.""" + """Initialize database tables with connection pool.""" conn = None + cursor = None try: conn = get_connection() cursor = conn.cursor() - # Create subscribers table (if not exists) cursor.execute( """ CREATE TABLE IF NOT EXISTS subscribers ( id SERIAL PRIMARY KEY, email TEXT UNIQUE NOT NULL ) - """ + """ ) - # Create admin_users table (if not exists) cursor.execute( """ CREATE TABLE IF NOT EXISTS admin_users ( @@ -56,94 +55,96 @@ def init_db(): username TEXT UNIQUE NOT NULL, password TEXT NOT NULL ) - """ + """ ) - # Newsletter storage cursor.execute( """ - CREATE TABLE IF NOT EXISTS newsletters ( - id SERIAL PRIMARY KEY, - subject TEXT NOT NULL, - body TEXT NOT NULL, - sent_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP - ) - """ + CREATE TABLE IF NOT EXISTS newsletters ( + id SERIAL PRIMARY KEY, + subject TEXT NOT NULL, + body TEXT NOT NULL, + sent_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP + ) + """ ) conn.commit() - logger.info("Database initialized successfully.") - except Exception as e: - logger.error(f"Database initialization error: {e}") + except Exception: if conn: - conn.rollback() # Rollback if there's an error - + conn.rollback() raise finally: - if conn: + if cursor: cursor.close() - conn.close() + if conn: + conn_pool.putconn(conn) def get_all_emails(): """Return a list of all subscriber emails.""" + conn = None + cursor = None try: conn = get_connection() cursor = conn.cursor() cursor.execute("SELECT email FROM subscribers") results = cursor.fetchall() - emails = [row[0] for row in results] - logger.debug(f"Retrieved emails: {emails}") - return emails - except Exception as e: - logger.error(f"Error retrieving emails: {e}") + return [row[0] for row in results] + except Exception: return [] finally: - if conn: + if cursor: cursor.close() - conn.close() + if conn: + conn_pool.putconn(conn) def add_email(email): """Insert an email into the subscribers table.""" conn = None + cursor = None try: conn = get_connection() cursor = conn.cursor() cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,)) conn.commit() - logger.info(f"Email {email} added successfully.") return True except IntegrityError: - logger.warning(f"Attempted to add duplicate email: {email}") + if conn: + conn.rollback() return False - except Exception as e: - logger.error(f"Error adding email {email}: {e}") + except Exception: + if conn: + conn.rollback() return False finally: - if conn: + if cursor: cursor.close() - conn.close() + if conn: + conn_pool.putconn(conn) def remove_email(email): """Remove an email from the subscribers table.""" conn = None + cursor = None try: conn = get_connection() cursor = conn.cursor() cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,)) rowcount = cursor.rowcount conn.commit() - logger.info(f"Email {email} removed successfully.") return rowcount > 0 - except Exception as e: - logger.error(f"Error removing email {email}: {e}") + except Exception: + if conn: + conn.rollback() return False finally: - if conn: + if cursor: cursor.close() - conn.close() + if conn: + conn_pool.putconn(conn) def get_admin(username): @@ -151,6 +152,7 @@ def get_admin(username): Returns a tuple (username, password_hash) if found, otherwise None. """ conn = None + cursor = None try: conn = get_connection() cursor = conn.cursor() @@ -158,15 +160,14 @@ def get_admin(username): "SELECT username, password FROM admin_users WHERE username = %s", (username,), ) - result = cursor.fetchone() - return result # (username, password_hash) - except Exception as e: - logger.error(f"Error retrieving admin: {e}") + return cursor.fetchone() + except Exception: return None finally: - if conn: + if cursor: cursor.close() - conn.close() + if conn: + conn_pool.putconn(conn) def create_default_admin(): @@ -174,30 +175,60 @@ def create_default_admin(): default_username = os.getenv("ADMIN_USERNAME", "admin") default_password = os.getenv("ADMIN_PASSWORD", "changeme") hashed_password = generate_password_hash(default_password, method="pbkdf2:sha256") + conn = None + cursor = None try: conn = get_connection() cursor = conn.cursor() - # Check if the admin already exists cursor.execute( "SELECT id FROM admin_users WHERE username = %s", (default_username,) ) - if cursor.fetchone() is None: + exists = cursor.fetchone() + if exists is None: cursor.execute( "INSERT INTO admin_users (username, password) VALUES (%s, %s)", (default_username, hashed_password), ) conn.commit() - logger.info("Default admin created successfully") - else: - logger.info("Default admin already exists") - except Exception as e: - logger.error(f"Error creating default admin: {e}") + except Exception: if conn: conn.rollback() finally: - if conn: + if cursor: cursor.close() - conn.close() + if conn: + conn_pool.putconn(conn) + +def close_pool(): + """Close the database connection pool.""" + try: + conn_pool.closeall() + except Exception: + pass + + +class DatabaseContext: + """Optional context manager for manual transactions.""" + def __init__(self): + self.conn = None + self.cursor = None + + def __enter__(self): + self.conn = get_connection() + self.cursor = self.conn.cursor() + return self.cursor + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if exc_type: + self.conn.rollback() + else: + self.conn.commit() + finally: + if self.cursor: + self.cursor.close() + if self.conn: + conn_pool.putconn(self.conn) \ No newline at end of file