Source code for liesel.model.viz

"""
Model visualization.
"""

import logging
from typing import IO, Literal

import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.lines import Line2D

# Module-level logger; the layout helper in this module emits a warning through
# it when the Graphviz (pydot) layout cannot be used and a fallback is chosen.
logger: logging.Logger = logging.getLogger(__name__)


def plot_nodes(
    model,
    show: bool = True,
    save_path: str | None | IO = None,
    width: int = 14,
    height: int = 10,
    prog: Literal[
        "dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
    ] = "dot",
):
    """
    Plots the nodes of a Liesel model.

    Nodes whose ``outdated`` flag is set are filled with ``#fc8d62``; all
    other nodes are filled with ``#8da0cb``.

    Parameters
    ----------
    model
        The model to be plotted. If it has no ``node_graph`` attribute, it is
        itself treated as the graph to draw.
    show
        Whether to show the plot in a new window.
    save_path
        Path or file-like object to save the plot to. If not provided, the
        plot will not be saved.
    width
        Width of the plot in inches.
    height
        Height of the plot in inches.
    prog
        Graphviz layout program. Available layouts: circo, dot (the default),
        fdp, neato, osage, patchwork, sfdp, twopi. If the Graphviz layout
        fails, a fallback layout is used instead.

    See Also
    --------
    .Var.plot_vars : Plots the variables of the Liesel sub-model that terminates
        in this variable.
    .Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
        this variable.
    .Model.plot_vars : Plots the variables of a Liesel model.
    .Model.plot_nodes : Plots the nodes of a Liesel model.
    .viz.plot_vars : Plots the variables of a Liesel model.
    .viz.plot_nodes : Plots the nodes of a Liesel model.
    """
    # A model exposes ``node_graph``; anything else is used as the graph as-is.
    graph = getattr(model, "node_graph", model)
    node_colors = [
        "#fc8d62" if node.outdated else "#8da0cb" for node in graph.nodes
    ]

    _, axis, pos = _prepare_figure(graph, width, height, prog)
    nx.draw_networkx_nodes(graph, pos, node_color=node_colors, ax=axis)
    _add_labels(graph, axis, pos)
    _draw_edges(graph, axis, pos, is_var=False)

    # Save before showing: with some backends the figure is gone once the
    # interactive window has been closed.
    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
def plot_vars(
    model,
    show: bool = True,
    save_path: str | None | IO = None,
    width: int = 14,
    height: int = 10,
    prog: Literal[
        "dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
    ] = "dot",
):
    """
    Plots the variables of a Liesel model.

    Variables with a distribution are drawn as stars, all others as circles.
    Weak variables are filled with ``#fc8d62``, strong ones with ``#8da0cb``.
    Edge colors show whether the source variable is used in the target's
    value, in its distribution, or in both; a legend explains this encoding.

    Parameters
    ----------
    model
        The model to be plotted. If it has no ``var_graph`` attribute, it is
        itself treated as the graph to draw.
    show
        Whether to show the plot in a new window.
    save_path
        Path or file-like object to save the plot to. If not provided, the
        plot will not be saved.
    width
        Width of the plot in inches.
    height
        Height of the plot in inches.
    prog
        Graphviz layout program. Available layouts: circo, dot (the default),
        fdp, neato, osage, patchwork, sfdp, twopi. If the Graphviz layout
        fails, a fallback layout is used instead.

    See Also
    --------
    .Var.plot_vars : Plots the variables of the Liesel sub-model that terminates
        in this variable.
    .Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
        this variable.
    .Model.plot_vars : Plots the variables of a Liesel model.
    .Model.plot_nodes : Plots the nodes of a Liesel model.
    .viz.plot_vars : Plots the variables of a Liesel model.
    .viz.plot_nodes : Plots the nodes of a Liesel model.
    """
    # A model exposes ``var_graph``; anything else is used as the graph as-is.
    graph = getattr(model, "var_graph", model)
    _, axis, pos = _prepare_figure(graph, width, height, prog)

    # Draw order matters: star nodes, then circle nodes, then the labels.
    for draw_layer in (
        _add_nodes_with_distribution_to_plot,
        _add_nodes_without_distribution_to_plot,
        _add_labels,
    ):
        draw_layer(graph, axis, pos)

    _draw_edges(graph, axis, pos, is_var=True)
    _add_legend(axis)

    # Save before showing: with some backends the figure is gone once the
    # interactive window has been closed.
    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
# Fill colors for variables (shared by the node drawing and the legend).
_STRONG_COLOR = "#8da0cb"
_WEAK_COLOR = "#fc8d62"
# Edge colors for the variable graph (shared by the edge drawing and the legend).
_VALUE_EDGE_COLOR = "#111111"
_DIST_EDGE_COLOR = "#aaaaaa"
_BOTH_EDGE_COLOR = "#FF0000"
# Node size passed to ``draw_networkx_edges``; networkx uses it to position
# arrow heads relative to the (separately drawn) nodes.
_EDGE_NODE_SIZE = 500


def _prepare_figure(graph, width, height, prog):
    """Creates a new figure and computes node positions for ``graph``.

    Positions are computed by Graphviz via pydot using the layout program
    ``prog``. If that fails, a warning is logged and a Kamada-Kawai layout of
    the undirected version of ``graph`` is used instead.

    Returns
    -------
    A ``(figure, axis, pos)`` tuple, where ``pos`` maps each node to its
    position.
    """
    fig, axis = plt.subplots()
    fig.set_size_inches(width, height)

    try:
        pos = nx.nx_pydot.pydot_layout(graph, prog=prog)
    except FileNotFoundError:
        # Treated as "the Graphviz executable is not installed / not on PATH".
        logger.warning(
            "Graphviz not found in PATH. Using fallback graph layout. "
            "Consider installing Graphviz: https://graphviz.org/download"
        )
        pos = _fallback_layout(graph)
    except Exception as e:
        # Deliberately broad: plotting is best-effort, so any pydot failure
        # degrades to the fallback layout instead of aborting the plot.
        logger.warning(
            "Graphviz via pydot failed. Using fallback graph layout. "
            "Raised exception: %s",
            e,
        )
        pos = _fallback_layout(graph)

    return fig, axis, pos


def _fallback_layout(graph):
    """Computes a Kamada-Kawai layout of the undirected version of ``graph``."""
    return nx.kamada_kawai_layout(graph.to_undirected())


def _node_strength_color(node):
    """Returns the fill color of a variable: weak or strong."""
    return _WEAK_COLOR if node.weak else _STRONG_COLOR


def _add_var_nodes_to_plot(graph, axis, pos, *, with_dist, node_size, node_shape):
    """Draws the variables of ``graph`` whose ``has_dist`` flag equals ``with_dist``.

    Each variable is filled according to :func:`_node_strength_color`.
    """
    nodes = [node for node in graph.nodes if bool(node.has_dist) == with_dist]
    nx.draw_networkx_nodes(
        graph,
        pos,
        node_size=node_size,
        # Plain lists (not dict views) are the documented input types here.
        node_color=[_node_strength_color(node) for node in nodes],
        nodelist=nodes,
        node_shape=node_shape,
        ax=axis,
    )


def _add_nodes_with_distribution_to_plot(graph, axis, pos):
    """Adds variables with a distribution to the figure, drawn as large stars."""
    _add_var_nodes_to_plot(
        graph, axis, pos, with_dist=True, node_size=1200, node_shape="*"
    )


def _add_nodes_without_distribution_to_plot(graph, axis, pos):
    """Adds variables without a distribution to the figure, drawn as circles."""
    _add_var_nodes_to_plot(
        graph, axis, pos, with_dist=False, node_size=500, node_shape="o"
    )


def _node_label(node):
    """Returns the plot label of ``node``.

    Named nodes get their type name and their name on two lines; nodes whose
    ``name`` is ``None`` are labelled with ``node.role.name``.
    """
    if node.name is None:
        return node.role.name
    return f"{type(node).__name__}\n{node.name}"


