refactor(db): switch to connection pool, remove logging, and tidy helpers
This commit is contained in:
parent
a8589b659f
commit
d9c86aa1bb
1 changed files with 105 additions and 74 deletions
179
database.py
179
database.py
|
|
@ -1,54 +1,53 @@
|
|||
import os
|
||||
import logging
|
||||
import psycopg2
|
||||
from psycopg2 import IntegrityError
|
||||
from psycopg2 import IntegrityError, pool, OperationalError
|
||||
from dotenv import load_dotenv
|
||||
from werkzeug.security import generate_password_hash
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Logging setup
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
DB_MIN_CONN = int(os.getenv("DB_MIN_CONN", 1))
|
||||
DB_MAX_CONN = int(os.getenv("DB_MAX_CONN", 10))
|
||||
|
||||
conn_pool = pool.ThreadedConnectionPool(
|
||||
minconn=DB_MIN_CONN,
|
||||
maxconn=DB_MAX_CONN,
|
||||
host=os.getenv("PG_HOST"),
|
||||
port=os.getenv("PG_PORT"),
|
||||
dbname=os.getenv("PG_DATABASE"),
|
||||
user=os.getenv("PG_USER"),
|
||||
password=os.getenv("PG_PASSWORD"),
|
||||
connect_timeout=10,
|
||||
)
|
||||
except OperationalError:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
|
||||
def get_connection():
|
||||
"""Return a new connection to the PostgreSQL database."""
|
||||
try:
|
||||
conn = psycopg2.connect(
|
||||
host=os.getenv("PG_HOST"),
|
||||
port=os.getenv("PG_PORT"),
|
||||
dbname=os.getenv("PG_DATABASE"),
|
||||
user=os.getenv("PG_USER"),
|
||||
password=os.getenv("PG_PASSWORD"),
|
||||
connect_timeout=10,
|
||||
)
|
||||
return conn
|
||||
except Exception as e:
|
||||
logger.error(f"Database connection error: {e}")
|
||||
raise
|
||||
"""Get a connection from the connection pool."""
|
||||
return conn_pool.getconn()
|
||||
|
||||
|
||||
def init_db():
|
||||
"""Initialize the database tables."""
|
||||
"""Initialize database tables with connection pool."""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create subscribers table (if not exists)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS subscribers (
|
||||
id SERIAL PRIMARY KEY,
|
||||
email TEXT UNIQUE NOT NULL
|
||||
)
|
||||
"""
|
||||
"""
|
||||
)
|
||||
|
||||
# Create admin_users table (if not exists)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS admin_users (
|
||||
|
|
@ -56,94 +55,96 @@ def init_db():
|
|||
username TEXT UNIQUE NOT NULL,
|
||||
password TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
"""
|
||||
)
|
||||
|
||||
# Newsletter storage
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS newsletters (
|
||||
id SERIAL PRIMARY KEY,
|
||||
subject TEXT NOT NULL,
|
||||
body TEXT NOT NULL,
|
||||
sent_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS newsletters (
|
||||
id SERIAL PRIMARY KEY,
|
||||
subject TEXT NOT NULL,
|
||||
body TEXT NOT NULL,
|
||||
sent_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
logger.info("Database initialized successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Database initialization error: {e}")
|
||||
except Exception:
|
||||
if conn:
|
||||
conn.rollback() # Rollback if there's an error
|
||||
|
||||
conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
if conn:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
if conn:
|
||||
conn_pool.putconn(conn)
|
||||
|
||||
|
||||
def get_all_emails():
|
||||
"""Return a list of all subscriber emails."""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT email FROM subscribers")
|
||||
results = cursor.fetchall()
|
||||
emails = [row[0] for row in results]
|
||||
logger.debug(f"Retrieved emails: {emails}")
|
||||
return emails
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving emails: {e}")
|
||||
return [row[0] for row in results]
|
||||
except Exception:
|
||||
return []
|
||||
finally:
|
||||
if conn:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
if conn:
|
||||
conn_pool.putconn(conn)
|
||||
|
||||
|
||||
def add_email(email):
|
||||
"""Insert an email into the subscribers table."""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,))
|
||||
conn.commit()
|
||||
logger.info(f"Email {email} added successfully.")
|
||||
return True
|
||||
except IntegrityError:
|
||||
logger.warning(f"Attempted to add duplicate email: {email}")
|
||||
if conn:
|
||||
conn.rollback()
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding email {email}: {e}")
|
||||
except Exception:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
return False
|
||||
finally:
|
||||
if conn:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
if conn:
|
||||
conn_pool.putconn(conn)
|
||||
|
||||
|
||||
def remove_email(email):
|
||||
"""Remove an email from the subscribers table."""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,))
|
||||
rowcount = cursor.rowcount
|
||||
conn.commit()
|
||||
logger.info(f"Email {email} removed successfully.")
|
||||
return rowcount > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing email {email}: {e}")
|
||||
except Exception:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
return False
|
||||
finally:
|
||||
if conn:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
if conn:
|
||||
conn_pool.putconn(conn)
|
||||
|
||||
|
||||
def get_admin(username):
|
||||
|
|
@ -151,6 +152,7 @@ def get_admin(username):
|
|||
Returns a tuple (username, password_hash) if found, otherwise None.
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
|
@ -158,15 +160,14 @@ def get_admin(username):
|
|||
"SELECT username, password FROM admin_users WHERE username = %s",
|
||||
(username,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
return result # (username, password_hash)
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving admin: {e}")
|
||||
return cursor.fetchone()
|
||||
except Exception:
|
||||
return None
|
||||
finally:
|
||||
if conn:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
if conn:
|
||||
conn_pool.putconn(conn)
|
||||
|
||||
|
||||
def create_default_admin():
|
||||
|
|
@ -174,30 +175,60 @@ def create_default_admin():
|
|||
default_username = os.getenv("ADMIN_USERNAME", "admin")
|
||||
default_password = os.getenv("ADMIN_PASSWORD", "changeme")
|
||||
hashed_password = generate_password_hash(default_password, method="pbkdf2:sha256")
|
||||
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check if the admin already exists
|
||||
cursor.execute(
|
||||
"SELECT id FROM admin_users WHERE username = %s", (default_username,)
|
||||
)
|
||||
if cursor.fetchone() is None:
|
||||
exists = cursor.fetchone()
|
||||
if exists is None:
|
||||
cursor.execute(
|
||||
"INSERT INTO admin_users (username, password) VALUES (%s, %s)",
|
||||
(default_username, hashed_password),
|
||||
)
|
||||
conn.commit()
|
||||
logger.info("Default admin created successfully")
|
||||
else:
|
||||
logger.info("Default admin already exists")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating default admin: {e}")
|
||||
except Exception:
|
||||
if conn:
|
||||
conn.rollback()
|
||||
finally:
|
||||
if conn:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
if conn:
|
||||
conn_pool.putconn(conn)
|
||||
|
||||
|
||||
def close_pool():
|
||||
"""Close the database connection pool."""
|
||||
try:
|
||||
conn_pool.closeall()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseContext:
|
||||
"""Optional context manager for manual transactions."""
|
||||
def __init__(self):
|
||||
self.conn = None
|
||||
self.cursor = None
|
||||
|
||||
def __enter__(self):
|
||||
self.conn = get_connection()
|
||||
self.cursor = self.conn.cursor()
|
||||
return self.cursor
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
try:
|
||||
if exc_type:
|
||||
self.conn.rollback()
|
||||
else:
|
||||
self.conn.commit()
|
||||
finally:
|
||||
if self.cursor:
|
||||
self.cursor.close()
|
||||
if self.conn:
|
||||
conn_pool.putconn(self.conn)
|
||||
Loading…
Add table
Add a link
Reference in a new issue