summaryrefslogtreecommitdiff
path: root/nemesis/causal_event_loop.py
diff options
context:
space:
mode:
Diffstat (limited to 'nemesis/causal_event_loop.py')
-rw-r--r--nemesis/causal_event_loop.py309
1 files changed, 309 insertions, 0 deletions
diff --git a/nemesis/causal_event_loop.py b/nemesis/causal_event_loop.py
new file mode 100644
index 0000000..d01dfd4
--- /dev/null
+++ b/nemesis/causal_event_loop.py
@@ -0,0 +1,309 @@
+'''
+Copyright 2025 bdunahu
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+Commentary:
+
+Code:
+'''
+
import asyncio
import bisect
import collections
import heapq
import selectors
import time
import traceback
from asyncio import Task, events
from asyncio.base_events import _format_handle
from asyncio.log import logger

from sortedcontainers import SortedList
+
# Thresholds for purging cancelled timer handles from the scheduled heap
# (same values asyncio.base_events uses internally).
_MIN_SCHEDULED_TIMER_HANDLES = 100
_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5
# Upper bound, in seconds, on any timeout handed to the selector.
MAXIMUM_SELECT_TIMEOUT = 24 * 3600
+
+
class TimeAwareMixin:
    """Stamps a callback handle with timestamps over its lifecycle.

    Mixed into asyncio's Handle classes (see the patching loop below) so
    that the event loop can measure how long each callback waited.
    """

    register_time = None       # when the callback was registered
    io_time = None             # when the callback completed i/o
    process_start_time = None  # when the event loop invoked the callback

    def __init__(self):
        # Registration is stamped at construction time.
        self.register_time = time.monotonic()
+
+
+orig_handle = asyncio.events.Handle
+
+
def create_subclass(base_class):
    """Return a subclass of *base_class* that also mixes in TimeAwareMixin.

    The returned class forwards construction to *base_class* and then
    stamps the new instance with TimeAwareMixin's registration timestamp.
    The original class name is preserved so that reprs, logs, and
    introspection remain readable after the class is swapped into
    ``asyncio.events`` under the original name.
    """
    class NewSubclass(base_class, TimeAwareMixin):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Explicit call: base_class.__init__ does not cooperate with
            # super(), so the mixin must be initialized directly.
            TimeAwareMixin.__init__(self)

    # Masquerade as the original class so the setattr-based swap is
    # transparent to code inspecting type names.
    NewSubclass.__name__ = base_class.__name__
    NewSubclass.__qualname__ = base_class.__qualname__
    return NewSubclass
+
+
# Replace asyncio.events.Handle and each of its currently-known subclasses
# (e.g. TimerHandle) with a TimeAwareMixin-augmented version, rebound under
# the same attribute name, so every handle asyncio creates records
# timestamps.  The list is materialized before the loop runs, so the new
# subclasses created inside the loop are not themselves re-patched.
for sc in [orig_handle] + orig_handle.__subclasses__():
    subclass = create_subclass(sc)
    setattr(asyncio.events, sc.__name__, subclass)
