#!/usr/bin/env python3
# Provenance: aergia/aergia.py as added by commit ac55d9ff0b588b91202ccad72ee71e508e33ad08
# ("Reformat repository to allow for new unit tests", bd, 2025-07-19).
'''
        _/_/                                  _/
     _/    _/    _/_/    _/  _/_/    _/_/_/        _/_/_/
    _/_/_/_/  _/_/_/_/  _/_/      _/    _/  _/  _/    _/
   _/    _/  _/        _/        _/    _/  _/  _/    _/
  _/    _/    _/_/_/  _/          _/_/_/  _/    _/_/_/
                                     _/
                                _/_/
Copyright:

  This program is free software: you can redistribute it
  and/or modify it under the terms of the GNU General
  Public License as published by the Free Software
  Foundation, either version 3 of the License, or (at your
  option) any later version.

  This program is distributed in the hope that it will be
  useful, but WITHOUT ANY WARRANTY; without even the implied
  warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
  PURPOSE.  See the GNU General Public License for more details.

  You should have received a copy of the GNU General Public
  License along with this program.  If not, see
  <https://www.gnu.org/licenses/>.


Commentary:

  Aergia is a sampling-based profiler based off of SCALENE
  by Emery Berger (https://github.com/plasma-umass/scalene).

  It is not particularly informative, but unlike SCALENE
  or other sampling-based profilers I could find, it reports
  the wall-time each asyncio await call spends idling.

  The goal behind Aergia is to eventually have these features,
  or similar, merged into SCALENE.


Code:
'''

from collections import defaultdict, namedtuple
from typing import Optional
import argparse
import asyncio
import os
import signal
import sys
import threading
import time
import traceback


# The genuine Thread.join, captured before it is monkey-patched below.
orig_thread_join = threading.Thread.join


def thread_join_replacement(
        self: threading.Thread, timeout: Optional[float] = None
) -> None:
    '''
    We replace threading.Thread.join with this method which always
    periodically yields.

    Instead of blocking for the whole TIMEOUT in one call, the join is
    performed in slices no longer than the interpreter switch interval
    so that the joining thread keeps returning to the interpreter.

    TIMEOUT has the same meaning as for threading.Thread.join: None
    waits until the thread terminates, a number waits at most that many
    seconds.  Always returns None.  Raises RuntimeError in the same
    situations as the original join (e.g. joining a thread that was
    never started).
    '''
    start_time = time.perf_counter()
    interval = sys.getswitchinterval()
    while self.is_alive():
        if timeout is None:
            wait = interval
        else:
            # Never sleep past the caller's deadline.
            remaining = timeout - (time.perf_counter() - start_time)
            wait = min(interval, max(remaining, 0.0))
        orig_thread_join(self, wait)
        # If a timeout was specified, check to see if it's expired.
        if timeout is not None:
            if time.perf_counter() - start_time >= timeout:
                return None
    # The thread is not alive here: either it finished (this call
    # returns immediately) or it was never started, in which case the
    # original join raises RuntimeError exactly as the stdlib would.
    orig_thread_join(self, 0)
    return None


threading.Thread.join = thread_join_replacement

# a tuple used as a key in the sample-dict
Sample = namedtuple('Sample', ['file', 'line', 'func'])


