summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorbd <bdunahu@operationnull.com>2025-06-26 20:54:06 -0400
committerbd <bdunahu@operationnull.com>2025-06-26 20:54:06 -0400
commit708b2fc12bc6016c7359ea603c115950bddfefc5 (patch)
tree31b238b8d4d5d4f614f9072dba8b53cd81265e63
parentf685a5a17ddbd9fa6b11640ec1f2fe5a0ec29953 (diff)
base support for multi-threaded programs
-rwxr-xr-xaergia91
1 files changed, 75 insertions, 16 deletions
diff --git a/aergia b/aergia
index 685a13f..5a601c7 100755
--- a/aergia
+++ b/aergia
@@ -43,16 +43,43 @@ Code:
from collections import defaultdict
from types import FrameType
-from typing import cast, List, Tuple
+from typing import cast, List, Optional, Tuple
import argparse
import asyncio
import selectors
import signal
import sys
import time
+import threading
import traceback
+orig_thread_join = threading.Thread.join
+
+
+def thread_join_replacement(
+ self: threading.Thread, timeout: Optional[float] = None
+) -> None:
+ '''
+ We replace threading.Thread.join with this method which always
+ periodically yields.
+ '''
+
+ start_time = time.perf_counter()
+ interval = sys.getswitchinterval()
+ while self.is_alive():
+ orig_thread_join(self, interval)
+ # If a timeout was specified, check to see if it's expired.
+ if timeout is not None:
+ end_time = time.perf_counter()
+ if end_time - start_time >= timeout:
+ return None
+ return None
+
+
+threading.Thread.join = thread_join_replacement
+
+
class ReplacementEpollSelector(selectors.EpollSelector):
'''
Provides a replacement for selectors.PollSelector that
@@ -154,9 +181,7 @@ class Aergia(object):
@staticmethod
def compute_frames_to_record():
'''Collects all stack frames that Aergia actually processes.'''
- frames = []
- if Aergia.is_event_loop_running():
- frames = Aergia.get_current_idle_tasks()
+ frames = Aergia.get_frames_from_runners(Aergia.get_event_loops())
# Process all the frames to remove ones we aren't going to track.
new_frames = []
@@ -186,13 +211,42 @@ class Aergia(object):
return new_frames
@staticmethod
- def get_current_idle_tasks():
- '''Obtains the stack of frames of all currently idle tasks.'''
- curr_task = asyncio.current_task()
+ def get_event_loops():
+ runners = []
+ for t in threading.enumerate():
+ frame = sys._current_frames().get(t.ident)
+ if not frame:
+ continue
+ runner = Aergia.walk_back_until_runner(frame)
+ if runner:
+ runners.append(runner)
+ return runners
+
+ @staticmethod
+ def walk_back_until_runner(frame):
+ while frame:
+ r = Aergia.find_runner_in_locals(frame.f_locals)
+ if r:
+ return r
+ frame = frame.f_back
+ return None
+
+ @staticmethod
+ def find_runner_in_locals(locals_dict):
+ '''Given a dictionary of local variables for a stack frame,
+ retrieves the asyncio runner object, if there is one.'''
+ for val in locals_dict.values():
+ if type(val).__name__ == 'Runner' and \
+ val.__class__.__module__ == 'asyncio.runners':
+ return val
+ return None
+
+ @staticmethod
+ def get_frames_from_runners(runners):
+ '''Given RUNNERS, returns a flat list of tasks.'''
return [
- task.get_coro().cr_frame
- for task in asyncio.all_tasks()
- if task is not curr_task
+ task for runner in runners
+ for task in Aergia.get_idle_task_frames(runner._loop)
]
@staticmethod
@@ -220,12 +274,6 @@ class Aergia(object):
return True
@staticmethod
- def is_event_loop_running() -> bool:
- '''Returns TRUE if there is an exent loop running. This is what
- `asyncio.get_event_loop()' did, before it was deprecated in 3.12'''
- return asyncio.get_event_loop_policy()._local._loop is not None
-
- @staticmethod
def sort_samples(sample_dict):
'''Returns SAMPLE_DICT in descending order by number of samples.'''
return {k: v for k, v in sorted(sample_dict.items(),
@@ -234,6 +282,17 @@ class Aergia(object):
reverse=True)}
@staticmethod
+ def get_idle_task_frames(loop):
+ '''Given an asyncio event loop, returns the list of idle task frames.
+ A task is considered 'idle' if it is not currently executing.'''
+ curr_task = asyncio.current_task(loop)
+ return [
+ task.get_coro().cr_frame
+ for task in asyncio.all_tasks(loop)
+ if task is not curr_task
+ ]
+
+ @staticmethod
def sum_sample(sample):
return sample[0] + sample[1]