summaryrefslogtreecommitdiff
path: root/aergia/aergia.py
diff options
context:
space:
mode:
Diffstat (limited to 'aergia/aergia.py')
-rwxr-xr-xaergia/aergia.py332
1 files changed, 332 insertions, 0 deletions
diff --git a/aergia/aergia.py b/aergia/aergia.py
new file mode 100755
index 0000000..b7b4f35
--- /dev/null
+++ b/aergia/aergia.py
@@ -0,0 +1,332 @@
+#!/usr/bin/env python3
+'''
+ _/_/ _/
+ _/ _/ _/_/ _/ _/_/ _/_/_/ _/_/_/
+ _/_/_/_/ _/_/_/_/ _/_/ _/ _/ _/ _/ _/
+ _/ _/ _/ _/ _/ _/ _/ _/ _/
+ _/ _/ _/_/_/ _/ _/_/_/ _/ _/_/_/
+ _/
+ _/_/
+Copyright:
+
+ This program is free software: you can redistribute it
+ and/or modify it under the terms of the GNU General
+ Public License as published by the Free Software
+ Foundation, either version 3 of the License, or (at your
+ option) any later version.
+
+ This program is distributed in the hope that it will be
+ useful, but WITHOUT ANY WARRANTY; without even the implied
+ warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ PURPOSE. See the GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public
+ License along with this program. If not, see
+ <https://www.gnu.org/licenses/>.
+
+
+Commentary:
+
+ Aergia is a sampling based profiler based off of SCALENE
+ by Emery Berger (https://github.com/plasma-umass/scalene).
+
+ It is not particularly informative, but unlike SCALENE
+ or other sampling-based profilers I could find, reports
+ the wall-time each asyncio await call spends idling.
+
+ The goal behind Aergia is to eventually have these features,
+ or similar, merged into SCALENE.
+
+
+Code:
+'''
+
+from collections import defaultdict, namedtuple
+from typing import Optional
+import argparse
+import asyncio
+import os
+import signal
+import sys
+import threading
+import time
+import traceback
+
+
# Keep a handle on the real implementation so the replacement below can
# delegate to it after threading.Thread.join has been patched.
orig_thread_join = threading.Thread.join


def thread_join_replacement(
    self: threading.Thread, timeout: Optional[float] = None
) -> None:
    '''
    Drop-in replacement for threading.Thread.join that waits in short
    slices instead of one long blocking call, so the joining thread
    periodically yields.

    Each slice lasts at most sys.getswitchinterval() seconds.  When
    TIMEOUT is given, a slice is additionally clamped to the time that
    is still left, so the call never overshoots the deadline by a
    whole slice and `join(0)' stays non-blocking.

    Returns None whether the thread finished or the timeout expired,
    exactly like threading.Thread.join; callers use is_alive() to tell
    the two apart.
    '''

    interval = sys.getswitchinterval()
    deadline = None
    if timeout is not None:
        deadline = time.perf_counter() + timeout

    while self.is_alive():
        wait = interval
        if deadline is not None:
            # Never wait past the deadline; a zero wait degrades to a
            # non-blocking poll of the thread's state.
            remaining = deadline - time.perf_counter()
            wait = min(interval, max(0.0, remaining))
        orig_thread_join(self, wait)
        if deadline is not None and time.perf_counter() >= deadline:
            return None
    return None


