234 lines
		
	
	
		
			No EOL
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			234 lines
		
	
	
		
			No EOL
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import logging
import os

import psycopg2
from dotenv import load_dotenv
from psycopg2 import IntegrityError, pool, OperationalError
from werkzeug.security import generate_password_hash
 | |
| 
 | |
load_dotenv()

# Pool sizing and the connect timeout are configurable through the
# environment; the defaults match the previously hard-coded values.
DB_MIN_CONN = int(os.getenv("DB_MIN_CONN", 1))
DB_MAX_CONN = int(os.getenv("DB_MAX_CONN", 10))
DB_CONNECT_TIMEOUT = int(os.getenv("DB_CONNECT_TIMEOUT", 10))

# Shared ThreadedConnectionPool used by every helper in this module.
# psycopg2 opens the first `minconn` connections when the pool is built, so
# connection errors (e.g. OperationalError) propagate out of module import.
conn_pool = pool.ThreadedConnectionPool(
    minconn=DB_MIN_CONN,
    maxconn=DB_MAX_CONN,
    host=os.getenv("PG_HOST"),
    port=os.getenv("PG_PORT"),
    dbname=os.getenv("PG_DATABASE"),
    user=os.getenv("PG_USER"),
    password=os.getenv("PG_PASSWORD"),
    connect_timeout=DB_CONNECT_TIMEOUT,
)
 | |
| 
 | |
| 
 | |
def get_connection():
    """Check out a connection from the module-wide pool.

    The caller must hand it back with ``conn_pool.putconn()`` when done.
    """
    borrowed = conn_pool.getconn()
    return borrowed
 | |
| 
 | |
| 
 | |
