diff options
| author | bd <bdunahu@operationnull.com> | 2025-10-06 18:27:09 -0400 |
|---|---|---|
| committer | bd <bdunahu@operationnull.com> | 2025-10-06 18:27:09 -0400 |
| commit | 2edc08465723f444a1ef4108d41bac852f7be88a (patch) | |
| tree | 53f5d1c4eca459c0c9784844b3d8c80bc4b03287 /nemesis/causal_event_loop.py | |
initial commit
Diffstat (limited to 'nemesis/causal_event_loop.py')
| -rw-r--r-- | nemesis/causal_event_loop.py | 309 |
1 files changed, 309 insertions, 0 deletions
diff --git a/nemesis/causal_event_loop.py b/nemesis/causal_event_loop.py new file mode 100644 index 0000000..d01dfd4 --- /dev/null +++ b/nemesis/causal_event_loop.py @@ -0,0 +1,309 @@ +''' +Copyright 2025 bdunahu + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +Commentary: + +Code: +''' + +import asyncio +import collections +from sortedcontainers import SortedList +import heapq +import selectors +import time +import traceback +from asyncio.log import logger +from asyncio import Task, events +from asyncio.base_events import _format_handle + +_MIN_SCHEDULED_TIMER_HANDLES = 100 +_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5 +MAXIMUM_SELECT_TIMEOUT = 24 * 3600 + + +class TimeAwareMixin: + + # the timestamp this callback was registered + register_time = None + # the timestamp this callback completed i/o + io_time = None + # the timestamp this callback was called by the event loop + process_start_time = None + + def __init__(self): + self.register_time = time.monotonic() + + +orig_handle = asyncio.events.Handle + + +def create_subclass(base_class): + class NewSubclass(base_class, TimeAwareMixin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + TimeAwareMixin.__init__(self) + + return NewSubclass + + +# make all the subclasses inherit from TimeAwareHandle as well +for sc in [orig_handle] + orig_handle.__subclasses__(): + subclass = create_subclass(sc) + setattr(asyncio.events, sc.__name__, subclass) + +class CausalEventLoop(asyncio.SelectorEventLoop): + + # a value between 0 and 1. 0 means no optimization, + # 1 means the target coroutine is optimized away entirely + _speedup = 1.0 + # a list of callbacks which have recently completed + _pause_buffer = [] + # a list of intervals in which the target coroutine has been active + _coro_intervals = SortedList() + # a list of completed callbacks, and their associated queue time + _completed_coros = [] + # the last time we entered the target coro + _time_entered_coro = None + + def __init__(self) -> None: + super().__init__() + + def set_speedup(self, speedup): + self._speedup = speedup + + # reset experiment counters + self._time_entered_coro = None + self._coro_intervals.clear() + self._completed_coros.clear() + + def get_completed_coros(self): + return self._completed_coros + + def get_pause_time(self): + if not self._coro_intervals: + return 0 + + start_interval = self._coro_intervals[0][0] + end_interval = self.time() if self._time_entered_coro else self._coro_intervals[-1][1] + interval = (start_interval, end_interval) + return self._get_pause_time(interval) + + def ping_enter_coro(self): + self._time_entered_coro = self.time() + + def ping_exit_coro(self): + assert isinstance(self._time_entered_coro, float), f"Tried to exit coro before recorded entry!" + self._coro_intervals.add((self._time_entered_coro, self.time())) + self._time_entered_coro = None + + def update_ready(self, can_stall=True): + ''' + Polls the IO selector, schedules resulting callbacks, and schedules + 'call_later' callbacks. + + This function can be called in the middle of an event loop iteration. + + This logic was separated out of `run_once` so that the list of `ready` + tasks may be updated more frequently than once per iteration. + + If SAMPLING is true, the timeout passed to the selector will always be 0. + ''' + curr_time = self.time() + sched_count = len(self._scheduled) + if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and + self._timer_cancelled_count / sched_count > + _MIN_CANCELLED_TIMER_HANDLES_FRACTION): + # Remove delayed calls that were cancelled if their number + # is too high + new_scheduled = [] + for handle in self._scheduled: + if handle._cancelled: + handle._scheduled = False + else: + new_scheduled.append(handle) + + heapq.heapify(new_scheduled) + self._scheduled = new_scheduled + self._timer_cancelled_count = 0 + else: + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + self._timer_cancelled_count -= 1 + handle = heapq.heappop(self._scheduled) + handle._scheduled = False + + timeout = None + # TODO this needs to be rewritten + # We can't miss things placed in timeout either + # if not can_stall or self._ready or self._stopping: + timeout = 0 + # elif self._scheduled: + # # Compute the desired timeout. + # timeout = self._scheduled[0]._when - self.time() + # if timeout > MAXIMUM_SELECT_TIMEOUT: + # timeout = MAXIMUM_SELECT_TIMEOUT + # elif timeout < 0: + # timeout = 0 + + event_list = self._selector.select(timeout) + self._process_events(event_list) + # Needed to break cycles when an exception occurs. + event_list = None + + # Handle 'later' callbacks that are ready. + end_time = self.time() + self._clock_resolution + while self._scheduled: + handle = self._scheduled[0] + if handle._when >= end_time: + break + handle = heapq.heappop(self._scheduled) + handle._scheduled = False + + time_interval = (handle.register_time, curr_time) + time_to_buffer = curr_time + self._get_pause_time(time_interval) + handle.io_time = time_to_buffer + self._ready.append(handle) + + def _run_once(self): + """ + Run one full iteration of the event loop. + + This calls all currently ready callbacks. + """ + self.update_ready() + + current_time = self.time() + to_process = collections.deque([ + handle for handle in self._ready + if handle.io_time < (current_time + self._clock_resolution) + ]) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is thread-safe without using locks. + ntodo = len(to_process) + for i in range(ntodo): + handle = to_process.popleft() + self._ready.remove(handle) + if handle._cancelled: + continue + try: + self._current_handle = handle + + process_start_time = self.time() + handle.process_start_time = process_start_time + + handle._run() + + process_end_time = self.time() + dt = process_end_time - process_start_time + if self._debug and dt >= self.slow_callback_duration: + logger.warning('Executing %s took %.3f seconds', + _format_handle(handle), dt) + + time_interval = (handle.io_time, process_start_time) + pause_time = self._get_pause_time(time_interval) + adjusted_start_time = handle.process_start_time - \ + pause_time + wait_time = adjusted_start_time - handle.io_time + assert wait_time >= -0.0001, f"wait time on {_format_handle(handle)} was found to be {wait_time:.4f}!" + self._completed_coros.append((_format_handle(handle), wait_time)) + except Exception: + traceback.print_exc() + finally: + self._current_handle = None + handle = None # Needed to break cycles when an exception occurs. + + def _process_events(self, event_list): + curr_time = self.time() + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self._remove_reader(fileobj) + else: + time_interval = (reader.register_time, curr_time) + time_to_buffer = curr_time + \ + self._get_pause_time(time_interval) + reader.io_time = time_to_buffer + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self._remove_writer(fileobj) + else: + time_interval = (writer.register_time, curr_time) + time_to_buffer = curr_time + \ + self._get_pause_time(time_interval) + writer.io_time = time_to_buffer + self._add_callback(writer) + + def _call_soon(self, callback, args, context): + handle = events.Handle(callback, args, self, context) + if handle._source_traceback: + del handle._source_traceback[-1] + handle.io_time = self.time() + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args, context=None): + """Like call_soon(), but thread-safe.""" + self._check_closed() + if self._debug: + self._check_callback(callback, 'call_soon_threadsafe') + handle = events._ThreadSafeHandle(callback, args, self, context) + handle.io_time = self.time() + self._ready.append(handle) + if handle._source_traceback: + del handle._source_traceback[-1] + if handle._source_traceback: + del handle._source_traceback[-1] + self._write_to_self() + return handle + + def _get_pause_time(self, cb_interval): + time = 0 + start, end = cb_interval + + for coro_start, coro_end in self._coro_intervals: + if start < coro_end and coro_start < end: + time += self._get_overlap(start, end, coro_start, coro_end) + # coro_intervals are sorted, so by this time all overlap has passed + if end < coro_start: + break + + curr_time = self.time() + if self._time_entered_coro and \ + start < curr_time and self._time_entered_coro < end: + time += self._get_overlap(start, end, self._time_entered_coro, + curr_time) + + return time * self._speedup + + def _get_overlap(self, a_start, a_end, b_start, b_end): + overlap_start = max(a_start, b_start) + overlap_end = min(a_end, b_end) + return overlap_end - overlap_start + + +class CausalEventLoopPolicy(asyncio.DefaultEventLoopPolicy): + def new_event_loop(self): + return CausalEventLoop() + + +asyncio.set_event_loop_policy(CausalEventLoopPolicy()) |
