Port get.py improvements from upstream.

Generally standardizes the get_* helpers:
 - Adds type hinting.
 - Deduplicates block property addition.
 - Respects `graceful` in more contexts.
 - More resilient to invalid user input / less boilerplate necessary
   at call-sites.
This commit is contained in:
TLSM 2022-11-28 12:36:04 -05:00
parent 6b832aba99
commit 9953c5763c
No known key found for this signature in database
GPG key ID: E745A82778055C7E
3 changed files with 193 additions and 160 deletions

View file

@ -1,15 +1,19 @@
from files.classes import * from typing import Iterable, List, Optional, Type, Union
from files.helpers.strings import sql_ilike_clean
from flask import g from flask import g
from sqlalchemy import and_, any_, or_
from files.classes import *
from files.helpers.const import AUTOJANNY_ID
from files.helpers.strings import sql_ilike_clean
def get_id(username, v=None, graceful=False): def get_id(
username:str,
graceful:bool=False) -> Optional[int]:
username = sql_ilike_clean(username) username = sql_ilike_clean(username)
user = g.db.query( user = g.db.query(User.id).filter(
User.id
).filter(
or_( or_(
User.username.ilike(username), User.username.ilike(username),
User.original_username.ilike(username) User.original_username.ilike(username)
@ -17,25 +21,23 @@ def get_id(username, v=None, graceful=False):
).one_or_none() ).one_or_none()
if not user: if not user:
if not graceful: if graceful: return None
abort(404) abort(404)
else:
return None
return user[0] return user[0]
def get_user(username, v=None, graceful=False): def get_user(
username:Optional[str],
if not username: v:Optional[User]=None,
if not graceful: abort(404) graceful:bool=False,
else: return None include_blocks:bool=False) -> Optional[User]:
username = sql_ilike_clean(username) username = sql_ilike_clean(username)
if not username:
if graceful: return None
abort(404)
user = g.db.query( user = g.db.query(User).filter(
User
).filter(
or_( or_(
User.username.ilike(username), User.username.ilike(username),
User.original_username.ilike(username) User.original_username.ilike(username)
@ -43,34 +45,23 @@ def get_user(username, v=None, graceful=False):
).one_or_none() ).one_or_none()
if not user: if not user:
if not graceful: abort(404) if graceful: return None
else: return None abort(404)
if v: if v and include_blocks:
block = g.db.query(UserBlock).filter( user = _add_block_props(user, v)
or_(
and_(
UserBlock.user_id == v.id,
UserBlock.target_id == user.id
),
and_(UserBlock.user_id == user.id,
UserBlock.target_id == v.id
)
)
).first()
user.is_blocking = block and block.user_id == v.id
user.is_blocked = block and block.target_id == v.id
return user return user
def get_users(usernames, v=None, graceful=False):
if not usernames:
if not graceful: abort(404)
else: return []
def get_users(
usernames:Iterable[str],
graceful:bool=False) -> List[User]:
if not usernames: return []
usernames = [ sql_ilike_clean(n) for n in usernames ] usernames = [ sql_ilike_clean(n) for n in usernames ]
if not any(usernames):
if graceful and len(usernames) == 0: return []
abort(404)
users = g.db.query(User).filter( users = g.db.query(User).filter(
or_( or_(
User.username == any_(usernames), User.username == any_(usernames),
@ -78,96 +69,90 @@ def get_users(usernames, v=None, graceful=False):
) )
).all() ).all()
if not users: if len(users) != len(usernames) and not graceful:
if not graceful: abort(404) abort(404)
else: return []
return users return users
def get_account(id, v=None):
try: id = int(id) def get_account(
except: abort(404) id:Union[str,int],
v:Optional[User]=None,
graceful:bool=False,
include_blocks:bool=False) -> Optional[User]:
try:
id = int(id)
except:
if graceful: return None
abort(404)
user = g.db.query(User).filter_by(id = id).one_or_none() user = g.db.get(User, id)
if not user:
if not user: abort(404) if graceful: return None
abort(404)
if v: if v and include_blocks:
block = g.db.query(UserBlock).filter( user = _add_block_props(user, v)
or_(
and_(
UserBlock.user_id == v.id,
UserBlock.target_id == user.id
),
and_(UserBlock.user_id == user.id,
UserBlock.target_id == v.id
)
)
).first()
user.is_blocking = block and block.user_id == v.id
user.is_blocked = block and block.target_id == v.id
return user return user
def get_post(i, v=None, graceful=False): def get_post(
i:Union[str,int],
v:Optional[User]=None,
graceful:bool=False) -> Optional[Submission]:
try: i = int(i) try: i = int(i)
except: except:
if graceful: return None if graceful: return None
else: abort(404) abort(404)
if v: if v:
vt = g.db.query(Vote).filter_by( vt = g.db.query(Vote).filter_by(
user_id=v.id, submission_id=i).subquery() user_id=v.id, submission_id=i).subquery()
blocking = v.blocking.subquery() blocking = v.blocking.subquery()
items = g.db.query( post = g.db.query(
Submission, Submission,
vt.c.vote_type, vt.c.vote_type,
blocking.c.target_id, blocking.c.target_id,
) )
items=items.filter(Submission.id == i post = post.filter(Submission.id == i
).join( ).join(
vt, vt,
vt.c.submission_id == Submission.id, vt.c.submission_id == Submission.id,
isouter=True isouter=True
).join( ).join(
blocking, blocking,
blocking.c.target_id == Submission.author_id, blocking.c.target_id == Submission.author_id,
isouter=True isouter=True
) )
post = post.one_or_none()
items=items.one_or_none()
if not items: if not post:
if graceful: return None if graceful: return None
else: abort(404) else: abort(404)
x = items[0] x = post[0]
x.voted = items[1] or 0 x.voted = post[1] or 0
x.is_blocking = items[2] or 0 x.is_blocking = post[2] or 0
else: else:
items = g.db.query( post = g.db.get(Submission, i)
Submission if not post:
).filter(Submission.id == i).one_or_none()
if not items:
if graceful: return None if graceful: return None
else: abort(404) else: abort(404)
x=items x = post
return x return x
def get_posts(pids, v=None): def get_posts(
pids:Iterable[int],
if not pids: v:Optional[User]=None) -> List[Submission]:
return [] if not pids: return []
if v: if v:
vt = g.db.query(Vote).filter( vt = g.db.query(Vote.vote_type, Vote.submission_id).filter(
Vote.submission_id.in_(pids), Vote.submission_id.in_(pids),
Vote.user_id==v.id Vote.user_id==v.id
).subquery() ).subquery()
@ -183,67 +168,52 @@ def get_posts(pids, v=None):
).filter( ).filter(
Submission.id.in_(pids) Submission.id.in_(pids)
).join( ).join(
vt, vt.c.submission_id==Submission.id, isouter=True vt, vt.c.submission_id == Submission.id, isouter=True
).join( ).join(
blocking, blocking, blocking.c.target_id == Submission.author_id, isouter=True
blocking.c.target_id == Submission.author_id,
isouter=True
).join( ).join(
blocked, blocked, blocked.c.user_id == Submission.author_id, isouter=True
blocked.c.user_id == Submission.author_id, )
isouter=True
).all()
output = [p[0] for p in query]
for i in range(len(output)):
output[i].voted = query[i][1] or 0
output[i].is_blocking = query[i][2] or 0
output[i].is_blocked = query[i][3] or 0
else: else:
output = g.db.query(Submission,).filter(Submission.id.in_(pids)).all() query = g.db.query(Submission).filter(Submission.id.in_(pids))
results = query.all()
if v:
output = [p[0] for p in results]
for i in range(len(output)):
output[i].voted = results[i][1] or 0
output[i].is_blocking = results[i][2] or 0
output[i].is_blocked = results[i][3] or 0
else:
output = results
return sorted(output, key=lambda x: pids.index(x.id)) return sorted(output, key=lambda x: pids.index(x.id))
def get_comment(i, v=None, graceful=False):
def get_comment(
i:Union[str,int],
v:Optional[User]=None,
graceful:bool=False) -> Optional[Comment]:
try: i = int(i) try: i = int(i)
except: except:
if graceful: return None if graceful: return None
abort(404) abort(404)
if not i:
if graceful: return None
else: abort(404)
if v: comment = g.db.get(Comment, i)
if not comment:
if graceful: return None
else: abort(404)
comment=g.db.query(Comment).filter(Comment.id == i).one_or_none() return _add_vote_and_block_props(comment, v, CommentVote)
if not comment and not graceful: abort(404)
block = g.db.query(UserBlock).filter(
or_(
and_(
UserBlock.user_id == v.id,
UserBlock.target_id == comment.author_id
),
and_(
UserBlock.user_id == comment.author_id,
UserBlock.target_id == v.id
)
)
).first()
vts = g.db.query(CommentVote).filter_by(user_id=v.id, comment_id=comment.id)
vt = g.db.query(CommentVote).filter_by(user_id=v.id, comment_id=comment.id).one_or_none()
comment.is_blocking = block and block.user_id == v.id
comment.is_blocked = block and block.target_id == v.id
comment.voted = vt.vote_type if vt else 0
else:
comment = g.db.query(Comment).filter(Comment.id == i).one_or_none()
if not comment and not graceful:abort(404)
return comment
def get_comments(cids, v=None, load_parent=False): def get_comments(
cids:Iterable[int],
v:Optional[User]=None) -> List[Comment]:
if not cids: return [] if not cids: return []
if v: if v:
@ -261,7 +231,8 @@ def get_comments(cids, v=None, load_parent=False):
).filter(Comment.id.in_(cids)) ).filter(Comment.id.in_(cids))
if not (v and (v.shadowbanned or v.admin_level > 1)): if not (v and (v.shadowbanned or v.admin_level > 1)):
comments = comments.join(User, User.id == Comment.author_id).filter(User.shadowbanned == None) comments = comments.join(User, User.id == Comment.author_id) \
.filter(User.shadowbanned == None)
comments = comments.join( comments = comments.join(
votes, votes,
@ -284,21 +255,18 @@ def get_comments(cids, v=None, load_parent=False):
comment.is_blocking = c[2] or 0 comment.is_blocking = c[2] or 0
comment.is_blocked = c[3] or 0 comment.is_blocked = c[3] or 0
output.append(comment) output.append(comment)
else: else:
output = g.db.query(Comment).join(User, User.id == Comment.author_id).filter(User.shadowbanned == None, Comment.id.in_(cids)).all() output = g.db.query(Comment) \
.join(User, User.id == Comment.author_id) \
if load_parent: .filter(User.shadowbanned == None, Comment.id.in_(cids)) \
parents = [x.parent_comment_id for x in output if x.parent_comment_id] .all()
parents = get_comments(parents, v=v)
parents = {x.id: x for x in parents}
for c in output: c.sex = parents.get(c.parent_comment_id)
return sorted(output, key=lambda x: cids.index(x.id)) return sorted(output, key=lambda x: cids.index(x.id))
def get_domain(s): # TODO: This function was concisely inlined into posts.py in upstream.
# Think it involved adding `tldextract` as a dependency.
def get_domain(s:str) -> Optional[BannedDomain]:
parts = s.split(".") parts = s.split(".")
domain_list = set() domain_list = set()
for i in range(len(parts)): for i in range(len(parts)):
@ -308,7 +276,9 @@ def get_domain(s):
domain_list.add(new_domain) domain_list.add(new_domain)
doms = [x for x in g.db.query(BannedDomain).filter(BannedDomain.domain.in_(domain_list)).all()] doms = g.db.query(BannedDomain) \
.filter(BannedDomain.domain.in_(domain_list)).all()
doms = [x for x in doms]
if not doms: if not doms:
return None return None
@ -316,3 +286,70 @@ def get_domain(s):
doms = sorted(doms, key=lambda x: len(x.domain), reverse=True) doms = sorted(doms, key=lambda x: len(x.domain), reverse=True)
return doms[0] return doms[0]
def _add_block_props(
target:Union[Submission, Comment, User],
v:Optional[User]):
if not v: return target
id = None
if any(isinstance(target, cls) for cls in [Submission, Comment]):
id = target.author_id
elif isinstance(target, User):
id = target.id
else:
raise TypeError("add_block_props only supports non-None "
"submissions, comments, and users")
if hasattr(target, 'is_blocking') and hasattr(target, 'is_blocked'):
return target
# users can't block or be blocked by themselves or AutoJanny
if v.id == id or id == AUTOJANNY_ID:
target.is_blocking = False
target.is_blocked = False
return target
block = g.db.query(UserBlock).filter(
or_(
and_(
UserBlock.user_id == v.id,
UserBlock.target_id == id
),
and_(
UserBlock.user_id == id,
UserBlock.target_id == v.id
)
)
).first()
target.is_blocking = block and block.user_id == v.id
target.is_blocked = block and block.target_id == v.id
return target
def _add_vote_props(
target:Union[Submission, Comment],
v:Optional[User],
vote_cls:Union[Type[Vote], Type[CommentVote], None]):
if hasattr(target, 'voted'): return target
vt = g.db.query(vote_cls.vote_type).filter_by(user_id=v.id)
if vote_cls is Vote:
vt = vt.filter_by(submission_id=target.id)
elif vote_cls is CommentVote:
vt = vt.filter_by(comment_id=target.id)
else:
vt = None
if vt: vt = vt.one_or_none()
target.voted = vt.vote_type if vt else 0
return target
def _add_vote_and_block_props(
target:Union[Submission, Comment],
v:Optional[User],
vote_cls:Union[Type[Vote], Type[CommentVote], None]):
if not v: return target
target = _add_block_props(target, v)
return _add_vote_props(target, v, vote_cls)

View file

@ -1,8 +1,9 @@
import typing import typing
# clean strings for searching # clean strings for searching
def sql_ilike_clean(my_str): def sql_ilike_clean(my_str):
if my_str is None:
return None
return my_str.replace(r'\\', '').replace('_', r'\_').replace('%', '').strip() return my_str.replace(r'\\', '').replace('_', r'\_').replace('%', '').strip()
# this will also just return a bool verbatim # this will also just return a bool verbatim

View file

@ -540,7 +540,7 @@ def message2(v, username):
return {"error": "You have been permabanned and cannot send messages; " + \ return {"error": "You have been permabanned and cannot send messages; " + \
"contact modmail if you think this decision was incorrect."}, 403 "contact modmail if you think this decision was incorrect."}, 403
user = get_user(username, v=v) user = get_user(username, v=v, include_blocks=True)
if hasattr(user, 'is_blocking') and user.is_blocking: return {"error": "You're blocking this user."}, 403 if hasattr(user, 'is_blocking') and user.is_blocking: return {"error": "You're blocking this user."}, 403
if v.admin_level <= 1 and hasattr(user, 'is_blocked') and user.is_blocked: if v.admin_level <= 1 and hasattr(user, 'is_blocked') and user.is_blocked:
@ -772,9 +772,7 @@ def visitors(v):
@app.get("/@<username>") @app.get("/@<username>")
@auth_desired @auth_desired
def u_username(username, v=None): def u_username(username, v=None):
u = get_user(username, v=v, include_blocks=True)
u = get_user(username, v=v)
if username != u.username: if username != u.username:
@ -858,8 +856,7 @@ def u_username(username, v=None):
@app.get("/@<username>/comments") @app.get("/@<username>/comments")
@auth_desired @auth_desired
def u_username_comments(username, v=None): def u_username_comments(username, v=None):
user = get_user(username, v=v, include_blocks=True)
user = get_user(username, v=v)
if username != user.username: return redirect(f'/@{user.username}/comments') if username != user.username: return redirect(f'/@{user.username}/comments')
@ -945,8 +942,7 @@ def u_username_comments(username, v=None):
@app.get("/@<username>/info") @app.get("/@<username>/info")
@auth_required @auth_required
def u_username_info(username, v=None): def u_username_info(username, v=None):
user = get_user(username, v=v, include_blocks=True)
user=get_user(username, v=v)
if hasattr(user, 'is_blocking') and user.is_blocking: if hasattr(user, 'is_blocking') and user.is_blocking:
return {"error": "You're blocking this user."}, 401 return {"error": "You're blocking this user."}, 401
@ -958,8 +954,7 @@ def u_username_info(username, v=None):
@app.get("/<id>/info") @app.get("/<id>/info")
@auth_required @auth_required
def u_user_id_info(id, v=None): def u_user_id_info(id, v=None):
user = get_account(id, v=v, include_blocks=True)
user=get_account(id, v=v)
if hasattr(user, 'is_blocking') and user.is_blocking: if hasattr(user, 'is_blocking') and user.is_blocking:
return {"error": "You're blocking this user."}, 401 return {"error": "You're blocking this user."}, 401