summaryrefslogtreecommitdiff
path: root/aergia
diff options
context:
space:
mode:
authorbd <bdunahu@operationnull.com>2025-08-08 21:25:14 -0400
committerbd <bdunahu@operationnull.com>2025-08-08 21:25:14 -0400
commit019a74e6337e83478998bd7e6faeca635332710e (patch)
tree0c32164f4eb496ff9a0b43997c8bf90c0cdfd415 /aergia
parent25b62280ae5db0544e20151c5561dbafce9cb3dc (diff)
More in-depth type-checking for _search_future (now _trace_down)
Diffstat (limited to 'aergia')
-rwxr-xr-xaergia/aergia.py88
1 files changed, 44 insertions, 44 deletions
diff --git a/aergia/aergia.py b/aergia/aergia.py
index bb8dcbe..a41ef92 100755
--- a/aergia/aergia.py
+++ b/aergia/aergia.py
@@ -44,7 +44,8 @@ Code:
'''
from collections import defaultdict, namedtuple
-from typing import Optional
+from typing import Optional, Tuple
+from types import FrameType
import argparse
import asyncio
import signal
@@ -293,7 +294,7 @@ class Aergia(object):
@staticmethod
def _get_deepest_traceable_frame(coro):
'''Get the deepest frame of coro we care to trace.
- This is possible because each corooutine keeps a reference to the
+ This is possible because each coroutine keeps a reference to the
coroutine it is waiting on.
Note that it cannot be the case that a task is suspended in a frame
@@ -302,55 +303,26 @@ class Aergia(object):
curr = coro
deepest_frame = None
while curr:
- frame = getattr(curr, 'cr_frame', None)
+ frame, curr = Aergia._trace_down(curr)
- if not frame:
- curr = Aergia._search_future(curr)
- if isinstance(curr, AsyncGeneratorType):
- frame = getattr(curr, 'ag_frame', None)
- else:
- break
+ if frame is None:
+ break
if Aergia._should_trace(frame.f_code.co_filename):
deepest_frame = frame
- if isinstance(curr, AsyncGeneratorType):
- curr = getattr(curr, 'ag_await', None)
- else:
- curr = getattr(curr, 'cr_await', None)
-
# if this task is found to point to another task we're profiling,
# then we will get the deepest frame later and should return nothing.
- if isinstance(curr, list) and any(
- Aergia._should_trace_task(task)
- for task in curr
- ):
- return None
-
- return deepest_frame
-
- @staticmethod
- def _search_future(future):
- '''Given an awaitable which is not a coroutine, assume it is a future
- and attempt to find references to tasks or async generators.'''
- awaitable = None
- if not isinstance(future, asyncio.Future):
- # TODO some wrappers like _asyncio.FutureIter,
- # async_generator_asend get caught here, I am not sure if a more
- # robust approach is necessary
-
- # can gc be avoided here?
- refs = gc.get_referents(future)
- if refs:
- awaitable = refs[0]
-
# this is specific to gathering futures, i.e., gather statement.
- # Other cases may need to be added.
- if isinstance(awaitable, asyncio.Future):
- return getattr(awaitable, '_children', [])
+ if curr is not None:
+ tasks = getattr(curr, '_children', [])
+ if any(
+ Aergia._should_trace_task(task)
+ for task in tasks
+ ):
+ return None
- # if this is not AsyncGeneratorType, it is ignored
- return awaitable
+ return deepest_frame
@staticmethod
def _should_trace_task(task):
@@ -373,8 +345,8 @@ class Aergia(object):
# statement.
# if this isn't the case, the associated coroutine will
# be 'waiting' on the coroutine declaration. No! Bad!
- if getattr(coro, 'cr_frame', None) is None or \
- getattr(coro, 'cr_await', None) is None:
+ frame, awaitable = Aergia._trace_down(coro)
+ if frame is None or awaitable is None:
return False
frame = getattr(coro, 'cr_frame', None)
@@ -382,6 +354,34 @@ class Aergia(object):
return Aergia._should_trace(frame.f_code.co_filename)
@staticmethod
+ def _trace_down(awaitable) -> \
+ Tuple[Optional[FrameType], Optional[asyncio.Future]]:
+ ''' Helper for _get_deepest_traceable_frame
+ Given AWAITABLE, returns its associated frame and
+ the future it is waiting on, if any.'''
+ if asyncio.iscoroutine(awaitable) and \
+ hasattr(awaitable, 'cr_await') and \
+ hasattr(awaitable, 'cr_frame'):
+ return getattr(awaitable, 'cr_frame', None), \
+ getattr(awaitable, 'cr_await', None)
+
+ # attempt to obtain an async-generator
+ # can gc be avoided here?
+ refs = gc.get_referents(awaitable)
+ if refs:
+ awaitable = refs[0]
+
+ if isinstance(awaitable, AsyncGeneratorType):
+ return getattr(awaitable, 'ag_frame', None), \
+ getattr(awaitable, 'ag_await', None)
+
+ if isinstance(awaitable, asyncio.Future):
+ # return whatever future we found.
+ return None, awaitable
+
+ return None, None
+
+ @staticmethod
def _should_trace(filename):
'''Returns FALSE if filename is uninteresting to the user.
Don't depend on this. It's good enough for testing.'''