summaryrefslogtreecommitdiff
path: root/nemesis
diff options
context:
space:
mode:
Diffstat (limited to 'nemesis')
-rw-r--r--nemesis/causal_event_loop.py41
1 files changed, 22 insertions, 19 deletions
diff --git a/nemesis/causal_event_loop.py b/nemesis/causal_event_loop.py
index 12a4084..9b774d0 100644
--- a/nemesis/causal_event_loop.py
+++ b/nemesis/causal_event_loop.py
@@ -17,6 +17,9 @@ Commentary:
Code:
'''
+from typing import TYPE_CHECKING, Any, Callable, Self
+if TYPE_CHECKING:
+ import contextvars
import asyncio
import collections
@@ -54,7 +57,7 @@ class TimeAwareMixin:
# the timestamp this callback entered the pause buffer
time_entered_pause_buffer = None
- def __init__(self):
+ def __init__(self: Self) -> None:
self.register_time = time.monotonic()
def __hash__(self):
@@ -81,9 +84,9 @@ class TimeAwareMixin:
return NotImplemented
-def create_subclass(base_class):
+def create_subclass(base_class: "type[orig_handle]"):
class CausalHandle(base_class, TimeAwareMixin):
- def __init__(self, *args, **kwargs):
+ def __init__(self: Self, *args: tuple|Any|Callable|None, **kwargs: None) -> None:
super().__init__(*args, **kwargs)
TimeAwareMixin.__init__(self)
@@ -114,11 +117,11 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
# the time this experiment started
_start_time = float("inf")
- def __init__(self) -> None:
+ def __init__(self: Any) -> None:
super().__init__()
self.set_logger()
- def set_logger(self):
+ def set_logger(self: Any) -> None:
if not os.path.exists(TRACE_PATH):
with open(TRACE_PATH, 'w') as file:
pass
@@ -134,7 +137,7 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
self._logger.info("STARTING LOOP")
self._logger.info("═" * 40)
- def set_speedup(self, speedup):
+ def set_speedup(self: Any, speedup: float) -> None:
# print(self._coro_intervals)
self._start_time = self.time()
self._speedup = speedup
@@ -146,10 +149,10 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
self._coro_intervals.clear()
self._completed_coros.clear()
- def get_completed_coros(self):
+ def get_completed_coros(self: Any) -> list[tuple[str, float]]:
return copy(self._completed_coros)
- def get_run_time(self):
+ def get_run_time(self: Any) -> float:
curr_time = self.time()
if not self._coro_intervals:
return 0
@@ -160,11 +163,11 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
pause_time = self._get_pause_time(interval)
return (curr_time - self._start_time) - pause_time
- def ping_enter_coro(self):
+ def ping_enter_coro(self: Any) -> None:
# self._logger.debug(f"Recording coro ENTER.")
self._time_entered_coro = self.time()
- def ping_exit_coro(self):
+ def ping_exit_coro(self: Any) -> None:
try:
assert isinstance(self._time_entered_coro, float), "Tried to exit coro before recorded entry!"
except AssertionError as e:
@@ -174,12 +177,12 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
self._coro_intervals.add((self._time_entered_coro, self.time()))
self._time_entered_coro = None
- def collect_ready_events(self, timeout=0):
+ def collect_ready_events(self: Any, timeout: int=0) -> None:
event_list = self._selector.select(timeout)
if event_list:
self._ready_events.append((event_list, self.time()))
- def update_ready(self):
+ def update_ready(self: Any) -> None:
'''
Polls the IO selector, schedules resulting callbacks, and schedules
'call_later' callbacks.
@@ -292,7 +295,7 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
self._ready.append(handle)
self._logger.debug(f"\tpause_buffer -> ready for {_format_handle(handle)}")
- def _run_once(self):
+ def _run_once(self: Any) -> None:
"""
Run one full iteration of the event loop.
@@ -370,7 +373,7 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
writer._when = time_to_buffer
self._add_callback(writer)
- def _call_soon(self, callback, args, context):
+ def _call_soon(self: Any, callback: Any, args: tuple, context: "contextvars.Context"):
"""Do not add 'callsoon' events to the pause buffer.
Add them directly to ready."""
curr_time = self.time()
@@ -422,7 +425,7 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
else:
self._logger.warning(f"\t_add_callback called on cancelled handle {_format_handle(handle)}")
- def _get_pause_for_io(self, handle, io_time):
+ def _get_pause_for_io(self: Any, handle, io_time: float) -> float:
time_interval = (handle.register_time, io_time)
p_time = self._get_pause_time(time_interval)
@@ -434,11 +437,11 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
return p_time
- def _get_pause_for_pause_time(self, handle, exit_time):
+ def _get_pause_for_pause_time(self: Any, handle, exit_time: float) -> float:
time_interval = (handle.time_entered_pause_buffer, exit_time)
return self._get_pause_time(time_interval)
- def _get_pause_time(self, cb_interval):
+ def _get_pause_time(self: Any, cb_interval: tuple[float, float]) -> float:
time = 0
start, end = cb_interval
@@ -459,7 +462,7 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
assert pause_time >= 0, f"Delay was found to be less than 0: {pause_time}"
return pause_time
- def _get_overlap(self, a_start, a_end, b_start, b_end):
+ def _get_overlap(self: Any, a_start: float, a_end: float, b_start: float, b_end: float) -> float:
overlap_start = max(a_start, b_start)
overlap_end = min(a_end, b_end)
assert overlap_end >= overlap_start, f"Bad overlaps: {a_start} {a_end} : {b_start} {b_end} ({overlap_end - overlap_start})"
@@ -480,5 +483,5 @@ class CausalEventLoop(asyncio.SelectorEventLoop):
# to profile your program, start the event loop with
# asyncio.run(coro, loop_factory=causal_loop_factory)
-def causal_loop_factory():
+def causal_loop_factory() -> Any:
return CausalEventLoop()