# Source code for latenpy.latent

"""A module for lazy evaluation and computation caching with dependency tracking.

This module provides tools for creating and managing lazy computations through the
Latent class and latent decorator. It supports nested data structures, dependency
tracking, and optional result caching.

Classes
-------
Latent
   A class that implements lazy evaluation with dependency tracking and caching.

Functions
---------
compute_nested
   Recursively computes Latent objects found within nested data structures.
latent
   Decorator that creates lazy computations with optional caching.
"""

from collections import defaultdict
from collections.abc import Mapping, Sequence
from functools import wraps
from typing import Any, Callable, Dict, Optional, Set, Tuple
from warnings import warn

import numpy as np
from networkx import DiGraph

from .types import LatentData, T
from .graph import get_updated_nodes, correct_computed_status, validate_no_cycles


def compute_nested(
    obj: Any,
    force_recompute: bool = False,
    dont_cache: bool = False,
    depth: int = 0,
    maximum_depth: Optional[int] = None,
) -> Any:
    """Recursively compute any Latent objects found within nested data structures.

    Parameters
    ----------
    obj : Any
        The object to process. Can be a Latent object, any nested structure
        containing Latent objects, or any other object.
    force_recompute : bool, optional
        If True, forces recomputation of cached results, by default False.
    dont_cache : bool, optional
        If True, prevents caching of computed results, by default False.
    depth : int, optional
        Current recursion depth, by default 0.
    maximum_depth : int, optional
        Maximum recursion depth before stopping, by default None (no limit).

    Returns
    -------
    Any
        The computed result with all Latent objects evaluated.

    Warns
    -----
    UserWarning
        If ``depth`` exceeds ``maximum_depth``; the object is then returned
        as-is, without being traversed or computed.

    Notes
    -----
    Supports the following nested structures:

    - NumPy arrays (i.e. if you have a numpy array of Latent objects, this
      will compute them all). Arrays whose dtype cannot hold Python objects
      are copied without being traversed.
    - Mappings (dict-like objects, including ``defaultdict``)
    - Sequences (list-like objects, including named tuples)
    - Sets and frozensets
    - Generators and iterators (returned as a lazy generator)

    Strings, bytes, ``range`` objects, memoryviews and plain numbers are
    treated as leaves and returned unchanged.
    """
    if maximum_depth is not None and depth > maximum_depth:
        warn(f"Maximum depth ({maximum_depth}) reached in compute_nested at object: {type(obj).__name__}. Returning object as-is.")
        return obj

    # Leaves: these can never contain a Latent. ``range`` and ``memoryview``
    # are registered Sequences but cannot be rebuilt from an iterable, so they
    # must not reach the sequence branch below.
    if obj is None or isinstance(obj, (str, bytes, range, memoryview, int, float, complex)):
        return obj

    # Handle Latent objects. This check must precede the container checks
    # because Latent defines __iter__ and would otherwise look like an iterable.
    if isinstance(obj, Latent):
        child_depth = maximum_depth - 1 if maximum_depth is not None else None
        return obj.compute(force_recompute=force_recompute, dont_cache=dont_cache, maximum_depth=child_depth)

    # Arguments forwarded to every recursive call, one level deeper
    nested = dict(
        force_recompute=force_recompute,
        dont_cache=dont_cache,
        depth=depth + 1,
        maximum_depth=maximum_depth,
    )

    # Handle numpy arrays
    if isinstance(obj, np.ndarray):
        if not obj.dtype.hasobject:
            # No element can be a Latent; copy instead of looping in Python
            return np.array(obj)
        items = [compute_nested(x, **nested) for x in obj.flat]
        try:
            return np.array(items).reshape(obj.shape)
        except ValueError:
            # Results are themselves sequences/arrays, so numpy either could
            # not stack them or stacked them into a different size. Keep one
            # result per original cell instead.
            cells = np.empty(obj.size, dtype=object)
            for index, item in enumerate(items):
                cells[index] = item
            return cells.reshape(obj.shape)

    # Handle mappings (dict-like objects)
    if isinstance(obj, Mapping):
        computed = {compute_nested(key, **nested): compute_nested(value, **nested) for key, value in obj.items()}
        if isinstance(obj, defaultdict):
            # defaultdict takes its factory as the first positional argument
            return type(obj)(obj.default_factory, computed)
        return type(obj)(computed)

    # Handle sequences (list-like objects; strings/bytes were returned above)
    if isinstance(obj, Sequence):
        items = [compute_nested(item, **nested) for item in obj]
        if isinstance(obj, tuple) and hasattr(type(obj), "_make"):
            # Named tuples take one argument per field, not a single iterable
            return type(obj)._make(items)
        return type(obj)(items)

    # Handle sets (frozenset included, so the container type is preserved)
    if isinstance(obj, (set, frozenset)):
        return type(obj)(compute_nested(item, **nested) for item in obj)

    # Handle generators, iterators and any other iterable: evaluate lazily
    if hasattr(obj, "__iter__"):
        return (compute_nested(item, **nested) for item in obj)

    # Base case: return the object as-is
    return obj


