'''
Copyright 2025 bdunahu

Licensed under the Apache License, Version 2.0 (the "License"); you may not
use this file except in compliance with the License. You may obtain a copy of
the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License.

Commentary:

Provides CausalEventLoop, a selector event loop that delays callbacks by the
(speedup-scaled) time a target coroutine spent running, and installs it as the
default event loop policy on import. Importing this module also replaces
asyncio.events.Handle and its direct subclasses with timing-aware subclasses.

Code:
'''
import asyncio
import collections
import heapq
import selectors
import sys
import time
import traceback
from asyncio import Task, events
from asyncio.base_events import _format_handle
from asyncio.log import logger
from copy import copy

from sortedcontainers import SortedList

_MIN_SCHEDULED_TIMER_HANDLES = 100
_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5
MAXIMUM_SELECT_TIMEOUT = 24 * 3600

# Callback (or task-coroutine) names that must never wait in the pause buffer.
_VITAL_CALLBACK_NAMES = frozenset({'_read_from_self', '_read_ready'})

orig_handle = asyncio.events.Handle


def _abort(message):
    """Report a violated timing invariant and terminate the process.

    Used instead of ``assert`` so the check still runs under ``python -O``.
    """
    print(f"Assertion failed: {message}")
    sys.exit(1)


class TimeAwareMixin:
    """Timing bookkeeping plus ``_when``-based ordering for asyncio handles."""

    # the timestamp (time.monotonic) at which this handle was created
    register_time = None
    # the (pause-adjusted) time this callback became runnable; TimerHandle
    # subclasses store their scheduled time here instead
    _when = None
    # the loop time at which this callback last entered the pause buffer
    time_entered_pause_buffer = None

    def __init__(self):
        self.register_time = time.monotonic()

    # NOTE(review): this hash depends on the mutable `_when`, so a handle's
    # hash changes whenever the loop re-stamps it; avoid using handles as
    # set members / dict keys.
    def __hash__(self):
        return hash(self._when)

    def __lt__(self, other):
        if isinstance(other, orig_handle):
            return self._when < other._when
        return NotImplemented

    def __le__(self, other):
        # Compare `_when` directly: object.__eq__ returns NotImplemented for
        # distinct handles, which would make `<=` on equal times raise.
        if isinstance(other, orig_handle):
            return self._when <= other._when
        return NotImplemented

    def __gt__(self, other):
        if isinstance(other, orig_handle):
            return self._when > other._when
        return NotImplemented

    def __ge__(self, other):
        if isinstance(other, orig_handle):
            return self._when >= other._when
        return NotImplemented


def create_subclass(base_class):
    """Return a subclass of ``base_class`` that also carries TimeAwareMixin state."""

    class CausalHandle(base_class, TimeAwareMixin):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Initialise the mixin explicitly rather than relying on the
            # base handle's __init__ to chain to it via super().
            TimeAwareMixin.__init__(self)

    return CausalHandle


# Make Handle and each of its direct subclasses (e.g. TimerHandle) timing-aware:
# asyncio code that constructs `events.<Name>(...)` now gets the subclass.
for sc in [orig_handle] + orig_handle.__subclasses__():
    setattr(asyncio.events, sc.__name__, create_subclass(sc))


