refactor(db): switch to connection pool, remove logging, and tidy helpers

This commit is contained in:
Cipher Vance 2025-08-31 12:11:31 -05:00
parent a8589b659f
commit d9c86aa1bb

View file

@ -1,54 +1,53 @@
import os
import logging
import psycopg2
from psycopg2 import IntegrityError
from psycopg2 import IntegrityError, pool, OperationalError
from dotenv import load_dotenv
from werkzeug.security import generate_password_hash
load_dotenv()
# Logging setup
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
try:
DB_MIN_CONN = int(os.getenv("DB_MIN_CONN", 1))
DB_MAX_CONN = int(os.getenv("DB_MAX_CONN", 10))
conn_pool = pool.ThreadedConnectionPool(
minconn=DB_MIN_CONN,
maxconn=DB_MAX_CONN,
host=os.getenv("PG_HOST"),
port=os.getenv("PG_PORT"),
dbname=os.getenv("PG_DATABASE"),
user=os.getenv("PG_USER"),
password=os.getenv("PG_PASSWORD"),
connect_timeout=10,
)
except OperationalError:
raise
except Exception:
raise
def get_connection():
"""Return a new connection to the PostgreSQL database."""
try:
conn = psycopg2.connect(
host=os.getenv("PG_HOST"),
port=os.getenv("PG_PORT"),
dbname=os.getenv("PG_DATABASE"),
user=os.getenv("PG_USER"),
password=os.getenv("PG_PASSWORD"),
connect_timeout=10,
)
return conn
except Exception as e:
logger.error(f"Database connection error: {e}")
raise
"""Get a connection from the connection pool."""
return conn_pool.getconn()
def init_db():
"""Initialize the database tables."""
"""Initialize database tables with connection pool."""
conn = None
cursor = None
try:
conn = get_connection()
cursor = conn.cursor()
# Create subscribers table (if not exists)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS subscribers (
id SERIAL PRIMARY KEY,
email TEXT UNIQUE NOT NULL
)
"""
"""
)
# Create admin_users table (if not exists)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS admin_users (
@ -56,94 +55,96 @@ def init_db():
username TEXT UNIQUE NOT NULL,
password TEXT NOT NULL
)
"""
"""
)
# Newsletter storage
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS newsletters (
id SERIAL PRIMARY KEY,
subject TEXT NOT NULL,
body TEXT NOT NULL,
sent_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)
"""
CREATE TABLE IF NOT EXISTS newsletters (
id SERIAL PRIMARY KEY,
subject TEXT NOT NULL,
body TEXT NOT NULL,
sent_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)
"""
)
conn.commit()
logger.info("Database initialized successfully.")
except Exception as e:
logger.error(f"Database initialization error: {e}")
except Exception:
if conn:
conn.rollback() # Rollback if there's an error
conn.rollback()
raise
finally:
if conn:
if cursor:
cursor.close()
conn.close()
if conn:
conn_pool.putconn(conn)
def get_all_emails():
"""Return a list of all subscriber emails."""
conn = None
cursor = None
try:
conn = get_connection()
cursor = conn.cursor()
cursor.execute("SELECT email FROM subscribers")
results = cursor.fetchall()
emails = [row[0] for row in results]
logger.debug(f"Retrieved emails: {emails}")
return emails
except Exception as e:
logger.error(f"Error retrieving emails: {e}")
return [row[0] for row in results]
except Exception:
return []
finally:
if conn:
if cursor:
cursor.close()
conn.close()
if conn:
conn_pool.putconn(conn)
def add_email(email):
"""Insert an email into the subscribers table."""
conn = None
cursor = None
try:
conn = get_connection()
cursor = conn.cursor()
cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,))
conn.commit()
logger.info(f"Email {email} added successfully.")
return True
except IntegrityError:
logger.warning(f"Attempted to add duplicate email: {email}")
if conn:
conn.rollback()
return False
except Exception as e:
logger.error(f"Error adding email {email}: {e}")
except Exception:
if conn:
conn.rollback()
return False
finally:
if conn:
if cursor:
cursor.close()
conn.close()
if conn:
conn_pool.putconn(conn)
def remove_email(email):
"""Remove an email from the subscribers table."""
conn = None
cursor = None
try:
conn = get_connection()
cursor = conn.cursor()
cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,))
rowcount = cursor.rowcount
conn.commit()
logger.info(f"Email {email} removed successfully.")
return rowcount > 0
except Exception as e:
logger.error(f"Error removing email {email}: {e}")
except Exception:
if conn:
conn.rollback()
return False
finally:
if conn:
if cursor:
cursor.close()
conn.close()
if conn:
conn_pool.putconn(conn)
def get_admin(username):
@ -151,6 +152,7 @@ def get_admin(username):
Returns a tuple (username, password_hash) if found, otherwise None.
"""
conn = None
cursor = None
try:
conn = get_connection()
cursor = conn.cursor()
@ -158,15 +160,14 @@ def get_admin(username):
"SELECT username, password FROM admin_users WHERE username = %s",
(username,),
)
result = cursor.fetchone()
return result # (username, password_hash)
except Exception as e:
logger.error(f"Error retrieving admin: {e}")
return cursor.fetchone()
except Exception:
return None
finally:
if conn:
if cursor:
cursor.close()
conn.close()
if conn:
conn_pool.putconn(conn)
def create_default_admin():
@ -174,30 +175,60 @@ def create_default_admin():
default_username = os.getenv("ADMIN_USERNAME", "admin")
default_password = os.getenv("ADMIN_PASSWORD", "changeme")
hashed_password = generate_password_hash(default_password, method="pbkdf2:sha256")
conn = None
cursor = None
try:
conn = get_connection()
cursor = conn.cursor()
# Check if the admin already exists
cursor.execute(
"SELECT id FROM admin_users WHERE username = %s", (default_username,)
)
if cursor.fetchone() is None:
exists = cursor.fetchone()
if exists is None:
cursor.execute(
"INSERT INTO admin_users (username, password) VALUES (%s, %s)",
(default_username, hashed_password),
)
conn.commit()
logger.info("Default admin created successfully")
else:
logger.info("Default admin already exists")
except Exception as e:
logger.error(f"Error creating default admin: {e}")
except Exception:
if conn:
conn.rollback()
finally:
if conn:
if cursor:
cursor.close()
conn.close()
if conn:
conn_pool.putconn(conn)
def close_pool():
"""Close the database connection pool."""
try:
conn_pool.closeall()
except Exception:
pass
class DatabaseContext:
"""Optional context manager for manual transactions."""
def __init__(self):
self.conn = None
self.cursor = None
def __enter__(self):
self.conn = get_connection()
self.cursor = self.conn.cursor()
return self.cursor
def __exit__(self, exc_type, exc_val, exc_tb):
try:
if exc_type:
self.conn.rollback()
else:
self.conn.commit()
finally:
if self.cursor:
self.cursor.close()
if self.conn:
conn_pool.putconn(self.conn)