'''
Copyright:
Copyright © 2025 bdunahu

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Commentary:

Code:
'''
from typing import TYPE_CHECKING, Any, Callable, Self

if TYPE_CHECKING:
    import contextvars

import line_profiler
import asyncio
import collections
from sortedcontainers import SortedList
import heapq
import selectors
import sys
import os
import time
import traceback
from copy import copy
from asyncio.log import logger
from asyncio import Task, events
from asyncio.base_events import _format_handle
from pathlib import Path
from pprint import pformat
import logging

# Thresholds controlling when the expensive full-rebuild cleanup of the
# scheduled-timer heap kicks in (mirrors asyncio.base_events).
_MIN_SCHEDULED_TIMER_HANDLES = 100
_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5

# Upper bound (seconds) on any selector poll timeout.
MAXIMUM_SELECT_TIMEOUT = 24 * 3600

TRACE_PATH = \
    Path(__file__).resolve().parent.parent / "logs" / "trace.log"

orig_handle = asyncio.events.Handle


class TimeAwareMixin:
    """Mixin adding registration/firing timestamps and `_when`-based ordering
    to asyncio Handle classes, so handles can live in a time-ordered heap."""

    # the timestamp this callback was registered
    register_time = None
    # the timestamp this callback ideally runs
    _when = None

    def __init__(self: Self) -> None:
        self.register_time = time.monotonic()

    def __hash__(self):
        # NOTE(review): hashes on mutable _when (rebound elsewhere in this
        # module) — valid only while the handle is not used as a dict/set key
        # across a _when update.
        return hash(self._when)

    def __lt__(self, other):
        if isinstance(other, orig_handle):
            return self._when < other._when
        return NotImplemented

    def __le__(self, other):
        if isinstance(other, orig_handle):
            # falls back to default identity __eq__ for plain Handles
            return self._when < other._when or self.__eq__(other)
        return NotImplemented

    def __gt__(self, other):
        if isinstance(other, orig_handle):
            return self._when > other._when
        return NotImplemented

    def __ge__(self, other):
        if isinstance(other, orig_handle):
            return self._when > other._when or self.__eq__(other)
        return NotImplemented


def create_subclass(base_class: "type[orig_handle]"):
    """Return a subclass of `base_class` augmented with TimeAwareMixin."""
    class CausalHandle(base_class, TimeAwareMixin):
        def __init__(self: Self, *args: tuple | Any | Callable | None,
                     **kwargs: None) -> None:
            super().__init__(*args, **kwargs)
            TimeAwareMixin.__init__(self)
    return CausalHandle


# make all the subclasses inherit from TimeAwareMixin as well, and patch
# them back into asyncio.events under their original names
for sc in [orig_handle] + orig_handle.__subclasses__():
    subclass = create_subclass(sc)
    setattr(asyncio.events, sc.__name__, subclass)