class Aergia(object):
    '''Sampling profiler recording where asyncio tasks sit idle.

    All state lives on the class itself and every method is static, so
    the profiler behaves as a process-wide singleton:
    `Aergia(interval).start()' and `Aergia.start()' act on the same
    data.
    '''

    # a key-value pair where keys represent frame metadata (see
    # Aergia.frame_to_tuple) and values represent number of times
    # sampled.
    samples = defaultdict(int)
    # number of times samples have been collected
    total_samples = 0
    # the (ideal) interval between samples
    signal_interval = 0.0

    # NOTE: __init__ is deliberately left a staticmethod taking only the
    # interval; `Aergia(interval)' merely stores it on the class.
    @staticmethod
    def __init__(signal_interval):
        '''Stores SIGNAL_INTERVAL (seconds between samples) on the class.'''
        Aergia.signal_interval = signal_interval

    @staticmethod
    def start():
        '''Turns on asyncio debug mode and sets up our signals.

        Debug mode must be on by default to avoid losing samples.
        Debug mode is required to view the current coroutine being waited on
        in `Aergia.get_idle_task_frames'. The TimerHandle object otherwise
        does not keep track of a _source_traceback.

        Installs `Aergia.idle_signal_handler' for SIGALRM and arms a
        repeating real-time interval timer of `Aergia.signal_interval'
        seconds.
        '''
        os.environ["PYTHONASYNCIODEBUG"] = "1"
        signal.signal(signal.SIGALRM,
                      Aergia.idle_signal_handler)
        signal.setitimer(signal.ITIMER_REAL,
                         Aergia.signal_interval,
                         Aergia.signal_interval)

    @staticmethod
    def stop():
        '''Stops the profiler by disarming the interval timer.'''
        signal.setitimer(signal.ITIMER_REAL, 0)

    @staticmethod
    def clear():
        '''Discards every collected sample and resets the tick counter.'''
        Aergia.total_samples = 0
        Aergia.samples = defaultdict(int)

    @staticmethod
    def get_samples():
        '''Returns the profiling results.'''
        return Aergia.samples

    @staticmethod
    def print_samples():
        '''Pretty-print profiling results, most-sampled line first.'''
        if Aergia.total_samples > 0:
            print("FILE\tFUNC\tPERC\t(ACTUAL -> SECONDS)")
            for key in Aergia.sort_samples(Aergia.samples):
                Aergia.print_sample(key)
        else:
            print("No samples were gathered. If you are using concurrency, "
                  "this is likely a bug and you may run the profiler again.")

    @staticmethod
    def print_sample(key):
        '''Pretty-print a single sample.

        KEY is a Sample.  The percentage is relative to the number of
        sampling ticks; the seconds figure is the sample count multiplied
        by the sampling interval.
        '''
        sig_intv = Aergia.signal_interval
        value = Aergia.samples[key]
        print(f"{Aergia.tuple_to_string(key)} :"
              f"\t\t{value * 100 / Aergia.total_samples:.3f}%"
              f"\t({value:.3f} ->"
              f" {value*sig_intv:.6f} seconds)")

    @staticmethod
    def idle_signal_handler(sig, frame):
        '''Obtains and records which lines are currently being waited on.

        Installed as the SIGALRM handler; SIG and FRAME are the standard
        handler arguments and are not used.
        '''
        keys = Aergia.compute_frames_to_record()
        for key in keys:
            Aergia.samples[Aergia.frame_to_tuple(key)] += 1
        # One tick per signal, regardless of how many frames were idle.
        Aergia.total_samples += 1

    @staticmethod
    def compute_frames_to_record():
        '''Collects all stack frames which are currently being awaited on
        at the moment of sampling, and returns them as a flat list.

        Note that we do NOT need to walk back up the call-stack to find
        which of the user's lines caused the await call. There is NEVER
        a previous frame, because idle frames aren't on the call stack!

        Luckily, the event loop and asyncio.all_tasks keeps track of
        what is running for us.'''
        loops = Aergia.get_event_loops()
        frames = Aergia.get_frames_from_loops(loops)
        return frames

    @staticmethod
    def get_event_loops():
        '''Obtains each thread's event loop by relying on the fact that
        if an event loop is active, its own `run_once' and `run_forever'
        will appear in the callstack.

        Each loop is returned once even when it is visible from the
        stacks of several threads, so its tasks are not sampled more
        than once per tick.'''
        # Take one snapshot so all threads are inspected at the same instant.
        frames_by_thread = sys._current_frames()
        loops = []
        for t in threading.enumerate():
            frame = frames_by_thread.get(t.ident)
            if not frame:
                continue
            for loop in Aergia.walk_back_until_loops(frame):
                if loop not in loops:  # Avoid duplicates across threads
                    loops.append(loop)
        return loops

    @staticmethod
    def walk_back_until_loops(frame):
        '''Walks back the callstack until all event loops are found.

        Starting at FRAME and following f_back to the outermost frame,
        returns the distinct event loops found among the frames' locals.'''
        loops = []
        while frame:
            loop = Aergia.find_loop_in_locals(frame.f_locals)
            if loop and loop not in loops:  # Avoid duplicates
                loops.append(loop)
            frame = frame.f_back
        return loops

    @staticmethod
    def find_loop_in_locals(locals_dict):
        '''Given a dictionary of local variables for a stack frame,
        retrieves the asyncio loop object, if there is one.

        Only the first loop encountered is returned; None if there is
        no loop among the values.'''
        for val in locals_dict.values():
            if isinstance(val, asyncio.AbstractEventLoop):
                return val
        return None

    @staticmethod
    def get_frames_from_loops(loops):
        '''Given LOOPS, returns a flat list of frames.'''
        return [
            frames for loop in loops
            for frames in Aergia.get_idle_task_frames(loop)
        ]

    @staticmethod
    def frame_to_tuple(frame):
        '''Given a frame, constructs a sample key for tallying lines.'''
        co = frame.f_code
        func_name = co.co_name
        line_no = frame.f_lineno
        filename = co.co_filename
        return Sample(filename, line_no, func_name)

    @staticmethod
    def tuple_to_string(sample):
        '''Given a namedtuple corresponding to a sample key,
        pretty-prints a frame as a function/file name and a line number.'''
        return sample.file + ':' + str(sample.line) + '\t' + sample.func

    @staticmethod
    def sort_samples(sample_dict):
        '''Returns SAMPLE_DICT in descending order by number of samples.

        The result is a new plain dict; SAMPLE_DICT is left untouched.'''
        return dict(sorted(sample_dict.items(),
                           key=lambda item: item[1],
                           reverse=True))

    @staticmethod
    def get_idle_task_frames(loop):
        '''Given an asyncio event loop, returns the list of idle task frames.
        A task is considered 'idle' if it is not currently executing.'''
        idle = []
        current = asyncio.current_task(loop)
        for task in asyncio.all_tasks(loop):
            if task is current:
                continue
            coro = task.get_coro()
            if coro:
                f = Aergia.get_deepest_traceable_frame(coro)
                if f:
                    idle.append(f)
        return idle

    @staticmethod
    def get_deepest_traceable_frame(coro):
        '''Follows the chain of awaited coroutines starting at CORO and
        returns the last frame that `Aergia.should_trace' accepts.

        The walk stops at the first coroutine that has no frame, whose
        file is not traceable, or which awaits something that is not a
        coroutine (has no cr_frame).  Returns None when CORO is falsy or
        its own frame is not traceable.'''
        if not coro:
            return None
        curr = coro
        lframe = None
        while True:
            frame = getattr(curr, 'cr_frame', None)
            if not frame or not Aergia.should_trace(frame.f_code.co_filename):
                return lframe

            lframe = frame
            awaited = getattr(curr, 'cr_await', None)
            if not awaited or not hasattr(awaited, 'cr_frame'):
                return lframe
            curr = awaited

    @staticmethod
    def should_trace(filename):
        '''Returns FALSE if filename is uninteresting to the user.'''
        # FIXME Assume GuixSD. Makes filtering easy
        if '/gnu/store' in filename:
            return False
        if 'site-packages' in filename:
            return False
        # Pseudo-files such as '<string>' or '<frozen ...>'.  startswith
        # is used so that an empty filename does not raise IndexError.
        if filename.startswith('<'):
            return False
        return True

    @staticmethod
    def gettime():
        '''returns the wallclock time'''
        # perf_counter includes time spent sleeping, unlike process_time,
        # which only counts CPU time used by this process.
        return time.perf_counter()


