''' Copyright 2025 bdunahu

Licensed under the Apache License, Version 2.0 (the "License"); you may
not use this file except in compliance with the License.  You may obtain
a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Commentary:

Code:
'''
import asyncio
import collections
from sortedcontainers import SortedList
import heapq
import selectors
import sys
import os
import time
import traceback
from copy import copy
from asyncio.log import logger
from asyncio import Task, events
from asyncio.base_events import _format_handle
from pathlib import Path
from pprint import pformat
import logging

# Thresholds deciding when the expensive "rebuild the whole heap" cleanup
# of cancelled timer handles is used instead of lazily popping from the
# head (same names/semantics as the constants in asyncio.base_events).
_MIN_SCHEDULED_TIMER_HANDLES = 100
_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5
# Upper bound, in seconds, for a single selector poll.
MAXIMUM_SELECT_TIMEOUT = 24 * 3600
# Trace log lives two directories above this file, under logs/.
TRACE_PATH = \
    Path(__file__).resolve().parent.parent / "logs" / "trace.log"

# The pristine asyncio Handle class, captured BEFORE it is monkey-patched
# below; also used as the isinstance target in the comparison operators.
orig_handle = asyncio.events.Handle


class TimeAwareMixin:
    """Mixin stamping asyncio handles with timing metadata and ordering
    them by their ``_when`` arrival time so they can live in heaps."""

    # the timestamp this callback was registered
    register_time = None
    # the timestamp this callback completed i/o
    _when = None
    # the timestamp this callback entered the pause buffer
    time_entered_pause_buffer = None

    def __init__(self):
        # Registration time is taken from the same monotonic clock the
        # event loop's time() is based on.
        self.register_time = time.monotonic()

    # NOTE(review): hash is derived from the mutable `_when` attribute,
    # which this file rewrites in place (update_ready, _process_events);
    # confirm no handle is held in a dict/set across such a rewrite.
    def __hash__(self):
        return hash(self._when)

    def __lt__(self, other):
        if isinstance(other, orig_handle):
            return self._when < other._when
        return NotImplemented

    # NOTE(review): __le__/__ge__ fall back to the default identity
    # __eq__ (no __eq__ override in this mixin), so two distinct handles
    # with equal _when are neither <= nor >= each other — confirm intended.
    def __le__(self, other):
        if isinstance(other, orig_handle):
            return self._when < other._when or self.__eq__(other)
        return NotImplemented

    def __gt__(self, other):
        if isinstance(other, orig_handle):
            return self._when > other._when
        return NotImplemented

    def __ge__(self, other):
        if isinstance(other, orig_handle):
            return self._when > other._when or self.__eq__(other)
        return NotImplemented


def create_subclass(base_class):
    """Return a subclass of *base_class* that also mixes in
    TimeAwareMixin, stamping ``register_time`` at construction time."""
    class CausalHandle(base_class, TimeAwareMixin):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            TimeAwareMixin.__init__(self)
    return CausalHandle


# make all the subclasses inherit from TimeAwareHandle as well
# (monkey-patches asyncio.events so every Handle/TimerHandle/... created
# from here on carries the timing fields defined above)
for sc in [orig_handle] + orig_handle.__subclasses__():
    subclass = create_subclass(sc)
    setattr(asyncio.events, sc.__name__, subclass)