class CausalEventLoop(asyncio.SelectorEventLoop):
    """A SelectorEventLoop that tracks time spent inside a target coroutine
    and shifts callback firing times by a configurable speedup factor,
    simulating the effect of optimizing that coroutine away."""

    _log_level = logging.DEBUG
    _logger = None
    # a value between 0 and 1. 0 means no optimization,
    # 1 means the target coroutine is optimized away entirely
    _speedup = 0.0
    # the last time we entered the target coro
    _time_entered_coro = None
    # the time this experiment started
    _start_time = float("inf")

    def __init__(self: Any) -> None:
        super().__init__()
        # FIX: these were mutable class attributes shared by every loop
        # instance; make them per-instance state.
        # intervals in which the target coroutine has been active, sorted
        self._coro_intervals = SortedList()
        # completed callbacks, and their associated queue time
        self._completed_coros = []
        # (event_list, timestamp) pairs polled but not yet processed
        self._ready_events = []
        self.set_logger()

    def set_logger(self: Any) -> None:
        """Ensure the trace file exists and build a file handler.

        NOTE(review): the handler is created but never attached — trace
        logging is currently disabled by design; the handler construction is
        kept so the file side effects remain unchanged."""
        if not os.path.exists(TRACE_PATH):
            with open(TRACE_PATH, 'w') as file:
                pass
        file_handler = logging.FileHandler(TRACE_PATH)
        file_handler.setLevel(self._log_level)
        formatter = logging.Formatter(
            '%(name)s - %(asctime)s - %(levelname)s --- %(message)s')
        file_handler.setFormatter(formatter)

    def set_speedup(self: Any, speedup: float) -> None:
        """Start a new experiment with the given speedup factor in [0, 1]."""
        self._start_time = self.time()
        self._speedup = speedup
        # reset experiment counters
        self._coro_intervals.clear()
        self._completed_coros.clear()

    def get_completed_coros(self: Any) -> list[tuple[str, float]]:
        """Return a copy of (formatted handle, wait time) records so far."""
        return copy(self._completed_coros)

    def get_run_time(self: Any) -> float:
        """Return the experiment's elapsed time, minus simulated pause time."""
        curr_time = self.time()
        if not self._coro_intervals:
            return 0
        start_interval = self._coro_intervals[0][0]
        # if we are currently inside the target coro, the interval is open
        end_interval = curr_time if self._time_entered_coro \
            else self._coro_intervals[-1][1]
        interval = (start_interval, end_interval)
        pause_time = self._get_pause_time(interval)
        return (curr_time - self._start_time) - pause_time

    def ping_enter_coro(self: Any) -> None:
        """Record that the target coroutine has just been entered."""
        self._time_entered_coro = self.time()

    def ping_exit_coro(self: Any) -> None:
        """Close the currently open target-coroutine interval."""
        try:
            assert isinstance(self._time_entered_coro, float), \
                "Tried to exit coro before recorded entry!"
        except AssertionError as e:
            sys.exit(1)
        self._coro_intervals.add((self._time_entered_coro, self.time()))
        self._time_entered_coro = None

    def collect_ready_events(self: Any,
                             timeout: float | None = 0) -> None:
        """Poll the selector and buffer any ready events with a timestamp."""
        event_list = self._selector.select(timeout)
        if event_list:
            self._ready_events.append((event_list, self.time()))

    @line_profiler.profile
    def update_ready(self: Any) -> None:
        '''
        Polls the IO selector, schedules resulting callbacks, and
        schedules 'call_later' callbacks.
        '''
        sched_count = len(self._scheduled)
        # two methods to cleanup cancelled callbacks;
        # avoid the expensive one whenever possible
        if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and
                self._timer_cancelled_count / sched_count >
                _MIN_CANCELLED_TIMER_HANDLES_FRACTION):
            # Remove delayed calls that were cancelled if their number
            # is too high
            new_scheduled = []
            for eff_guess, handle in self._scheduled:
                if handle._cancelled:
                    handle._scheduled = False
                else:
                    new_scheduled.append((eff_guess, handle))
            heapq.heapify(new_scheduled)
            self._scheduled = new_scheduled
            self._timer_cancelled_count = 0
        else:
            # Remove delayed calls that were cancelled from head of queue.
            while self._scheduled and self._scheduled[0][1]._cancelled:
                self._timer_cancelled_count -= 1
                _, handle = heapq.heappop(self._scheduled)
                handle._scheduled = False

        timeout = None
        if self._ready or self._stopping:
            timeout = 0
        elif self._scheduled:
            # FIX: was a bare `_get_next_effective_time()` call — the
            # function only exists as a method on this class.
            timeout = self._get_next_effective_time()

        self.collect_ready_events(timeout)
        while len(self._ready_events):
            event_list = self._ready_events.pop(0)
            self._process_events(*event_list)

        # Handle 'scheduled' callbacks that are ready.
        # note 'scheduled' callbacks include both I/O bound or call_later
        self._add_ready_handles()

    @line_profiler.profile
    def _run_once(self: Any) -> None:
        """
        Run one full iteration of the event loop.
        This calls all currently ready callbacks.
        """
        self.update_ready()

        # This is the only place where callbacks are actually *called*.
        # All other places just add them to ready.
        # Note: We run all currently scheduled callbacks, but not any
        # callbacks scheduled by callbacks run this time around --
        # they will be run the next time (after another I/O poll).
        # Use an idiom that is thread-safe without using locks.
        ntodo = len(self._ready)
        for i in range(ntodo):
            handle = self._ready.popleft()
            if handle._cancelled:
                continue
            try:
                self._current_handle = handle
                process_start_time = self.time()
                handle._run()
                process_end_time = self.time()
                dt = process_end_time - process_start_time
                if self._debug and dt >= self.slow_callback_duration:
                    logger.warning('Executing %s took %.3f seconds',
                                   _format_handle(handle), dt)
                # compute how long this handle waited, after discounting
                # simulated pause time accumulated while it was queued
                time_interval = (handle._when, process_start_time)
                pause_time = self._get_pause_time(time_interval)
                adjusted_start_time = process_start_time - pause_time
                wait_time = adjusted_start_time - handle._when
                try:
                    # small epsilon tolerates clock resolution jitter
                    assert wait_time >= -0.0001, \
                        f"wait time on {_format_handle(handle)} was found to be {wait_time:.4f}!"
                except AssertionError as e:
                    sys.exit(1)
                self._completed_coros.append(
                    (_format_handle(handle), wait_time))
            except Exception:
                traceback.print_exc()
            finally:
                self._current_handle = None
        handle = None  # Needed to break cycles when an exception occurs.

    def _get_next_effective_time(self):
        """
        Returns the delta time in which the next currently processing
        callback will be ready.

        The time this takes is O(log(n)), where n is the number of
        currently processing callbacks. We can get away with lazily
        updating the estimated effective run time of each item because
        `self._get_pause_for_io` never underestimates the true effective
        firing time.

        By recomputing the heap head's true effective run time, if our
        guess is correct, no other timer can have a smaller effective
        time.
        """
        next_effective_time = None
        curr_time = self.time()
        while True:
            eff_guess, handle = self._scheduled[0]
            eff_true = handle._when + self._get_pause_for_io(handle, curr_time)
            if eff_true == eff_guess:
                next_effective_time = eff_true - curr_time
                break
            # guess was stale; reinsert with the corrected effective time
            heapq.heappop(self._scheduled)
            heapq.heappush(self._scheduled, (eff_true, handle))
        return min(next_effective_time, MAXIMUM_SELECT_TIMEOUT)

    def _add_ready_handles(self):
        """Move every scheduled handle whose effective time has arrived
        onto the ready queue."""
        curr_time = self.time() + self._clock_resolution
        while self._scheduled:
            eff_guess, handle = heapq.heappop(self._scheduled)
            eff_true = handle._when + self._get_pause_for_io(handle, curr_time)
            if curr_time >= eff_true:
                # set the true 'when'
                handle._when = eff_true
                self._ready.append(handle)
            else:
                heapq.heappush(self._scheduled, (eff_true, handle))
                break

    def _process_events(self, event_list, event_time):
        """Schedule reader/writer callbacks for polled selector events,
        buffering each by its simulated pause time.

        Renamed the second parameter from `time` to avoid shadowing the
        `time` module; this private method is only called positionally."""
        for key, mask in event_list:
            fileobj, (reader, writer) = key.fileobj, key.data
            if mask & selectors.EVENT_READ and reader is not None:
                if reader._cancelled:
                    self._remove_reader(fileobj)
                else:
                    time_to_buffer = event_time + \
                        self._get_pause_for_io(reader, event_time)
                    reader._when = time_to_buffer
                    self._add_callback(reader)
            if mask & selectors.EVENT_WRITE and writer is not None:
                if writer._cancelled:
                    self._remove_writer(fileobj)
                else:
                    time_to_buffer = event_time + \
                        self._get_pause_for_io(writer, event_time)
                    writer._when = time_to_buffer
                    self._add_callback(writer)

    def call_at(self, when, callback, *args, context=None):
        """Like call_later(), but uses an absolute time.

        Absolute time corresponds to the event loop's time() method.
        """
        if when is None:
            raise TypeError("when cannot be None")
        self._check_closed()
        if self._debug:
            self._check_thread()
            self._check_callback(callback, 'call_at')
        timer = events.TimerHandle(when, callback, args, self, context)
        if timer._source_traceback:
            del timer._source_traceback[-1]
        # FIX: _scheduled holds (effective_time_guess, handle) tuples
        # everywhere else in this class; the original pushed the bare
        # handle here (breaking heap comparisons) and also pushed onto
        # `self._estimated`, an attribute that is never created.
        heapq.heappush(self._scheduled, (timer._when, timer))
        timer._scheduled = True
        return timer

    def _call_soon(self: Any, callback: Any, args: tuple,
                   context: "contextvars.Context"):
        """Do not add 'callsoon' events to the pause buffer. Add them
        directly to ready."""
        curr_time = self.time()
        handle = events.Handle(callback, args, self, context)
        if handle._source_traceback:
            del handle._source_traceback[-1]
        if not handle._when:
            handle._when = curr_time
        self._ready.append(handle)
        return handle

    def call_soon_threadsafe(self, callback, *args, context=None):
        """Like call_soon(), but thread-safe."""
        self._check_closed()
        if self._debug:
            self._check_callback(callback, 'call_soon_threadsafe')
        handle = events._ThreadSafeHandle(callback, args, self, context)
        if not handle._when:
            handle._when = self.time()
        self._ready.append(handle)
        # FIX: this traceback-trim stanza was duplicated, deleting two
        # frames instead of one.
        if handle._source_traceback:
            del handle._source_traceback[-1]
        self._write_to_self()
        return handle

    def _add_callback(self, handle):
        """Route a non-cancelled handle either straight to ready (vital
        callbacks) or onto the effective-time heap."""
        curr_time = self.time()
        if not handle._when:
            handle._when = curr_time
        # required in cases where the event loop reuses the same handle
        if handle._when < curr_time:
            handle._when = curr_time
        if not handle._cancelled:
            if self._is_vital(handle):
                handle._when = curr_time
                self._ready.append(handle)
            else:
                heapq.heappush(self._scheduled, (handle._when, handle))

    def _get_pause_for_io(self: Any, handle, io_time: float) -> float:
        """Return the simulated pause accumulated between a handle's
        registration and `io_time`; aborts on a negative result."""
        time_interval = (handle.register_time, io_time)
        p_time = self._get_pause_time(time_interval)
        try:
            assert p_time >= 0, \
                f"calculated pause time on {_format_handle(handle)} was found to be {p_time:.4f}!"
        except AssertionError as e:
            print(f"Assertion failed: {e}")
            sys.exit(1)
        return p_time

    def _get_pause_time(self: Any,
                        cb_interval: tuple[float, float]) -> float:
        """Total overlap of `cb_interval` with recorded (and any open)
        target-coroutine intervals, scaled by the speedup factor.

        Renamed the accumulator from `time` to avoid shadowing the
        imported `time` module."""
        overlap_total = 0
        start, end = cb_interval
        for coro_start, coro_end in self._coro_intervals:
            if start < coro_end and coro_start < end:
                overlap_total += self._get_overlap(
                    start, end, coro_start, coro_end)
            # coro_intervals are sorted, so by this time all overlap has
            # passed
            if end < coro_start:
                break
        curr_time = self.time()
        # account for a currently open interval, if inside the target coro
        if self._time_entered_coro and \
                start < curr_time and self._time_entered_coro < end:
            overlap_total += self._get_overlap(
                start, end, self._time_entered_coro, curr_time)
        pause_time = overlap_total * self._speedup
        assert pause_time >= 0, \
            f"Delay was found to be less than 0: {pause_time}"
        return pause_time

    def _get_overlap(self: Any, a_start: float, a_end: float,
                     b_start: float, b_end: float) -> float:
        """Length of the intersection of [a_start, a_end] and
        [b_start, b_end]; asserts the intervals actually overlap."""
        overlap_start = max(a_start, b_start)
        overlap_end = min(a_end, b_end)
        assert overlap_end >= overlap_start, \
            f"Bad overlaps: {a_start} {a_end} : {b_start} {b_end} ({overlap_end - overlap_start})"
        return overlap_end - overlap_start

    def _is_vital(self, handle):
        """ Methods which cannot afford to be paused."""
        blacklist = ['_read_from_self', '_read_ready', '_accept_connection']
        cb = handle._callback
        if isinstance(getattr(cb, '__self__', None), asyncio.tasks.Task):
            if cb.__self__.get_coro().__name__ in blacklist:
                return True
        else:
            if getattr(cb, '__name__', None) in blacklist:
                return True
        return False


# to profile your program, start the event loop with
# asyncio.run(coro, loop_factory=causal_loop_factory)
def causal_loop_factory() -> Any:
    return CausalEventLoop()