diff --git a/files/helpers/comments.py b/files/helpers/comments.py index fc9c79ef1..4be95eced 100644 --- a/files/helpers/comments.py +++ b/files/helpers/comments.py @@ -75,14 +75,15 @@ def update_ancestor_descendant_counts(comment, delta): g.db.add(parent) update_ancestor_descendant_counts(parent, delta) -def bulk_recompute_descendant_counts(predicate = None): +def bulk_recompute_descendant_counts(predicate = None, db=None): """ Recomputes the descendant_count of a large number of comments. The descendant_count of a comment is equal to the number of direct visible child comments plus the sum of the descendant_count of those visible child comments. - :param Callable predicate: If set, only update comments matching this predicate + :param predicate: If set, only update comments matching this predicate + :param db: If set, use this instead of g.db So for example @@ -117,7 +118,8 @@ def bulk_recompute_descendant_counts(predicate = None): AND comments.level = :level_1 """ - max_level_query = g.db.query(func.max(Comment.level)) + db = db if db is not None else g.db + max_level_query = db.query(func.max(Comment.level)) if predicate: max_level_query = predicate(max_level_query) @@ -159,8 +161,8 @@ def bulk_recompute_descendant_counts(predicate = None): ) if predicate: update_statement = predicate(update_statement) - g.db.execute(update_statement) - g.db.commit() + db.execute(update_statement) + db.commit() def comment_on_publish(comment:Comment): """ diff --git a/files/tests/test_child_comment_counts.py b/files/tests/test_child_comment_counts.py index de983fbd4..71980d09a 100644 --- a/files/tests/test_child_comment_counts.py +++ b/files/tests/test_child_comment_counts.py @@ -4,9 +4,10 @@ from . import fixture_comments from . import util from flask import g from files.__main__ import app, db_session -from files.classes import Submission, Comment +from files.classes import Submission, Comment, User from files.helpers.comments import bulk_recompute_descendant_counts import json +import random def assert_comment_visibility(post, comment_body, clients): @@ -171,8 +172,20 @@ def test_bulk_update_descendant_count_quick(accounts, submissions, comments): 4. Delete the comments/posts """ with app.app_context(): - g.db = db_session() - alice_client, alice = accounts.client_and_user_for_account('Alice') + db = db_session() + + lastname = ''.join(random.choice('aio') + random.choice('bfkmprst') for i in range(3)) + alice = User(**{ + "username": f"alice_{lastname}", + "original_username": f"alice_{lastname}", + "admin_level": 0, + "password":"themotteuser", + "email":None, + "ban_evade":0, + "profileurl":"/e/feather.webp" + }) + db.add(alice) + db.commit() posts = [] for i in range(2): post = Submission(**{ @@ -192,8 +205,8 @@ def test_bulk_update_descendant_count_quick(accounts, submissions, comments): 'ghost': False, 'filter_state': 'normal' }) - g.db.add(post) - g.db.commit() + db.add(post) + db.commit() posts.append(post) parent_comment = None top_comment = None @@ -214,15 +227,18 @@ def test_bulk_update_descendant_count_quick(accounts, submissions, comments): if parent_comment is None: top_comment = comment parent_comment = comment - g.db.add(comment) - g.db.commit() + db.add(comment) + db.commit() sorted_comments_0 = sorted(posts[0].comments, key=lambda c: c.id) sorted_comments_1 = sorted(posts[1].comments, key=lambda c: c.id) assert [i+1 for i in range(20)] == [c.level for c in sorted_comments_0] assert [i+1 for i in range(20)] == [c.level for c in sorted_comments_1] assert [0 for i in range(20)] == [c.descendant_count for c in sorted_comments_0] assert [0 for i in range(20)] == [c.descendant_count for c in sorted_comments_1] - bulk_recompute_descendant_counts(lambda q: q.where(Comment.parent_submission == posts[0].id)) + bulk_recompute_descendant_counts( + lambda q: q.where(Comment.parent_submission == posts[0].id), + db + ) sorted_comments_0 = sorted(posts[0].comments, key=lambda c: c.id) sorted_comments_1 = sorted(posts[1].comments, key=lambda c: c.id) assert [i+1 for i in range(20)] == [c.level for c in sorted_comments_0] @@ -231,6 +247,6 @@ def test_bulk_update_descendant_count_quick(accounts, submissions, comments): assert [0 for i in range(20)] == [c.descendant_count for c in sorted_comments_1] for post in posts: for comment in post.comments: - g.db.delete(comment) - g.db.delete(post) - g.db.commit() + db.delete(comment) + db.delete(post) + db.commit() diff --git a/migrations/versions/2023_01_03_07_59_56_1f30a37b08a0_dml_populate_comments_descendant_count_.py b/migrations/versions/2023_01_03_07_59_56_1f30a37b08a0_dml_populate_comments_descendant_count_.py index faee6b549..f9bab2772 100644 --- a/migrations/versions/2023_01_03_07_59_56_1f30a37b08a0_dml_populate_comments_descendant_count_.py +++ b/migrations/versions/2023_01_03_07_59_56_1f30a37b08a0_dml_populate_comments_descendant_count_.py @@ -9,7 +9,6 @@ from alembic import op from sqlalchemy.sql.expression import func, text from sqlalchemy.orm.session import Session from sqlalchemy import update -from flask import g from files.__main__ import db_session from files.classes import Comment @@ -22,18 +21,10 @@ down_revision = 'f8ba0e88ddd1' branch_labels = None depends_on = None -class g_db_set_from_alembic(): - def __enter__(self, *args, **kwargs): - g.db = Session(bind=op.get_bind()) - self.old_db = getattr(g, 'db', None) - def __exit__(self, *args, **kwargs): - g.db = self.old_db - - def upgrade(): - with g_db_set_from_alembic(): - bulk_recompute_descendant_counts() + db =Session(bind=op.get_bind()) + bulk_recompute_descendant_counts(lambda q: q, db) def downgrade(): - with g_db_set_from_alembic(): - g.db.execute(update(Comment).values(descendant_count=0)) + db =Session(bind=op.get_bind()) + db.execute(update(Comment).values(descendant_count=0))