class CausalEventLoop(asyncio.SelectorEventLoop):
    """Selector event loop that virtually speeds up a target coroutine.

    Time the target coroutine spends running (bracketed by ping_enter_coro /
    ping_exit_coro) is scaled by ``_speedup`` and added as extra delay to
    other callbacks, which wait in a pause buffer until that delay elapses.
    Each callback's queueing delay, net of pause time, is recorded in
    ``_completed_coros``.
    """

    # a value between 0 and 1. 0 means no optimization,
    # 1 means the target coroutine is optimized away entirely
    _speedup = 1.0
    # the loop time at which we last entered the target coro, or None
    _time_entered_coro = None
    # the loop time at which this experiment started
    _start_time = float("inf")

    def __init__(self, selector=None) -> None:
        # Mutable state lives on the instance so that separate loops (e.g. one
        # per asyncio.run() call) never share buffers or measurements.
        # heap (ordered by _when) of callbacks waiting out accumulated pause
        self._pause_buffer = []
        # (start, end) loop-time intervals in which the target coro was active
        self._coro_intervals = SortedList()
        # (formatted handle, wait time) for each callback run this experiment
        self._completed_coros = []
        # (selector event list, loop time) batches not yet dispatched
        self._ready_events = []
        super().__init__(selector)

    def set_speedup(self, speedup: float) -> None:
        """Start a new experiment using ``speedup``.

        Records the experiment start time and clears the coroutine intervals
        and completed-callback records left by any previous experiment.
        """
        self._start_time = self.time()
        self._speedup = speedup
        self._coro_intervals.clear()
        self._completed_coros.clear()

    def get_completed_coros(self) -> list:
        """Return a shallow copy of the (formatted handle, wait time) records."""
        return copy(self._completed_coros)

    def get_run_time(self):
        """Return loop time elapsed since set_speedup(), minus virtual pause time.

        Returns 0 when no completed target-coroutine interval is recorded.
        """
        curr_time = self.time()
        if not self._coro_intervals:
            return 0
        start_interval = self._coro_intervals[0][0]
        if self._time_entered_coro is not None:
            # the coroutine is still running: count up to now
            end_interval = curr_time
        else:
            end_interval = self._coro_intervals[-1][1]
        pause_time = self._get_pause_time((start_interval, end_interval))
        return (curr_time - self._start_time) - pause_time

    def ping_enter_coro(self) -> None:
        """Record that the target coroutine has just started running."""
        self._time_entered_coro = self.time()

    def ping_exit_coro(self) -> None:
        """Record the end of the current target-coroutine interval.

        Terminates the process if no matching ping_enter_coro() was recorded.
        """
        if not isinstance(self._time_entered_coro, float):
            _abort("Tried to exit coro before recorded entry!")
        self._coro_intervals.add((self._time_entered_coro, self.time()))
        self._time_entered_coro = None

    def collect_ready_events(self, timeout=0):
        """Poll the selector and queue the ready events with the poll's loop time."""
        event_list = self._selector.select(timeout)
        self._ready_events.append((event_list, self.time()))

    def update_ready(self):
        '''
        Polls the IO selector, schedules resulting callbacks, and schedules
        'call_later' callbacks.
        '''
        self._prune_cancelled_timers()
        timeout = self._select_timeout()
        # NOTE(review): the computed timeout is overridden, so the selector is
        # always polled without blocking (busy-polling). Preserved as-is;
        # confirm whether blocking on the computed timeout is intended.
        timeout = 0
        self._dispatch_io(timeout)

        curr_time = self.time()
        end_time = curr_time + self._clock_resolution
        self._buffer_due_timers(curr_time, end_time)
        self._release_pause_buffer(curr_time, end_time)

    def _prune_cancelled_timers(self):
        """Remove cancelled timer handles from the ``_scheduled`` heap."""
        sched_count = len(self._scheduled)
        if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and
                self._timer_cancelled_count / sched_count >
                _MIN_CANCELLED_TIMER_HANDLES_FRACTION):
            # Too many cancelled timers: rebuild the heap without them.
            survivors = []
            for handle in self._scheduled:
                if handle._cancelled:
                    handle._scheduled = False
                else:
                    survivors.append(handle)
            heapq.heapify(survivors)
            self._scheduled = survivors
            self._timer_cancelled_count = 0
        else:
            # Otherwise only pop cancelled timers off the head of the heap.
            while self._scheduled and self._scheduled[0]._cancelled:
                self._timer_cancelled_count -= 1
                handle = heapq.heappop(self._scheduled)
                handle._scheduled = False

    def _select_timeout(self):
        """Return how long the selector could block (None = indefinitely)."""
        if self._ready or self._stopping:
            return 0
        curr_time = self.time()
        timeout = None
        if self._scheduled:
            # Best-guess arrival: the timer's time plus the pause it has
            # accrued since registration.
            handle = self._scheduled[0]
            timeout = (handle._when + self._get_pause_for_io(handle, curr_time)) \
                - curr_time
            timeout = max(0, min(timeout, MAXIMUM_SELECT_TIMEOUT))
        if self._pause_buffer:
            # Buffered handles already carry their exact release time.
            pause_timeout = self._pause_buffer[0]._when - curr_time
            # `is None`, not falsiness: a due timer (timeout == 0) must not be
            # replaced by a later pause-buffer deadline.
            if timeout is None:
                timeout = pause_timeout
            else:
                timeout = min(pause_timeout, timeout)
            timeout = max(0, min(timeout, MAXIMUM_SELECT_TIMEOUT))
        return timeout

    def _dispatch_io(self, timeout):
        """Poll the selector, then dispatch each queued event batch exactly once."""
        self.collect_ready_events(timeout)
        # Detach the queue before dispatching. Batches left in place would be
        # replayed on every iteration: the queue would grow without bound, and
        # stale cancelled handles could unregister a reader/writer newly
        # registered on a reused file descriptor.
        pending, self._ready_events = self._ready_events, []
        for event_list, event_time in pending:
            self._process_events(event_list, event_time)

    def _buffer_due_timers(self, curr_time, end_time):
        """Move timers due before ``end_time`` into the pause buffer.

        Each timer's ``_when`` is pushed back by the pause accrued since it
        was registered.
        """
        while self._scheduled and self._scheduled[0]._when < end_time:
            handle = heapq.heappop(self._scheduled)
            handle._scheduled = False
            handle._when += self._get_pause_for_io(handle, curr_time)
            handle.time_entered_pause_buffer = curr_time
            heapq.heappush(self._pause_buffer, handle)

    def _release_pause_buffer(self, curr_time, end_time):
        """Move handles whose pause has fully elapsed from the buffer to ``_ready``."""
        buffer = self._pause_buffer
        # _process_events / _add_callback may re-stamp `_when` on a handle that
        # is still in the buffer, so restore the heap invariant once here.
        # Inside the loop handles are changed only after being popped, and are
        # re-inserted with heappush, so the invariant holds throughout.
        heapq.heapify(buffer)
        while buffer and buffer[0]._when < end_time:
            handle = heapq.heappop(buffer)
            # If the target coroutine ran while this handle was buffered, that
            # pause delays it again.
            # TODO: this whole file has 'rounding' timing errors :(
            handle._when += self._get_pause_for_pause_time(handle, curr_time)
            if handle._when >= end_time:
                handle.time_entered_pause_buffer = curr_time
                heapq.heappush(buffer, handle)
            else:
                self._ready.append(handle)

    def _run_once(self):
        """Run one full iteration of the event loop.

        Refreshes ``_ready`` via update_ready(), then runs every handle that
        was ready at that point, recording each one's pause-adjusted wait.
        """
        self.update_ready()

        # This is the only place where callbacks are actually *called*. Only
        # callbacks ready now are run; callbacks they schedule run next
        # iteration (after another I/O poll).
        ntodo = len(self._ready)
        for _ in range(ntodo):
            handle = self._ready.popleft()
            if handle._cancelled:
                continue
            try:
                self._current_handle = handle
                process_start_time = self.time()
                handle._run()
                dt = self.time() - process_start_time
                if self._debug and dt >= self.slow_callback_duration:
                    logger.warning('Executing %s took %.3f seconds',
                                   _format_handle(handle), dt)
                self._record_wait_time(handle, process_start_time)
            except Exception:
                traceback.print_exc()
            finally:
                self._current_handle = None
                handle = None  # Needed to break cycles when an exception occurs.

    def _record_wait_time(self, handle, process_start_time):
        """Append ``(handle, wait time)`` to _completed_coros, excluding pause time.

        Terminates the process if the computed wait is meaningfully negative.
        """
        # Skip handles that left I/O before this experiment started: their
        # buffered pause was computed for the previous experiment, and nothing
        # can correct it now.
        if not handle._when > self._start_time:
            return
        pause_time = self._get_pause_time((handle._when, process_start_time))
        adjusted_start_time = process_start_time - pause_time
        wait_time = adjusted_start_time - handle._when
        if not wait_time >= -0.0001:
            _abort(f"wait time on {_format_handle(handle)} was found to be {wait_time:.4f}!")
        self._completed_coros.append((_format_handle(handle), wait_time))

    def _process_events(self, event_list, time):
        """Queue the reader/writer callbacks for selector events seen at loop time ``time``."""
        for key, mask in event_list:
            fileobj, (reader, writer) = key.fileobj, key.data
            if mask & selectors.EVENT_READ and reader is not None:
                self._buffer_io_handle(fileobj, reader, time, self._remove_reader)
            if mask & selectors.EVENT_WRITE and writer is not None:
                self._buffer_io_handle(fileobj, writer, time, self._remove_writer)

    def _buffer_io_handle(self, fileobj, handle, io_time, remove):
        """Unregister a cancelled I/O handle, else stamp its delayed time and queue it."""
        if handle._cancelled:
            remove(fileobj)
            return
        handle._when = io_time + self._get_pause_for_io(handle, io_time)
        self._add_callback(handle)

    def _call_soon(self, callback, args, context):
        """Do not add 'callsoon' events to the pause buffer.
        Add them directly to ready."""
        curr_time = self.time()
        handle = events.Handle(callback, args, self, context)
        if handle._source_traceback:
            del handle._source_traceback[-1]
        if not handle._when:
            handle._when = curr_time
        self._ready.append(handle)
        return handle

    def call_soon_threadsafe(self, callback, *args, context=None):
        """Like call_soon(), but thread-safe."""
        self._check_closed()
        if self._debug:
            self._check_callback(callback, 'call_soon_threadsafe')
        handle = events._ThreadSafeHandle(callback, args, self, context)
        # NOTE(review): _add_callback may push onto the _pause_buffer heap from
        # a foreign thread while the loop thread is draining it; confirm this
        # race is acceptable.
        self._add_callback(handle)
        # Two frames are trimmed, presumably mirroring CPython's own
        # call_soon_threadsafe; verify against the targeted Python version.
        if handle._source_traceback:
            del handle._source_traceback[-1]
        if handle._source_traceback:
            del handle._source_traceback[-1]
        self._write_to_self()
        return handle

    def _add_callback(self, handle):
        """Add a Handle to _pause_buffer.

        Vital callbacks (see _is_vital) go straight to ``_ready`` instead.
        Cancelled handles and handles already buffered are only re-stamped.
        """
        curr_time = self.time()
        # Never let `_when` lag behind now; the event loop can reuse the same
        # handle, so a stale stamp from a previous firing is possible.
        if not handle._when or handle._when < curr_time:
            handle._when = curr_time
        if handle._cancelled or handle in self._pause_buffer:
            return
        if self._is_vital(handle):
            handle._when = curr_time
            self._ready.append(handle)
        else:
            handle.time_entered_pause_buffer = curr_time
            heapq.heappush(self._pause_buffer, handle)

    def _get_pause_for_io(self, handle, io_time):
        """Return the pause accrued between the handle's creation and ``io_time``.

        Terminates the process if the result is negative.
        """
        # NOTE(review): register_time is set once, at construction, so for a
        # handle that fires repeatedly (e.g. a long-lived reader) this interval
        # spans its whole lifetime — confirm this is intended.
        p_time = self._get_pause_time((handle.register_time, io_time))
        if not p_time >= 0:
            _abort(f"calculated pause time on {_format_handle(handle)} was found to be {p_time:.4f}!")
        return p_time

    def _get_pause_for_pause_time(self, handle, exit_time):
        """Return the pause accrued while the handle sat in the pause buffer."""
        return self._get_pause_time((handle.time_entered_pause_buffer, exit_time))

    def _get_pause_time(self, cb_interval):
        """Return ``_speedup`` times the target coroutine's run time within ``cb_interval``.

        Counts the recorded intervals plus, if the coroutine is running right
        now, the span from its entry up to the current loop time.
        """
        start, end = cb_interval
        total = 0
        for coro_start, coro_end in self._coro_intervals:
            if start < coro_end and coro_start < end:
                total += self._get_overlap(start, end, coro_start, coro_end)
            # coro_intervals are sorted, so by this time all overlap has passed
            if end < coro_start:
                break
        curr_time = self.time()
        entered = self._time_entered_coro
        if entered is not None and start < curr_time and entered < end:
            total += self._get_overlap(start, end, entered, curr_time)
        return total * self._speedup

    def _get_overlap(self, a_start, a_end, b_start, b_end):
        """Return the length of the intersection of [a_start, a_end] and [b_start, b_end].

        Callers only pass intervals that are known to overlap.
        """
        return min(a_end, b_end) - max(a_start, b_start)

    def _is_vital(self, handle):
        """Return True for callbacks that cannot afford to be paused.

        A task step is matched by its coroutine's name; any other callback by
        its own ``__name__``.
        """
        cb = handle._callback
        owner = getattr(cb, '__self__', None)
        if isinstance(owner, Task):
            # getattr: get_coro() may return an object without __name__
            name = getattr(owner.get_coro(), '__name__', None)
        else:
            name = getattr(cb, '__name__', None)
        return name in _VITAL_CALLBACK_NAMES


class CausalEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
    """Event loop policy whose new loops are CausalEventLoop instances."""

    def new_event_loop(self):
        return CausalEventLoop()


asyncio.set_event_loop_policy(CausalEventLoopPolicy())