Source code for chronos.timer_utils

# Copyright © 2026 UChicago Argonne, LLC All right reserved
# Full license accessible at https://github.com/AdvancedPhotonSource/chronos/blob/main/LICENSE.TXT

import time
from functools import wraps
from typing import Callable, TypeVar, Optional, Union
import numpy as np
try:
    import cupy as cp
    CUPY_AVAILABLE = True
except ImportError:
    CUPY_AVAILABLE = False
from collections import defaultdict

# Type variables to retain function signatures
T = TypeVar("T", bound=Callable)

ENABLE_TIMING = False
"Global flag to enable or disable timing."
ELAPSED_TIME_DICT: dict[str, np.ndarray] = defaultdict(lambda: np.array([]))
"""
A dictionary containing numpy arrays of the measured execution times of
each timed function.
"""
ADVANCED_TIME_DICT: dict[str, Union[np.ndarray, dict]] = defaultdict(dict)
"""
A nested dictionary, where each level of the dictionary contains
1) a key-value pair ("time": np.ndarray) that contains all measured
execution times for that function and 2) zero or more key-value
pairs (function_name: dict) where function_name refers to the name
of each function called in the function currently being timed.

Note that only functions with the `timer` decorator will show up
in `ADVANCED_TIME_DICT`.
"""
CURRENT_DICT_REFERENCE = ADVANCED_TIME_DICT  # Initialized to the top level of ADVANCED_TIME_DICT
"""
A reference to the level of `ADVANCED_TIME_DICT` that corresponds to
the function currently being executed.
"""
TIMING_OVERHEAD_ARRAY = np.array([])
"""
A numpy array containing measurements of how long it takes to execute
the bookkeeping code of the `timer` decorator and `InlineTimer`
(one entry per recorded measurement).
"""

# Device ids that `wait_for_process_completion_on_all_gpus` synchronizes.
# Stays empty when CuPy is missing or when CuPy cannot query the devices.
list_of_all_gpus = ()
if CUPY_AVAILABLE:
    try:
        list_of_all_gpus = tuple(range(cp.cuda.runtime.getDeviceCount()))
    except cp.cuda.runtime.CUDARuntimeError:
        # CuPy is installed but the CUDA runtime could not report a device
        # count (e.g. no device or no driver). Fall back to CPU-only timing
        # instead of failing at import time.
        list_of_all_gpus = ()


def toggle_timer(enable: bool):
    """
    Set the module-wide ``ENABLE_TIMING`` switch.

    Parameters
    ----------
    enable : bool
        Value stored in ``ENABLE_TIMING``. True turns timing on,
        False turns it off.
    """
    # Write straight into the module namespace; equivalent to a
    # ``global`` declaration followed by an assignment.
    globals()["ENABLE_TIMING"] = enable


def timer(enabled: bool = True, override_with_name: Optional[str] = None):
    """
    Decorator to time a function's execution time and the execution time of the timed code
    within that function. This function is enabled or disabled depending on the state of the
    global ENABLE_TIMING flag.

    The results of the timer function will be recorded in `ELAPSED_TIME_DICT` and
    `ADVANCED_TIME_DICT`.

    Parameters
    ----------
    enabled : bool, optional
        Whether timing is enabled for the decorated function. Default is True.
    override_with_name : str, optional
        Custom name to use for the function in the timing dictionary. If not
        specified, the function name is automatically generated.

    Returns
    -------
    Callable
        The wrapped function.
    """

    def decorator(func: T) -> T:
        @wraps(func)
        def wrapper(*args, **kwargs):
            if enabled and globals().get("ENABLE_TIMING", False):
                # Measure the overhead from running the timer function
                measure_overhead_start_1 = time.time()
                if override_with_name is None:
                    function_name = func.__qualname__
                else:
                    function_name = override_with_name
                saved_dict_reference = update_current_dict_reference(function_name)
                overhead_time_1 = time.time() - measure_overhead_start_1

                # Measure function execution time
                wait_for_process_completion_on_all_gpus()
                start_time = time.time()
                result = func(*args, **kwargs)
                wait_for_process_completion_on_all_gpus()
                elapsed_time = time.time() - start_time

                # Measure the overhead from running the timer function
                measure_overhead_start_2 = time.time()
                update_elapsed_time_dict(function_name, elapsed_time)
                update_advanced_time_dict(elapsed_time)
                # Traverse back up the advanced timing dicts
                revert_current_dict_reference(saved_dict_reference)
                global TIMING_OVERHEAD_ARRAY
                overhead_time_2 = time.time() - measure_overhead_start_2
                TIMING_OVERHEAD_ARRAY = np.append(
                    TIMING_OVERHEAD_ARRAY,
                    overhead_time_1 + overhead_time_2,
                )
            else:
                # If timing is disabled, just call the function
                result = func(*args, **kwargs)
            return result

        # Ensure the wrapper function has the same type as the original
        return wrapper  # type: ignore

    return decorator


