Initial checkin of customizable sqlalchemy-easy-profile.

This commit is contained in:
Ben Rog-Wilhelm 2022-11-10 08:43:22 -06:00 committed by Ben Rog-Wilhelm
parent 937d36de31
commit 6b55cc1f5b
25 changed files with 1698 additions and 0 deletions

View file

@ -0,0 +1,12 @@
# The following names are available as part of the public API for
# ``sqlalchemy-easy-profile``. End users of this package can import
# these names by doing ``from easy_profile import SessionProfiler``,
# for example.
from .middleware import EasyProfileMiddleware
from .profiler import SessionProfiler
from .reporters import StreamReporter
# Explicit export list: ``from easy_profile import *`` exposes exactly the
# three names imported above.
__all__ = ["EasyProfileMiddleware", "SessionProfiler", "StreamReporter"]
# Package metadata.
__author__ = "Dmitry Vasilishin"
__version__ = "1.2.1"

View file

@ -0,0 +1,63 @@
import re
from .profiler import SessionProfiler
from .reporters import Reporter, StreamReporter
class EasyProfileMiddleware(object):
    """WSGI middleware that profiles the SQL queries of every request.

    Each request that is not excluded is wrapped in a ``SessionProfiler``
    session; once the wrapped application returns (or raises) the collected
    statistics are handed to the reporter, provided they reach the
    configured thresholds.

    :param app: WSGI application to wrap
    :param sqlalchemy.engine.base.Engine engine: sqlalchemy database engine
        passed to ``SessionProfiler``
    :param Reporter reporter: reporter instance; a ``StreamReporter`` is
        created when omitted
    :param list exclude_path: regex patterns; a request whose ``PATH_INFO``
        matches any of them (``re.match``) is not profiled
    :param int min_time: lowest ``stats["duration"]`` that is still reported
    :param int min_query_count: lowest ``stats["total"]`` that is still
        reported
    :raises TypeError: if ``reporter`` is given but is not a ``Reporter``
    """

    def __init__(self,
                 app,
                 engine=None,
                 reporter=None,
                 exclude_path=None,
                 min_time=0,
                 min_query_count=1):
        if not reporter:
            reporter = StreamReporter()
        elif not isinstance(reporter, Reporter):
            raise TypeError("reporter must be inherited from 'Reporter'")

        self.reporter = reporter
        self.app = app
        self.engine = engine
        self.exclude_path = exclude_path or []
        self.min_time = min_time
        self.min_query_count = min_query_count

    def __call__(self, environ, start_response):
        """Handle one WSGI request, profiling it unless it is excluded."""
        profiler = SessionProfiler(self.engine)
        path = environ.get("PATH_INFO", "")

        # Excluded requests go straight to the wrapped application.
        if self._ignore_request(path):
            return self.app(environ, start_response)

        # Report label is "<METHOD> <path>" when the method is known.
        method = environ.get("REQUEST_METHOD")
        label = "{0} {1}".format(method, path) if method else path

        try:
            with profiler:
                response = self.app(environ, start_response)
        finally:
            # Runs on failure too, so failing requests are reported as well.
            self._report_stats(label, profiler.stats)
        return response

    def _ignore_request(self, path):
        """Return True when ``path`` matches one of the exclude patterns."""
        for pattern in self.exclude_path:
            if re.match(pattern, path):
                return True
        return False

    def _report_stats(self, path, stats):
        """Forward ``stats`` to the reporter if both thresholds are met."""
        if stats["total"] >= self.min_query_count:
            if stats["duration"] >= self.min_time:
                self.reporter.report(path, stats)

View file

@ -0,0 +1,186 @@
from collections import Counter, namedtuple, OrderedDict
import functools
import inspect
from queue import Queue
import re
import sys
import time
from sqlalchemy import event
from sqlalchemy.engine.base import Engine
from .reporters import StreamReporter
# Optimize timer function for the platform
# NOTE(review): ``time.time`` is wall-clock time and may jump when the system
# clock is adjusted; kept as-is because ``DebugQuery.start_time`` /
# ``end_time`` are exposed to reporters. Confirm before switching clocks.
if sys.platform == "win32":  # pragma: no cover
    _timer = time.perf_counter
else:
    _timer = time.time

# SQL verbs that get a dedicated counter in the profiler statistics.
SQL_OPERATORS = ["select", "insert", "update", "delete"]