class Latent:
    """A class for lazy evaluation with dependency tracking and caching.

    This class enables delayed computation of functions and their arguments,
    with support for dependency tracking, result caching, and nested
    computations.

    Parameters
    ----------
    func : Callable[..., T]
        The function to be computed lazily.
    *args : Any
        Positional arguments for the function.
    disable_cache : bool, optional
        If True, disables result caching, by default False.
    **kwargs : Any
        Keyword arguments for the function.

    Attributes
    ----------
    latent_data : LatentData
        Storage for cached computation results.
    computed : bool
        Whether the computation has been performed and cached.

    Methods
    -------
    compute(force_recompute=False, recompute_dependencies=False, dont_cache=False, maximum_depth=None)
        Execute the computation and return the result.
    update_func(func)
        Update the function to be computed.
    update_args(*args)
        Update the positional arguments.
    update_kwargs(full_reset=False, **kwargs)
        Update the keyword arguments.
    clear_cache(dependencies=False, dependents=False)
        Clear all cached results in the computation graph.

    Notes
    -----
    Only Latent objects passed *directly* as positional or keyword arguments
    are tracked as dependencies (for invalidation and for the dependency
    graph). Latent objects nested inside containers are still evaluated at
    compute time, but are not tracked.
    """

    def __init__(self, func: Callable[..., T], *args: Any, disable_cache: bool = False, **kwargs: Any):
        self._func = func
        self._args = tuple(args)
        self._kwargs = kwargs
        self.disable_cache = disable_cache
        self.latent_data = LatentData()
        self._needs_recomputation = False  # True if the underlying function or arguments have been updated
        self._computing = False  # For tracking computation status
        self._dependents: Set[Latent] = set()  # Set of nodes that depend on this node

        # Register this node with every Latent it directly depends on
        self._sync_dependencies(())

    @property
    def func(self) -> Callable[..., T]:
        """Get the function to be computed.

        Returns
        -------
        Callable[..., T]
            The function that will be executed when compute() is called.
        """
        return self._func

    @property
    def args(self) -> Tuple[Any, ...]:
        """Get the positional arguments for the function.

        Returns
        -------
        Tuple[Any, ...]
            The tuple of positional arguments.
        """
        return self._args

    @property
    def kwargs(self) -> Mapping[str, Any]:
        """Get the keyword arguments for the function.

        Returns
        -------
        Mapping[str, Any]
            The mapping of keyword arguments.
        """
        return self._kwargs

    @property
    def computed(self) -> bool:
        """Check if the computation has been performed and cached.

        Returns
        -------
        bool
            True if the result has been computed and cached, False otherwise.
        """
        return bool(self.latent_data)

    @property
    def _func_name(self) -> str:
        """Name used for the function in messages, reprs and graph labels.

        Falls back to the callable's type name for callables that have no
        ``__name__`` (e.g. ``functools.partial`` objects).
        """
        name = getattr(self.func, "__name__", None)
        return name if isinstance(name, str) else type(self.func).__name__

    def _direct_dependencies(self) -> Tuple["Latent", ...]:
        """Return the Latent objects passed directly as args or kwargs values."""
        candidates = (*self._args, *self._kwargs.values())
        return tuple(value for value in candidates if isinstance(value, Latent))

    def _sync_dependencies(self, previous: Tuple["Latent", ...]) -> None:
        """Bring the ``_dependents`` sets of dependencies in line with the current arguments.

        Parameters
        ----------
        previous : Tuple[Latent, ...]
            The direct dependencies before the arguments were changed. Any of
            them that is no longer an argument stops listing this node as a
            dependent.
        """
        current = self._direct_dependencies()
        still_used = {id(dependency) for dependency in current}
        for dependency in previous:
            if id(dependency) not in still_used:
                dependency._dependents.discard(self)
        for dependency in current:
            dependency._dependents.add(self)

    def _invalidate(self) -> None:
        """Flag this node as changed, invalidate its dependents and drop its cache."""
        self._needs_recomputation = True
        self._update_dependents()
        self.latent_data.clear()

    def update_func(self, func: Callable[..., T]) -> None:
        """Update the function to be computed.

        Parameters
        ----------
        func : Callable[..., T]
            The new function to use for computation.

        Notes
        -----
        This will clear the cached result and mark all dependent computations
        for recomputation.
        """
        self._func = func
        self._invalidate()

    def update_args(self, *args: Any) -> None:
        """Update the positional arguments.

        Parameters
        ----------
        *args : Any
            The new positional arguments.

        Notes
        -----
        This will clear the cached result and mark all dependent computations
        for recomputation. Latent objects among the new arguments gain this
        node as a dependent; Latent objects that are no longer arguments lose
        it.
        """
        previous = self._direct_dependencies()
        self._args = tuple(args)
        self._sync_dependencies(previous)
        self._invalidate()

    def update_kwargs(self, full_reset: bool = False, **kwargs: Any) -> None:
        """Update the keyword arguments.

        Parameters
        ----------
        full_reset : bool, optional
            If True, replace all existing kwargs. If False, update only the
            provided kwargs, by default False.
        **kwargs : Any
            The new keyword arguments.

        Notes
        -----
        This will clear the cached result and mark all dependent computations
        for recomputation. Latent objects among the new values gain this node
        as a dependent; Latent objects that are no longer arguments lose it.
        """
        previous = self._direct_dependencies()
        if full_reset:
            self._kwargs = kwargs
        else:
            self._kwargs.update(kwargs)
        self._sync_dependencies(previous)
        self._invalidate()

    def _update_dependents(self, visited: Optional[Set["Latent"]] = None) -> None:
        """Recursively flag every dependent node for recomputation and clear its cache.

        Parameters
        ----------
        visited : Set[Latent], optional
            Nodes already handled during this traversal. Callers should leave
            it as None; a fresh set is created for each top-level call so one
            traversal never suppresses a later one.
        """
        if visited is None:
            visited = set()
        for dependent in self._dependents:
            if dependent not in visited:
                visited.add(dependent)
                dependent._needs_recomputation = True
                dependent.latent_data.clear()
                dependent._update_dependents(visited)

    def _wrap_error(self, error: Exception) -> Exception:
        """Return a same-typed exception whose message names the failing function.

        If the exception type cannot be constructed from a single message
        string, the original exception is returned unchanged so it is never
        masked by a constructor failure.
        """
        message = f"Error in delayed computation of {self._func_name}: {str(error)}"
        try:
            return type(error)(message)
        except Exception:
            return error

    def compute(
        self,
        force_recompute: bool = False,
        recompute_dependencies: bool = False,
        dont_cache: bool = False,
        maximum_depth: Optional[int] = None,
    ) -> T:
        """Compute the result and cache if enabled.

        Will iteratively compute all dependencies in arguments and key-word
        arguments if they are also Latent objects.

        Parameters
        ----------
        force_recompute : bool, optional
            If True, recompute even if cached, by default False.
        recompute_dependencies : bool, optional
            If True, recompute all dependencies, by default False.
        dont_cache : bool, optional
            If True, skip caching the result, by default False.
        maximum_depth : int, optional
            Maximum depth for nested computations, by default None.

        Returns
        -------
        T
            The computed result.

        Raises
        ------
        RecursionError
            If a circular dependency is detected.
        Exception
            If the computation fails. The exception has the same type as the
            original failure, with the function name prepended to the message
            when that type can be built from a message string; the original
            is chained as ``__cause__``.

        Notes
        -----
        This method will:
        1. Check for circular dependencies
        2. Validate the dependency graph
        3. Compute any required dependencies
        4. Execute the computation
        5. Cache the result (unless disabled)
        """
        if self._computing:
            raise RecursionError(f"Circular dependency detected in delayed computation of {self._func_name}")

        graph = self.get_dependency_graph()
        validate_no_cycles(graph)

        # Update dependents of any changed nodes
        correct_computed_status(graph)

        # Get nodes that have been updated or have updated dependencies and need to be recomputed
        # NOTE(review): nothing in this class resets `_needs_recomputation` to False;
        # presumably the graph helpers take care of that - confirm against .graph.
        updated_nodes = get_updated_nodes(graph)

        # Check if we can use a cached result
        # If force_recompute, never use cached result
        # If recompute_dependencies, never use cached result (also recompute all dependencies)
        # If updated_nodes, then a previously computed dependency needs to be recomputed
        # If there is no cache, we can't use the cache!
        can_use_cache = not force_recompute and not recompute_dependencies and not updated_nodes and self.latent_data
        if can_use_cache:
            return self.latent_data()

        try:
            # This Latent object is now trying to compute its result
            self._computing = True

            nested = dict(
                force_recompute=recompute_dependencies,
                dont_cache=dont_cache,
                maximum_depth=maximum_depth,
            )
            computed_args = [compute_nested(arg, **nested) for arg in self.args]
            computed_kwargs = {key: compute_nested(value, **nested) for key, value in self.kwargs.items()}

            result = self.func(*computed_args, **computed_kwargs)
            if not self.disable_cache and not dont_cache:
                self.latent_data.set(result)
            return result

        except Exception as e:
            wrapped = self._wrap_error(e)
            if wrapped is e:
                # Could not build an annotated copy; surface the original untouched
                raise
            raise wrapped from e

        finally:
            # Reset the computation flag
            self._computing = False

    def __bool__(self) -> bool:
        """Check if the computation has been performed.

        Returns
        -------
        bool
            True if the result has been computed, False otherwise.
        """
        return bool(self.latent_data)

    def __repr__(self) -> str:
        """Get the string representation of the delayed computation.

        Returns
        -------
        str
            A string showing the function name and computation status.
        """
        return f"Latent({self._func_name}):{'Computed' if self.latent_data else 'Not computed'}"

    def __len__(self) -> "Latent":
        """Create a delayed ``len()`` of this computation's result.

        Returns
        -------
        Latent
            A new Latent object that will return the length.

        Notes
        -----
        Because this returns a Latent rather than an int, the builtin
        ``len(obj)`` raises TypeError; call ``obj.__len__()`` to obtain the
        delayed length.
        """

        def get_len(obj):
            return len(obj)

        return Latent(get_len, self)

    def __iter__(self) -> "Latent":
        """Create a delayed ``iter()`` of this computation's result.

        Returns
        -------
        Latent
            A new Latent object that will return an iterator.

        Notes
        -----
        Because this returns a Latent rather than an iterator, the builtin
        ``iter(obj)`` (and ``for ... in obj``) raises TypeError; call
        ``obj.__iter__()`` to obtain the delayed iterator.
        """

        def get_iter(obj):
            return iter(obj)

        return Latent(get_iter, self)

    def clear_cache(self, dependencies: bool = False, dependents: bool = False) -> None:
        """Clear all cached results in this computation graph.

        Parameters
        ----------
        dependencies : bool, optional
            If True, clear caches of dependencies, by default False
        dependents : bool, optional
            If True, clear caches of dependents, by default False
        """
        self.latent_data.clear()

        if dependencies:
            for dependency in self._direct_dependencies():
                dependency.clear_cache(dependencies=True)

        if dependents:
            for dependent in self._dependents:
                dependent.clear_cache(dependents=True)

    def _get_node_id(self) -> str:
        """Generate a node ID that is unique among live Latent objects.

        The ID has the form ``name(arg,...,key=value,...)#<id>``. The argument
        part is a compact, purely descriptive summary; uniqueness comes from
        the full ``id(self)`` suffix, so two distinct nodes can never share an
        ID and be merged in the dependency graph.
        """
        # Create simplified arg representations
        arg_strs = []
        for arg in self.args:
            if isinstance(arg, Latent):
                # Just use the function name and id of delayed args
                arg_strs.append(arg._func_name + f"#{str(id(arg))[-4:]}")
            else:
                # For non-delayed args, use a short hash
                try:
                    arg_str = str(hash(arg))[-4:]
                except TypeError:
                    arg_str = type(arg).__name__[:4]
                arg_strs.append(arg_str)

        # Handle kwargs similarly but more concisely
        kwarg_strs = []
        for key in sorted(self.kwargs):
            value = self.kwargs[key]
            if isinstance(value, Latent):
                kwarg_strs.append(f"{key[:4]}={value._func_name}")
            else:
                try:
                    value_str = str(hash(value))[-4:]
                except TypeError:
                    value_str = type(value).__name__[:4]
                kwarg_strs.append(f"{key[:4]}={value_str}")

        # Combine everything into a compact string
        return f"{self._func_name}({','.join(arg_strs + kwarg_strs)})#{id(self)}"

    def get_dependency_graph(self) -> DiGraph:
        """Build and return a directed graph of computation dependencies.

        Returns
        -------
        DiGraph
            A directed graph where:
            - Nodes are computations
            - Edges point from a dependency to the computation that uses it
            - Node attributes include:
                - 'label': Name of the function
                - 'computed': Boolean for cache status
                - 'needs_recomputation': Whether the node was updated
                - 'delayed_obj': The Latent object itself
        """
        graph = DiGraph()
        self._build_graph(graph, set())
        return graph

    def _build_graph(self, G: DiGraph, visited: Set[str]) -> None:
        """Recursively build the dependency graph.

        Parameters
        ----------
        G : DiGraph
            The graph to build
        visited : Set[str]
            Set of node IDs already processed to avoid redundant traversal
        """
        node_id = self._get_node_id()
        if node_id in visited:
            return
        visited.add(node_id)

        # Add this node to the graph
        G.add_node(
            node_id,
            label=self._func_name,
            computed=bool(self.latent_data),
            needs_recomputation=self._needs_recomputation,
            delayed_obj=self,
        )

        # Process positional and keyword arguments (direct Latent values only)
        for dependency in self._direct_dependencies():
            dependency._build_graph(G, visited)
            G.add_edge(dependency._get_node_id(), node_id)