class CausalEventLoop(asyncio.SelectorEventLoop):
    """A SelectorEventLoop that delays non-vital callbacks through a
    'pause buffer' in proportion to how long a target coroutine was
    active, simulating the effect of speeding that coroutine up."""

    _log_level = logging.DEBUG
    _logger = None
    # a value between 0 and 1. 0 means no optimization,
    # 1 means the target coroutine is optimized away entirely
    _speedup = 0.0
    # NOTE(review): the mutable containers below are CLASS-level
    # attributes, shared by every CausalEventLoop instance — fine if only
    # one loop ever exists, otherwise move them into __init__; confirm.
    # a list of callbacks which have recently completed
    _pause_buffer = []
    # a list of intervals in which the target coroutine has been active
    _coro_intervals = SortedList()
    # a list of completed callbacks, and their associated queue time
    _completed_coros = []
    # the last time we entered the target coro
    _time_entered_coro = None
    _ready_events = []
    # the time this experiment started
    _start_time = float("inf")

    def __init__(self) -> None:
        super().__init__()
        self.set_logger()

    def set_logger(self):
        """Create/attach a per-loop file logger writing to TRACE_PATH."""
        # Touch the trace file if missing (FileHandler would also create
        # it, but the directory must already exist either way).
        if not os.path.exists(TRACE_PATH):
            with open(TRACE_PATH, 'w') as file:
                pass
        self._logger = logging.getLogger(f'LOOP {self._thread_id}')
        self._logger.setLevel(self._log_level)
        self._logger.propagate = False
        file_handler = logging.FileHandler(TRACE_PATH)
        file_handler.setLevel(self._log_level)
        formatter = logging.Formatter('%(name)s - %(asctime)s - %(levelname)s --- %(message)s')
        file_handler.setFormatter(formatter)
        # NOTE(review): called once per loop; if a loop with the same
        # _thread_id were constructed twice, handlers would accumulate.
        self._logger.addHandler(file_handler)
        self._logger.info("═" * 40)
        self._logger.info("STARTING LOOP")
        self._logger.info("═" * 40)

    def set_speedup(self, speedup):
        """Start a new experiment with the given speedup factor (0..1),
        resetting the interval and completion counters."""
        # print(self._coro_intervals)
        self._start_time = self.time()
        self._speedup = speedup
        # self._logger.info(f"STARTING EXPERIMENT WITH {self._speedup}")
        # self._logger.info("═" * 30)
        # reset experiment counters
        self._coro_intervals.clear()
        self._completed_coros.clear()

    def get_completed_coros(self):
        # Shallow copy so callers can't mutate the experiment record.
        return copy(self._completed_coros)

    def get_run_time(self):
        """Return wall time since the experiment start, minus the
        simulated pause time overlapping the observed coro intervals."""
        curr_time = self.time()
        if not self._coro_intervals:
            return 0
        start_interval = self._coro_intervals[0][0]
        # If we are currently inside the target coro, extend the window
        # to "now"; otherwise stop at the last recorded interval's end.
        end_interval = curr_time if self._time_entered_coro else self._coro_intervals[-1][1]
        interval = (start_interval, end_interval)
        pause_time = self._get_pause_time(interval)
        return (curr_time - self._start_time) - pause_time

    def ping_enter_coro(self):
        """Record that the target coroutine has become active."""
        # self._logger.debug(f"Recording coro ENTER.")
        self._time_entered_coro = self.time()

    def ping_exit_coro(self):
        """Record that the target coroutine has gone inactive, closing
        the currently open activity interval."""
        # NOTE(review): assert is stripped under `python -O`, in which
        # case the critical-log/exit path below can never trigger.
        try:
            assert isinstance(self._time_entered_coro, float), "Tried to exit coro before recorded entry!"
        except AssertionError as e:
            self._logger.critical(f"Tried to exit coro before recorded entry: {e}. Aborting.")
            sys.exit(1)
        # self._logger.debug(f"Recording coro EXIT.")
        self._coro_intervals.add((self._time_entered_coro, self.time()))
        self._time_entered_coro = None

    def collect_ready_events(self, timeout=0):
        """Poll the selector once and stash (events, timestamp) pairs."""
        event_list = self._selector.select(timeout)
        if event_list:
            self._ready_events.append((event_list, self.time()))

    def update_ready(self):
        ''' Polls the IO selector, schedules resulting callbacks, and schedules
        'call_later' callbacks. '''
        curr_time = self.time()
        sched_count = len(self._scheduled)
        # two methods to cleanup cancelled callbacks;
        # avoid the expensive one whenever possible
        if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and
                self._timer_cancelled_count / sched_count >
                _MIN_CANCELLED_TIMER_HANDLES_FRACTION):
            # Remove delayed calls that were cancelled if their number
            # is too high
            new_scheduled = []
            for handle in self._scheduled:
                if handle._cancelled:
                    handle._scheduled = False
                    self._logger.debug(f"\tSlow cleanup killed {_format_handle(handle)}")
                else:
                    new_scheduled.append(handle)
            heapq.heapify(new_scheduled)
            self._scheduled = new_scheduled
            self._timer_cancelled_count = 0
        else:
            # Remove delayed calls that were cancelled from head of queue.
            while self._scheduled and self._scheduled[0]._cancelled:
                self._timer_cancelled_count -= 1
                handle = heapq.heappop(self._scheduled)
                self._logger.debug(f"\tLazy cleanup killed {_format_handle(handle)}")
                handle._scheduled = False

        timeout = None
        if self._ready or self._stopping:
            timeout = 0
        else:
            curr_time = self.time()
            if self._scheduled:
                # Compute the desired timeout.
                # requires computing our-best guess arrival time
                handle = self._scheduled[0]
                timeout = (handle._when + self._get_pause_for_io(handle, curr_time)) \
                    - curr_time
                timeout = max(0, min(timeout, MAXIMUM_SELECT_TIMEOUT))
            if self._pause_buffer:
                # pause buffer has an exact arrival time
                pause_timeout = self._pause_buffer[0]._when - curr_time
                # NOTE(review): `if timeout` treats a computed timeout of
                # exactly 0 as "unset" and replaces it with pause_timeout;
                # `if timeout is not None` is probably what was meant.
                timeout = min(pause_timeout, timeout) if timeout else pause_timeout
                timeout = max(0, min(timeout, MAXIMUM_SELECT_TIMEOUT))
        # NOTE(review): this unconditionally discards the timeout computed
        # above, making every poll non-blocking (busy loop) — looks like a
        # leftover debug override; confirm intent before removing.
        timeout = 0

        self._logger.debug(f"HANDLING I/O.")
        self._logger.info("-" * 20)
        self.collect_ready_events(timeout)
        self._logger.debug(f"\tPolled events for {timeout} (waiting={len(self._ready_events)})\n{pformat(self._ready_events, indent=2)}")
        while len(self._ready_events):
            event_list = self._ready_events.pop(0)
            self._process_events(*event_list)

        self._logger.debug(f"HANDLING SCHEDULED.")
        self._logger.info("-" * 20)
        # Handle 'later' callbacks that are ready.
        curr_time = self.time()
        end_time = curr_time + self._clock_resolution
        while self._scheduled:
            handle = self._scheduled[0]
            when = handle._when
            if when >= end_time:
                break
            handle = heapq.heappop(self._scheduled)
            handle._scheduled = False
            # Push the handle's arrival time forward by the simulated
            # pause accumulated since it was registered.
            delay = self._get_pause_for_io(handle, curr_time)
            time_to_buffer = when + delay
            handle._when = time_to_buffer
            handle.time_entered_pause_buffer = curr_time
            # self._ready.append(handle)
            # do not allow duplicates (FIX?)
            if handle not in self._pause_buffer:
                heapq.heappush(self._pause_buffer, handle)
            self._logger.debug(f"\tscheduled -> pause_buffer for {_format_handle(handle)} (delay={delay})")

        self._logger.debug(f"HANDLING PAUSED.")
        self._logger.info("-" * 20)
        # handle callbacks which can leave pause timeout
        while self._pause_buffer:
            # required when handle's _when is modified in place
            heapq.heapify(self._pause_buffer)
            handle = self._pause_buffer[0]
            # Sanity-check the heap invariant (O(n) per iteration; also a
            # no-op under `python -O`).
            for other in self._pause_buffer[1:]:
                assert handle._when <= other._when, f"Heap root {handle} is not smallest"
            if handle._when >= end_time:
                self._logger.debug(f"\t{_format_handle(handle)} is not ready to leave. Moving on.")
                break
            # pop the first item in the list
            handle = heapq.heappop(self._pause_buffer)
            # if we paused during buffering, we need to delay again
            # TODO clean this up
            # this whole file has 'rounding' timing errors :(
            delay = self._get_pause_for_pause_time(handle, curr_time)
            handle._when = handle._when + delay
            if handle._when >= end_time:
                # Still paused: re-buffer with a fresh entry timestamp.
                handle.time_entered_pause_buffer = curr_time
                heapq.heappush(self._pause_buffer, handle)
                self._logger.debug(f"\tpause_buffer -> pause_buffer for {_format_handle(handle)} (delay={delay})")
            else:
                self._ready.append(handle)
                self._logger.debug(f"\tpause_buffer -> ready for {_format_handle(handle)}")

    def _run_once(self):
        """ Run one full iteration of the event loop.

        This calls all currently ready callbacks. """
        self._logger.debug("STARTING ITERATION.")
        self._logger.info("-" * 40)
        self.update_ready()
        # This is the only place where callbacks are actually *called*.
        # All other places just add them to ready.
        # Note: We run all currently scheduled callbacks, but not any
        # callbacks scheduled by callbacks run this time around --
        # they will be run the next time (after another I/O poll).
        # Use an idiom that is thread-safe without using locks.
        self._logger.info("-" * 40)
        ntodo = len(self._ready)
        self._logger.debug(f"RUNNING {ntodo} CALLBACKS.")
        for i in range(ntodo):
            handle = self._ready.popleft()
            if handle._cancelled:
                continue
            try:
                self._current_handle = handle
                process_start_time = self.time()
                handle._run()
                process_end_time = self.time()
                dt = process_end_time - process_start_time
                if self._debug and dt >= self.slow_callback_duration:
                    logger.warning('Executing %s took %.3f seconds',
                                   _format_handle(handle), dt)
                    self._logger.warning('Executing %s took %.3f seconds',
                                         _format_handle(handle), dt)
                # Queue latency = time between the handle's (adjusted)
                # arrival and the moment it actually ran, with simulated
                # pause time subtracted back out.
                time_interval = (handle._when, process_start_time)
                pause_time = self._get_pause_time(time_interval)
                adjusted_start_time = process_start_time - pause_time
                wait_time = adjusted_start_time - handle._when
                # NOTE(review): assert (no-op under -O); the -0.0001
                # tolerance absorbs the file's known rounding errors.
                try:
                    assert wait_time >= -0.0001, f"wait time on {_format_handle(handle)} was found to be {wait_time:.4f}!"
                except AssertionError as e:
                    self._logger.critical(f'Negative latency on callback {_format_handle(handle)} ({dt}). Aborting.')
                    sys.exit(1)
                self._completed_coros.append((_format_handle(handle), wait_time))
                self._logger.debug(f'\tCompleted {_format_handle(handle)} with latency {dt}')
            except Exception:
                traceback.print_exc()
            finally:
                self._current_handle = None
        handle = None  # Needed to break cycles when an exception occurs.
    def _process_events(self, event_list, time):
        """Dispatch selector events recorded at timestamp *time*,
        stamping each reader/writer handle with its pause-adjusted
        arrival time before queueing it.

        NOTE(review): the *time* parameter shadows the `time` module
        inside this method.
        """
        for key, mask in event_list:
            fileobj, (reader, writer) = key.fileobj, key.data
            if mask & selectors.EVENT_READ and reader is not None:
                if reader._cancelled:
                    self._remove_reader(fileobj)
                    self._logger.info(f"\treader {reader} was cancelled.")
                else:
                    # Arrival time = I/O completion time plus the pause
                    # owed for coro activity since registration.
                    time_to_buffer = time + \
                        self._get_pause_for_io(reader, time)
                    reader._when = time_to_buffer
                    self._add_callback(reader)
            if mask & selectors.EVENT_WRITE and writer is not None:
                if writer._cancelled:
                    self._remove_writer(fileobj)
                    self._logger.info(f"\twriter {writer} was cancelled.")
                else:
                    time_to_buffer = time + \
                        self._get_pause_for_io(writer, time)
                    writer._when = time_to_buffer
                    self._add_callback(writer)

    def _call_soon(self, callback, args, context):
        """Do not add 'callsoon' events to the pause buffer.

        Add them directly to ready."""
        curr_time = self.time()
        # events.Handle was monkey-patched above, so this handle carries
        # the TimeAwareMixin fields.
        handle = events.Handle(callback, args, self, context)
        if handle._source_traceback:
            del handle._source_traceback[-1]
        if not handle._when:
            handle._when = curr_time
        self._ready.append(handle)
        self._logger.debug(f"\tio -> ready for {_format_handle(handle)}")
        return handle

    def call_soon_threadsafe(self, callback, *args, context=None):
        """Like call_soon(), but thread-safe."""
        self._check_closed()
        if self._debug:
            self._check_callback(callback, 'call_soon_threadsafe')
        handle = events._ThreadSafeHandle(callback, args, self, context)
        if not handle._when:
            handle._when = self.time()
        self._ready.append(handle)
        self._logger.debug(f"\tio -> ready for {_format_handle(handle)}")
        # NOTE(review): the traceback tail is deleted twice (trims two
        # frames) — upstream asyncio deletes it once; confirm intended.
        if handle._source_traceback:
            del handle._source_traceback[-1]
        if handle._source_traceback:
            del handle._source_traceback[-1]
        # Wake the loop from its selector wait.
        self._write_to_self()
        return handle

    def _add_callback(self, handle):
        """Add a Handle to _pause_buffer."""
        curr_time = self.time()
        if not handle._when:
            handle._when = curr_time
        # required in cases where the event loop reuses the same handle
        if handle._when < curr_time:
            handle._when = curr_time
        if not handle._cancelled:
            if self._is_vital(handle):
                # Vital callbacks bypass the pause buffer entirely.
                handle._when = curr_time
                self._ready.append(handle)
                self._logger.debug(f"\tio -> ready for VITAL {_format_handle(handle)}")
            else:
                handle.time_entered_pause_buffer = curr_time
                # do not allow duplicates (FIX?)
                if handle not in self._pause_buffer:
                    heapq.heappush(self._pause_buffer, handle)
                self._logger.debug(f"\tio -> pause_buffer for {_format_handle(handle)} (delay={handle._when - curr_time})")
        else:
            self._logger.warning(f"\t_add_callback called on cancelled handle {_format_handle(handle)}")

    def _get_pause_for_io(self, handle, io_time):
        """Pause owed to *handle* for coro activity between its
        registration time and *io_time* (always >= 0)."""
        time_interval = (handle.register_time, io_time)
        p_time = self._get_pause_time(time_interval)
        # NOTE(review): assert (no-op under -O) guarding the non-negative
        # invariant; failure aborts the process.
        try:
            assert p_time >= 0, f"calculated pause time on {_format_handle(handle)} was found to be {p_time:.4f}!"
        except AssertionError as e:
            print(f"Assertion failed: {e}")
            sys.exit(1)
        return p_time

    def _get_pause_for_pause_time(self, handle, exit_time):
        """Pause accrued while *handle* sat in the pause buffer."""
        time_interval = (handle.time_entered_pause_buffer, exit_time)
        return self._get_pause_time(time_interval)

    def _get_pause_time(self, cb_interval):
        """Return the simulated pause for *cb_interval* = (start, end):
        total overlap with recorded coro-activity intervals (plus the
        currently open interval, if any) scaled by the speedup factor.

        NOTE(review): local `time` shadows the `time` module here.
        """
        time = 0
        start, end = cb_interval
        for coro_start, coro_end in self._coro_intervals:
            if start < coro_end and coro_start < end:
                time += self._get_overlap(start, end, coro_start, coro_end)
            # coro_intervals are sorted, so by this time all overlap has passed
            if end < coro_start:
                break
        # Include the in-progress activity interval, if we are inside it.
        curr_time = self.time()
        if self._time_entered_coro and \
                start < curr_time and self._time_entered_coro < end:
            time += self._get_overlap(start, end, self._time_entered_coro, curr_time)
        pause_time = time * self._speedup
        assert pause_time >= 0, f"Delay was found to be less than 0: {pause_time}"
        return pause_time

    def _get_overlap(self, a_start, a_end, b_start, b_end):
        """Length of the intersection of intervals [a_start, a_end] and
        [b_start, b_end]; callers guarantee the intervals intersect."""
        overlap_start = max(a_start, b_start)
        overlap_end = min(a_end, b_end)
        assert overlap_end >= overlap_start, f"Bad overlaps: {a_start} {a_end} : {b_start} {b_end} ({overlap_end - overlap_start})"
        return overlap_end - overlap_start

    def _is_vital(self, handle):
        """ Methods which cannot afford to be paused."""
        blacklist = ['_read_from_self', '_read_ready', '_accept_connection']
        cb = handle._callback
        # Task steps are identified by the wrapped coroutine's name;
        # plain callbacks by the function's own __name__.
        if isinstance(getattr(cb, '__self__', None), asyncio.tasks.Task):
            if cb.__self__.get_coro().__name__ in blacklist:
                return True
        else:
            if getattr(cb, '__name__', None) in blacklist:
                return True
        return False


# to profile your program, start the event loop with
# asyncio.run(coro, loop_factory=causal_loop_factory)
def causal_loop_factory():
    return CausalEventLoop()