def init_db():
    """Create the application's tables if they do not already exist.

    Creates ``subscribers``, ``admin_users`` and ``newsletters`` and commits
    once at the end. On any error the transaction is rolled back (when the
    connection is still open) and the original exception is re-raised.
    """
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()

        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS subscribers (
                id SERIAL PRIMARY KEY,
                email TEXT UNIQUE NOT NULL
            )
            """
        )

        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS admin_users (
                id SERIAL PRIMARY KEY,
                username TEXT UNIQUE NOT NULL,
                password TEXT NOT NULL
            )
            """
        )

        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS newsletters (
                id SERIAL PRIMARY KEY,
                subject TEXT NOT NULL,
                body TEXT NOT NULL,
                sent_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
            )
            """
        )

        conn.commit()
    except Exception:
        # rollback() on a closed/broken connection would itself raise and
        # replace the original exception, so only roll back an open one.
        if conn is not None and not conn.closed:
            conn.rollback()
        raise
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn_pool.putconn(conn)
 | |
| 
 | |
| 
 | |
def get_all_emails():
    """Return a list of all subscriber emails.

    Best-effort: if the query fails the error is logged and an empty list is
    returned, so an empty result can also mean "lookup failed" - check the
    logs when the list is unexpectedly empty.
    """
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()
        cursor.execute("SELECT email FROM subscribers")
        return [row[0] for row in cursor.fetchall()]
    except Exception:
        logging.getLogger(__name__).exception("Failed to fetch subscriber emails")
        # Match the other helpers: leave an open connection clean before it
        # goes back to the pool; a closed one cannot be rolled back.
        if conn is not None and not conn.closed:
            conn.rollback()
        return []
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn_pool.putconn(conn)
 | |
| 
 | |
| 
 | |
def add_email(email):
    """Insert an email into the subscribers table.

    Returns True when the row was inserted. Returns False when the insert
    violates a constraint (IntegrityError - e.g. the address is already
    subscribed, since ``email`` is UNIQUE NOT NULL) or when any other
    database error occurs; those other errors are logged.
    """
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()
        cursor.execute("INSERT INTO subscribers (email) VALUES (%s)", (email,))
        conn.commit()
        return True
    except Exception as exc:
        # A constraint violation is an expected outcome; anything else is a
        # real failure that should not disappear silently.
        if not isinstance(exc, IntegrityError):
            logging.getLogger(__name__).exception("Failed to add subscriber email")
        # rollback() on a closed/broken connection would raise instead of
        # letting us return False.
        if conn is not None and not conn.closed:
            conn.rollback()
        return False
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn_pool.putconn(conn)
 | |
| 
 | |
| 
 | |
def remove_email(email):
    """Remove an email from the subscribers table.

    Returns True if at least one row was deleted. Returns False if no row
    matched or a database error occurred; errors are logged.
    """
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()
        cursor.execute("DELETE FROM subscribers WHERE email = %s", (email,))
        deleted = cursor.rowcount
        conn.commit()
        return deleted > 0
    except Exception:
        logging.getLogger(__name__).exception("Failed to remove subscriber email")
        # rollback() on a closed/broken connection would raise instead of
        # letting us return False.
        if conn is not None and not conn.closed:
            conn.rollback()
        return False
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn_pool.putconn(conn)
 | |
| 
 | |
| 
 | |
def get_admin(username):
    """Retrieve admin credentials for a given username.

    Returns a tuple (username, password_hash) if found, otherwise None.
    None is also returned when the lookup fails; the error is logged so a
    database outage is not mistaken for an unknown user.
    """
    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()
        cursor.execute(
            "SELECT username, password FROM admin_users WHERE username = %s",
            (username,),
        )
        return cursor.fetchone()
    except Exception:
        logging.getLogger(__name__).exception("Failed to look up admin user")
        if conn is not None and not conn.closed:
            conn.rollback()
        return None
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn_pool.putconn(conn)
 | |
| 
 | |
| 
 | |
def create_default_admin():
    """Create a default admin user if one doesn't already exist.

    Credentials come from ADMIN_USERNAME / ADMIN_PASSWORD, falling back to
    "admin" / "changeme". The password is hashed (pbkdf2:sha256) only when
    the user actually has to be inserted. Failures are logged and otherwise
    ignored, so this never raises to the caller.
    """
    log = logging.getLogger(__name__)
    default_username = os.getenv("ADMIN_USERNAME", "admin")
    default_password = os.getenv("ADMIN_PASSWORD", "changeme")

    conn = None
    cursor = None
    try:
        conn = get_connection()
        cursor = conn.cursor()

        cursor.execute(
            "SELECT id FROM admin_users WHERE username = %s", (default_username,)
        )
        if cursor.fetchone() is None:
            if os.getenv("ADMIN_PASSWORD") is None:
                log.warning(
                    "ADMIN_PASSWORD is not set; creating admin user %r with the "
                    "built-in default password. Change it immediately.",
                    default_username,
                )
            # Hashing is deliberately slow, so only do it when inserting.
            hashed_password = generate_password_hash(
                default_password, method="pbkdf2:sha256"
            )
            # ON CONFLICT turns a concurrent start-up that inserted the same
            # username first (username is UNIQUE) into a no-op.
            cursor.execute(
                "INSERT INTO admin_users (username, password) VALUES (%s, %s) "
                "ON CONFLICT (username) DO NOTHING",
                (default_username, hashed_password),
            )
            conn.commit()
    except Exception:
        log.exception("Failed to create default admin user")
        if conn is not None and not conn.closed:
            conn.rollback()
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn_pool.putconn(conn)
 | |
| 
 | |
| 
 | |
def close_pool():
    """Close every connection held by the module-wide pool.

    Any exception raised while closing is swallowed, so shutdown code can
    call this without guarding it.
    """
    try:
        conn_pool.closeall()
    except Exception:
        # Best-effort shutdown: there is nothing useful to do on failure.
        return None
 | |
| 
 | |
| 
 | |
class DatabaseContext:
    """Optional context manager for manual transactions.

    ``__enter__`` borrows a pooled connection and yields a cursor on it. On
    normal exit the transaction is committed; if the ``with`` body raises, it
    is rolled back and the exception propagates. Either way the cursor is
    closed and the connection is returned to the pool.

    Example::

        with DatabaseContext() as cur:
            cur.execute("SELECT 1")
    """

    def __init__(self):
        self.conn = None  # pooled connection while the context is active
        self.cursor = None  # cursor on self.conn, yielded by __enter__

    def __enter__(self):
        self.conn = get_connection()
        try:
            self.cursor = self.conn.cursor()
        except Exception:
            # __exit__ is not called when __enter__ raises, so release the
            # connection here or it leaks from the pool.
            conn_pool.putconn(self.conn)
            self.conn = None
            raise
        return self.cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is not None:
                # Rolling back a closed/broken connection would raise and
                # replace the body's exception, so only roll back an open one.
                if not self.conn.closed:
                    self.conn.rollback()
            else:
                self.conn.commit()
        finally:
            if self.cursor is not None:
                self.cursor.close()
            if self.conn is not None:
                conn_pool.putconn(self.conn)
            # Drop references so the released connection is not reused
            # through this object.
            self.cursor = None
            self.conn = None