def latent(func=None, *, disable_cache=False):
    """Decorator to create a latent computation that executes only when requested.

    Parameters
    ----------
    func : callable or None
        The function to be delayed. Will be None if the decorator is called
        with parameters.
    disable_cache : bool, optional
        If True, disables caching of computation results. Each call to
        compute() will re-execute the function, by default False.

    Returns
    -------
    callable
        If used as ``@latent``: a wrapper around ``func`` that, when called,
        returns a Latent object holding the function and arguments for later
        execution.
        If used as ``@latent(disable_cache=...)``: a decorator that produces
        such a wrapper.

    See Also
    --------
    Latent : The class that handles lazy evaluation of functions

    Examples
    --------
    Basic usage with default caching:

    >>> @latent
    ... def expensive_computation(x):
    ...     return x * 2
    ...
    >>> result = expensive_computation(10)  # No computation yet
    >>> result.compute()  # Now computes
    20

    Disable caching for always-fresh results:

    >>> @latent(disable_cache=True)
    ... def always_recompute(x):
    ...     return x * 2
    ...
    >>> result = always_recompute(10)
    >>> result.compute()  # Computes without caching
    20

    Notes
    -----
    The decorated function's computation is deferred until the .compute()
    method is called on the returned Latent object. Every call to the
    decorated function creates a new Latent object, and each Latent object
    holds its own cached result unless ``disable_cache=True``.

    The keyword ``disable_cache`` is always forwarded to Latent, so the
    decorated function cannot itself receive a keyword argument with that
    name.
    """
    # Called as @latent(disable_cache=...): return the actual decorator
    if func is None:

        def decorator(target):
            return latent(target, disable_cache=disable_cache)

        return decorator

    # Called as @latent or latent(func, ...)
    @wraps(func)
    def wrapper(*args, **kwargs):
        return Latent(func, *args, **kwargs, disable_cache=disable_cache)

    return wrapper