# Recognises the leading verb of a statement; group 1 is the verb as written.
# Leading whitespace is skipped so indented or multi-line SQL strings are
# classified too, and ``\b`` keeps identifiers that merely start with a verb
# (e.g. ``selection_proc()``) from being counted.
OPERATOR_REGEX = re.compile(
    r"\s*(%s)\b" % "|".join(SQL_OPERATORS), re.IGNORECASE
)
def _get_object_name(obj):
    """Return a dotted ``module.qualname`` label for ``obj``.

    The module is taken from ``obj.__module__`` and only looked up through
    ``inspect.getmodule`` when that attribute is missing or ``None``. The
    name is ``__qualname__``, then ``__name__``, then the qualified name of
    the object's type (for callable instances and similar objects).

    :param obj: function, class or other callable to describe
    :return: ``"<module>.<name>"``, or just ``"<name>"`` when no module can
        be determined
    :rtype: str
    """
    module = getattr(obj, "__module__", None)
    if module is None:
        # Resolve lazily: getmodule() may itself return None, so its result
        # must not be dereferenced unconditionally.
        owner = inspect.getmodule(obj)
        module = getattr(owner, "__name__", None)

    name = getattr(obj, "__qualname__", None)
    if name is None:
        name = getattr(obj, "__name__", None)
    if name is None:
        # Instances expose neither attribute; describe them by their type.
        name = type(obj).__qualname__

    if module is None:
        return name
    return module + "." + name
# Plain record of one executed statement and its timer readings.
_DebugQuery = namedtuple(
    "_DebugQuery",
    ["statement", "parameters", "start_time", "end_time"],
)


class DebugQuery(_DebugQuery):
    """Executed statement record with a derived ``duration``."""

    @property
    def duration(self):
        """Difference between the end and the start timer readings."""
        started, finished = self.start_time, self.end_time
        return finished - started
