refactor(db): switch to connection pool, remove logging, and tidy helpers
This commit is contained in:
parent
a8589b659f
commit
d9c86aa1bb
1 changed files with 105 additions and 74 deletions
179
database.py
179
database.py
|
|
@ -1,54 +1,53 @@
|
||||||
import os
|
import os
|
||||||
import logging
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
from psycopg2 import IntegrityError
|
from psycopg2 import IntegrityError, pool, OperationalError
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from werkzeug.security import generate_password_hash
|
from werkzeug.security import generate_password_hash
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Logging setup
|
try:
|
||||||
logging.basicConfig(
|
DB_MIN_CONN = int(os.getenv("DB_MIN_CONN", 1))
|
||||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
DB_MAX_CONN = int(os.getenv("DB_MAX_CONN", 10))
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
conn_pool = pool.ThreadedConnectionPool(
|
||||||
|
minconn=DB_MIN_CONN,
|
||||||
|
maxconn=DB_MAX_CONN,
|
||||||
|
host=os.getenv("PG_HOST"),
|
||||||
|
port=os.getenv("PG_PORT"),
|
||||||
|
dbname=os.getenv("PG_DATABASE"),
|
||||||
|
user=os.getenv("PG_USER"),
|
||||||
|
password=os.getenv("PG_PASSWORD"),
|
||||||
|
connect_timeout=10,
|
||||||
|
)
|
||||||
|
except OperationalError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_connection():
|
def get_connection():
|
||||||
"""Return a new connection to the PostgreSQL database."""
|
"""Get a connection from the connection pool."""
|
||||||
try:
|
return conn_pool.getconn()
|
||||||
conn = psycopg2.connect(
|
|
||||||
host=os.getenv("PG_HOST"),
|
|
||||||
port=os.getenv("PG_PORT"),
|
|
||||||
dbname=os.getenv("PG_DATABASE"),
|
|
||||||
user=os.getenv("PG_USER"),
|
|
||||||
password=os.getenv("PG_PASSWORD"),
|
|
||||||
connect_timeout=10,
|
|
||||||
)
|
|
||||||
return conn
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Database connection error: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
def init_db():
|
||||||
"""Initialize the database tables."""
|
"""Initialize database tables with connection pool."""
|
||||||
conn = None
|
conn = None
|
||||||
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = get_connection()
|
conn = get_connection()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# Create subscribers table (if not exists)
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS subscribers (
|
CREATE TABLE IF NOT EXISTS subscribers (
|
||||||
id SERIAL PRIMARY KEY,
|
id SERIAL PRIMARY KEY,
|
||||||
email TEXT UNIQUE NOT NULL
|
email TEXT UNIQUE NOT NULL
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create admin_users table (if not exists)
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS admin_users (
|
CREATE TABLE IF NOT EXISTS admin_users (
|
||||||
|
|
@ -56,94 +55,96 @@ def init_db():
|
||||||
username TEXT UNIQUE NOT NULL,
|
username TEXT UNIQUE NOT NULL,
|
||||||
password TEXT NOT NULL
|
password TEXT NOT NULL
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Newsletter storage
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS newsletters (
|
CREATE TABLE IF NOT EXISTS newsletters (
|
||||||
id SERIAL PRIMARY KEY,
|
id SERIAL PRIMARY KEY,
|
||||||
subject TEXT NOT NULL,
|
subject TEXT NOT NULL,
|
||||||
body TEXT NOT NULL,
|
body TEXT NOT NULL,
|
||||||
sent_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
sent_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logger.info("Database initialized successfully.")
|
except Exception:
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Database initialization error: {e}")
|
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback() # Rollback if there's an error
|
conn.rollback()
|
||||||
|
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
if conn:
|
if cursor:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
if conn:
|
||||||
|
conn_pool.putconn(conn)
|
||||||
|
|
||||||
|
|
||||||
def get_all_emails():
|
def get_all_emails():
|
||||||
"""Return a list of all subscriber emails."""
|
"""Return a list of all subscriber emails."""
|
||||||
|
conn = None
|
||||||
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = get_connection()
|
conn = get_connection()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("SELECT email FROM subscribers")
|
cursor.execute("SELECT email FROM subscribers")
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
emails = [row[0] for row in results]
|
return [row[0] for row in results]
|
||||||
logger.debug(f"Retrieved emails: {emails}")
|
except Exception:
|
||||||
return emails
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving emails: {e}")
|
|
||||||
return []
|
return []
|
||||||
finally:
|
finally:
|
||||||
if conn:
|
if cursor:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
if conn:
|
||||||
|
conn_pool.putconn(conn)
|
||||||
|
|
||||||
|
|
||||||
def add_email(email):
|
def add_email(email):
|
||||||
"""Insert an email into the subscribers table."""
|
"""Insert an email into the subscribers table."""
|
||||||
conn = None
|
conn = None
|
||||||
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = get_connection()
|
conn = get_connection()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,))
|
cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logger.info(f"Email {email} added successfully.")
|
|
||||||
return True
|
return True
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
logger.warning(f"Attempted to add duplicate email: {email}")
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"Error adding email {email}: {e}")
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
return False
|
return False
|
||||||
finally:
|
finally:
|
||||||
if conn:
|
if cursor:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
if conn:
|
||||||
|
conn_pool.putconn(conn)
|
||||||
|
|
||||||
|
|
||||||
def remove_email(email):
|
def remove_email(email):
|
||||||
"""Remove an email from the subscribers table."""
|
"""Remove an email from the subscribers table."""
|
||||||
conn = None
|
conn = None
|
||||||
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = get_connection()
|
conn = get_connection()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,))
|
cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,))
|
||||||
rowcount = cursor.rowcount
|
rowcount = cursor.rowcount
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logger.info(f"Email {email} removed successfully.")
|
|
||||||
return rowcount > 0
|
return rowcount > 0
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"Error removing email {email}: {e}")
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
return False
|
return False
|
||||||
finally:
|
finally:
|
||||||
if conn:
|
if cursor:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
if conn:
|
||||||
|
conn_pool.putconn(conn)
|
||||||
|
|
||||||
|
|
||||||
def get_admin(username):
|
def get_admin(username):
|
||||||
|
|
@ -151,6 +152,7 @@ def get_admin(username):
|
||||||
Returns a tuple (username, password_hash) if found, otherwise None.
|
Returns a tuple (username, password_hash) if found, otherwise None.
|
||||||
"""
|
"""
|
||||||
conn = None
|
conn = None
|
||||||
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = get_connection()
|
conn = get_connection()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
@ -158,15 +160,14 @@ def get_admin(username):
|
||||||
"SELECT username, password FROM admin_users WHERE username = %s",
|
"SELECT username, password FROM admin_users WHERE username = %s",
|
||||||
(username,),
|
(username,),
|
||||||
)
|
)
|
||||||
result = cursor.fetchone()
|
return cursor.fetchone()
|
||||||
return result # (username, password_hash)
|
except Exception:
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error retrieving admin: {e}")
|
|
||||||
return None
|
return None
|
||||||
finally:
|
finally:
|
||||||
if conn:
|
if cursor:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
if conn:
|
||||||
|
conn_pool.putconn(conn)
|
||||||
|
|
||||||
|
|
||||||
def create_default_admin():
|
def create_default_admin():
|
||||||
|
|
@ -174,30 +175,60 @@ def create_default_admin():
|
||||||
default_username = os.getenv("ADMIN_USERNAME", "admin")
|
default_username = os.getenv("ADMIN_USERNAME", "admin")
|
||||||
default_password = os.getenv("ADMIN_PASSWORD", "changeme")
|
default_password = os.getenv("ADMIN_PASSWORD", "changeme")
|
||||||
hashed_password = generate_password_hash(default_password, method="pbkdf2:sha256")
|
hashed_password = generate_password_hash(default_password, method="pbkdf2:sha256")
|
||||||
|
|
||||||
conn = None
|
conn = None
|
||||||
|
cursor = None
|
||||||
try:
|
try:
|
||||||
conn = get_connection()
|
conn = get_connection()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# Check if the admin already exists
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT id FROM admin_users WHERE username = %s", (default_username,)
|
"SELECT id FROM admin_users WHERE username = %s", (default_username,)
|
||||||
)
|
)
|
||||||
if cursor.fetchone() is None:
|
exists = cursor.fetchone()
|
||||||
|
if exists is None:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"INSERT INTO admin_users (username, password) VALUES (%s, %s)",
|
"INSERT INTO admin_users (username, password) VALUES (%s, %s)",
|
||||||
(default_username, hashed_password),
|
(default_username, hashed_password),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
logger.info("Default admin created successfully")
|
except Exception:
|
||||||
else:
|
|
||||||
logger.info("Default admin already exists")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating default admin: {e}")
|
|
||||||
if conn:
|
if conn:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
finally:
|
finally:
|
||||||
if conn:
|
if cursor:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
conn.close()
|
if conn:
|
||||||
|
conn_pool.putconn(conn)
|
||||||
|
|
||||||
|
|
||||||
|
def close_pool():
|
||||||
|
"""Close the database connection pool."""
|
||||||
|
try:
|
||||||
|
conn_pool.closeall()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseContext:
|
||||||
|
"""Optional context manager for manual transactions."""
|
||||||
|
def __init__(self):
|
||||||
|
self.conn = None
|
||||||
|
self.cursor = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.conn = get_connection()
|
||||||
|
self.cursor = self.conn.cursor()
|
||||||
|
return self.cursor
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
try:
|
||||||
|
if exc_type:
|
||||||
|
self.conn.rollback()
|
||||||
|
else:
|
||||||
|
self.conn.commit()
|
||||||
|
finally:
|
||||||
|
if self.cursor:
|
||||||
|
self.cursor.close()
|
||||||
|
if self.conn:
|
||||||
|
conn_pool.putconn(self.conn)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue