diff --git a/files/__main__.py b/files/__main__.py index f0e1a3be2..8f22f00c4 100644 --- a/files/__main__.py +++ b/files/__main__.py @@ -194,11 +194,12 @@ limiter = flask_limiter.Limiter( # ...and then after that we can load the database. engine: Engine = create_engine(DATABASE_URL) -db_session: scoped_session = scoped_session(sessionmaker( +db_session_factory: sessionmaker = sessionmaker( bind=engine, autoflush=False, future=True, -)) +) +db_session: scoped_session = scoped_session(db_session_factory) # now that we've that, let's add the cache, compression, and mail extensions to our app... diff --git a/files/commands/cron.py b/files/commands/cron.py index a905e88ee..4f672f332 100644 --- a/files/commands/cron.py +++ b/files/commands/cron.py @@ -4,9 +4,9 @@ import time from datetime import datetime, timezone from typing import Final -from sqlalchemy.orm import scoped_session, Session +from sqlalchemy.orm import sessionmaker, Session -from files.__main__ import app, db_session +from files.__main__ import app, db_session_factory from files.classes.cron.tasks import (DayOfWeek, RepeatableTask, RepeatableTaskRun, ScheduledTaskState) @@ -41,7 +41,7 @@ def cron_app_worker(): logging.info("Starting scheduler worker process") while True: try: - _run_tasks(db_session) + _run_tasks(db_session_factory) except Exception as e: logging.exception( "An unhandled exception occurred while running tasks", @@ -77,7 +77,7 @@ def _acquire_lock_exclusive(db: Session, table: str): raise -def _run_tasks(db_session_factory: scoped_session): +def _run_tasks(db_session_factory: sessionmaker): ''' Runs tasks, attempting to guarantee that a task is ran once and only once. This uses postgres to lock the table containing our tasks at key points in @@ -116,10 +116,11 @@ def _run_tasks(db_session_factory: scoped_session): task.run_time_last = now task.run_state_enum = ScheduledTaskState.RUNNING + # This *must* happen before we start doing db queries, including sqlalchemy db queries + db.begin() task_debug_identifier = f"(ID {task.id}:{task.label})" logging.info(f"Running task {task_debug_identifier}") - db.begin() run: RepeatableTaskRun = task.run(db, task.run_time_last_or_created_utc) if run.exception: