Source code for latenpy.graph

from typing import Any, Dict, Set, Tuple
import numpy as np
import networkx as nx
from networkx import DiGraph
from matplotlib import pyplot as plt


def get_computed_nodes(G: DiGraph) -> Set[str]:
    """Return the IDs of all nodes whose ``"computed"`` attribute is truthy.

    Parameters
    ----------
    G : DiGraph
        Graph whose node attribute dictionaries are inspected.

    Returns
    -------
    Set[str]
        Nodes flagged as computed. A node without a ``"computed"``
        attribute is treated as not computed.
    """
    computed: Set[str] = set()
    for name in G.nodes:
        if G.nodes[name].get("computed", False):
            computed.add(name)
    return computed
def get_uncached_nodes(G: DiGraph) -> Set[str]:
    """Return the IDs of all nodes whose ``"computed"`` attribute is falsy.

    Parameters
    ----------
    G : DiGraph
        Graph whose node attribute dictionaries are inspected.

    Returns
    -------
    Set[str]
        Nodes that are not flagged as computed, including nodes that
        have no ``"computed"`` attribute at all.
    """
    node_attrs = G.nodes
    return {
        name
        for name in node_attrs
        if not node_attrs[name].get("computed", False)
    }
def get_updated_nodes(G: DiGraph) -> Set[str]:
    """Return the IDs of all nodes whose ``"needs_recomputation"`` attribute is truthy.

    Parameters
    ----------
    G : DiGraph
        Graph whose node attribute dictionaries are inspected.

    Returns
    -------
    Set[str]
        Nodes flagged as needing recomputation. A node without the
        attribute is treated as not needing recomputation.
    """
    flagged: Set[str] = set()
    for name in G.nodes:
        if G.nodes[name].get("needs_recomputation", False):
            flagged.add(name)
    return flagged
def correct_computed_status(G: DiGraph) -> None:
    """Clear the latent data of every node downstream of an uncached node.

    Starting from the nodes returned by ``get_uncached_nodes``, the graph is
    walked along successor edges. Every node reached this way has
    ``G.nodes[node]["delayed_obj"].latent_data.clear()`` called on it so that
    it will be recomputed. The starting nodes themselves are only cleared
    when they are also reachable from another uncached node.

    Parameters
    ----------
    G : DiGraph
        Graph to update in place. Each reached node must carry a
        ``"delayed_obj"`` attribute exposing ``latent_data``.
    """
    stale: Set[str] = set()
    frontier = list(get_uncached_nodes(G))

    while frontier:
        current = frontier.pop()
        for child in G.successors(current):
            # Visit each downstream node once; this also keeps the walk
            # finite if the graph happens to contain a cycle.
            if child not in stale:
                stale.add(child)
                frontier.append(child)

    # Drop stored data on everything downstream so it gets recomputed.
    for name in stale:
        G.nodes[name]["delayed_obj"].latent_data.clear()
def analyze_dependencies(G: DiGraph) -> Dict[str, Any]:
    """Analyze the computation graph's dependencies and structure.

    Parameters
    ----------
    G : DiGraph
        The directed graph to analyze.

    Returns
    -------
    Dict[str, Any]
        A dictionary containing the following metrics:

        - depth : int
            Maximum depth of the computation graph, i.e. the number of
            edges on the longest path from any root to any leaf. For a
            graph that contains cycles the longest path is not well
            defined; in that case the longest of the *shortest*
            root-to-leaf paths is reported instead.
        - n_nodes : int
            Total number of computation nodes
        - n_edges : int
            Total number of dependencies
        - leaf_nodes : int
            Number of nodes with no outgoing edges (no dependents)
        - root_nodes : int
            Number of nodes with no incoming edges (no dependencies)
        - is_cyclic : bool
            Whether the graph contains cycles
        - max_in_degree : int
            Maximum number of direct dependencies for any node
        - max_out_degree : int
            Maximum number of direct dependents for any node
    """
    in_degrees = dict(G.in_degree())
    out_degrees = dict(G.out_degree())

    # Roots have no predecessors; leaves have no successors.
    root_nodes = [node for node, degree in in_degrees.items() if degree == 0]
    leaf_nodes = [node for node, degree in out_degrees.items() if degree == 0]

    is_cyclic = not nx.is_directed_acyclic_graph(G)

    if not is_cyclic:
        # In a DAG, a node's topological generation index equals the length
        # of the longest path reaching it from a root, so the number of
        # generations minus one is the longest root-to-leaf path. Using
        # shortest paths here would under-report the depth whenever a
        # shortcut edge exists (e.g. a -> b -> c plus a -> c).
        n_generations = sum(1 for _ in nx.topological_generations(G))
        max_depth = max(n_generations - 1, 0)
    else:
        # Longest simple path is ill-defined/expensive with cycles; keep the
        # shortest-path based estimate for this case.
        max_depth = 0
        for root in root_nodes:
            for leaf in leaf_nodes:
                try:
                    path_length = len(nx.shortest_path(G, root, leaf)) - 1
                except nx.NetworkXNoPath:
                    continue
                max_depth = max(max_depth, path_length)

    return {
        "depth": max_depth,
        "n_nodes": G.number_of_nodes(),
        "n_edges": G.number_of_edges(),
        "leaf_nodes": len(leaf_nodes),
        "root_nodes": len(root_nodes),
        "is_cyclic": is_cyclic,
        "max_in_degree": max(in_degrees.values(), default=0),
        "max_out_degree": max(out_degrees.values(), default=0),
    }
def validate_no_cycles(G: DiGraph) -> None:
    """Validate that the computation graph has no cycles.

    Parameters
    ----------
    G : DiGraph
        The dependency graph to check.

    Raises
    ------
    ValueError
        If cycles are detected in the dependency graph. The message lists
        the nodes of the first cycle found.
    """
    if nx.is_directed_acyclic_graph(G):
        return

    # Only the first cycle is reported, so pull a single item from the
    # cycle iterator instead of materialising every simple cycle (their
    # number can grow exponentially with graph size).
    first_cycle = next(iter(nx.simple_cycles(G)))
    # str() keeps the join working even if node IDs are not strings.
    cycle_str = " -> ".join(map(str, first_cycle))
    raise ValueError(f"Circular dependency detected in computation graph: {cycle_str}")
def get_optimized_pos(G: DiGraph, scale: float = 1.0) -> Dict[str, Tuple[float, float]]:
    """Get optimized node positions with proper scaling.

    Nodes are first placed on a grid: one row per topological generation
    (top to bottom), with the nodes of each generation sorted and spread
    evenly from left to right. That layout is then handed to
    ``nx.spring_layout`` as the starting position, and the result is
    multiplied by ``scale``.

    Parameters
    ----------
    G : DiGraph
        Graph to lay out. ``nx.topological_generations`` is used, so the
        graph is expected to be acyclic, and node IDs within a generation
        must be sortable.
    scale : float, default 1.0
        Factor applied to both coordinates of every position. When it is
        exactly ``1.0`` every node is passed to ``spring_layout`` as
        ``fixed``; for any other value no node is fixed.

    Returns
    -------
    Dict[str, Tuple[float, float]]
        Mapping from node ID to its scaled ``(x, y)`` position. Empty for
        an empty graph.
    """
    # An empty graph would hand spring_layout an empty ``pos`` mapping,
    # which it cannot derive a domain size from; there is nothing to place.
    if G.number_of_nodes() == 0:
        return {}

    # Get base positions using hierarchical layout
    generations = list(nx.topological_generations(G))
    pos = {}
    y_step = 1.0 / (len(generations) + 1)
    for i, gen in enumerate(generations):
        y = 1 - y_step * (i + 1)
        x_step = 1.0 / (len(gen) + 1)
        for j, node in enumerate(sorted(gen)):
            pos[node] = (x_step * (j + 1), y)

    # Fine-tune with spring layout, using hierarchical as starting point
    fixed = pos.keys() if scale == 1.0 else None
    pos = nx.spring_layout(G, k=2, iterations=50, pos=pos, fixed=fixed)

    # Scale positions
    return {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
def visualize(
    G: DiGraph,
    figsize: Tuple[float, float] = (8, 7),
    scale: float = 1.0,
    jitter: float = 0.0,
) -> None:
    """Draw the computation graph on a new matplotlib figure.

    Nodes are coloured by status: ``"forestgreen"`` when the node's
    ``"computed"`` attribute is truthy, otherwise ``"indianred"`` when its
    ``"needs_recomputation"`` attribute is truthy, otherwise ``"silver"``.
    The figure is created but not shown or saved here.

    Parameters
    ----------
    G : DiGraph
        Graph to draw.
    figsize : Tuple[float, float], default (8, 7)
        Width and height passed to ``plt.figure``.
    scale : float, default 1.0
        Forwarded to ``get_optimized_pos``.
    jitter : float, default 0.0
        Half-width of the uniform random offset added independently to the
        x and y coordinate of every node.
    """
    plt.figure(figsize=figsize)

    # Get optimized positions, then perturb each coordinate by up to ``jitter``.
    pos = get_optimized_pos(G, scale=scale)
    pos = {
        node: (
            x + np.random.uniform(-jitter, jitter),
            y + np.random.uniform(-jitter, jitter),
        )
        for node, (x, y) in pos.items()
    }

    # "computed" takes precedence over "needs_recomputation".
    node_color = []
    for node in G.nodes:
        attrs = G.nodes[node]
        if attrs.get("computed", False):
            node_color.append("forestgreen")
        elif attrs.get("needs_recomputation", False):
            node_color.append("indianred")
        else:
            node_color.append("silver")

    # Draw with more spacing
    nx.draw(
        G,
        pos,
        with_labels=True,
        node_color=node_color,
        node_size=1000,
        font_size=12,
        font_weight="bold",
        arrows=True,
        edge_color="gray",
        arrowsize=12,
        # Add minimum spacing between nodes
        min_target_margin=10,
        min_source_margin=10,
    )