threading.Thread.join = thread_join_replacement
+
# Hashable key for the profiler's sample dictionary: identifies one
# source line by the file it lives in, its line number, and the name
# of the enclosing function.
Sample = namedtuple('Sample', ('file', 'line', 'func'))
+
+
class Aergia(object):
    '''Signal-driven sampler that tallies which lines asyncio tasks
    are suspended on.

    All state is kept on the class itself and every method is static,
    so the class behaves like a process-wide singleton: `Aergia(x)'
    only records the sampling interval and the rest of the API is
    called as `Aergia.start()', `Aergia.stop()', and so on.
    '''

    # Maps a sample key (see Aergia.frame_to_tuple) to the number of
    # times that line was observed suspended.
    samples = defaultdict(int)
    # number of times samples have been collected
    total_samples = 0
    # the (ideal) interval between samples, in seconds
    signal_interval = 0.0

    @staticmethod
    def __init__(signal_interval):
        '''Records SIGNAL_INTERVAL (seconds) as the class-wide
        sampling interval.

        Declared static on purpose: there is no per-instance state,
        so no `self' parameter is taken.
        '''
        Aergia.signal_interval = signal_interval

    @staticmethod
    def start():
        '''Turns on asyncio debug mode and sets up our signals.

        Sets PYTHONASYNCIODEBUG in the environment, installs
        `Aergia.idle_signal_handler' for SIGALRM, and arms the
        real-time interval timer so that the handler fires every
        `Aergia.signal_interval' seconds.

        Debug mode must be on by default to avoid losing samples.
        Debug mode is required to view the current coroutine being
        waited on in `Aergia.get_idle_task_frames'.
        '''
        os.environ["PYTHONASYNCIODEBUG"] = "1"
        signal.signal(signal.SIGALRM, Aergia.idle_signal_handler)
        signal.setitimer(signal.ITIMER_REAL,
                         Aergia.signal_interval,
                         Aergia.signal_interval)

    @staticmethod
    def stop():
        '''Stops the profiler by disarming the interval timer.'''
        signal.setitimer(signal.ITIMER_REAL, 0)

    @staticmethod
    def clear():
        '''Discards every sample gathered so far.'''
        Aergia.total_samples = 0
        Aergia.samples = defaultdict(int)

    @staticmethod
    def get_samples():
        '''Returns the profiling results (sample key -> hit count).'''
        return Aergia.samples

    @staticmethod
    def print_samples():
        '''Pretty-print profiling results, most-sampled line first.'''
        if Aergia.total_samples > 0:
            print("FILE\tFUNC\tPERC\t(ACTUAL -> SECONDS)")
            for key in Aergia.sort_samples(Aergia.samples):
                Aergia.print_sample(key)
        else:
            print("No samples were gathered. If you are using concurrency, "
                  "this is likely a bug and you may run the profiler again.")

    @staticmethod
    def print_sample(key):
        '''Pretty-print a single sample.

        Shows the share of `Aergia.total_samples' the line accounts
        for, its raw hit count, and the hit count scaled by the
        sampling interval as an estimate in seconds.
        '''
        sig_intv = Aergia.signal_interval
        value = Aergia.samples[key]
        print(f"{Aergia.tuple_to_string(key)} :"
              f"\t\t{value * 100 / Aergia.total_samples:.3f}%"
              f"\t({value:.3f} ->"
              f" {value*sig_intv:.6f} seconds)")

    @staticmethod
    def idle_signal_handler(sig, frame):
        '''Obtains and records which lines are currently being waited on.

        Installed as the SIGALRM handler; SIG and FRAME are supplied
        by the signal machinery and are not used.
        '''
        for idle_frame in Aergia.compute_frames_to_record():
            Aergia.samples[Aergia.frame_to_tuple(idle_frame)] += 1
        # NOTE(review): counted once per delivered signal, not once
        # per idle frame -- confirm against upstream, the listing this
        # was reviewed from had lost its indentation.
        Aergia.total_samples += 1

    @staticmethod
    def compute_frames_to_record():
        '''Collects all stack frames which are currently being awaited
        on at the moment of the call, and returns them as a flat list.

        Note that we do NOT need to walk back up the call-stack to find
        which of the user's lines caused the await call. There is NEVER
        a previous frame, because idle frames aren't on the call stack!

        Luckily, the event loop and asyncio.all_tasks keeps track of
        what is running for us.'''
        loops = Aergia.get_event_loops()
        return Aergia.get_frames_from_loops(loops)

    @staticmethod
    def get_event_loops():
        '''Obtains each thread's event loop by relying on the fact that
        if an event loop is active, its own `run_once' and `run_forever'
        will appear in the callstack.

        Every loop is returned at most once, even when the stacks of
        several threads hold a reference to it; otherwise its idle
        tasks would be tallied once per referencing thread.'''
        loops = []
        # One snapshot for all threads instead of one per thread.
        frames_by_thread = sys._current_frames()
        for t in threading.enumerate():
            frame = frames_by_thread.get(t.ident)
            if not frame:
                continue
            for loop in Aergia.walk_back_until_loops(frame):
                if loop not in loops:
                    loops.append(loop)
        return loops

    @staticmethod
    def walk_back_until_loops(frame):
        '''Walks back the callstack from FRAME to its outermost caller
        and returns every distinct event loop found in the locals.'''
        loops = []
        while frame:
            loop = Aergia.find_loop_in_locals(frame.f_locals)
            if loop and loop not in loops:  # Avoid duplicates
                loops.append(loop)
            frame = frame.f_back
        return loops

    @staticmethod
    def find_loop_in_locals(locals_dict):
        '''Given a dictionary of local variables for a stack frame,
        retrieves the first asyncio loop object found among its values,
        or None if there is none.'''
        for val in locals_dict.values():
            if isinstance(val, asyncio.AbstractEventLoop):
                return val
        return None

    @staticmethod
    def get_frames_from_loops(loops):
        '''Given LOOPS, returns a flat list of idle task frames.'''
        return [
            frame for loop in loops
            for frame in Aergia.get_idle_task_frames(loop)
        ]

    @staticmethod
    def frame_to_tuple(frame):
        '''Given a frame, constructs a sample key for tallying lines.'''
        co = frame.f_code
        return Sample(co.co_filename, frame.f_lineno, co.co_name)

    @staticmethod
    def tuple_to_string(sample):
        '''Given a namedtuple corresponding to a sample key, formats it
        as "file:line<TAB>function".'''
        return f"{sample.file}:{sample.line}\t{sample.func}"

    @staticmethod
    def sort_samples(sample_dict):
        '''Returns a new dict with the items of SAMPLE_DICT in
        descending order by number of samples.'''
        return dict(sorted(sample_dict.items(),
                           key=lambda item: item[1],
                           reverse=True))

    @staticmethod
    def get_idle_task_frames(loop):
        '''Given an asyncio event loop, returns the list of idle task frames.
        A task is considered 'idle' if it is not currently executing.'''
        idle = []
        current = asyncio.current_task(loop)
        for task in asyncio.all_tasks(loop):
            if task == current:
                continue
            coro = task.get_coro()
            if coro:
                f = Aergia.get_deepest_traceable_frame(coro)
                if f:
                    idle.append(f)
        return idle

    @staticmethod
    def get_deepest_traceable_frame(coro):
        '''Follows the chain of awaited coroutines starting at CORO and
        returns the frame of the innermost one whose file passes
        `Aergia.should_trace'.

        The walk stops at the first object that has no frame, whose
        frame is filtered out, or that awaits something which is not a
        coroutine.  Returns None when CORO is falsy or when not even
        its own frame is traceable.'''
        if not coro:
            return None
        curr = coro
        lframe = None
        while True:
            frame = getattr(curr, 'cr_frame', None)
            if not frame or not Aergia.should_trace(frame.f_code.co_filename):
                return lframe

            lframe = frame
            awaited = getattr(curr, 'cr_await', None)
            if not awaited or not hasattr(awaited, 'cr_frame'):
                return lframe
            curr = awaited

    @staticmethod
    def should_trace(filename):
        '''Returns FALSE if filename is uninteresting to the user.

        Uninteresting means: empty, a pseudo-file such as "<string>",
        or a path containing "/gnu/store" or "site-packages".'''
        # An empty name has no source location to report; checking it
        # here also avoids indexing into an empty string.
        if not filename or filename.startswith('<'):
            return False
        # FIXME Assume GuixSD. Makes filtering easy
        if '/gnu/store' in filename:
            return False
        if 'site-packages' in filename:
            return False
        return True

    @staticmethod
    def gettime():
        '''returns the wallclock time'''
        # The sampler is driven by ITIMER_REAL, i.e. wall-clock time;
        # time.process_time() would report CPU time instead.
        return time.perf_counter()
