diff options
| author | bd <bdunahu@operationnull.com> | 2025-08-08 21:25:14 -0400 |
|---|---|---|
| committer | bd <bdunahu@operationnull.com> | 2025-08-08 21:25:14 -0400 |
| commit | 019a74e6337e83478998bd7e6faeca635332710e (patch) | |
| tree | 0c32164f4eb496ff9a0b43997c8bf90c0cdfd415 /aergia | |
| parent | 25b62280ae5db0544e20151c5561dbafce9cb3dc (diff) | |
More in-depth type-checking for _search_future (now _trace_down)
Diffstat (limited to 'aergia')
| -rwxr-xr-x | aergia/aergia.py | 88 |
1 files changed, 44 insertions, 44 deletions
diff --git a/aergia/aergia.py b/aergia/aergia.py index bb8dcbe..a41ef92 100755 --- a/aergia/aergia.py +++ b/aergia/aergia.py @@ -44,7 +44,8 @@ Code: ''' from collections import defaultdict, namedtuple -from typing import Optional +from typing import Optional, Tuple +from types import FrameType import argparse import asyncio import signal @@ -293,7 +294,7 @@ class Aergia(object): @staticmethod def _get_deepest_traceable_frame(coro): '''Get the deepest frame of coro we care to trace. - This is possible because each corooutine keeps a reference to the + This is possible because each coroutine keeps a reference to the coroutine it is waiting on. Note that it cannot be the case that a task is suspended in a frame @@ -302,55 +303,26 @@ class Aergia(object): curr = coro deepest_frame = None while curr: - frame = getattr(curr, 'cr_frame', None) + frame, curr = Aergia._trace_down(curr) - if not frame: - curr = Aergia._search_future(curr) - if isinstance(curr, AsyncGeneratorType): - frame = getattr(curr, 'ag_frame', None) - else: - break + if frame is None: + break if Aergia._should_trace(frame.f_code.co_filename): deepest_frame = frame - if isinstance(curr, AsyncGeneratorType): - curr = getattr(curr, 'ag_await', None) - else: - curr = getattr(curr, 'cr_await', None) - # if this task is found to point to another task we're profiling, # then we will get the deepest frame later and should return nothing. - if isinstance(curr, list) and any( - Aergia._should_trace_task(task) - for task in curr - ): - return None - - return deepest_frame - - @staticmethod - def _search_future(future): - '''Given an awaitable which is not a coroutine, assume it is a future - and attempt to find references to tasks or async generators.''' - awaitable = None - if not isinstance(future, asyncio.Future): - # TODO some wrappers like _asyncio.FutureIter, - # async_generator_asend get caught here, I am not sure if a more - # robust approach is necessary - - # can gc be avoided here? - refs = gc.get_referents(future) - if refs: - awaitable = refs[0] - # this is specific to gathering futures, i.e., gather statement. - # Other cases may need to be added. - if isinstance(awaitable, asyncio.Future): - return getattr(awaitable, '_children', []) + if curr is not None: + tasks = getattr(curr, '_children', []) + if any( + Aergia._should_trace_task(task) + for task in tasks + ): + return None - # if this is not AsyncGeneratorType, it is ignored - return awaitable + return deepest_frame @staticmethod def _should_trace_task(task): @@ -373,8 +345,8 @@ class Aergia(object): # statement. # if this isn't the case, the associated coroutine will # be 'waiting' on the coroutine declaration. No! Bad! - if getattr(coro, 'cr_frame', None) is None or \ - getattr(coro, 'cr_await', None) is None: + frame, awaitable = Aergia._trace_down(coro) + if frame is None or awaitable is None: return False frame = getattr(coro, 'cr_frame', None) @@ -382,6 +354,34 @@ class Aergia(object): return Aergia._should_trace(frame.f_code.co_filename) @staticmethod + def _trace_down(awaitable) -> \ + Tuple[Optional[FrameType], Optional[asyncio.Future]]: + ''' Helper for _get_deepest_traceable_frame + Given AWAITABLE, returns its associated frame and + the future it is waiting on, if any.''' + if asyncio.iscoroutine(awaitable) and \ + hasattr(awaitable, 'cr_await') and \ + hasattr(awaitable, 'cr_frame'): + return getattr(awaitable, 'cr_frame', None), \ + getattr(awaitable, 'cr_await', None) + + # attempt to obtain an async-generator + # can gc be avoided here? + refs = gc.get_referents(awaitable) + if refs: + awaitable = refs[0] + + if isinstance(awaitable, AsyncGeneratorType): + return getattr(awaitable, 'ag_frame', None), \ + getattr(awaitable, 'ag_await', None) + + if isinstance(awaitable, asyncio.Future): + # return whatever future we found. + return None, awaitable + + return None, None + + @staticmethod def _should_trace(filename): '''Returns FALSE if filename is uninteresting to the user. Don't depend on this. It's good enough for testing.''' |