# Namespace the profiled script is executed in, so that it sees itself
# as the `__main__' module.
the_globals = {
    '__name__': '__main__',
    '__doc__': None,
    '__package__': None,
    '__loader__': globals()['__loader__'],
    '__spec__': None,
    '__annotations__': {},
    '__builtins__': globals()['__builtins__'],
    '__file__': None,
    '__cached__': None,
}


def main():
    '''Parses CLI arguments and facilitates profiler runtime.'''
    parser = argparse.ArgumentParser(
        usage='%(prog)s [args] script [args]'
    )

    parser.add_argument('-i', '--interval',
                        help='The minimum amount of time inbetween '
                             'samples in seconds.',
                        metavar='',
                        type=float,
                        default=0.01)
    parser.add_argument('script', help='A python script to run.')
    parser.add_argument('s_args', nargs=argparse.REMAINDER,
                        help='python script args')
    args = parser.parse_args()

    sys.argv = [args.script] + args.s_args
    try:
        with open(args.script, 'r', encoding='utf-8') as fp:
            code = compile(fp.read(), args.script, "exec")
        # Mimic `python script.py': expose the script's path and make
        # modules next to the script importable.
        script_path = os.path.abspath(args.script)
        the_globals['__file__'] = script_path
        sys.path.insert(0, os.path.dirname(script_path))
        Aergia(args.interval).start()
        try:
            # NOTE: this runs the user-supplied script with full
            # privileges, exactly as the interpreter itself would.
            exec(code, the_globals)
        finally:
            # Always disarm the timer, and do so before reporting so the
            # tallies cannot change while they are being printed.
            Aergia.stop()
        Aergia.print_samples()
    except Exception:
        traceback.print_exc()


if __name__ == "__main__":
    main()