diff options
| -rw-r--r-- | aergia/__init__.py (renamed from __init__.py) | 0 | ||||
| -rwxr-xr-x | aergia/aergia.py (renamed from aergia) | 133 | ||||
| -rw-r--r-- | run_tests.py | 14 | ||||
| -rw-r--r-- | t/__init__.py | 0 | ||||
| -rw-r--r-- | t/async-generator-and-comprehension.py | 15 | ||||
| -rw-r--r-- | t/eager_and_scheduled.py | 19 | ||||
| -rw-r--r-- | t/manual/flask.py | 26 | ||||
| -rw-r--r-- | t/random_wait.py | 22 | ||||
| -rw-r--r-- | t/task_groups_and_cancel.py | 35 | ||||
| -rw-r--r-- | t/test_functionality.py | 57 | ||||
| -rw-r--r-- | t/threads.py | 23 | ||||
| -rw-r--r-- | t/utils.py | 34 |
12 files changed, 326 insertions, 52 deletions
diff --git a/__init__.py b/aergia/__init__.py index e69de29..e69de29 100644 --- a/__init__.py +++ b/aergia/__init__.py diff --git a/aergia b/aergia/aergia.py index 3f629ef..b7b4f35 100755 --- a/aergia +++ b/aergia/aergia.py @@ -41,7 +41,7 @@ Commentary: Code: ''' -from collections import defaultdict +from collections import defaultdict, namedtuple from typing import Optional import argparse import asyncio @@ -78,6 +78,9 @@ def thread_join_replacement( threading.Thread.join = thread_join_replacement +# a tuple used as a key in the sample-dict +Sample = namedtuple('Sample', ['file', 'line', 'func']) + class Aergia(object): @@ -88,9 +91,11 @@ class Aergia(object): # number of times samples have been collected total_samples = 0 # the (ideal) interval between samples + signal_interval = 0.0 - def __init__(self, signal_interval): + @staticmethod + def __init__(signal_interval): Aergia.signal_interval = signal_interval @staticmethod @@ -111,8 +116,18 @@ class Aergia(object): @staticmethod def stop(): - Aergia.disable_signals() - Aergia.print_samples() + '''Stops the profiler.''' + signal.setitimer(signal.ITIMER_REAL, 0) + + @staticmethod + def clear(): + Aergia.total_samples = 0 + Aergia.samples = defaultdict(lambda: 0) + + @staticmethod + def get_samples(): + '''Returns the profiling results.''' + return Aergia.samples @staticmethod def print_samples(): @@ -130,20 +145,17 @@ class Aergia(object): '''Pretty-print a single sample.''' sig_intv = Aergia.signal_interval value = Aergia.samples[key] - print(f"{key} :\t\t{value * 100 / Aergia.total_samples:.3f}%" + print(f"{Aergia.tuple_to_string(key)} :" + f"\t\t{value * 100 / Aergia.total_samples:.3f}%" f"\t({value:.3f} ->" f" {value*sig_intv:.6f} seconds)") @staticmethod - def disable_signals(): - signal.setitimer(signal.ITIMER_REAL, 0) - - @staticmethod def idle_signal_handler(sig, frame): '''Obtains and records which lines are currently being waited on.''' keys = Aergia.compute_frames_to_record() for key in keys: - Aergia.samples[Aergia.frame_to_string(key)] += 1 + Aergia.samples[Aergia.frame_to_tuple(key)] += 1 Aergia.total_samples += 1 @staticmethod @@ -159,10 +171,7 @@ class Aergia(object): what is running for us.''' loops = Aergia.get_event_loops() frames = Aergia.get_frames_from_loops(loops) - return [ - f for f in frames - if f is not None and Aergia.should_trace(f.filename) - ] + return frames @staticmethod def get_event_loops(): @@ -206,28 +215,18 @@ class Aergia(object): ] @staticmethod - def frame_to_string(frame): - '''Pretty-prints a frame as a function/file name and a line number. - Additionally used as a key for tallying lines.''' - func_name = frame.name - line_no = frame.lineno - filename = frame.filename - return filename + ':' + str(line_no) + '\t' + func_name - - @staticmethod - def should_trace(filename): - '''Returns FALSE if filename is uninteresting to the user.''' - # return True - # FIXME Assume GuixSD. Makes filtering easy - if '/gnu/store' in filename: - return False - if 'site-packages' in filename: - return False - if filename[0] == '<': - return False - if 'aergia' in filename: - return False - return True + def frame_to_tuple(frame): + '''Given a frame, constructs a sample key for tallying lines.''' + co = frame.f_code + func_name = co.co_name + line_no = frame.f_lineno + filename = co.co_filename + return Sample(filename, line_no, func_name) + + def tuple_to_string(sample): + '''Given a namedtuple corresponding to a sample key, + pretty-prints a frame as a function/file name and a line number.''' + return sample.file + ':' + str(sample.line) + '\t' + sample.func @staticmethod def sort_samples(sample_dict): @@ -241,13 +240,50 @@ class Aergia(object): '''Given an asyncio event loop, returns the list of idle task frames. A task is considered 'idle' if it is not currently executing.''' idle = [] - for th in loop._scheduled: - st = th._source_traceback - if st: - idle += st + current = asyncio.current_task(loop) + for task in asyncio.all_tasks(loop): + if task == current: + continue + coro = task.get_coro() + if coro: + f = Aergia.get_deepest_traceable_frame(coro) + if f: + idle.append(f) return idle @staticmethod + def get_deepest_traceable_frame(coro): + if not coro: + return None + curr = coro + lframe = None + while True: + frame = getattr(curr, 'cr_frame', None) + if not frame or not Aergia.should_trace(frame.f_code.co_filename): + return lframe + + lframe = frame + awaited = getattr(curr, 'cr_await', None) + if not awaited or not hasattr(awaited, 'cr_frame'): + return lframe + curr = awaited + + @staticmethod + def should_trace(filename): + '''Returns FALSE if filename is uninteresting to the user.''' + # print(filename) + # FIXME Assume GuixSD. Makes filtering easy + if '/gnu/store' in filename: + return False + if 'site-packages' in filename: + return False + if filename[0] == '<': + return False + # if 'aergia' in filename: + # return False + return True + + @staticmethod def gettime(): '''returns the wallclock time''' return time.process_time() @@ -266,8 +302,9 @@ the_globals = { } -def parse_arguments(): - '''Parse CLI args''' +if __name__ == "__main__": + # parses CLI arguments and facilitates profiler runtime. + # foo parser = argparse.ArgumentParser( usage='%(prog)s [args] script [args]' ) @@ -281,12 +318,7 @@ def parse_arguments(): parser.add_argument('script', help='A python script to run.') parser.add_argument('s_args', nargs=argparse.REMAINDER, help='python script args') - - return parser.parse_args() - - -def main(): - args = parse_arguments() + args = parser.parse_args() sys.argv = [args.script] + args.s_args try: @@ -294,10 +326,7 @@ def main(): code = compile(fp.read(), args.script, "exec") Aergia(args.interval).start() exec(code, the_globals) + Aergia.print_samples() Aergia.stop() except Exception: traceback.print_exc() - - -if __name__ == "__main__": - main() diff --git a/run_tests.py b/run_tests.py new file mode 100644 index 0000000..ff11fb5 --- /dev/null +++ b/run_tests.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +import unittest +import sys + +if __name__ == '__main__': + sys.path.append('t/') + sys.path.append('aergia/') + t_loader = unittest.defaultTestLoader + t_runner = unittest.TextTestRunner(verbosity=2) + t = ['test_functionality'] + + t_suite = t_loader.loadTestsFromNames(t) + result = t_runner.run(t_suite) + sys.exit(not result.wasSuccessful()) diff --git a/t/__init__.py b/t/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/t/__init__.py diff --git a/t/async-generator-and-comprehension.py b/t/async-generator-and-comprehension.py new file mode 100644 index 0000000..ce85f45 --- /dev/null +++ b/t/async-generator-and-comprehension.py @@ -0,0 +1,15 @@ +import asyncio + + +async def async_generator(): + for i in range(10): + await asyncio.sleep(0.5) + yield i + + +async def main(): + r = [a async for a in async_generator()] + print(sum(r)) + + +asyncio.run(main()) diff --git a/t/eager_and_scheduled.py b/t/eager_and_scheduled.py new file mode 100644 index 0000000..178a9f4 --- /dev/null +++ b/t/eager_and_scheduled.py @@ -0,0 +1,19 @@ +import asyncio + + +async def foo(): + await asyncio.sleep(1.0) + await baz() + await asyncio.sleep(0.5) + + +async def bar(): + await asyncio.create_subprocess_shell('sleep 1.0') + + +async def baz(): + await asyncio.sleep(1.0) + + +asyncio.run(foo()) +asyncio.run(bar()) diff --git a/t/manual/flask.py b/t/manual/flask.py new file mode 100644 index 0000000..57951b6 --- /dev/null +++ b/t/manual/flask.py @@ -0,0 +1,26 @@ +#!/usr/bin/env -S python3 -m flask --app +import asyncio +import os +import signal +import time +from flask import Flask + +app = Flask(__name__) + +async def query_db(): + await asyncio.sleep(2.0) + return 1 + +@app.route("/") +async def hello_world(): + await asyncio.sleep(10.0) + return "<p>Hello, World!</p>" + +@app.route("/die") +async def die(): + await asyncio.sleep(2.0) + os.kill(os.getpid(), signal.SIGINT) + return "You've killed me!" + +if __name__ == "__main__": + app.run(debug=True) diff --git a/t/random_wait.py b/t/random_wait.py new file mode 100644 index 0000000..2cfc290 --- /dev/null +++ b/t/random_wait.py @@ -0,0 +1,22 @@ +# SuperFastPython.com +# example of waiting for all tasks to complete +from random import random +import asyncio + +total = 0 + + +async def task_coro(arg): + value = random() + total += value + await asyncio.sleep(value) + print(f'>{arg} done in {value}') + + +async def main(): + tasks = [asyncio.create_task(task_coro(i)) for i in range(10)] + done, pending = await asyncio.wait(tasks) + print(f'All done. Total waiting time: {total}') + + +asyncio.run(main()) diff --git a/t/task_groups_and_cancel.py b/t/task_groups_and_cancel.py new file mode 100644 index 0000000..dcc6bbe --- /dev/null +++ b/t/task_groups_and_cancel.py @@ -0,0 +1,35 @@ +import asyncio + + +async def sleep(): + await asyncio.sleep(3) + print('I should never finish!') + return 0 + + +async def work(): + i = 0 + while i < 50: + i += 1 + await asyncio.sleep(0.2) + return 0 + + +async def explode(): + await asyncio.sleep(1.5) + a = 1 / 0 + return a + + +async def main(): + # exploding will bring all other tasks down with it! + try: + async with asyncio.TaskGroup() as tg: + tg.create_task(sleep()) + tg.create_task(work()) + tg.create_task(explode()) + except: + pass + + +asyncio.run(main()) diff --git a/t/test_functionality.py b/t/test_functionality.py new file mode 100644 index 0000000..5253534 --- /dev/null +++ b/t/test_functionality.py @@ -0,0 +1,57 @@ +import utils +import aiohttp +import asyncio + + +class BasicUsage(utils.AergiaUnitTestCase): + + def test_asyncless(self): + def a(): + x = 100 + while x > 0: + x -= 1 + + self.Aergia.start() + a() + self.Aergia.stop() + + samples = self.Aergia.get_samples() + self.assertFalse(samples) + + def test_sequential_tasks(self): + delay = 0.2 + num_times = 5 + + async def b(tot, num): + await asyncio.sleep(delay) + return tot + num + + async def a(): + tot = 0 + for i in range(num_times): + tot = await b(tot, i) + assert tot == 10 + + self.Aergia.start() + asyncio.run(a()) + self.Aergia.stop() + + samples = self.Aergia.get_samples() + + self.assertFuncContains('b', [self.expected_samples(delay * num_times)], + samples) + + def test_simultaneous_tasks(self): + delay = 0.2 + async def b(): await asyncio.sleep(delay) + async def a(): await asyncio.gather(b(), b(), b()) + + self.Aergia.start() + asyncio.run(a()) + self.Aergia.stop() + + samples = self.Aergia.get_samples() + + print(self.expected_samples(delay * 3)) + self.assertFuncContains('b', [self.expected_samples(delay * 3)], + samples) diff --git a/t/threads.py b/t/threads.py new file mode 100644 index 0000000..fa9c7ac --- /dev/null +++ b/t/threads.py @@ -0,0 +1,23 @@ +import asyncio +import threading + + + +async def count(): + print("it's a secret!") + await asyncio.sleep(3) + + +async def main(): + await asyncio.gather(count(), count(), count()) + print("done") + + +def thread_func(): + asyncio.run(main()) + + +if __name__ == "__main__": + x = threading.Thread(target=thread_func) + x.start() + x.join() diff --git a/t/utils.py b/t/utils.py new file mode 100644 index 0000000..fe9db26 --- /dev/null +++ b/t/utils.py @@ -0,0 +1,34 @@ +from aergia.aergia import Aergia +import unittest + + +class AergiaUnitTestCase(unittest.TestCase): + + interval = 0.01 + Aergia = Aergia(interval) + + def setUp(self): + self.Aergia.clear() + + def assertFuncContains(self, func_name, samples_expected, samples): + samples_actual = self.extract_values_by_func(samples, + func_name) + self.assertTrue(len(samples_expected) == len(samples_actual)) + s1 = sorted(samples_expected) + s2 = sorted(samples_actual) + for v1, v2 in zip(s1, s2): + self.assertRoughlyEqual(v1, v2) + + def assertRoughlyEqual(self, v1, v2): + print(f'{v1}, {v2}') + a = abs(v1 - v2) + self.assertTrue(a <= 1) + + def expected_samples(self, total_seconds): + return (total_seconds / self.interval) + + def extract_values_by_func(self, samples, func_name): + return [ + value for key, value in samples.items() + if key.func == func_name + ] |