class InlineTimer:
    """
    A timer class for inline timing of code blocks.

    Call `start` before and `end` after the code to be timed, or use the
    instance as a context manager (``with InlineTimer("name"): ...``).

    Parameters
    ----------
    name : str
        The name associated with the timer that will be recorded
        in the timing dictionaries.
    enabled : bool, optional
        Whether the timer is enabled, by default True.
    """

    def __init__(self, name: str, enabled: bool = True):
        self.name = name
        self.enabled = enabled
        self.overhead_time = 0
        # True only between a start() that actually began a measurement
        # and the matching end().
        self._running = False

    def start(self):
        """
        Starts the timer if timing is enabled.
        """
        if self.enabled and globals().get("ENABLE_TIMING", False):
            measure_overhead_start = time.perf_counter()
            self.saved_dict_reference = update_current_dict_reference(self.name)
            self._running = True
            self.overhead_time = time.perf_counter() - measure_overhead_start
            wait_for_process_completion_on_all_gpus()
            # perf_counter is monotonic, so the interval cannot be distorted
            # by system clock adjustments.
            self.start_time = time.perf_counter()

    def end(self):
        """
        Stops the timer and records the elapsed time.

        Does nothing unless a preceding `start` call began a measurement.
        Whether to record is decided by what `start` did, not by the current
        state of the flags, so a measurement that was started is always
        closed and the position in `ADVANCED_TIME_DICT` is always restored.
        """
        global TIMING_OVERHEAD_ARRAY
        if not self._running:
            return
        try:
            wait_for_process_completion_on_all_gpus()
            elapsed_time = time.perf_counter() - self.start_time

            measure_overhead_start = time.perf_counter()
            update_elapsed_time_dict(self.name, elapsed_time)
            update_advanced_time_dict(elapsed_time)
        finally:
            # Always traverse back up the advanced timing dicts
            revert_current_dict_reference(self.saved_dict_reference)
            self._running = False
        self.overhead_time += time.perf_counter() - measure_overhead_start
        TIMING_OVERHEAD_ARRAY = np.append(TIMING_OVERHEAD_ARRAY, self.overhead_time)

    def __enter__(self):
        """Start the timer and return this instance."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop the timer; exceptions from the timed block are not suppressed."""
        self.end()


def update_elapsed_time_dict(function_name: str, elapsed_time: float):
    """
    Append one measurement to the flat timing record of a function.

    Parameters
    ----------
    function_name : str
        Key in `ELAPSED_TIME_DICT` under which the measurement is stored.
    elapsed_time : float
        The measured execution time to append.
    """
    previous_times = ELAPSED_TIME_DICT[function_name]
    # np.append returns a new array, so the entry has to be rebound
    ELAPSED_TIME_DICT[function_name] = np.append(previous_times, elapsed_time)


def update_current_dict_reference(function_name: str) -> dict:
    """
    Descend one level in the advanced timing dictionary.

    The entry for `function_name` under the current level is created on
    first use (with an empty "time" array) and then becomes the new
    `CURRENT_DICT_REFERENCE`.

    Parameters
    ----------
    function_name : str
        The name of the function being timed.

    Returns
    -------
    dict
        The level that was current before the call, to be handed back to
        `revert_current_dict_reference` later.
    """
    global CURRENT_DICT_REFERENCE
    parent_level = CURRENT_DICT_REFERENCE
    if function_name not in parent_level:
        # First time this function is seen under this parent
        child_level = defaultdict(lambda: {})
        child_level["time"] = np.array([])
        parent_level[function_name] = child_level
    # Point at the child so nested timers record beneath it
    CURRENT_DICT_REFERENCE = parent_level[function_name]
    return parent_level


def update_advanced_time_dict(elapsed_time: float):
    """
    Append one measurement to the "time" array of the current level of the
    advanced timing dictionary.

    Parameters
    ----------
    elapsed_time : float
        The measured execution time to append.
    """
    # Only an item of the referenced dict is replaced; the module-level
    # name itself is not rebound, so no ``global`` declaration is needed.
    current_level = CURRENT_DICT_REFERENCE
    recorded_times = current_level["time"]
    current_level["time"] = np.append(recorded_times, elapsed_time)


def revert_current_dict_reference(saved_dict_reference: dict):
    """
    Point `CURRENT_DICT_REFERENCE` back at a previously saved level.

    Parameters
    ----------
    saved_dict_reference : dict
        The level to restore, as returned by
        `update_current_dict_reference`.
    """
    # Rebind the module-level name through the module namespace
    globals()["CURRENT_DICT_REFERENCE"] = saved_dict_reference


def clear_timer_globals():
    """
    Reset all recorded timing data.

    Rebinds `ELAPSED_TIME_DICT`, `ADVANCED_TIME_DICT` and
    `TIMING_OVERHEAD_ARRAY` to fresh empty containers and points
    `CURRENT_DICT_REFERENCE` at the new top level of `ADVANCED_TIME_DICT`.
    """
    global ELAPSED_TIME_DICT, ADVANCED_TIME_DICT
    global CURRENT_DICT_REFERENCE, TIMING_OVERHEAD_ARRAY
    TIMING_OVERHEAD_ARRAY = np.array([])
    ELAPSED_TIME_DICT = defaultdict(lambda: np.array([]))
    # Both names must refer to the same new top-level dictionary
    ADVANCED_TIME_DICT = CURRENT_DICT_REFERENCE = defaultdict(lambda: {})


def wait_for_process_completion_on_all_gpus():
    """
    Synchronize the null stream of every device in `list_of_all_gpus`.

    Does nothing when CuPy is not available.
    """
    if CUPY_AVAILABLE:
        for device_id in list_of_all_gpus:
            # Make the device current for the duration of the synchronize call
            with cp.cuda.Device(device_id):
                cp.cuda.Stream.null.synchronize()