+
+
# Namespace handed to exec() for the profiled script.  It imitates the
# globals of a module run as a program: the name is '__main__', the
# loader and builtins are borrowed from this module, and every other
# module attribute starts out empty.
the_globals = dict(
    __name__='__main__',
    __doc__=None,
    __package__=None,
    __loader__=__loader__,
    __spec__=None,
    __annotations__=dict(),
    __builtins__=__builtins__,
    __file__=None,
    __cached__=None,
)
+
+
if __name__ == "__main__":
    # Parses CLI arguments, then runs the given script under the
    # profiler and prints the gathered samples.
    parser = argparse.ArgumentParser(
        usage='%(prog)s [args] script [args]'
    )

    parser.add_argument('-i', '--interval',
                        help='The minimum amount of time inbetween '
                             'samples in seconds.',
                        metavar='',
                        type=float,
                        default=0.01)
    parser.add_argument('script', help='A python script to run.')
    parser.add_argument('s_args', nargs=argparse.REMAINDER,
                        help='python script args')
    args = parser.parse_args()

    # An interval of zero disarms the timer and a negative one is
    # rejected by setitimer; refuse both up front.
    if args.interval <= 0:
        parser.error('the sampling interval must be a positive number')

    # Make the script see the environment it would get from
    # `python script args...': its own argv, its own directory first
    # on the import path, and a usable __file__.
    sys.argv = [args.script] + args.s_args
    script_path = os.path.abspath(args.script)
    sys.path.insert(0, os.path.dirname(script_path))
    the_globals['__file__'] = script_path

    try:
        with open(args.script, 'r', encoding='utf-8') as fp:
            code = compile(fp.read(), args.script, "exec")
        Aergia(args.interval).start()
        try:
            exec(code, the_globals)
        finally:
            # Runs even when the script raises or calls sys.exit():
            # disarm the timer first so no sample lands while the
            # results are being printed, then report what was gathered.
            Aergia.stop()
            Aergia.print_samples()
    except Exception:
        traceback.print_exc()