+
class CausalEventLoop(asyncio.SelectorEventLoop):
    """Selector event loop that measures per-callback queue time while
    simulating a speedup of one "target" coroutine.

    The target coroutine reports its active periods through
    ``ping_enter_coro()``/``ping_exit_coro()``.  Any portion of a
    callback's wait that overlaps those periods is treated as avoidable
    pause time, scaled by the configured speedup, and subtracted from the
    callback's measured wait before it is recorded.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE: these were previously *class* attributes, which made the
        # mutable experiment state shared across every loop instance;
        # they are now per-instance.
        #
        # Speedup factor in [0, 1]: 0 means no optimization, 1 means the
        # target coroutine is optimized away entirely.
        self._speedup = 1.0
        # Callbacks which have recently completed (not consumed in this
        # class; presumably read by external tooling -- TODO confirm).
        self._pause_buffer = []
        # (start, end) intervals in which the target coroutine was
        # active; kept sorted via bisect.insort.
        self._coro_intervals = []
        # (formatted handle, adjusted wait time) per completed callback.
        self._completed_coros = []
        # Timestamp of the last ping_enter_coro() with no matching exit.
        self._time_entered_coro = None

    def set_speedup(self, speedup):
        """Set the simulated speedup and reset all experiment counters."""
        self._speedup = speedup

        self._time_entered_coro = None
        self._coro_intervals.clear()
        self._completed_coros.clear()

    def get_completed_coros(self):
        """Return the (handle label, wait time) records collected so far."""
        return self._completed_coros

    def get_pause_time(self):
        """Return total pause time over the whole observed span."""
        if not self._coro_intervals:
            return 0

        span_start = self._coro_intervals[0][0]
        # If we are currently inside the target coro the span is still open,
        # so it extends to "now".
        span_end = (self.time() if self._time_entered_coro
                    else self._coro_intervals[-1][1])
        return self._get_pause_time((span_start, span_end))

    def ping_enter_coro(self):
        """Mark the target coroutine as active from now on."""
        self._time_entered_coro = self.time()

    def ping_exit_coro(self):
        """Close the currently open target-coroutine activity interval."""
        assert isinstance(self._time_entered_coro, float), \
            "Tried to exit coro before recorded entry!"
        bisect.insort(self._coro_intervals,
                      (self._time_entered_coro, self.time()))
        self._time_entered_coro = None

    def update_ready(self, can_stall=True):
        """Poll the IO selector, schedule resulting callbacks, and move
        due 'call_later' callbacks onto the ready queue.

        This can be called in the middle of an event loop iteration; the
        logic was split out of ``_run_once`` so the ready list can be
        refreshed more often than once per iteration.
        """
        curr_time = self.time()
        sched_count = len(self._scheduled)
        if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and
                self._timer_cancelled_count / sched_count >
                _MIN_CANCELLED_TIMER_HANDLES_FRACTION):
            # Too many cancelled timers: rebuild the heap without them
            # (mirrors asyncio.base_events._run_once).
            new_scheduled = []
            for handle in self._scheduled:
                if handle._cancelled:
                    handle._scheduled = False
                else:
                    new_scheduled.append(handle)

            heapq.heapify(new_scheduled)
            self._scheduled = new_scheduled
            self._timer_cancelled_count = 0
        else:
            # Remove delayed calls that were cancelled from head of queue.
            while self._scheduled and self._scheduled[0]._cancelled:
                self._timer_cancelled_count -= 1
                handle = heapq.heappop(self._scheduled)
                handle._scheduled = False

        # TODO(rewrite): the selector is currently always polled without
        # blocking so scheduled timeouts are never missed; the original
        # blocking-timeout computation (bounded by MAXIMUM_SELECT_TIMEOUT)
        # is disabled for now and `can_stall` is ignored.
        timeout = 0

        event_list = self._selector.select(timeout)
        self._process_events(event_list)
        # Needed to break cycles when an exception occurs.
        event_list = None

        # Handle 'later' callbacks that are ready.
        end_time = self.time() + self._clock_resolution
        while self._scheduled:
            handle = self._scheduled[0]
            if handle._when >= end_time:
                break
            handle = heapq.heappop(self._scheduled)
            handle._scheduled = False

            # Defer the callback by however much pause time accrued while
            # it sat in the timer queue.
            time_interval = (handle.register_time, curr_time)
            handle.io_time = curr_time + self._get_pause_time(time_interval)
            self._ready.append(handle)

    def _run_once(self):
        """Run one full iteration of the event loop.

        Polls IO, then calls every ready callback whose (pause-adjusted)
        ``io_time`` has arrived.
        """
        self.update_ready()

        current_time = self.time()
        to_process = collections.deque([
            handle for handle in self._ready
            if handle.io_time < (current_time + self._clock_resolution)
        ])

        # This is the only place where callbacks are actually *called*.
        # All other places just add them to ready.  Callbacks scheduled by
        # callbacks run this time around are run on the next iteration
        # (after another I/O poll).
        ntodo = len(to_process)
        for i in range(ntodo):
            handle = to_process.popleft()
            self._ready.remove(handle)
            if handle._cancelled:
                continue
            try:
                self._current_handle = handle

                process_start_time = self.time()
                handle.process_start_time = process_start_time

                handle._run()

                process_end_time = self.time()
                dt = process_end_time - process_start_time
                if self._debug and dt >= self.slow_callback_duration:
                    logger.warning('Executing %s took %.3f seconds',
                                   _format_handle(handle), dt)

                # Subtract any target-coroutine activity that overlapped
                # this callback's wait, then record the adjusted wait.
                time_interval = (handle.io_time, process_start_time)
                pause_time = self._get_pause_time(time_interval)
                adjusted_start_time = handle.process_start_time - pause_time
                wait_time = adjusted_start_time - handle.io_time
                # Small negative slack tolerated for clock resolution.
                assert wait_time >= -0.0001, (
                    f"wait time on {_format_handle(handle)} "
                    f"was found to be {wait_time:.4f}!")
                self._completed_coros.append(
                    (_format_handle(handle), wait_time))
            except Exception:
                traceback.print_exc()
            finally:
                self._current_handle = None
                handle = None  # Needed to break cycles when an exception occurs.

    def _process_events(self, event_list):
        """Dispatch selector events, stamping each handle's ``io_time``
        with the current time plus any accrued pause time.
        """
        curr_time = self.time()
        for key, mask in event_list:
            fileobj, (reader, writer) = key.fileobj, key.data
            if mask & selectors.EVENT_READ and reader is not None:
                if reader._cancelled:
                    self._remove_reader(fileobj)
                else:
                    pause = self._get_pause_time(
                        (reader.register_time, curr_time))
                    reader.io_time = curr_time + pause
                    self._add_callback(reader)
            if mask & selectors.EVENT_WRITE and writer is not None:
                if writer._cancelled:
                    self._remove_writer(fileobj)
                else:
                    pause = self._get_pause_time(
                        (writer.register_time, curr_time))
                    writer.io_time = curr_time + pause
                    self._add_callback(writer)

    def _call_soon(self, callback, args, context):
        """Create a handle that is immediately ready (``io_time`` = now)."""
        handle = events.Handle(callback, args, self, context)
        if handle._source_traceback:
            del handle._source_traceback[-1]
        handle.io_time = self.time()
        self._ready.append(handle)
        return handle

    def call_soon_threadsafe(self, callback, *args, context=None):
        """Like call_soon(), but thread-safe."""
        self._check_closed()
        if self._debug:
            self._check_callback(callback, 'call_soon_threadsafe')
        handle = events._ThreadSafeHandle(callback, args, self, context)
        handle.io_time = self.time()
        self._ready.append(handle)
        # Two frames stripped: presumably one for this method and one for
        # the TimeAwareMixin wrapper __init__ -- TODO confirm.
        if handle._source_traceback:
            del handle._source_traceback[-1]
        if handle._source_traceback:
            del handle._source_traceback[-1]
        self._write_to_self()
        return handle

    def _get_pause_time(self, cb_interval):
        """Return the speedup-scaled target-coroutine time overlapping
        *cb_interval*, a ``(start, end)`` pair of loop timestamps.
        """
        paused = 0
        start, end = cb_interval

        for coro_start, coro_end in self._coro_intervals:
            if start < coro_end and coro_start < end:
                paused += self._get_overlap(start, end, coro_start, coro_end)
            # Intervals are sorted, so nothing later can overlap either.
            if end < coro_start:
                break

        # Account for a still-open interval (we are inside the coro now).
        curr_time = self.time()
        if self._time_entered_coro and \
                start < curr_time and self._time_entered_coro < end:
            paused += self._get_overlap(start, end,
                                        self._time_entered_coro, curr_time)

        return paused * self._speedup

    def _get_overlap(self, a_start, a_end, b_start, b_end):
        """Return the length of the intersection of [a_start, a_end] and
        [b_start, b_end]; negative when the intervals do not overlap."""
        overlap_start = max(a_start, b_start)
        overlap_end = min(a_end, b_end)
        return overlap_end - overlap_start
+
+
class CausalEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
    """Event loop policy whose loops are CausalEventLoop instances."""

    def new_event_loop(self):
        """Construct and return a fresh CausalEventLoop."""
        loop = CausalEventLoop()
        return loop
+
+
+asyncio.set_event_loop_policy(CausalEventLoopPolicy())