def _add_labels(graph, axis, pos):
    """Adds a text label (see :func:`_node_label`) to every positioned node."""
    labels = {node: _node_label(node) for node in pos}
    nx.draw_networkx_labels(graph, pos, labels=labels, ax=axis, font_size=10)


def _classify_var_edges(edges):
    """Splits variable-graph edges by how the source is used by the target.

    Let ``S`` be the nodes in ``source.nodes`` and in
    ``source.all_output_nodes()``. An edge ``(source, target)`` is a *value*
    edge if ``S`` shares a node with ``target.value_node.all_input_nodes()``.
    It is a *distribution* edge if ``target.has_dist`` is truthy and ``S``
    shares a node with ``target.dist_node.all_input_nodes()``.

    Returns
    -------
    A tuple ``(value_only, dist_only, both)`` of pairwise disjoint sets of
    edges. Edges that are neither value nor distribution edges are in none of
    them.
    """
    value_edges = set()
    dist_edges = set()

    for edge in edges:
        source, target = edge[0], edge[1]
        # Computed once per edge and reused for both checks.
        source_nodes = set(source.all_output_nodes()).union(source.nodes)

        # ``dist_node`` is only consulted when the target has a distribution.
        if target.has_dist and not source_nodes.isdisjoint(
            target.dist_node.all_input_nodes()
        ):
            dist_edges.add(edge)
        if not source_nodes.isdisjoint(target.value_node.all_input_nodes()):
            value_edges.add(edge)

    both = value_edges & dist_edges
    return value_edges - both, dist_edges - both, both


def _draw_edge_layer(graph, axis, pos, edgelist, color):
    """Draws one group of edges, with arrows, in a single ``color``."""
    nx.draw_networkx_edges(
        graph,
        pos,
        edgelist=list(edgelist),
        edge_color=color,
        arrows=True,
        ax=axis,
        node_size=_EDGE_NODE_SIZE,
    )


def _draw_edges(graph, axis, pos, is_var):
    """Draws the edges of ``graph`` onto ``axis``.

    For a node graph (``is_var=False``) every edge is drawn in
    ``_VALUE_EDGE_COLOR``. For a variable graph (``is_var=True``) edges are
    colored by :func:`_classify_var_edges`:
    - used in value and distribution: ``_BOTH_EDGE_COLOR``;
    - distribution only: ``_DIST_EDGE_COLOR``;
    - value only: ``_VALUE_EDGE_COLOR``.
    """
    edges = list(graph.edges)

    if is_var:
        value_only, dist_only, both = _classify_var_edges(edges)
        _draw_edge_layer(graph, axis, pos, both, _BOTH_EDGE_COLOR)
        _draw_edge_layer(graph, axis, pos, dist_only, _DIST_EDGE_COLOR)
        # Only the value-only edges are left for the common pass below.
        # NOTE(review): edges matching neither test are not drawn at all;
        # confirm this cannot happen for a well-formed variable graph.
        edges = value_only

    _draw_edge_layer(graph, axis, pos, edges, _VALUE_EDGE_COLOR)


def _add_legend(axis):
    """Adds a legend explaining node colors, node shapes and edge colors."""

    def arrow_entry(color, label):
        # Legend entry rendered as a colored arrow glyph without a line.
        return Line2D(
            [0],
            [0],
            marker=r"$\rightarrow$",
            color=color,
            label=label,
            markerfacecolor="k",
            markersize=12,
            lw=0,
        )

    legend_elements = [
        Line2D([0], [0], color=_STRONG_COLOR, lw=4, label="Strong"),
        Line2D([0], [0], color=_WEAK_COLOR, lw=4, label="Weak"),
        Line2D(
            [0],
            [0],
            marker="*",
            color="w",
            label="With distribution",
            markerfacecolor="k",
            markersize=18,
        ),
        Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            label="Without distribution",
            markerfacecolor="k",
            markersize=12,
        ),
        arrow_entry(_VALUE_EDGE_COLOR, "Used in value"),
        arrow_entry(_DIST_EDGE_COLOR, "Used in distribution"),
        arrow_entry(_BOTH_EDGE_COLOR, "Used in value and distribution"),
    ]
    axis.legend(handles=legend_elements, loc="best")