[themotte/rDrama#451] Passing in the db connection as an optional param makes things easier

This commit is contained in:
faul_sname 2023-01-13 19:33:56 -08:00
parent 12ca271fe7
commit c0a546d779
3 changed files with 38 additions and 29 deletions

View file

@ -75,14 +75,15 @@ def update_ancestor_descendant_counts(comment, delta):
g.db.add(parent)
update_ancestor_descendant_counts(parent, delta)
def bulk_recompute_descendant_counts(predicate = None):
def bulk_recompute_descendant_counts(predicate = None, db=None):
"""
Recomputes the descendant_count of a large number of comments.
The descendant_count of a comment is equal to the number of direct visible child comments
plus the sum of the descendant_count of those visible child comments.
:param Callable predicate: If set, only update comments matching this predicate
:param predicate: If set, only update comments matching this predicate
:param db: If set, use this instead of g.db
So for example
@ -117,7 +118,8 @@ def bulk_recompute_descendant_counts(predicate = None):
AND comments.level = :level_1
<predicate goes here>
"""
max_level_query = g.db.query(func.max(Comment.level))
db = db if db is not None else g.db
max_level_query = db.query(func.max(Comment.level))
if predicate:
max_level_query = predicate(max_level_query)
@ -159,8 +161,8 @@ def bulk_recompute_descendant_counts(predicate = None):
)
if predicate:
update_statement = predicate(update_statement)
g.db.execute(update_statement)
g.db.commit()
db.execute(update_statement)
db.commit()
def comment_on_publish(comment:Comment):
"""

View file

@ -4,9 +4,10 @@ from . import fixture_comments
from . import util
from flask import g
from files.__main__ import app, db_session
from files.classes import Submission, Comment
from files.classes import Submission, Comment, User
from files.helpers.comments import bulk_recompute_descendant_counts
import json
import random
def assert_comment_visibility(post, comment_body, clients):
@ -171,8 +172,20 @@ def test_bulk_update_descendant_count_quick(accounts, submissions, comments):
4. Delete the comments/posts
"""
with app.app_context():
g.db = db_session()
alice_client, alice = accounts.client_and_user_for_account('Alice')
db = db_session()
lastname = ''.join(random.choice('aio') + random.choice('bfkmprst') for i in range(3))
alice = User(**{
"username": f"alice_{lastname}",
"original_username": f"alice_{lastname}",
"admin_level": 0,
"password":"themotteuser",
"email":None,
"ban_evade":0,
"profileurl":"/e/feather.webp"
})
db.add(alice)
db.commit()
posts = []
for i in range(2):
post = Submission(**{
@ -192,8 +205,8 @@ def test_bulk_update_descendant_count_quick(accounts, submissions, comments):
'ghost': False,
'filter_state': 'normal'
})
g.db.add(post)
g.db.commit()
db.add(post)
db.commit()
posts.append(post)
parent_comment = None
top_comment = None
@ -214,15 +227,18 @@ def test_bulk_update_descendant_count_quick(accounts, submissions, comments):
if parent_comment is None:
top_comment = comment
parent_comment = comment
g.db.add(comment)
g.db.commit()
db.add(comment)
db.commit()
sorted_comments_0 = sorted(posts[0].comments, key=lambda c: c.id)
sorted_comments_1 = sorted(posts[1].comments, key=lambda c: c.id)
assert [i+1 for i in range(20)] == [c.level for c in sorted_comments_0]
assert [i+1 for i in range(20)] == [c.level for c in sorted_comments_1]
assert [0 for i in range(20)] == [c.descendant_count for c in sorted_comments_0]
assert [0 for i in range(20)] == [c.descendant_count for c in sorted_comments_1]
bulk_recompute_descendant_counts(lambda q: q.where(Comment.parent_submission == posts[0].id))
bulk_recompute_descendant_counts(
lambda q: q.where(Comment.parent_submission == posts[0].id),
db
)
sorted_comments_0 = sorted(posts[0].comments, key=lambda c: c.id)
sorted_comments_1 = sorted(posts[1].comments, key=lambda c: c.id)
assert [i+1 for i in range(20)] == [c.level for c in sorted_comments_0]
@ -231,6 +247,6 @@ def test_bulk_update_descendant_count_quick(accounts, submissions, comments):
assert [0 for i in range(20)] == [c.descendant_count for c in sorted_comments_1]
for post in posts:
for comment in post.comments:
g.db.delete(comment)
g.db.delete(post)
g.db.commit()
db.delete(comment)
db.delete(post)
db.commit()

View file

@ -9,7 +9,6 @@ from alembic import op
from sqlalchemy.sql.expression import func, text
from sqlalchemy.orm.session import Session
from sqlalchemy import update
from flask import g
from files.__main__ import db_session
from files.classes import Comment
@ -22,18 +21,10 @@ down_revision = 'f8ba0e88ddd1'
branch_labels = None
depends_on = None
class g_db_set_from_alembic():
def __enter__(self, *args, **kwargs):
g.db = Session(bind=op.get_bind())
self.old_db = getattr(g, 'db', None)
def __exit__(self, *args, **kwargs):
g.db = self.old_db
def upgrade():
with g_db_set_from_alembic():
bulk_recompute_descendant_counts()
db =Session(bind=op.get_bind())
bulk_recompute_descendant_counts(lambda q: q, db)
def downgrade():
with g_db_set_from_alembic():
g.db.execute(update(Comment).values(descendant_count=0))
db =Session(bind=op.get_bind())
db.execute(update(Comment).values(descendant_count=0))