class SessionProfiler:
    """A session profiler for sqlalchemy queries.

    Usable explicitly (``begin()`` / ``commit()``), as a context manager
    (``with SessionProfiler(engine) as profiler:``) or, by calling the
    instance, as a decorator factory.

    :param Engine engine: sqlalchemy database engine; when omitted the
        listeners are attached to the ``Engine`` class itself
    :attr bool alive: is True if profiling in progress
    :attr Queue queries: queue of ``DebugQuery`` records captured during the
        current session (``None`` before the first ``begin()``)
    """

    # Names of the sqlalchemy events the profiler listens to.
    _before = "before_cursor_execute"
    _after = "after_cursor_execute"

    def __init__(self, engine=None):
        if engine is None:
            self.engine = Engine
            self.db_name = "default"
        else:
            self.engine = engine
            self.db_name = engine.url.database or "undefined"

        self.alive = False
        self.queries = None
        self._stats = None

    def __enter__(self):
        """Begin a session and return the profiler for ``with ... as``."""
        self.begin()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Commit the session; exceptions from the body are not suppressed."""
        self.commit()

    def __call__(self, path=None, path_callback=None, reporter=None):
        """Decorate callable object and profile sqlalchemy queries.

        If reporter was not defined by default will be used a base
        streaming reporter.

        :param str path: label passed to the reporter; defaults to the
            dotted name of the decorated callable
        :param collections.abc.Callable path_callback: callback for getting
            more complex path; called as
            ``path_callback(func, *args, **kwargs)`` and takes precedence
            over ``path``
        :param easy_profile.reporters.Reporter reporter: profiling reporter
        :return: decorator that wraps a callable in a profiling session
        """
        if reporter is None:
            reporter = StreamReporter()

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if path_callback is not None:
                    _path = path_callback(func, *args, **kwargs)
                else:
                    _path = path or _get_object_name(func)

                self.begin()
                try:
                    result = func(*args, **kwargs)
                finally:
                    # Report even when ``func`` raises.
                    self.commit()
                    reporter.report(_path, self.stats)
                return result

            return wrapper

        return decorator

    @property
    def stats(self):
        """Statistics of the last session (zeroed if none has run yet)."""
        if self._stats is None:
            self._reset_stats()
        return self._stats

    def begin(self):
        """Begin profiling session.

        :raises AssertionError: When the session is already alive.
        """
        if self.alive:
            raise AssertionError("Profiling session has already begun")

        self.alive = True
        self.queries = Queue()
        self._reset_stats()

        event.listen(self.engine, self._before, self._before_cursor_execute)
        event.listen(self.engine, self._after, self._after_cursor_execute)

    def commit(self):
        """Commit profiling session.

        :raises AssertionError: When the session is not alive.
        """
        if not self.alive:
            raise AssertionError("Profiling session is already committed")

        self.alive = False
        try:
            self._get_stats()
        finally:
            # Always detach the listeners, otherwise a failure while
            # aggregating would leave them registered on the engine.
            event.remove(
                self.engine, self._before, self._before_cursor_execute
            )
            event.remove(
                self.engine, self._after, self._after_cursor_execute
            )

    def _get_stats(self):
        """Drain the query queue into the statistics and return them."""
        while not self.queries.empty():
            query = self.queries.get()
            self._stats["call_stack"].append(query)

            match = OPERATOR_REGEX.match(query.statement)
            if match:
                self._stats[match.group(1).lower()] += 1

            self._stats["total"] += 1
            self._stats["duration"] += query.duration

            # Starting from -1 makes the counter hold "extra occurrences":
            # 0 for a statement seen once, 1 for one seen twice, and so on.
            duplicates = self._stats["duplicates"].get(query.statement, -1)
            self._stats["duplicates"][query.statement] = duplicates + 1

        return self._stats

    def _reset_stats(self):
        """Replace the statistics with a fresh, zeroed mapping."""
        self._stats = OrderedDict()
        self._stats["db"] = self.db_name

        for operator in SQL_OPERATORS:
            self._stats[operator] = 0

        self._stats["total"] = 0
        self._stats["duration"] = 0
        self._stats["call_stack"] = []
        self._stats["duplicates"] = Counter()

    def _before_cursor_execute(self, conn, cursor, statement, parameters,
                               context, executemany):
        """Stamp the execution context with the query start time."""
        # sqlalchemy documents that ``context`` may be None.
        if context is not None:
            context._query_start_time = _timer()

    def _after_cursor_execute(self, conn, cursor, statement, parameters,
                              context, executemany):
        """Queue a ``DebugQuery`` for the statement that just finished."""
        end_time = _timer()
        # Without a start stamp (no context, or the listener was attached
        # mid-query) record the statement with a zero duration.
        start_time = getattr(context, "_query_start_time", end_time)
        self.queries.put(
            DebugQuery(statement, parameters, start_time, end_time)
        )

View file

@ -0,0 +1,161 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
import sys
import sqlparse
from .termcolors import colorize
def shorten(text, length, placeholder="..."):
    """Truncate the given text to fit in the given length.

    The result is never longer than ``length``. When ``length`` leaves no
    room for any text in front of the placeholder, the placeholder itself
    is cut down to ``length`` characters.

    :param str text: string for truncate
    :param int length: max length of string
    :param str placeholder: append to the end of truncated text
    :return: truncated string
    """
    if len(text) <= length:
        return text

    keep = length - len(placeholder)
    if keep < 0:
        # A negative slice bound would count from the end of ``text`` and
        # produce a result longer than ``length``.
        return placeholder[:max(length, 0)]
    return text[:keep] + placeholder
class Reporter(ABC):
    """Interface that every profiling reporter has to implement."""

    @abstractmethod
    def report(self, path, stats):
        """Emit the statistics collected for one profiled code path.

        :param str path: where profiling occurred
        :param dict stats: profiling statistics
        """
class StreamReporter(Reporter):
    """A base reporter for streaming to a file. By default reports
    will be written to ``sys.stdout``.

    :param int medium: a medium threshold count
    :param int high: a high threshold count
    :param file: output destination; ``None`` (the default) means the
        ``sys.stdout`` that is current when a report is written
    :param bool colorized: set True if output should be colorized
    :param int display_duplicates: how many of the most repeated sql
        statements are listed; ``0`` or ``None`` disables the listing
    :raises ValueError: if ``medium`` is not less than ``high``
    """

    # Column header -> key of the stats mapping, in display order.
    _display_names = OrderedDict([
        ("Database", "db"),
        ("SELECT", "select"),
        ("INSERT", "insert"),
        ("UPDATE", "update"),
        ("DELETE", "delete"),
        ("Totals", "total"),
        ("Duplicates", "duplicates_count"),
    ])

    def __init__(self,
                 medium=50,
                 high=100,
                 file=None,
                 colorized=True,
                 display_duplicates=5):
        if medium >= high:
            raise ValueError("Medium must be less than high")

        self._medium = medium
        self._high = high
        # Resolved lazily in report(): a ``file=sys.stdout`` default would
        # be bound once at import time and ignore later redirection.
        self._file = file
        self._colorized = colorized
        self._display_duplicates = display_duplicates or 0

    def report(self, path, stats):
        """Write the formatted report for ``path`` to the output stream.

        The caller's ``stats`` mapping is left untouched; the derived
        values are added to a shallow copy.

        :param str path: where profiling occurred
        :param dict stats: profiling statistics
        """
        duplicates = stats["duplicates"]

        stats = dict(stats)
        stats["duplicates_count"] = sum(duplicates.values())
        stats["db"] = shorten(stats["db"], 10)

        total = stats["total"]
        duration = float(stats["duration"])
        summary = "Total queries: {0} in {1:.3}s".format(total, duration)

        chunks = [
            self._colorize("\n{0}\n".format(path), ["bold"], fg="blue"),
            self.stats_table(stats),
            self._info_line("\n{0}\n".format(summary), total),
        ]

        # Display duplicated sql statements.
        #
        # Take the most repeated statements and skip those that were
        # executed only once (counter below 1). Nothing is listed when
        # `display_duplicates` was set to `0` or `None`.
        most_common = duplicates.most_common(self._display_duplicates)
        for statement, count in most_common:
            if count < 1:
                continue

            # Pretty-print the SQL statement before it is written out.
            statement = sqlparse.format(
                statement, reindent=True, keyword_case="upper"
            )
            # The counter holds extra occurrences, hence ``count + 1``.
            text = "\nRepeated {0} times:\n{1}\n".format(count + 1, statement)
            chunks.append(self._info_line(text, count))

        stream = sys.stdout if self._file is None else self._file
        stream.write("".join(chunks))

    def stats_table(self, stats, sep="|"):
        """Formats profiling statistics as table.

        :param dict stats: profiling statistics
        :param str sep: columns separator character
        :return: formatted table
        :rtype: str
        """
        line = sep + "{}" + sep + "\n"
        # Every column is its header padded by one space on each side.
        h_names = [n.center(len(n) + 2) for n in self._display_names]
        breakline = line.format(sep.join("-" * len(n) for n in h_names))

        # Formats row values in order by display_names.
        #
        # Row with values can be colorized for better perception. It
        # can be activated/deactivated through `colorized` parameter.
        values = [
            str(stats[key]).center(len(name) + 2)
            for name, key in self._display_names.items()
        ]
        row = line.format(sep.join(values))

        return "".join([
            breakline,
            line.format(sep.join(h_names)),
            breakline,
            self._info_line(row, stats["total"]),
            breakline,
        ])

    def _info_line(self, line, total):
        """Returns colorized text according threshold.

        Red above ``high``, yellow above ``medium``, green otherwise.

        :param str line: text which should be colorized
        :param int total: threshold count
        :return: colorized text
        """
        if total > self._high:
            return self._colorize(line, ["bold"], fg="red")
        elif total > self._medium:
            return self._colorize(line, ["bold"], fg="yellow")
        return self._colorize(line, ["bold"], fg="green")

    def _colorize(self, text, opts=(), fg=None, bg=None):
        """Apply ``colorize`` unless colorized output is disabled."""
        if not self._colorized:
            return text
        return colorize(text, opts, fg=fg, bg=bg)

View file

@ -0,0 +1,68 @@
# Foreground colour codes; the matching background code is the value + 10.
ansi_colors = {
    "black": 30,
    "red": 31,
    "green": 32,
    "yellow": 33,
    "blue": 34,
    "magenta": 35,
    "cyan": 36,
    "white": 37,
    "bright_black": 90,
    "bright_red": 91,
    "bright_green": 92,
    "bright_yellow": 93,
    "bright_blue": 94,
    "bright_magenta": 95,
    "bright_cyan": 96,
    "bright_white": 97,
}

# Sequence that switches every attribute off again.
ansi_reset = "\033[0m"

# Text attribute codes.
ansi_options = {
    "bold": 1,
    "underscore": 4,
    "blink": 5,
    "reverse": 7,
    "conceal": 8,
}


def colorize(text, opts=(), fg=None, bg=None):
    """Colorize text enclosed in ANSI graphics codes.

    Depends on the keyword arguments 'fg' and 'bg', and the contents of
    the opts tuple/list. Foreground and background may be combined;
    unknown colour or option names are ignored.

    Valid colors:
    'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
    and their 'bright_' variants (e.g. 'bright_red')

    Valid options:
    'bold', 'underscore', 'blink', 'reverse', 'conceal'
    'noreset' - string will not be terminated with the reset code
    'reset' - as the only option: return just the reset code

    :param str text: your text
    :param tuple opts: text options
    :param str fg: foreground color name
    :param str bg: background color name
    :return: colorized text
    """
    if len(opts) == 1 and opts[0] == "reset":
        return ansi_reset

    codes = []
    if fg and fg in ansi_colors:
        codes.append("\033[{0}m".format(ansi_colors[fg]))
    # Independent of ``fg``: an ``elif`` here would drop the background
    # whenever a foreground colour is given as well.
    if bg and bg in ansi_colors:
        codes.append("\033[{0}m".format(ansi_colors[bg] + 10))

    for opt in opts:
        if opt in ansi_options:
            codes.append("\033[{0}m".format(ansi_options[opt]))

    if "noreset" not in opts:
        text += ansi_reset
    return "".join(codes) + text