rDrama/thirdparty/sqlalchemy-easy-profile/easy_profile/profiler.py
2023-04-03 03:30:57 -06:00

195 lines
5.7 KiB
Python

from collections import Counter, namedtuple, OrderedDict, defaultdict
import functools
import inspect
from queue import Queue
import re
import sys
import time
import traceback
import pprint
from sqlalchemy import event
from sqlalchemy.engine.base import Engine
from .reporters import StreamReporter
# Optimize timer function for the platform
if sys.platform == "win32": # pragma: no cover
_timer = time.perf_counter
else:
_timer = time.time
SQL_OPERATORS = ["select", "insert", "update", "delete"]
OPERATOR_REGEX = re.compile("(%s) *." % "|".join(SQL_OPERATORS), re.IGNORECASE)
def _get_object_name(obj):
module = getattr(obj, "__module__", inspect.getmodule(obj).__name__)
if hasattr(obj, "__qualname__"):
name = obj.__qualname__
else:
name = obj.__name__
return module + "." + name
_DebugQuery = namedtuple(
"_DebugQuery", "statement,parameters,start_time,end_time,callstack"
)
class DebugQuery(_DebugQuery):
"""Public implementation of the debug query class"""
@property
def duration(self):
return self.end_time - self.start_time
class SessionProfiler:
"""A session profiler for sqlalchemy queries.
:param Engine engine: sqlalchemy database engine
:attr bool alive: is True if profiling in progress
:attr Queue queries: sqlalchemy queries queue
"""
_before = "before_cursor_execute"
_after = "after_cursor_execute"
def __init__(self, engine=None, stack_callback=None):
if engine is None:
self.engine = Engine
self.db_name = "default"
else:
self.engine = engine
self.db_name = engine.url.database or "undefined"
self.alive = False
self.queries = None
self._stats = None
self._stack_callback = stack_callback or self._stack_callback_default
def __enter__(self):
self.begin()
def __exit__(self, exc_type, exc_val, exc_tb):
self.commit()
def __call__(self, path=None, path_callback=None, reporter=None):
"""Decorate callable object and profile sqlalchemy queries.
If reporter was not defined by default will be used a base
streaming reporter.
:param easy_profile.reporters.Reporter reporter: profiling reporter
:param collections.abc.Callable path_callback: callback for getting
more complex path
"""
if reporter is None:
reporter = StreamReporter()
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if path_callback is not None:
_path = path_callback(func, *args, **kwargs)
else:
_path = path or _get_object_name(func)
self.begin()
try:
result = func(*args, **kwargs)
finally:
self.commit()
reporter.report(_path, self.stats)
return result
return wrapper
return decorator
@property
def stats(self):
if self._stats is None:
self._reset_stats()
return self._stats
def begin(self):
"""Begin profiling session.
:raises AssertionError: When the session is already alive.
"""
if self.alive:
raise AssertionError("Profiling session has already begun")
self.alive = True
self.queries = Queue()
self._reset_stats()
event.listen(self.engine, self._before, self._before_cursor_execute)
event.listen(self.engine, self._after, self._after_cursor_execute)
def commit(self):
"""Commit profiling session.
:raises AssertionError: When the session is not alive.
"""
if not self.alive:
raise AssertionError("Profiling session is already committed")
self.alive = False
self._get_stats()
event.remove(self.engine, self._before, self._before_cursor_execute)
event.remove(self.engine, self._after, self._after_cursor_execute)
def _get_stats(self):
"""Calculate and returns session statistics."""
while not self.queries.empty():
query = self.queries.get()
self._stats["call_stack"].append(query)
match = OPERATOR_REGEX.match(query.statement)
if match:
self._stats[match.group(1).lower()] += 1
self._stats["total"] += 1
self._stats["duration"] += query.duration
duplicates = self._stats["duplicates"].get(query.statement, -1)
self._stats["duplicates"][query.statement] = duplicates + 1
self._stats["callstacks"][query.statement].update({query.callstack: 1})
return self._stats
def _reset_stats(self):
self._stats = OrderedDict()
self._stats["db"] = self.db_name
for operator in SQL_OPERATORS:
self._stats[operator] = 0
self._stats["total"] = 0
self._stats["duration"] = 0
self._stats["call_stack"] = []
self._stats["duplicates"] = Counter()
self._stats["callstacks"] = defaultdict(Counter)
def _before_cursor_execute(self, conn, cursor, statement, parameters,
context, executemany):
context._query_start_time = _timer()
def _after_cursor_execute(self, conn, cursor, statement, parameters,
context, executemany):
self.queries.put(DebugQuery(
statement, parameters, context._query_start_time, _timer(), self._stack_callback()
))
def _stack_callback_default(self):
pprint.pprint(traceback.format_stack())
return ''.join(traceback.format_stack())