"""
The model and the graph builder.
"""
from __future__ import annotations
import logging
import math
import re
from collections import Counter
from collections.abc import Callable, Iterable, Sequence
from copy import deepcopy
from types import MappingProxyType
from typing import IO, Any, Literal, Self, TypeVar
import dill
import jax
import jax.numpy as jnp
import jax.random
import networkx as nx
import pandas as pd
from .nodes import Array, Calc, Dist, Group, Node, NodeState, Value, Var, VarValue
from .viz import plot_nodes, plot_vars
# Public API of this module; ``load_model`` and ``save_model`` are presumably
# defined further down in this file (not visible here).
__all__ = ["GraphBuilder", "Model", "load_model", "save_model"]
# Module logger used for the warnings and info messages emitted by the graph
# builder and the model (e.g. empty graphs, dtype conversions, inconsistent
# log-probability decompositions).
logger = logging.getLogger(__name__)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Graph builder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Constrained type variable: either ``Node`` or ``Var``. Used to annotate helpers
# that operate on iterables of nodes or of variables.
NV = TypeVar("NV", Node, Var)
def _reduced_sum(*args: Array) -> Array:
    """
    Adds up all arguments, first collapsing each one to a scalar.

    Arguments that provide a ``.sum()`` method (e.g. arrays) are reduced with it;
    all other arguments are added as they are. Without arguments, ``0`` is
    returned.
    """
    total = 0
    for arg in args:
        scalar = arg.sum() if hasattr(arg, "sum") else arg
        total = total + scalar
    return total
def _transform_back(var_transformed: Var) -> Calc:
    """
    Builds a :class:`.Calc` that maps a transformed variable back to the original
    domain.

    The calculator re-instantiates the variable's (transformed) distribution with
    the distribution node's inputs and applies the inverse of its ``bijector`` to
    the value of ``var_transformed``.

    Raises
    ------
    RuntimeError
        If ``var_transformed`` has no distribution node.
    """
    dist_node = var_transformed.dist_node
    if dist_node is None:
        raise RuntimeError(
            f"{repr(var_transformed)} must have a transformed distribution"
        )

    distribution_cls = dist_node.distribution

    # The name ``fn`` is kept on purpose; the Calc may use it for display/naming.
    def fn(at, *args, **kwargs):
        return distribution_cls(*args, **kwargs).bijector.inverse(at)

    return Calc(  # type: ignore
        fn, var_transformed.value_node, *dist_node.inputs, **dist_node.kwinputs
    )
class GraphBuilder:
    """
    A graph builder, used to set up a :class:`.Model`.

    Constructs a model containing all nodes and variables that were added to the graph
    builder and their recursive inputs.

    .. important::
        - In :meth:`.build_model` , the graph builder will automatically find all
          **inputs** to its nodes - and the inputs to these inputs
          (i.e. it finds inputs recursively).
        - The **outputs** of the nodes, however, are not added to the model
          automatically, so all **root nodes** need to be added explicitly.
        - Root nodes are nodes that are not inputs to any other node in the graph.
          The response in a regression model is an example of a root node.

    The standard workflow is to create the nodes and variables, add the root var to a
    graph builder, and construct a model from the graph builder. After the model has
    been constructed, some methods of the graph builder are not available anymore,
    because the graph is considered static.

    Parameters
    ----------
    to_float32
        Whether to convert the dtype of the values of the added nodes \
        from float64 to float32.

    See Also
    --------
    :class:`.Model` : The liesel model class, representing a static graph.
    :meth:`.GraphBuilder.add` : Method for adding variables and nodes to the
        GraphBuilder.
    :meth:`.GraphBuilder.build_model` : Method for building a model from the
        GraphBuilder.
    :meth:`.Var.transform` : Transforms a variable by adding a new transformed
        variable as an input. This is useful for variables that are constrained to a
        certain domain, e.g. positive values.

    Examples
    --------
    We start by creating some variables:

    >>> a = lsl.Var(1.0, name="a")
    >>> b = lsl.Var(2.0, name="b")
    >>> c = lsl.Var.new_calc(lambda x, y: x + y, a, b, name="c")

    We now initialize a GraphBuilder and add the root node ``c`` to it:

    >>> gb = lsl.GraphBuilder()
    >>> gb.add(c)
    GraphBuilder(0 nodes, 1 vars)

    We are now ready to build the model:

    >>> model = gb.build_model()
    >>> model
    Model(9 nodes, 3 vars)

    Note that when :meth:`.build_model` is called, all :attr:`~.Var.weak` variables in
    the graph will be updated. So the value of ``c`` is now available:

    >>> c.value
    Array(3., dtype=float32, weak_type=True)

    The graph builder is now empty:

    >>> gb.vars
    []
    """

    def __init__(self, to_float32: bool = False):
        self.nodes: list[Node] = []
        """The nodes that were explicitly added to the graph."""

        self.vars: list[Var] = []
        """The variables that were explicitly added to the graph."""

        self._log_lik_node: Node | None = None
        self._log_prior_node: Node | None = None
        self._log_prob_node: Node | None = None

        self.to_float32 = to_float32

    # ------------------------------------------------------------------------------
    # Model summary nodes
    # ------------------------------------------------------------------------------

    def _add_model_summary_node(
        self,
        name: str,
        user_node: Node | None,
        select_inputs: Callable[[list[Node], list[Var]], Iterable[Node]],
    ) -> GraphBuilder:
        """
        Adds a model summary node (log-likelihood, log-prior or log-probability).

        If ``user_node`` is given, the new node simply forwards its value. Otherwise,
        the new node sums the (reduced) values of the nodes that ``select_inputs``
        picks from all nodes and variables currently in the graph.

        Parameters
        ----------
        name
            The reserved name of the new node, e.g. ``_model_log_lik``.
        user_node
            A user-defined node to forward, or ``None``.
        select_inputs
            Receives ``(nodes, vars)`` of the graph and returns the nodes to sum.

        Returns
        -------
        The graph builder.
        """
        if user_node is not None:
            node = Calc(lambda x: x, user_node, _name=name, _update_on_init=False)
        else:
            nodes, _vars = self._all_nodes_and_vars()
            node = Calc(
                _reduced_sum,
                *select_inputs(nodes, _vars),
                _name=name,
                _update_on_init=False,
            )
        return self.add(node)

    def _add_model_log_lik_node(self) -> GraphBuilder:
        """Adds the model log-likelihood node with the name ``_model_log_lik``."""
        return self._add_model_summary_node(
            "_model_log_lik",
            self.log_lik_node,
            lambda nodes, _vars: (
                v.dist_node for v in _vars if v.has_dist and v.observed
            ),
        )

    def _add_model_log_prior_node(self) -> GraphBuilder:
        """Adds the model log-prior node with the name ``_model_log_prior``."""
        return self._add_model_summary_node(
            "_model_log_prior",
            self.log_prior_node,
            lambda nodes, _vars: (
                v.dist_node for v in _vars if v.has_dist and v.parameter
            ),
        )

    def _add_model_log_prob_node(self) -> GraphBuilder:
        """Adds the model log-probability node with the name ``_model_log_prob``."""
        return self._add_model_summary_node(
            "_model_log_prob",
            self.log_prob_node,
            lambda nodes, _vars: (n for n in nodes if isinstance(n, Dist)),
        )

    def _add_model_seed_nodes(self) -> GraphBuilder:
        """Adds the model seed nodes with the names ``_model_*_seed``."""
        nodes, _ = self._all_nodes_and_vars()
        for node in nodes:
            if node.needs_seed and not node.seed_node:
                seed = Value(jax.random.PRNGKey(0), _name=f"_model_{node.name}_seed")
                node.seed_node = seed
        return self

    # ------------------------------------------------------------------------------
    # Graph traversal and naming
    # ------------------------------------------------------------------------------

    def _all_nodes_and_vars(self) -> tuple[list[Node], list[Var]]:
        """
        Returns all nodes and variables that were explicitly or implicitly
        (as recursive inputs) added to the graph.

        The graph is traversed depth-first, starting from the explicitly added nodes,
        the nodes of the explicitly added variables, and the user-defined log-prob
        nodes. Visited nodes and variables are tracked in sets, so the traversal is
        linear in the size of the graph.
        """
        pending = self.nodes.copy()
        pending.extend(node for var in self.vars for node in var.nodes)
        pending = list(dict.fromkeys(pending))

        for user_node in (self.log_lik_node, self.log_prior_node, self.log_prob_node):
            if user_node is not None:
                pending.append(user_node)

        all_nodes: list[Node] = []
        all_vars: list[Var] = []
        seen_nodes: set[Node] = set()
        seen_vars: set[Var] = set()

        while pending:
            node = pending.pop()
            if node in seen_nodes:
                continue

            seen_nodes.add(node)
            all_nodes.append(node)
            pending.extend(node.all_input_nodes())

            var = node.var
            if var is not None and var not in seen_vars:
                seen_vars.add(var)
                all_vars.append(var)
                pending.extend(var.nodes)

        return all_nodes, all_vars

    @staticmethod
    def _do_set_missing_names(nodes_or_vars: Iterable[NV]) -> list[str]:
        """
        Sets the missing names for the given nodes or variables.

        Deprecated; use :meth:`.Var.ensure_name` instead.

        Returns
        -------
        The names that were set automatically.
        """
        automatically_set_names = []

        for nv in nodes_or_vars:
            if not nv.name:
                nv.ensure_name()
                automatically_set_names.append(str(nv.name))

        return automatically_set_names

    def _set_missing_names(self) -> dict[str, list[str]]:
        """
        Sets the missing node and variable names.

        Returns
        -------
        A dict with the keys ``"vars"`` and ``"nodes"``, listing the names that were
        newly introduced by this call.
        """
        nodes, _vars = self._all_nodes_and_vars()

        var_names_before = {v.name for v in _vars}
        for var in _vars:
            var.ensure_name()
        var_names_after = {v.name for v in _vars}
        auto_var_names = list(var_names_after - var_names_before)

        node_names_before = {n.name for n in nodes}
        for node in nodes:
            node.ensure_name()
        node_names_after = {n.name for n in nodes}
        auto_node_names = list(node_names_after - node_names_before)

        return {"vars": auto_var_names, "nodes": auto_node_names}

    # ------------------------------------------------------------------------------
    # Adding to the graph and building the model
    # ------------------------------------------------------------------------------

    def add(
        self, *args: Node | Var | GraphBuilder, to_float32: bool | None = None
    ) -> GraphBuilder:
        """
        Adds nodes, variables or other graph builders to the graph.

        Parameters
        ----------
        *args
            The nodes, variables or graph builders to add to the graph. Note that \
            the GraphBuilder will find input nodes recursively for all nodes and \
            variables that are added to it, so you only need to add root nodes.
        to_float32
            Whether to convert the dtype of the values of the added nodes \
            from float64 to float32. If ``None`` (default), the GraphBuilder's \
            attribute ``GraphBuilder.to_float32``, which is set during initialization \
            will be used instead.

        Returns
        -------
        The graph builder.

        Raises
        ------
        RuntimeError
            If an argument is not a node, a variable or a graph builder.

        See Also
        --------
        :meth:`.GraphBuilder.build_model` : Method for building a model from the \
            GraphBuilder.
        :meth:`.Var.transform` : Transforms a variable by adding a new
            transformed variable as an input.

        Examples
        --------
        We start by creating some variables:

        >>> a = lsl.Var(1.0, name="a")
        >>> b = lsl.Var(2.0, name="b")
        >>> c = lsl.Var.new_calc(lambda x, y: x + y, a, b, name="c")

        We now initialize a GraphBuilder and add the root node ``c`` to it:

        >>> gb = lsl.GraphBuilder()
        >>> gb.add(c)
        GraphBuilder(0 nodes, 1 vars)

        We are now ready to build the model:

        >>> model = gb.build_model()
        >>> model
        Model(9 nodes, 3 vars)
        """
        if to_float32 is None:
            to_float32 = self.to_float32

        for arg in args:
            if isinstance(arg, Node):
                self.nodes.append(arg)
            elif isinstance(arg, Var):
                self.vars.append(arg)
            elif isinstance(arg, GraphBuilder):
                self.nodes.extend(arg.nodes)
                self.vars.extend(arg.vars)
            else:
                raise RuntimeError(f"Cannot add {type(arg).__name__} to graph builder")

        if to_float32:
            self.convert_dtype("float64", "float32")

        return self

    def add_groups(
        self, *groups: Group, to_float32: bool | None = None
    ) -> GraphBuilder:
        """
        Adds groups to the graph.

        Parameters
        ----------
        *groups
            The groups to add to the graph.
        to_float32
            Whether to convert the dtype of the values of the added nodes \
            from float64 to float32. If ``None`` (default), the GraphBuilder's \
            attribute ``GraphBuilder.to_float32``, which is set during initialization \
            will be used instead.

        Returns
        -------
        The graph builder.

        Raises
        ------
        RuntimeError
            If a different group with the same name is already in the graph.
        """
        if to_float32 is None:
            to_float32 = self.to_float32

        for group in groups:
            old = self.groups()

            if group.name in old and group is not old[group.name]:
                raise RuntimeError(
                    f"Group with name {repr(group.name)} already exists "
                    "in graph builder"
                )

            self.add(*group.nodes_and_vars.values())

        if to_float32:
            self.convert_dtype("float64", "float32")

        return self

    def build_model(
        self, copy: bool = False, validate_log_prob_decomposition: bool = True
    ) -> Model:
        """
        Builds a model from the graph.

        Constructs a model containing all nodes and variables that were added to the
        graph builder and their recursive inputs. The outputs of the nodes are not added
        to the model automatically, so the root nodes always need to be added
        explicitly.

        The standard workflow is to create the nodes and variables, add them to a graph
        builder, and construct a model from the graph builder. After the model has been
        constructed, some methods of the graph builder are not available anymore.

        Parameters
        ----------
        copy
            Whether the nodes and variables should be copied when building the model.
        validate_log_prob_decomposition
            Whether to check that the model log-probability equals the sum of the
            log-prior and the log-likelihood. Inconsistencies are logged as warnings.

        Returns
        -------
        The liesel model, which is a static graph built from the GraphBuilder.

        Raises
        ------
        RuntimeError
            If a node uses the reserved name prefix ``_model``, or if a variable
            cannot be auto-transformed.

        Notes
        -----
        If this method is called with the argument ``copy=False``, all nodes and
        variables are removed from the graph builder, because most methods of the graph
        builder do not work with nodes that are part of a model.

        Examples
        --------
        We start by creating some variables:

        >>> a = lsl.Var(1.0, name="a")
        >>> b = lsl.Var(2.0, name="b")
        >>> c = lsl.Var.new_calc(lambda x, y: x + y, a, b, name="c")

        We now initialize a GraphBuilder and add the root node ``c`` to it:

        >>> gb = lsl.GraphBuilder()
        >>> gb.add(c)
        GraphBuilder(0 nodes, 1 vars)

        We are now ready to build the model:

        >>> model = gb.build_model()
        >>> model
        Model(9 nodes, 3 vars)

        Note that when :meth:`.build_model` is called, all :attr:`~.Var.weak` variables
        in the graph will be updated. So the value of ``c`` is now available:

        >>> c.value
        Array(3., dtype=float32, weak_type=True)

        The graph builder is now empty:

        >>> gb.vars
        []
        """
        nodes_and_vars = self._discover_nodes_and_vars()

        # Validation is deferred until the model is fully constructed.
        model = Model(
            nodes_and_vars,
            grow=False,
            copy=copy,
            to_float32=self.to_float32,
            validate_log_prob_decomposition=False,
        )

        if validate_log_prob_decomposition:
            model._validate_log_prob_decomposition()

        if not copy:
            self.nodes.clear()
            self.vars.clear()
            self._log_lik_node = None
            self._log_prior_node = None
            self._log_prob_node = None

        return model

    def _discover_nodes_and_vars(self) -> list[Node | Var]:
        """
        Prepares the graph for model building and returns all its nodes and vars.

        Works on a shallow copy of the graph builder, so the lists of explicitly added
        nodes and variables of ``self`` are not extended. The nodes and variables
        themselves are shared, however, and are modified in place: auto-transforms
        are applied, missing names are set, and seed nodes are attached. The
        ``_model_*`` summary nodes are added to the copy.

        Raises
        ------
        RuntimeError
            If a node uses the reserved name prefix ``_model`` (seed nodes excepted),
            or if a variable cannot be auto-transformed.
        """
        nodes, _vars = self._all_nodes_and_vars()

        if not nodes:
            logger.warning("No nodes in graph builder, building an empty model")

        for node in nodes:
            if node.name.startswith("_model") and not node.name.endswith("_seed"):
                raise RuntimeError(f"{repr(node)} has reserved name '_model*'")

        gb = self.copy()
        nodes, _vars = gb._all_nodes_and_vars()

        # Compare *names*: ``nodes`` and ``_vars`` hold Node/Var objects, so testing a
        # string against them directly would not detect a name clash.
        taken_names = {n.name for n in nodes} | {v.name for v in _vars}

        for var in _vars:
            if not var.auto_transform:
                continue

            if var.dist_node is None:
                raise RuntimeError(
                    f"Auto-transform of {var} failed, because it has no"
                    " distribution, which means no default bijector can be found."
                )

            tname = f"{var.name}_transformed"
            if tname in taken_names:
                raise RuntimeError(
                    f"Auto-transform of {var} failed, because a variable of the "
                    f"name {tname} is already present in {gb}."
                )

            var.transform(bijector=None)

        gb._set_missing_names()
        gb._add_model_log_lik_node()
        gb._add_model_log_prior_node()
        gb._add_model_log_prob_node()
        gb._add_model_seed_nodes()

        nodes, _vars = gb._all_nodes_and_vars()
        return nodes + _vars

    # ------------------------------------------------------------------------------
    # Graph manipulation
    # ------------------------------------------------------------------------------

    def convert_dtype(
        self, from_dtype: str | jax.numpy.dtype, to_dtype: str | jax.numpy.dtype
    ) -> GraphBuilder:
        """
        Tries to convert the node values in the graph to the specified data type.

        Works for nodes whose value is an array or pytree_. Nodes whose value is of
        another type are silently ignored.

        .. _pytree: https://jax.readthedocs.io/en/latest/pytrees.html

        Parameters
        ----------
        from_dtype
            The data type to convert from.
        to_dtype
            The data type to convert to.

        Returns
        -------
        The graph builder.
        """
        nodes, _ = self._all_nodes_and_vars()

        class ConversionWrapper:
            """Wraps a pytree leaf, converting it if its dtype is ``from_dtype``."""

            def __init__(self, value):
                self.value = value
                self.converted = False

                try:
                    if value.dtype == from_dtype:
                        self.value = value.astype(to_dtype)
                        self.converted = True
                except AttributeError:
                    # Leaves without a ``dtype`` are kept unchanged.
                    pass

        for node in nodes:
            model = node.model

            if model:
                # Disable automatic model updates while the value is swapped; the
                # previous setting is restored in ``finally`` below.
                auto_update_before = model.auto_update
                model.auto_update = False

            try:
                wrappers = jax.tree.map(ConversionWrapper, node.value)
                value = jax.tree.map(lambda x: x.value, wrappers)
                node.value = value  # type: ignore # data node

                converted = jax.tree.map(lambda x: x.converted, wrappers)
                if any(jax.tree_util.tree_flatten(converted)[0]):
                    logger.info(f"Converted dtype of {repr(node)}.value")
            except AttributeError:
                # Best effort: nodes that raise AttributeError here (presumably
                # because their value cannot be set) are skipped.
                pass
            finally:
                if model:
                    model.auto_update = auto_update_before

        return self

    def copy(self) -> GraphBuilder:
        """Returns a shallow copy of the graph builder."""
        gb = GraphBuilder(to_float32=self.to_float32)
        gb.nodes = self.nodes.copy()
        gb.vars = self.vars.copy()
        gb.log_lik_node = self.log_lik_node
        gb.log_prior_node = self.log_prior_node
        gb.log_prob_node = self.log_prob_node
        return gb

    def count_node_names(self) -> dict[str, int]:
        """Counts the number of times each node name occurs in the graph."""
        nodes, _ = self._all_nodes_and_vars()
        counter = Counter(node.name for node in nodes if node.name)
        return dict(counter.most_common())

    def count_var_names(self) -> dict[str, int]:
        """Counts the number of times each variable name occurs in the graph."""
        _, _vars = self._all_nodes_and_vars()
        counter = Counter(var.name for var in _vars if var.name)
        return dict(counter.most_common())

    def groups(self) -> dict[str, Group]:
        """Collects the groups from all nodes and variables."""
        nodes, _vars = self._all_nodes_and_vars()
        g1 = {g.name: g for n in nodes for g in n.groups.values()}
        g2 = {g.name: g for v in _vars for g in v.groups.values()}
        return g1 | g2

    @property
    def log_lik_node(self) -> Node | None:
        """User-defined log-likelihood node, if there is one."""
        return self._log_lik_node

    @log_lik_node.setter
    def log_lik_node(self, log_lik_node: Node | None):
        if log_lik_node is not None and not isinstance(log_lik_node, Node):
            raise RuntimeError("The log-likelihood node must be a node, not var")

        self._log_lik_node = log_lik_node

    @property
    def log_prior_node(self) -> Node | None:
        """User-defined log-prior node, if there is one."""
        return self._log_prior_node

    @log_prior_node.setter
    def log_prior_node(self, log_prior_node: Node | None):
        if log_prior_node is not None and not isinstance(log_prior_node, Node):
            raise RuntimeError("The log-prior node must be a node, not var")

        self._log_prior_node = log_prior_node

    @property
    def log_prob_node(self) -> Node | None:
        """User-defined log-probability node, if there is one."""
        return self._log_prob_node

    @log_prob_node.setter
    def log_prob_node(self, log_prob_node: Node | None):
        if log_prob_node is not None and not isinstance(log_prob_node, Node):
            raise RuntimeError("The log-probability node must be a node, not var")

        self._log_prob_node = log_prob_node

    def _with_temporary_model(
        self, action: Callable[[Model], Any] | None = None
    ) -> GraphBuilder:
        """
        Builds a temporary :class:`.Model` from the graph and optionally runs
        ``action`` on it.

        Missing names are set first; no ``_model_*`` summary nodes are added. The
        nodes and variables are released with :meth:`.Model.pop_nodes_and_vars`
        afterwards, even if ``action`` raises, so the graph builder stays usable.

        Returns
        -------
        The graph builder.
        """
        nodes, _vars = self._all_nodes_and_vars()
        nodes_and_vars = nodes + _vars
        self._set_missing_names()
        model = Model(nodes_and_vars, grow=False)

        try:
            if action is not None:
                action(model)
        finally:
            model.pop_nodes_and_vars()

        return self

    def plot_nodes(self) -> GraphBuilder:
        """
        Plots all nodes in the graph.

        Returns
        -------
        The graph builder.

        See Also
        --------
        :meth:`.viz.plot_nodes` : The function used to plot the nodes.
        """
        # Inside the method body, ``plot_nodes`` refers to the imported viz function.
        return self._with_temporary_model(plot_nodes)

    def plot_vars(self) -> GraphBuilder:
        """
        Plots all variables in the graph.

        Returns
        -------
        The graph builder.

        See Also
        --------
        :meth:`.viz.plot_vars` : The function used to plot the variables.
        """
        # Inside the method body, ``plot_vars`` refers to the imported viz function.
        return self._with_temporary_model(plot_vars)

    def rename(self, pattern: str, replacement: str) -> GraphBuilder:
        """Renames all nodes and variables in the graph (``re.sub`` semantics)."""
        self.rename_nodes(pattern, replacement)
        self.rename_vars(pattern, replacement)
        return self

    def rename_nodes(self, pattern: str, replacement: str) -> GraphBuilder:
        """Renames all named nodes in the graph (``re.sub`` semantics)."""
        nodes, _ = self._all_nodes_and_vars()

        for node in nodes:
            if node.name:
                node.name = re.sub(pattern, replacement, node.name)

        return self

    def rename_vars(self, pattern: str, replacement: str) -> GraphBuilder:
        """Renames all named variables in the graph (``re.sub`` semantics)."""
        _, _vars = self._all_nodes_and_vars()

        for var in _vars:
            if var.name:
                var.name = re.sub(pattern, replacement, var.name)

        return self

    def replace_node(self, old: Node, new: Node) -> GraphBuilder:
        """
        Replaces the ``old`` with the ``new`` node, both in the list of explicitly
        added nodes and in the (keyword) inputs of all nodes in the graph.
        """
        self.nodes = [new if x is old else x for x in self.nodes]
        nodes, _ = self._all_nodes_and_vars()

        for node in nodes:
            inputs = [new if x is old else x for x in node.inputs]
            kwinputs = {k: new if v is old else v for k, v in node.kwinputs.items()}
            node.set_inputs(*inputs, **kwinputs)

        return self

    def replace_var(self, old: Var, new: Var) -> GraphBuilder:
        """
        Replaces the ``old`` with the ``new`` variable.

        Raises
        ------
        RuntimeError
            If ``old`` has a distribution but ``new`` does not.
        """
        self.vars = [new if x is old else x for x in self.vars]
        self.replace_node(old.var_value_node, new.var_value_node)
        self.replace_node(old.value_node, new.value_node)

        if old.dist_node:
            if not new.dist_node:
                raise RuntimeError(
                    f"Cannot replace {repr(old)} with distribution "
                    f"with {repr(new)} without distribution"
                )

            self.replace_node(old.dist_node, new.dist_node)

        return self

    def update(self) -> GraphBuilder:
        """
        Updates all nodes in the graph by building (and releasing) a temporary model.

        Returns
        -------
        The graph builder.
        """
        return self._with_temporary_model()

    def __repr__(self) -> str:
        brackets = f"({len(self.nodes)} nodes, {len(self.vars)} vars)"
        return type(self).__name__ + brackets
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
[docs]
class Model:
    """
    A probabilistic graphical model.

    Parameters
    ----------
    nodes_and_vars
        The nodes and variables to include in the model. For backwards \
        compatibility, iterables of nodes and variables are accepted as well; \
        they are flattened by one level.
    grow
        Whether a :class:`.GraphBuilder` should be used to grow the model (finding \
        the recursive inputs of the nodes and variables), and to add the model nodes.
    copy
        Whether the nodes and variables should be copied upon initialization.
    to_float32
        Whether to convert the dtype of the values of the added nodes \
        from float64 to float32. Only takes effect if ``grow=True``.
    validate_log_prob_decomposition
        Whether to check, after initialization, that the model log-probability \
        equals the sum of the log-prior and the log-likelihood, and to flag \
        variables that have a distribution but are neither parameters nor \
        observed. Problems are reported as logged warnings, not raised.

    See Also
    --------
    .Var.new_obs : Initializes a strong variable that holds observed data.
    .Var.new_param : Initializes a strong variable that acts as a model parameter.
    .Var.new_calc :
        Initializes a weak variable that is a function of other variables.
    .Var.new_value : Initializes a strong variable without a distribution.
    :class:`.GraphBuilder` :
        A graph builder, which can be used to set up and manipulate a model if you need
        more control.

    Examples
    --------
    Here, we set up a basic model based on three variables:

    >>> a = lsl.Var.new_value(1.0, name="a")
    >>> b = lsl.Var.new_value(2.0, name="b")
    >>> c = lsl.Var.new_calc(lambda x, y: x + y, a, b, name="c")

    We now build a model:

    >>> model = lsl.Model([c])
    >>> model
    Model(9 nodes, 3 vars)
    """
def __init__(
self,
*nodes_and_vars: Node | Var | Iterable[Node | Var],
grow: bool = True,
copy: bool = False,
to_float32: bool = True,
validate_log_prob_decomposition: bool = True,
):
# this is for backwards compatibility: Old code, in which an iterable of
# nodes or vars is passed, still works.
_nodes_and_vars: list[Node | Var] = []
for nv in nodes_and_vars:
if isinstance(nv, Node | Var):
_nodes_and_vars.append(nv)
else:
try:
iter(nv)
_nodes_and_vars.extend(list(nv))
except TypeError:
pass
nodes_and_vars_list = _nodes_and_vars
# end of compatibility block
self._to_float32 = to_float32
self._auto_update = True
self._update_graph(nodes_and_vars_list, copy=copy, grow=grow)
self.graph_outdated = False
self.update_graph_lazily = False
self.locked = False
self.seed_nodes_and_vars = nodes_and_vars_list
if validate_log_prob_decomposition:
self._validate_log_prob_decomposition()
def _validate_log_prob_decomposition(self):
consistent = jnp.allclose(self.log_prob, self.log_prior + self.log_lik)
if not consistent:
logger.warning(
"Inconsistent log prob decomposition: "
f"Model.log_prob={self.log_prob:.2f} ≠ "
f"(Model.log_lik={self.log_lik:.2f} + "
f"Model.log_prior={self.log_prior:.2f}). "
)
for var in self.vars.values():
if var.dist_node is not None:
if not var.parameter and not var.observed:
logger.warning(
f"{var} has a distribution but "
"Var.parameter=False and Var.observed=False."
)
@property
def graph_outdated(self) -> bool:
"""
Whether the model graph is outdated.
The model graph can be updated with :meth:`.update_graph` or
:meth:`.rebuild_graph`.
"""
return self._graph_outdated
@graph_outdated.setter
def graph_outdated(self, value: bool) -> None:
if not isinstance(value, bool):
raise TypeError(f"Value must be bool, got {type(value)}.")
self._graph_outdated = value
@property
def locked(self) -> bool:
"""
Whether the model graph is locked.
If the model graph is locked, the in- and outputs, names, and distributions of
variables and nodes in the model graph cannot be changed.
"""
return self._locked
@locked.setter
def locked(self, value: bool) -> None:
if not isinstance(value, bool):
raise TypeError(f"Value must be bool, got {type(value)}.")
self._locked = value
@property
def seed_nodes_and_vars(self) -> list[Node | Var]:
"""
The seed nodes and variables passed to the model during initialization.
"""
return self._seed_nodes_and_vars
@seed_nodes_and_vars.setter
def seed_nodes_and_vars(self, value: list[Node | Var]) -> None:
self._seed_nodes_and_vars = value
def _replace_node(self, old: Node, new: Node) -> Self:
"""Replaces the ``old`` with the ``new`` node."""
if not isinstance(old, Node):
raise TypeError(f"'old' must be of type Node, got {type(old).__name__}.")
if not isinstance(new, Node):
raise TypeError(f"'new' must be of type Node, got {type(new).__name__}.")
nodes = [new if x is old else x for x in self.nodes.values()]
GraphBuilder._do_set_missing_names(nodes)
for node in nodes:
inputs = [new if x is old else x for x in node.inputs]
kwinputs = {k: new if v is old else v for k, v in node.kwinputs.items()}
node.set_inputs(*inputs, **kwinputs)
self._nodes = {nd.name: nd for nd in nodes}
self.graph_outdated = True
return self
def _replace_var_with_var(self, old: Var, new: Var) -> Self:
"""Replaces the ``old`` with the ``new`` variable."""
if not isinstance(old, Var):
raise TypeError(f"'old' must be of type Var, got {type(old).__name__}.")
if not isinstance(new, Var):
raise TypeError(f"'new' must be of type Var, got {type(new).__name__}.")
vars_ = [new if x is old else x for x in self.vars.values()]
GraphBuilder._do_set_missing_names(vars_)
self._vars = {v.name: v for v in vars_}
if old.dist_node:
if not new.dist_node:
lazy_before = self.update_graph_lazily
self.update_graph_lazily = True
auto_update_before = self.auto_update
self.auto_update = False
old.dist_node = None
self.auto_update = auto_update_before
self.update_graph_lazily = lazy_before
else:
self._replace_node(old.dist_node, new.dist_node)
self._replace_node(old.var_value_node, new.var_value_node)
self._replace_node(old.value_node, new.value_node)
self.graph_outdated = True
return self
def _replace_var_with_node(self, old: Var, new: Node) -> Self:
"""Replaces the ``old`` with the ``new`` variable."""
if not isinstance(old, Var):
raise TypeError(f"'old' must be of type Var, got {type(old).__name__}.")
if not isinstance(new, Node):
raise TypeError(f"'new' must be of type Node, got {type(new).__name__}.")
vars_ = [x for x in self.vars.values() if x is not old]
self._vars = {v.name: v for v in vars_}
self._replace_node(old.var_value_node, new)
self.graph_outdated = True
return self
[docs]
def replace(
self, old: str | Var, new: Node | Var | float | int | jax.Array
) -> Self:
"""
Replaces the ``old`` with the ``new`` node or variable.
Examples
--------
>>> import liesel.model as lsl
>>> x1 = lsl.Var.new_obs(1.0, name="x1")
>>> x2 = lsl.Var.new_obs(1.0, name="x2")
>>> m = lsl.Model(x1)
>>> list(m.vars)
['x1']
>>> m.replace("x1", x2)
Model(5 nodes, 1 vars)
>>> list(m.vars)
['x2']
"""
if old is new:
return self
if isinstance(old, str):
if old in self.nodes:
raise TypeError(f"{old=} must be of type Var, got a Node.")
elif old in self.vars:
old_nv = self.vars[old]
else:
raise KeyError(f"{old=} not found in the model.")
else:
old_nv = old
same_name = False
if isinstance(old_nv, Var):
if not isinstance(new, Var | Node):
new = Var.new_value(new)
new.name = old_nv.name
if isinstance(new, Var):
if new.model and new.model is not self:
raise RuntimeError(f"{new} can only be part of one model")
same_name = old_nv.name == new.name
if same_name:
new.name = new.name + "__tmp_new__"
self._replace_var_with_var(old_nv, new)
elif isinstance(new, Node):
if new.model and new.model is not self:
raise RuntimeError(f"{new} can only be part of one model")
same_name = old_nv.name == new.name
if same_name:
new.name = new.name + "__tmp_new__"
self._replace_var_with_node(old_nv, new)
else:
raise RuntimeError("Unexpected unknown problem in Model.replace().")
else:
raise TypeError(f"{old=} must be of type Var, got {type(old_nv)}.")
if not self.update_graph_lazily:
self.update_graph()
old_is_still_in_model = old_nv.name in self.nodes or old_nv.name in self.vars
if old_is_still_in_model:
self._remove_disconnected_parental_submodel(old)
if same_name and hasattr(new, "name"):
new.name = old_nv.name
if old_nv in self.seed_nodes_and_vars:
self.seed_nodes_and_vars.remove(old_nv)
if new not in self.seed_nodes_and_vars:
self.seed_nodes_and_vars.append(new)
return self
def _remove_disconnected_parental_submodel(self, of: str | Node | Var) -> Self:
"""
Removes the variable/node supplied as ``of`` and its inputs, then updates the
graph. If any of the removed variables/nodes is still an input to any of the
remaining variables/nodes in the model graph, they are re-added through the
update. Otherwise, they are removed from the model graph.
Note that, if any of the removed variables/nodes are in
:attr:`.seed_nodes_and_vars`, they remain in :attr:`.seed_nodes_and_vars` and
would be re-added to the graph if :meth:`.rebuild_graph` is called without
arguments.
"""
if isinstance(of, str):
if of in self.nodes:
of = self.nodes[of]
elif of in self.vars:
of = self.vars[of]
else:
raise KeyError(f"{of=} not found in the model.")
else:
of = of
p_nodes_and_vars = set()
if isinstance(of, Var):
p_nodes_and_vars.update(nx.ancestors(self.var_graph, of))
p_nodes_and_vars.update(nx.ancestors(self.node_graph, of.var_value_node))
p_nodes_and_vars.update(nx.ancestors(self.node_graph, of.value_node))
p_nodes_and_vars.add(of)
p_nodes_and_vars.add(of.var_value_node)
p_nodes_and_vars.add(of.value_node)
if of.dist_node:
p_nodes_and_vars.update(nx.ancestors(self.node_graph, of.dist_node))
p_nodes_and_vars.add(of.dist_node)
else:
p_nodes_and_vars.update(nx.ancestors(self.node_graph, of))
p_nodes_and_vars.add(of)
for nv in p_nodes_and_vars.copy():
if isinstance(nv, VarValue):
p_nodes_and_vars.add(nv.var)
self._nodes = {
n.name: n for n in self.nodes.values() if n not in p_nodes_and_vars
}
self._vars = {
v.name: v for v in self.vars.values() if v not in p_nodes_and_vars
}
self.graph_outdated = True
if not self.update_graph_lazily:
self.update_graph()
return self
def _get_singletons(self, graph: nx.DiGraph):
G = graph
singletons1 = [n for n, d in G.degree() if d == 0]
singletons2 = [
n for n in G.nodes() if G.in_degree(n) == 0 and G.out_degree(n) == 0
]
singletons = set(singletons1 + singletons2)
return [
nd
for nd in singletons
if isinstance(nd, Var) or not nd.name.startswith("_model")
]
def _ensure_unlocked(self):
if self.locked:
raise RuntimeError(
f"{self} is locked, cannot rebuild graph."
"To allow for changes to the model, you can set the Model.locked flag "
"to False. "
"ATTENTION: Note that, from v0.5, the default state for models will "
"be Model.locked = False, so do not rely on this error if you are "
"using the default."
)
[docs]
def modify_names(self, fn: Callable[[str], str]):
"""
Modifies the names of all variables and nodes in the model according to the
supplied function.
Examples
--------
>>> import liesel.model as lsl
>>> x1 = lsl.Var.new_obs(1.0, name="x1")
>>> x2 = lsl.Var.new_obs(1.0, name="x2")
>>> m = lsl.Model(x1, x2)
>>> list(m.vars)
['x2', 'x1']
>>> m.modify_names(lambda x: x.replace("x", "y"))
Model(7 nodes, 2 vars)
>>> list(m.vars)
['y1', 'y2']
"""
update_graph_lazily = self.update_graph_lazily
self.update_graph_lazily = True
nv_dict = self.nodes | self.vars
for nv in nv_dict.values():
nv.name = fn(nv.name)
self.update_graph()
self.update_graph_lazily = update_graph_lazily
return self
[docs]
def prefix_names(self, prefix: str) -> Self:
    """
    Prepends ``prefix`` to the name of every variable and node in the model.

    This is a thin wrapper around :meth:`.modify_names`.

    Examples
    --------
    >>> import liesel.model as lsl
    >>> x1 = lsl.Var.new_obs(1.0, name="x1")
    >>> x2 = lsl.Var.new_obs(1.0, name="x2")
    >>> m = lsl.Model(x1, x2)
    >>> list(m.vars)
    ['x2', 'x1']
    >>> m.prefix_names("m.")
    Model(10 nodes, 2 vars)
    >>> list(m.vars)
    ['m.x1', 'm.x2']
    """

    def _with_prefix(name: str) -> str:
        return prefix + name

    return self.modify_names(_with_prefix)
[docs]
def rebuild_graph(self, *vars_nodes_and_names: Var | Node | str) -> Self:
    """
    Rebuilds the model graph by re-discovering the outputs of all supplied nodes and
    variables. Also accepts strings, which must be the names of nodes or variables
    currently in the model.

    If no nodes or variables are supplied, uses the :attr:`.seed_nodes_and_vars`
    supplied to the model during initialization.

    Raises
    ------
    RuntimeError
        If the model is locked.
    KeyError
        If a supplied string is neither a variable name nor a node name in the
        model. Variable names take precedence over node names.

    Examples
    --------
    Here, a variable that was added manually but not added to the seed variables
    is dropped when rebuilding the graph:

    >>> import liesel.model as lsl
    >>> x1 = lsl.Var.new_obs(1.0, name="x1")
    >>> x2 = lsl.Var.new_obs(1.0, name="x2")
    >>> m = lsl.Model(x1)
    >>> m.add(x2, add_to_seeds=False)
    Model(7 nodes, 2 vars)
    >>> list(m.vars)
    ['x2', 'x1']
    >>> m.rebuild_graph()
    Model(5 nodes, 1 vars)
    >>> list(m.vars)
    ['x1']

    Here, the model is empty after dropping the singletons, but gets restored
    when rebuilding, because the whole graph can be rediscovered from the seed
    nodes.

    >>> x1 = lsl.Var.new_obs(1.0, name="x1")
    >>> x2 = lsl.Var.new_obs(1.0, name="x2")
    >>> m = lsl.Model(x1, x2)
    >>> m.drop_singletons()
    Model(3 nodes, 0 vars)
    >>> list(m.vars)
    []
    >>> m.seed_nodes_and_vars
    [Var(name="x1"), Var(name="x2")]
    >>> m.rebuild_graph()
    Model(7 nodes, 2 vars)
    >>> list(m.vars)
    ['x2', 'x1']
    """
    self._ensure_unlocked()

    def _resolve(nvn: Var | Node | str) -> Var | Node:
        # Objects are used as-is; strings are looked up, variables first.
        if not isinstance(nvn, str):
            return nvn
        if nvn in self.vars:
            return self.vars[nvn]
        if nvn in self.nodes:
            return self.nodes[nvn]
        raise KeyError(f"No Node or Var with name '{nvn}' found in model.")

    if vars_nodes_and_names:
        vars_nodes = [_resolve(nvn) for nvn in vars_nodes_and_names]
    else:
        # Copy, so that the graph update cannot alias the seed list.
        vars_nodes = list(self.seed_nodes_and_vars)

    self._update_graph(vars_nodes)
    self.graph_outdated = False
    return self
[docs]
def update_graph(self) -> Self:
    """
    Re-discovers the outputs of all nodes and variables in the model and rebuilds
    the graph accordingly.

    Singleton nodes (nodes with neither inputs nor outputs) in the updated graph
    are dropped. Singleton variables are kept; use :meth:`.drop_singletons` to
    remove them explicitly.
    """
    # ``add`` without arguments adds nothing new and only rebuilds the graph from
    # the nodes and variables that are already part of the model.
    nothing_to_add: tuple = ()
    return self.add(*nothing_to_add)
[docs]
def add(
    self, *args: Var | Node | Model, copy: bool = False, add_to_seeds: bool = True
) -> Self:
    """
    Adds any number of variables, nodes, or whole models to this model.

    Parameters
    ----------
    *args
        :class:`.Var` or :class:`.Node` objects to add. Other :class:`.Model`
        instances are accepted as well; all of their nodes and variables are then
        moved into this model. Duplicate names are not allowed.
    copy
        If ``True``, the supplied nodes, variables, and models are copied before
        they are added.
    add_to_seeds
        If ``True``, the supplied nodes and variables, as well as the seed nodes
        and variables of any supplied model, become seed nodes and variables of
        this model.

    See Also
    --------
    .Model.seed_nodes_and_vars : Seed nodes and variables.

    Notes
    -----
    With ``copy=False``, every supplied model is left empty once its contents
    have been moved into the calling model.

    Examples
    --------
    Adding a variable:

    >>> import liesel.model as lsl
    >>> x1 = lsl.Var.new_obs(1.0, name="x1")
    >>> x2 = lsl.Var.new_obs(1.0, name="x2")
    >>> m = lsl.Model(x1)
    >>> m.add(x2)
    Model(7 nodes, 2 vars)
    >>> list(m.vars)
    ['x2', 'x1']
    >>> m.seed_nodes_and_vars
    [Var(name="x1"), Var(name="x2")]

    Adding a model:

    >>> import liesel.model as lsl
    >>> x1 = lsl.Var.new_obs(1.0, name="x1")
    >>> x2 = lsl.Var.new_obs(1.0, name="x2")
    >>> m1 = lsl.Model(x1)
    >>> m2 = lsl.Model(x2)
    >>> m1.add(m2)
    Model(7 nodes, 2 vars)
    >>> list(m1.vars)
    ['x2', 'x1']
    >>> list(m2.vars)
    []
    >>> m1.seed_nodes_and_vars
    [Var(name="x1"), Var(name="x2")]
    """
    models = [arg for arg in args if isinstance(arg, Model)]
    new_vars_and_nodes = [arg for arg in args if isinstance(arg, Var | Node)]

    if len(models) + len(new_vars_and_nodes) != len(args):
        unexpected = [
            arg
            for arg in args
            if arg not in models and arg not in new_vars_and_nodes
        ]
        raise TypeError(f"Received arguments of unexpected types: {unexpected}")

    def _without_model_entries(objs) -> list:
        # Drop the model-internal "_model*" entries.
        return [obj for obj in objs if not obj.name.startswith("_model")]

    existing = _without_model_entries(self.nodes.values()) + _without_model_entries(
        self.vars.values()
    )

    from_models: list[Var | Node] = []
    for other in models:
        from_models.extend(_without_model_entries(other.nodes.values()))
        from_models.extend(_without_model_entries(other.vars.values()))

    # Fail on name clashes before anything is moved.
    self._check_for_duplicates(existing + new_vars_and_nodes + from_models)

    self._add_vars_and_nodes(
        *new_vars_and_nodes, copy=copy, add_to_seeds=add_to_seeds
    )
    self._add_models(*models, copy=copy, add_to_seeds=add_to_seeds)
    return self
def _add_vars_and_nodes(
    self, *vars_and_nodes: Var | Node, copy: bool = False, add_to_seeds: bool = True
) -> Self:
    """
    Adds a variable number of :class:`.Var`s and/or :class:`.Node`s to the model.

    The graph is rebuilt from the model's current nodes and variables (without
    the ``"_model"``-prefixed ones) plus the supplied objects.

    If ``add_to_seeds``, the nodes and variables are also added to the calling
    model's seed nodes and variables.

    Parameters
    ----------
    *vars_and_nodes
        The variables and nodes to add.
    copy
        If ``True``, the supplied objects are deep-copied first, and the model
        reference of every copied :class:`.Node` is unset.
    add_to_seeds
        Whether to append the (possibly copied) objects to
        ``self.seed_nodes_and_vars``.

    Returns
    -------
    The model itself.

    Raises
    ------
    RuntimeError
        If the model is locked.
    """
    self._ensure_unlocked()
    # The model-internal "_model*" entries are not passed on to the graph rebuild.
    nodes = [nd for nd in self.nodes.values() if not nd.name.startswith("_model")]
    vars_ = [nd for nd in self.vars.values() if not nd.name.startswith("_model")]
    existing_nodes_and_vars = nodes + vars_
    vn_list = list(vars_and_nodes)
    if copy:
        vn_list = deepcopy(vn_list)
        # Detach the copied nodes from whatever model their originals belonged to.
        for nv in vn_list:
            if isinstance(nv, Node):
                nv._unset_model()
    existing_nodes_and_vars += vn_list
    self._update_graph(existing_nodes_and_vars)
    self.graph_outdated = False
    if add_to_seeds:
        self.seed_nodes_and_vars += vn_list
    return self
def _add_models(
    self, *models: Model, copy: bool = False, add_to_seeds: bool = True
) -> Self:
    """
    Moves the nodes and variables of each supplied model into this model.

    With ``copy=False`` the supplied models are emptied in the process; with
    ``copy=True`` each model is copied first and the copy is emptied instead.
    With ``add_to_seeds=True``, the seed nodes and variables of every
    (possibly copied) model are appended to this model's seeds.
    """
    for source in models:
        donor = source.copy() if copy else source
        if add_to_seeds:
            self.seed_nodes_and_vars += donor.seed_nodes_and_vars
        popped_nodes, popped_vars = donor.pop_nodes_and_vars()
        self._add_vars_and_nodes(
            *popped_nodes.values(), *popped_vars.values(), add_to_seeds=False
        )
    return self
[docs]
def join_by_all(self, model: Model, copy: bool = False) -> Self:
    """
    Joins a second model into this one, using every variable name the two models
    have in common (except ``"_model"``-prefixed names) as join key.

    See Also
    --------
    .Model.join : Join by no or a manually supplied sequence of overlapping names.

    Examples
    --------
    >>> import liesel.model as lsl
    >>> x1 = lsl.Var.new_obs(1.0, name="x")
    >>> x2 = lsl.Var.new_obs(1.0, name="x")
    >>> y = lsl.Var.new_calc(lambda x: x, x2, name="y")
    >>> m1 = lsl.Model(x1)
    >>> m2 = lsl.Model(x2, y)
    >>> m1.join_by_all(m2)
    Model(7 nodes, 2 vars)
    >>> list(m1.vars)
    ['x', 'y']
    >>> list(m2.vars)
    []
    >>> y.value_node[0] is x1
    True
    >>> m1.seed_nodes_and_vars
    [Var(name="x"), Var(name="y")]
    """
    shared_names = [
        name
        for name in self.vars
        if name in model.vars and not name.startswith("_model")
    ]
    return self.join(model, by=shared_names, copy=copy)
[docs]
def join(
    self,
    model: Model,
    by: Sequence[str] | None = None,
    copy: bool = False,
    suffix: tuple[str, str] = (".x", ".y"),
) -> Self:
    """
    Joins a second model into this one.

    Parameters
    ----------
    model
        The second model to join into this one.
    by
        Sequence of variable names to join on.
    copy
        Whether to copy the second model before joining.
    suffix
        Suffixes to use for renaming of variables with duplicate names.

    Raises
    ------
    TypeError
        If ``by`` contains an element that is not a string.
    ValueError
        If a name in ``by`` is not a variable of this model, or not a variable or
        node of ``model``.

    See Also
    --------
    .Model.join_by_all : Automatically join by all overlapping names.

    Notes
    -----
    If there are variables with duplicate names, the method's behavior depends on
    ``by``:

    1. If the duplicate name is supplied in ``by``, then the variables from
       ``self`` (i.e., the model on which the method is called) are used. They
       replace the respective variables in the second model.
    2. If the duplicate name is not supplied in ``by``, then the duplicate names
       are resolved by renaming the respective variables from both models using
       ``suffix``.

    The seed nodes of the second model are added to the calling model's seed nodes.

    ``by`` is validated before the second model is copied or emptied, so an
    invalid ``by`` leaves both models untouched.

    Examples
    --------
    Nothing supplied in ``by``, duplicate names are resolved by renaming:

    >>> import liesel.model as lsl
    >>> x1 = lsl.Var.new_obs(1.0, name="x")
    >>> x2 = lsl.Var.new_obs(1.0, name="x")
    >>> y = lsl.Var.new_calc(lambda x: x, x2, name="y")
    >>> m1 = lsl.Model(x1)
    >>> m2 = lsl.Model(x2, y)
    >>> m1.join(m2)
    Model(9 nodes, 3 vars)
    >>> list(m1.vars)
    ['x.y', 'x.x', 'y']
    >>> list(m2.vars)
    []
    >>> m1.seed_nodes_and_vars
    [Var(name="x.x"), Var(name="x.y"), Var(name="y")]

    Joining on 'x':

    >>> import liesel.model as lsl
    >>> x1 = lsl.Var.new_obs(1.0, name="x")
    >>> x2 = lsl.Var.new_obs(1.0, name="x")
    >>> y = lsl.Var.new_calc(lambda x: x, x2, name="y")
    >>> m1 = lsl.Model(x1)
    >>> m2 = lsl.Model(x2, y)
    >>> m1.join(m2, by=["x"])
    Model(7 nodes, 2 vars)
    >>> list(m1.vars)
    ['x', 'y']
    >>> list(m2.vars)
    []
    >>> y.value_node[0] is x1
    True
    >>> m1.seed_nodes_and_vars
    [Var(name="x"), Var(name="y")]
    """
    by = by or []

    # Validate ``by`` *before* copying or popping ``model``: popping empties the
    # second model, so an error raised afterwards would leave it destroyed even
    # though nothing was joined.
    self_var_names = {
        nd.name for nd in self.vars.values() if not nd.name.startswith("_model")
    }
    model_names = {
        nv.name
        for nv in (*model.nodes.values(), *model.vars.values())
        if not nv.name.startswith("_model")
    }
    for name_ in by:
        if not isinstance(name_, str):
            raise TypeError(
                "The argument 'by' must be a sequence of strings or empty."
            )
        if name_ not in self_var_names:
            raise ValueError(f"No variable of name '{name_}' found in self.")
        if name_ not in model_names:
            raise ValueError(f"No variable of name '{name_}' found in model.")

    if copy:
        model = model.copy()

    nodes_model, vars_model = model.pop_nodes_and_vars()
    _nodes_list = [
        nd for nd in nodes_model.values() if not nd.name.startswith("_model")
    ]
    _vars_list = [
        nd for nd in vars_model.values() if not nd.name.startswith("_model")
    ]
    _vars_and_nodes: list[Var | Node] = [*_nodes_list, *_vars_list]

    nodes = {
        nd.name: nd
        for nd in self.nodes.values()
        if not nd.name.startswith("_model")
    }
    vars_ = {
        nd.name: nd for nd in self.vars.values() if not nd.name.startswith("_model")
    }

    # Variables: join on names in ``by``, rename every other duplicate.
    replacements = {}
    for nv in _vars_list:
        if nv.name in vars_:
            if nv.name in by:
                # Temporarily rename the incoming variable; it is replaced by the
                # variable of ``self`` after adding.
                dup = vars_[nv.name]
                nv.name = nv.name + suffix[1]
                replacements[nv.name] = dup
            else:
                logger.info(
                    f"{nv.name} found in both models. Renaming "
                    f"to '{nv.name}{suffix[0]}' "
                    f"and '{nv.name}{suffix[1]}'."
                )
                dup = vars_[nv.name]
                dup_name = dup.name
                dup.name = dup.name + suffix[0]
                nv.name = nv.name + suffix[1]

                # Keep default-named value nodes in sync with the new names.
                if dup.var_value_node.name == dup_name + "_var_value":
                    dup.var_value_node.name = dup.name + "_var_value"
                if dup.value_node.name == dup_name + "_value":
                    dup.value_node.name = dup.name + "_value"
                if nv.var_value_node.name == dup_name + "_var_value":
                    nv.var_value_node.name = nv.name + "_var_value"
                if nv.value_node.name == dup_name + "_value":
                    nv.value_node.name = nv.name + "_value"

    # Nodes: rename every duplicate. The value nodes of join variables are only
    # renamed temporarily and get their original name back at the end.
    renamings_to_reverse = {}  # key: value (temp: old)
    for nd in _nodes_list:
        if nd.name in nodes:
            nd_name = nd.name
            logger.debug(
                f"{nd.name} found in both models. "
                f"Renaming to '{nd.name}{suffix[0]}' and '{nd.name}{suffix[1]}'."
            )
            dup_nd = nodes[nd.name]
            dup_nd.name = dup_nd.name + suffix[0]
            nd.name = nd.name + suffix[1]
            if nd_name.removesuffix("_value") in by:
                renamings_to_reverse[dup_nd.name] = nd_name
            if nd_name.removesuffix("_var_value") in by:
                renamings_to_reverse[dup_nd.name] = nd_name

    self._add_vars_and_nodes(*_vars_and_nodes, add_to_seeds=False)
    self.seed_nodes_and_vars += model.seed_nodes_and_vars  # manual update

    # Sorted for a deterministic log message.
    replacement_names = sorted(
        {nv.name for nv in replacements.values() if isinstance(nv, Var)}
    )
    if replacements:
        logger.info(f"Joining by: {', '.join(replacement_names)}")

    for old, new in replacements.items():
        if old in self.vars:
            self.replace(old, new)

    for temp, old in renamings_to_reverse.items():
        self.nodes[temp].name = old

    return self
[docs]
def drop_singletons(self) -> Self:
    """
    Removes every node and variable that has neither inputs nor outputs.

    Notes
    -----
    The :class:`.Var.value_node` and :class:`.Var.var_value_node` of a singleton
    :class:`.Var` are removed together with the variable, even though they are not
    singletons in the *node graph*. Entries whose name starts with ``"_model"``
    are never removed here.

    Examples
    --------
    >>> import liesel.model as lsl
    >>> x1 = lsl.Var.new_obs(1.0, name="x1")
    >>> x2 = lsl.Var.new_obs(1.0, name="x2")
    >>> m = lsl.Model(x1, x2)
    >>> m.drop_singletons()
    Model(3 nodes, 0 vars)
    >>> list(m.vars)
    []
    """

    def _is_internal(obj) -> bool:
        return obj.name.startswith("_model")

    for var in self._get_singletons(self._var_graph):
        if _is_internal(var):
            continue
        self._vars.pop(var.name, None)
        for member in (var.var_value_node, var.value_node):
            self._nodes.pop(member.name, None)

    lonely_nodes = [
        nd for nd in self._get_singletons(self._node_graph) if not _is_internal(nd)
    ]
    for nd in lonely_nodes:
        self._nodes.pop(nd.name, None)

    self.update_graph()
    return self
def _update_graph(
    self,
    nodes_and_vars: Iterable[Node | Var],
    copy: bool = False,
    grow: bool = True,
) -> Self:
    """
    Rebuilds the model's internal node and variable bookkeeping from
    ``nodes_and_vars``.

    Parameters
    ----------
    nodes_and_vars
        The nodes and variables that make up the model.
    copy
        If ``True``, the node and variable dictionaries are deep-copied before the
        model is wired up, so that the model works on the copies.
    grow
        If ``True``, ``nodes_and_vars`` is first expanded via
        ``self._discover_nodes_and_vars`` (which delegates to
        :class:`.GraphBuilder`).

    Returns
    -------
    The model itself.

    Raises
    ------
    RuntimeError
        If ``self._check_for_duplicates`` finds duplicate node, variable, or group
        names.
    """
    if grow:
        nodes_and_vars = self._discover_nodes_and_vars(nodes_and_vars)
    nodes, _vars = self._check_for_duplicates(nodes_and_vars)
    self._nodes = {n.name: n for n in nodes}
    self._vars = {v.name: v for v in _vars}
    if copy:
        self._nodes, self._vars = deepcopy((self._nodes, self._vars))
    # Reset the output links of every node and attach it to this model ...
    for node in self._nodes.values():
        node._clear_outputs()
        if node.model is not self:
            node._set_model(self)
    # ... then re-register each node as an output of all of its inputs.
    for node in self._nodes.values():
        for _input in node.all_input_nodes():
            _input._add_output(node)
    self._node_graph = self._build_node_graph(self._nodes.values())
    self._var_graph = self._build_var_graph(self._vars.values())
    self._sorted_nodes = list(nx.topological_sort(self._node_graph))
    self._sorted_vars = list(nx.topological_sort(self._var_graph))
    # Re-key both dictionaries in topological order.
    self._nodes = {n.name: n for n in self._sorted_nodes}
    self._vars = {n.name: n for n in self._sorted_vars}
    self._simulation_graph = self._build_simulation_graph(self._nodes.values())
    self._simulation_nodes = list(nx.topological_sort(self._simulation_graph))
    self._seed_nodes = []
    for node in self._sorted_nodes:
        # Seed nodes are recognized purely by their naming convention.
        if node.name.startswith("_model_") and node.name.endswith("_seed"):
            self._seed_nodes.append(node)
        # Nodes are visited in topological order of the node graph (edges run
        # from input to node), so each node is updated after its inputs.
        if self.auto_update:
            node.update()
    return self
def _discover_nodes_and_vars(
    self, nodes_and_vars: Iterable[Node | Var]
) -> list[Node | Var]:
    """
    Returns the nodes and variables that a :class:`.GraphBuilder` discovers from
    ``nodes_and_vars``, using this model's ``_to_float32`` setting.
    """
    populated_builder = GraphBuilder(to_float32=self._to_float32).add(
        *nodes_and_vars
    )
    return populated_builder._discover_nodes_and_vars()
@staticmethod
def _check_for_duplicates(
    nodes_and_vars: Iterable[Node | Var],
) -> tuple[list[Node], list[Var]]:
    """
    Errors if there are two or more nodes/variables with the same name.

    Nodes, variables, and groups are checked separately, each after removing
    repeated occurrences of the same object.

    Returns
    -------
    A tuple ``(nodes, vars)`` of the de-duplicated nodes and variables, in
    first-seen order.

    Raises
    ------
    RuntimeError
        If two distinct nodes, two distinct variables, or two distinct groups
        share a name.
    """
    # Materialize once: the input is scanned three times below, and a one-shot
    # iterator (e.g. a generator) would otherwise be exhausted after the first
    # scan, silently skipping the variable and group checks.
    items = list(nodes_and_vars)

    def _unique(objs) -> list:
        # Remove repeated occurrences of the same object, keeping first-seen order.
        return list(dict.fromkeys(objs))

    def _raise_on_duplicate_names(objs, kind: str) -> None:
        counts = Counter(obj.name for obj in objs)
        dups = [name for name, count in counts.items() if count > 1]
        if dups:
            raise RuntimeError(f"Duplicate {kind} names: {', '.join(dups)}")

    nodes = _unique(nv for nv in items if isinstance(nv, Node))
    _raise_on_duplicate_names(nodes, "node")

    _vars = _unique(nv for nv in items if isinstance(nv, Var))
    _raise_on_duplicate_names(_vars, "variable")

    groups = _unique(g for nv in items for g in nv.groups.values())
    _raise_on_duplicate_names(groups, "group")

    return nodes, _vars
@staticmethod
def _build_node_graph(nodes: Iterable[Node]) -> nx.DiGraph:
    """
    Builds the directed graph of the model nodes.

    There is an edge ``(input, node)`` for every input of every node. Nodes
    without any edges are included as well.
    """
    # Materialize: ``nodes`` is traversed twice (edges, then all nodes), so a
    # one-shot iterator would otherwise lose every node without edges.
    nodes = list(nodes)
    edges: list[tuple[Node, Node]] = [
        (_input, node) for node in nodes for _input in node.all_input_nodes()
    ]
    graph = nx.DiGraph(edges)
    graph.add_nodes_from(nodes)
    return graph
@staticmethod
def _build_simulation_graph(nodes: Iterable[Node]) -> nx.DiGraph:
    """
    Builds the simulation graph of the model nodes.

    This is the node graph, except that the edge between a :class:`.Dist` node
    and its :attr:`.Dist.at` input is reversed. A topological sort therefore
    visits a distribution before the node whose value it describes.
    """
    # Materialize: ``nodes`` is traversed twice (edges, then all nodes), so a
    # one-shot iterator would otherwise lose every node without edges.
    nodes = list(nodes)
    edges: list[tuple[Node, Node]] = []
    for node in nodes:
        for _input in node.all_input_nodes():
            if isinstance(node, Dist) and _input is node.at:
                edges.append((node, _input))
            else:
                edges.append((_input, node))
    graph = nx.DiGraph(edges)
    graph.add_nodes_from(nodes)
    return graph
@staticmethod
def _build_var_graph(_vars: Iterable[Var]) -> nx.DiGraph:
    """
    Builds the directed graph of the model variables.

    There is an edge ``(input, var)`` for every input variable of every variable.
    Variables without any edges are included as well.
    """
    # Materialize: ``_vars`` is traversed twice (edges, then all variables), so a
    # one-shot iterator would otherwise lose every variable without edges.
    _vars = list(_vars)
    edges: list[tuple[Var, Var]] = [
        (_input, var) for var in _vars for _input in var.all_input_vars()
    ]
    graph = nx.DiGraph(edges)
    graph.add_nodes_from(_vars)
    return graph
def _copy_computational_model(self) -> Model:
    """
    Returns a deep copy of the model with all node states cleared.

    The state of this model is temporarily cleared to produce the copy and is
    restored afterwards, also if copying fails.
    """
    backup = self.state
    try:
        for node in self._nodes.values():
            node.clear_state()
        return deepcopy(self)
    finally:
        # Never leave this model with cleared nodes, even if deepcopy raises.
        self.state = backup
def _recursive_inputs(self, name: str) -> list[Node]:
    """
    Returns the recursive inputs of a model node.

    The node itself comes first, followed by its direct and indirect inputs in
    depth-first (stack) order. Each node appears once.

    Raises
    ------
    KeyError
        If ``name`` is not the name of a model node.
    """
    stack = [self._nodes[name]]
    visited: list[Node] = []
    # A set gives O(1) membership tests; the list keeps the visiting order.
    seen: set[Node] = set()
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        stack.extend(node.all_input_nodes())
        seen.add(node)
        visited.append(node)
    return visited
@property
def model_nodes(self) -> dict[str, Node]:
    """
    The nodes whose name starts with ``"_model"``, keyed by name.

    Typically, these are ``"_model_log_prob"``, ``"_model_log_lik"``, and
    ``"_model_log_prior"``.
    """
    internal_prefix = "_model"
    selected: dict[str, Node] = {}
    for node_name, node in self.nodes.items():
        if node_name.startswith(internal_prefix):
            selected[node_name] = node
    return selected
@property
def auto_update(self) -> bool:
    """
    Whether the model updates itself automatically when a node value changes.

    Switching this off can speed things up when many parameters are changed in
    one go.
    """
    flag = self._auto_update
    return flag

@auto_update.setter
def auto_update(self, auto_update: bool):
    setattr(self, "_auto_update", auto_update)
[docs]
def groups(self) -> dict[str, Group]:
    """
    Returns the groups of all nodes and variables, keyed by group name.

    If several groups share a name, the one seen last wins, with variables
    processed after nodes.
    """
    collected: dict[str, Group] = {}
    for owner in (*self._nodes.values(), *self._vars.values()):
        for group in owner.groups.values():
            collected[group.name] = group
    return collected
[docs]
def copy(self, clear_state: bool = False) -> Model:
    """
    Returns a new model filled with deep copies of all model nodes and variables.

    Parameters
    ----------
    clear_state
        If ``True``, the model state will be cleared before copying, i.e., the
        copy contains no values. This can be used to save memory, if only the
        model structure is required. The state of this model is restored
        afterwards.
    """
    backup = self.state
    try:
        if clear_state:
            for node in self._nodes.values():
                node.clear_state()
        return deepcopy(self)
    finally:
        # Restore the original state even if clearing or copying fails, so this
        # model is never left with cleared nodes.
        self.state = backup
[docs]
def copy_vars(self) -> dict[str, Var]:
    """Returns an unfrozen deep copy of the model variables, keyed by name."""
    _, copied_vars = self.copy_nodes_and_vars()
    return copied_vars
[docs]
def copy_nodes_and_vars(self) -> tuple[dict[str, Node], dict[str, Var]]:
    """
    Returns an unfrozen deep copy of the model nodes and variables.

    The model reference of every copied node is unset. Nodes whose name starts
    with ``"_model"`` are left out of the returned node dictionary.
    """
    snapshot = deepcopy((self._nodes, self._vars))
    copied_nodes: dict[str, Node] = snapshot[0]
    copied_vars: dict[str, Var] = snapshot[1]
    kept: dict[str, Node] = {}
    for node_name, node in copied_nodes.items():
        node._unset_model()
        if not node_name.startswith("_model"):
            kept[node_name] = node
    return kept, copied_vars
[docs]
def node_parental_subgraph(self, *of: Node) -> nx.DiGraph:
    """
    Returns the subgraph of the node graph formed by the given nodes and all of
    their ancestors.
    """
    graph = self.node_graph
    keep: set[Node] = set(of)
    for target in of:
        keep |= nx.ancestors(graph, target)
    return graph.subgraph(keep)
[docs]
def var_parental_subgraph(self, *of: Var) -> nx.DiGraph:
    """
    Returns the subgraph of the variable graph formed by the given variables and
    all of their ancestor variables.
    """
    graph = self.var_graph
    keep: set[Var] = set(of)
    for target in of:
        keep |= nx.ancestors(graph, target)
    return graph.subgraph(keep)
[docs]
def parental_submodel(
    self, *of: Var | Node | str, clear_state: bool = False
) -> Model:
    """
    Returns a new model that consists only of the given variables and nodes and \
    their parent variables and nodes. The new model contains copies of these \
    variables and nodes.

    Parameters
    ----------
    clear_state
        If ``True``, the model state will be cleared before constructing the
        parental submodel, i.e., all values will be removed. This can be used to
        save memory, if only the model structure is required. The state of this
        model is restored afterwards, also if construction fails.

    Raises
    ------
    KeyError
        If a supplied string is neither a variable name nor a node name in the
        model.
    """
    of_nv: list[Var | Node] = []
    for nvn in of:
        if isinstance(nvn, str):
            if nvn in self.vars:
                of_nv.append(self.vars[nvn])
            elif nvn in self.nodes:
                of_nv.append(self.nodes[nvn])
            else:
                raise KeyError(f"No node or variable of name {nvn} found in model.")
        else:
            of_nv.append(nvn)

    backup = self.state
    try:
        if clear_state:
            for node in self._nodes.values():
                node.clear_state()

        nodes, vars_ = self.copy_nodes_and_vars()
        # Look copies up in the dictionary that matches the object's kind: node
        # and variable names are de-duplicated separately, so merging both
        # dictionaries by name could return a variable for a same-named node.
        copy_of_nodes_to_include = [
            vars_[nv.name] if isinstance(nv, Var) else nodes[nv.name] for nv in of_nv
        ]

        stub = Value(0.0)
        model = Model(stub)  # adding a stub node to avoid printing a warning

        # turning auto update off to avoid errors caused by the Nones
        model.auto_update = False
        model.add(*copy_of_nodes_to_include)

        # removing the stub node
        model._nodes.pop(stub.name)
        model.seed_nodes_and_vars.remove(stub)
        model.update_graph()

        # turning auto update back on to restore default
        model.auto_update = True
    finally:
        # Always restore this model's state, even if building the submodel fails.
        self.state = backup

    return model
@property
def log_lik(self) -> Array:
    """
    The log-likelihood of the model.

    This is the summed log-probability of every observed variable that has a
    probability distribution.
    """
    log_lik_node = self._nodes["_model_log_lik"]
    return log_lik_node.value
@property
def log_prior(self) -> Array:
    """
    The log-prior of the model.

    This is the summed log-probability of every parameter variable that has a
    probability distribution.
    """
    log_prior_node = self._nodes["_model_log_prior"]
    return log_prior_node.value
@property
def log_prob(self) -> Array:
    """
    The (unnormalized) log-probability, i.e. log-posterior, of the model.

    This is the sum over all distribution nodes.
    """
    log_prob_node = self._nodes["_model_log_prob"]
    return log_prob_node.value
@property
def node_graph(self) -> nx.DiGraph:
    """Directed graph whose vertices are the model nodes."""
    graph = self._node_graph
    return graph
@property
def nodes(self) -> MappingProxyType[str, Node]:
    """Read-only, name-keyed view of the model nodes."""
    read_only_view = MappingProxyType(self._nodes)
    return read_only_view
[docs]
def pop_vars(self) -> dict[str, Var]:
    """
    Removes the variables from this model and returns them, keyed by name.

    All nodes and variables are unfrozen and lose their reference to this
    model, which becomes invalid and must not be used anymore.
    """
    _, popped_vars = self.pop_nodes_and_vars()
    return popped_vars
[docs]
def pop_nodes_and_vars(self) -> tuple[dict[str, Node], dict[str, Var]]:
    """
    Pops the nodes and variables out of this model.

    All nodes and variables are unfrozen and their reference to this model
    is removed. This model becomes invalid and cannot be used anymore.

    Returns
    -------
    A tuple ``(nodes, vars)`` of name-keyed dictionaries. Nodes whose name starts
    with ``"_model"`` are left out of ``nodes``.
    """
    nodes = self._nodes.copy()
    _vars = self._vars.copy()

    for node in nodes.values():
        node._unset_model()

    nodes = {nm: nd for nm, nd in nodes.items() if not nm.startswith("_model")}

    # Clear the model, including the sorted variables and the simulation graph
    # and node list, so that nothing (e.g. ``simulate``) can still reach the
    # popped nodes through this model.
    self._nodes.clear()
    self._vars.clear()
    self._node_graph.clear()
    self._var_graph.clear()
    self._simulation_graph.clear()
    self._sorted_nodes.clear()
    self._sorted_vars.clear()
    self._simulation_nodes.clear()
    self._seed_nodes.clear()

    return nodes, _vars
[docs]
def set_seed(self, seed: jax.Array) -> Model:
    """
    Splits the seed / PRNG key and hands one subkey to each seed node.

    Parameters
    ----------
    seed
        The seed is split and distributed to the seed nodes of the model.
        Must be a jax RNG key array that satisfies
        ``jnp.issubdtype(key.dtype, jax.dtypes.prng_key)``.
        See :mod:`jax.random` and
        https://docs.jax.dev/en/latest/jep/9263-typed-keys.html for more details.
    """
    seed_nodes = self._seed_nodes
    subkeys = jax.random.split(seed, len(seed_nodes))
    for idx, seed_node in enumerate(seed_nodes):
        seed_node.value = subkeys[idx]  # type: ignore # data node
    return self
[docs]
def simulate(self, seed: jax.Array, skip: Iterable[str] = ()) -> Model:
    """
    Updates the model state simulating from the probability distributions in the
    model using a provided random seed, optionally skipping specified nodes.

    Parameters
    ----------
    seed
        The seed is split and distributed to the seed nodes of the model. \
        Must be a jax RNG key array that satisfies \
        ``jnp.issubdtype(key.dtype, jax.dtypes.prng_key)``. \
        See :mod:`jax.random` and \
        https://docs.jax.dev/en/latest/jep/9263-typed-keys.html for more details.
    skip
        The names of the nodes or variables to be excluded from the simulation. \
        By default, no nodes or variables are skipped. A single string is \
        treated as one name.

    Returns
    -------
    The model instance itself after updating its state with the simulated values.

    Raises
    ------
    AttributeError
        If the value of the :attr:`.Dist.at` node of a distribution node cannot be
        set.

    Notes
    -----
    The simulation is based on the shapes of the current values of the
    :attr:`.Dist.at` nodes of the distribution nodes. If the :attr:`.Dist.at` node
    of a distribution node is a :Class:`.VarValue` node, the value of its input is
    updated.
    """
    # ``skip`` is queried up to three times per node below. Materialize it once so
    # that a one-shot iterator (e.g. a generator) is not consumed by the first
    # membership test. A single string is treated as one name instead of being
    # searched for substrings.
    skip_names = {skip} if isinstance(skip, str) else set(skip)

    dists = [
        node
        for node in self._simulation_nodes
        if isinstance(node, Dist)
        and node.at is not None
        and node.name not in skip_names
        and node.at.name not in skip_names
        and (node.var is not None and node.var.name not in skip_names)
    ]

    keys = jax.random.split(seed, len(dists))

    for dist, key in zip(dists, keys):
        tfp_dist = dist.init_dist()
        event_shape = tfp_dist.event_shape
        batch_shape = tfp_dist.batch_shape

        # The sample shape is made of the leading dimensions of the current value
        # that are not covered by the batch and event shapes.
        value_shape = jnp.asarray(dist.at.value).shape  # type: ignore
        sample_index = len(value_shape) - len(batch_shape) - len(event_shape)
        sample_shape = value_shape[:sample_index]

        value = tfp_dist.sample(sample_shape, key)

        # For a VarValue ``at`` node, the value is written to its input instead.
        target = dist.at.inputs[0] if isinstance(dist.at, VarValue) else dist.at

        try:
            target.value = value  # type: ignore
        except AttributeError as err:
            raise AttributeError(f"Cannot set value of {target}") from err

    return self
[docs]
def sample(
self,
shape: Sequence[int],
seed: jax.Array,
posterior_samples: dict[str, jax.typing.ArrayLike] | None = None,
fixed: Sequence[str] = (),
newdata: dict[str, jax.typing.ArrayLike] | None = None,
dists: dict[str, Dist] | None = None,
) -> dict[str, Array]:
"""
Draws samples from the model.
Parameters
----------
shape
Sample shape.
seed
The seed is split and distributed to the seed nodes of the model. \
Must be a jax RNG key array that satisfies \
``jnp.issubdtype(key.dtype, jax.dtypes.prng_key)``. \
See :mod:`jax.random` and \
https://docs.jax.dev/en/latest/jep/9263-typed-keys.html for more details.
posterior_samples
Dictionary of samples at which to evaluate predictions. All values of the \
dictionary are assumed to have two leading dimensions corresponding to \
``(nchains, niteration)``.
fixed
The names of the nodes or variables to be excluded from the simulation. \
By default, no nodes or variables are skipped.
newdata
Dictionary of new data at which to produce samples. The keys should \
correspond to variable or node names in the model whose values should be \
set to the given values before sampling. If ``None`` \
(default), the current variable values are used.
dists
Can be used to provide a dictionary of variable names and :class:`.Dist` \
instances to use in sampling. If ``None`` (default), samples are drawn for \
each variable using their :attr:`.Var.dist_node`.
Notes
-----
When compiling this function with ``jax.jit``, the arguments ``shape``,
``fixed``, and ``dists`` must be static.
Returns
-------
A dictionary of variable and node names and their sampled values. Includes
only sampled variables.
"""
posterior_samples = posterior_samples if posterior_samples is not None else {}
unique_sample_keys = set(list(posterior_samples))
unique_newdata_keys = set(list(newdata)) if newdata is not None else set()
intersection = unique_sample_keys & unique_newdata_keys
if len(intersection) > 0:
raise RuntimeError(
"The following keys are present in both 'samples' and 'newdata': "
f"{list(intersection)} "
"Any key should be present in only one of these arguments."
)
if posterior_samples is not None:
posterior_samples = jax.tree.map(jnp.asarray, posterior_samples)
if newdata is not None:
newdata = jax.tree.map(jnp.asarray, newdata)
# Pre-processing
# ------------------------------------------------------------------------------
state_for_sampling = (
self.update_state(newdata) if newdata is not None else self.state
)
dists = dists if dists is not None else {}
# Input validation
# ------------------------------------------------------------------------------
# validate values in dists
for var_name in dists:
vars_ = self.vars
if var_name not in vars_:
raise ValueError(f"No variable with name '{var_name}' found.")
if vars_[var_name].weak:
raise ValueError(f"Variable '{var_name}' is weak, cannot sample.")
# validate consistency of 'fixed' and 'posterior_samples' arguments
for name in fixed:
if name in posterior_samples:
raise ValueError(
f"Inconsistency: {name=} listed in 'fixed', but samples are"
" provided in 'posterior_samples'."
)
# Collect sampling information
# ------------------------------------------------------------------------------
# collect relevant distribution nodes in model
dists_list = [
node
for node in self._simulation_nodes
if isinstance(node, Dist)
and node.at is not None
and node.name not in fixed
and node.at.name not in fixed
and (node.var is not None and node.var.name not in fixed)
]
# collect information for sampling by processing dist nodes
sampling_specs = {}
for i, dist in enumerate(dists_list):
tfp_dist = dist.init_dist()
event_shape = tfp_dist.event_shape
batch_shape = tfp_dist.batch_shape
value_shape = jnp.asarray(dist.at.value).shape # type: ignore
sample_index = len(value_shape) - len(batch_shape) - len(event_shape)
sample_shape = value_shape[:sample_index]
if isinstance(dist.at, VarValue):
var_name = dist.at.var.name # type: ignore
value_var = dist.at.inputs[0]
else:
var_name = dist.at.name # type: ignore
value_var = dist.at # type: ignore
if var_name not in posterior_samples:
# pulls manually defined distribution from dists dict, returns current
# dist otherwise
dist = dists.get(var_name, dist)
sampling_specs[var_name] = {
"shape": sample_shape,
"dist": dist,
"i": i,
"value_var": value_var,
}
# add information for custom dists for variables that are not yet covered.
for var_name, dist in dists.items():
if var_name in sampling_specs:
# in this case, the variable has already been added to sampling specs,
# and it is also already using the custom dist
continue
i += 1
tfp_dist = dist.init_dist()
event_shape = tfp_dist.event_shape
batch_shape = tfp_dist.batch_shape
sample_index = len(value_shape) - len(batch_shape) - len(event_shape)
sample_shape = value_shape[:sample_index]
value_shape = jnp.asarray(self.vars[var_name].value).shape # type: ignore
value_var = self.vars[var_name].value_node
sampling_specs[var_name] = {
"shape": sample_shape,
"dist": dist,
"i": i,
"value_var": value_var,
}
# Shape handling
# ------------------------------------------------------------------------------
# set up shape of samples
samples_shape = (
next(iter(posterior_samples.values())).shape[:2]
if posterior_samples
else ()
)
nsamples = math.prod(
shape
) # total number of samples to draw (pure python so jit works)
# set up all seeds that will be needed
seeds = jax.random.split(
seed, (nsamples,) + samples_shape + (len(sampling_specs),)
)
def reshape(a):
# brings samples into the desired shape based on input argument.
# shape=(3,4)
# nsamples=12
# shape of drawn samples: (12,...)
# reshaped to (3,4, ...)
return jnp.reshape(a, shape=shape + a.shape[1:])
# Workhorse function
# ------------------------------------------------------------------------------
def one_draw(position, seeds):
# the position argument is for updating the state with posterior samples
previous_state = self.state
# update model state using the position (a single posterior sample, if any)
# and the state_for_sampling, which includes the observed values from
# newdata.
self.state = self.update_state(position, state_for_sampling)
# draw samples in order of the model graph
sampled_position = {}
for name, spec in sampling_specs.items():
# initializes the distribution node using the current model state,
# which may have been influenced by 'position', 'newdata', or sampled
# values from variables higher up the model hierarchy
tfp_dist = spec["dist"].init_dist()
# draw the actual sample
value = tfp_dist.sample(spec["shape"], seeds[spec["i"]])
# save the sampled value
sampled_position[name] = value
# update the variable's value with the sampled value so that the
# distributions of variables further down the model hierarchy will be
# correctly initialized based on the sampled values higher up
spec["value_var"].value = value
# to avoid tracer leakage we prevent side effects to persists
self.state = previous_state
return sampled_position
if not posterior_samples:
draw_chains = jax.vmap(one_draw, in_axes=(None, 0), out_axes=0)
# since we have no posterior samples, we use position={}
drawn_samples = draw_chains({}, seeds)
# return reshaped version of samples
return jax.tree.map(reshape, drawn_samples)
# this branch of the function continues only if posterior_samples is not None
# -----------------------------------------------------------------------------
draw_iter = jax.vmap(one_draw, in_axes=(0, 0), out_axes=0)
draw_chains = jax.vmap(draw_iter, in_axes=(0, 0), out_axes=0)
draw_samples = jax.vmap(draw_chains, in_axes=(None, 0), out_axes=0)
# filter samples to include only samples that belong to the model
vars_and_nodes = list(self.vars) + list(self.nodes)
filtered_samples = {
k: v for k, v in posterior_samples.items() if k in vars_and_nodes
}
try:
drawn_samples = draw_samples(filtered_samples, seeds)
except Exception as e:
msg = (
"Error during sampling. Make sure to check sample shapes! The values in"
" 'posterior_samples' must have two leading batching dimensions."
)
try:
error_to_raise = e.__class__(msg)
except Exception:
# fallback in case e has a custom error class that cannot simply
# be instantiated with a message.
error_to_raise = RuntimeError(msg)
raise error_to_raise from e
# return reshaped version of samples
return jax.tree.map(reshape, drawn_samples)
@property
def state(self) -> dict[str, NodeState]:
    """Snapshot of the model state: a fresh dict mapping node names to node states."""
    snapshot = {}
    for node_name, node in self._nodes.items():
        snapshot[node_name] = node.state
    return snapshot

@state.setter
def state(self, state: dict[str, NodeState]):
    # Only the nodes named in ``state`` are touched; an unknown name raises KeyError.
    nodes = self._nodes
    for node_name in state:
        nodes[node_name].state = state[node_name]
[docs]
def update(self, *names: str) -> Model:
    """
    Updates the target nodes and their recursive inputs if they are outdated.

    Outdated nodes are recomputed in topological order, which restores a
    consistent state of the model. Nodes call this method automatically when
    their value is modified (unless :attr:`.auto_update` is ``False``).

    Parameters
    ----------
    names
        The names of the target nodes to be updated. If no names are given, every
        outdated node of the model is updated.

    Returns
    -------
    The model itself, so that calls can be chained.
    """
    targets = None
    if names:
        # restrict the update to the recursive inputs of the requested nodes
        targets = set()
        for target_name in names:
            targets.update(self._recursive_inputs(target_name))

    for node in self._sorted_nodes:
        if targets is not None and node not in targets:
            continue
        if node.outdated:
            node.update()
    return self
@property
def var_graph(self) -> nx.DiGraph:
    """The directed graph connecting the variables of the model."""
    graph = self._var_graph
    return graph
@property
def vars(self) -> MappingProxyType[str, Var]:
    """Read-only, live view of the model variables, keyed by variable name."""
    variables = self._vars
    return MappingProxyType(variables)
@property
def parameters(self) -> MappingProxyType[str, Var]:
    """Read-only mapping of the variables flagged as parameters, keyed by name."""
    selected = {}
    for var_name, var in self._vars.items():
        if var.parameter:
            selected[var_name] = var
    return MappingProxyType(selected)
@property
def observed(self) -> MappingProxyType[str, Var]:
    """Read-only mapping of the observed variables, keyed by name."""
    selected = dict(filter(lambda item: item[1].observed, self._vars.items()))
    return MappingProxyType(selected)
def __repr__(self) -> str:
    # e.g. "Model(12 nodes, 4 vars)"
    n_nodes, n_vars = len(self._nodes), len(self._vars)
    return f"{type(self).__name__}({n_nodes} nodes, {n_vars} vars)"
[docs]
def plot(
    self,
    show: bool = True,
    save_path: str | None | IO = None,
    width: int = 14,
    height: int = 10,
    prog: Literal[
        "dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
    ] = "dot",
    legend: bool = True,
):
    """
    Plots the variables of this model.

    Alias for :meth:`.Model.plot_vars`, which in turn wraps
    :func:`~.viz.plot_vars`. All arguments are forwarded unchanged.

    Parameters
    ----------
    show
        Whether to show the plot in a new window.
    save_path
        Path to save the plot. If not provided, the plot will not be saved.
    width
        Width of the plot in inches.
    height
        Height of the plot in inches.
    prog
        Layout parameter. Available layouts: circo, dot (the default), fdp, neato, \
        osage, patchwork, sfdp, twopi.
    legend
        Whether to draw the legend.

    See Also
    --------
    .Var.plot_vars : Plots the variables of the Liesel sub-model that terminates in
        this variable.
    .Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
        this variable.
    .Model.plot_vars : Plots the variables of a Liesel model.
    .Model.plot_nodes : Plots the nodes of a Liesel model.
    .viz.plot_vars : Plots the variables of a Liesel model.
    .viz.plot_nodes : Plots the nodes of a Liesel model.
    """
    plot_options = {
        "show": show,
        "save_path": save_path,
        "width": width,
        "height": height,
        "prog": prog,
        "legend": legend,
    }
    return self.plot_vars(**plot_options)
[docs]
def plot_vars(
    self,
    show: bool = True,
    save_path: str | None | IO = None,
    width: int = 14,
    height: int = 10,
    prog: Literal[
        "dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
    ] = "dot",
    legend: bool = True,
):
    """
    Plots the variables of this model.

    Thin wrapper around :func:`~.viz.plot_vars`, which receives this model
    together with all arguments below.

    Parameters
    ----------
    show
        Whether to show the plot in a new window.
    save_path
        Path to save the plot. If not provided, the plot will not be saved.
    width
        Width of the plot in inches.
    height
        Height of the plot in inches.
    prog
        Layout parameter. Available layouts: circo, dot (the default), fdp, neato, \
        osage, patchwork, sfdp, twopi.
    legend
        Whether to draw the legend.

    See Also
    --------
    .Var.plot_vars : Plots the variables of the Liesel sub-model that terminates in
        this variable.
    .Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
        this variable.
    .Model.plot_vars : Plots the variables of a Liesel model.
    .Model.plot_nodes : Plots the nodes of a Liesel model.
    .viz.plot_vars : Plots the variables of a Liesel model.
    .viz.plot_nodes : Plots the nodes of a Liesel model.
    """
    options = dict(
        show=show,
        save_path=save_path,
        width=width,
        height=height,
        prog=prog,
        legend=legend,
    )
    # resolves to the module-level ``plot_vars`` imported from ``.viz``
    return plot_vars(self, **options)
[docs]
def plot_nodes(
    self,
    show: bool = True,
    save_path: str | None | IO = None,
    width: int = 14,
    height: int = 10,
    prog: Literal[
        "dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
    ] = "dot",
):
    """
    Plots the nodes of this model.

    Thin wrapper around :func:`~.viz.plot_nodes`, which receives this model
    together with all arguments below.

    Parameters
    ----------
    show
        Whether to show the plot in a new window.
    save_path
        Path to save the plot. If not provided, the plot will not be saved.
    width
        Width of the plot in inches.
    height
        Height of the plot in inches.
    prog
        Layout parameter. Available layouts: circo, dot (the default), fdp, neato, \
        osage, patchwork, sfdp, twopi.

    See Also
    --------
    .Var.plot_vars : Plots the variables of the Liesel sub-model that terminates in
        this variable.
    .Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
        this variable.
    .Model.plot_vars : Plots the variables of a Liesel model.
    .Model.plot_nodes : Plots the nodes of a Liesel model.
    .viz.plot_vars : Plots the variables of a Liesel model.
    .viz.plot_nodes : Plots the nodes of a Liesel model.
    """
    options = dict(
        show=show,
        save_path=save_path,
        width=width,
        height=height,
        prog=prog,
    )
    # resolves to the module-level ``plot_nodes`` imported from ``.viz``
    return plot_nodes(self, **options)
[docs]
def update_state(
    self,
    position: dict[str, Array],
    model_state: dict[str, NodeState] | None = None,
    inplace: bool = False,
) -> dict[str, NodeState]:
    """
    Writes a position into a model state, updates it, and returns it.

    Parameters
    ----------
    position
        A dictionary mapping variable or node names to new values. Names are looked
        up among the nodes first and among the variables second.
    model_state
        A dictionary of node names and their corresponding :class:`.NodeState`. \
        If ``None`` (default), the model's current state is used.
    inplace
        If ``False`` (default), the work is done on a copy of the computational \
        model, so the current model's state is left unchanged. If ``True``, the \
        current model's state is updated in place.

    Returns
    -------
    The updated model state.

    Raises
    ------
    KeyError
        If a name in ``position`` is neither a node nor a variable of the model.

    Warnings
    --------
    The ``model_state`` must be up-to-date, i.e. it must *not* contain any outdated
    nodes. Updates can only be triggered through new variable or node values in the
    ``position``. If you supply a ``model_state`` with outdated nodes, these nodes
    and their outputs will not be updated.
    """
    target = self if inplace else self._copy_computational_model()

    # Start from the supplied (or current) state and mark every node as up to
    # date; resetting the outdated flags is required for jittability.
    target.state = self.state if model_state is None else model_state
    for node in target.nodes.values():
        node._outdated = False

    # With auto_update on, each assignment would trigger an update immediately,
    # possibly mixing old and new shapes; switch it off while writing the
    # position and always restore the caller's setting afterwards.
    saved_auto_update = target.auto_update
    target.auto_update = False
    try:
        for key, new_value in position.items():
            try:
                target.nodes[key].value = new_value  # type: ignore # data node
            except KeyError:
                target.vars[key].value = new_value
    finally:
        target.auto_update = saved_auto_update

    target.update()
    return target.state
[docs]
def predict(
    self,
    samples: dict[str, jax.typing.ArrayLike],
    predict: Sequence[str] | None = None,
    newdata: dict[str, jax.typing.ArrayLike] | None = None,
) -> dict[str, Array]:
    """
    Returns a dictionary of predictions.

    ``newdata`` and the individual samples are applied through
    :meth:`.update_state` with ``inplace=False``, i.e. to copies of the model
    state; the state of the model itself is not overwritten.

    Parameters
    ----------
    samples
        Dictionary of samples at which to evaluate predictions. If ``samples``
        contains entries for weak variables or for nodes in :attr:`.model_nodes`
        they are ignored. All remaining entries that belong to the model must share
        the same leading batching dimensions.
    predict
        Sequence of strings, which are the names of nodes or variables. \
        Predictions will be returned only for the nodes or variables included \
        here. If ``None`` (default), predictions will be returned for all \
        *variables* in the model (but not for nodes).
    newdata
        Dictionary of new data at which to evaluate predictions. The keys should \
        correspond to variable or node names in the model whose values should be \
        set to the given values before evaluating predictions. If ``None`` \
        (default), the current variable values are used.

    Returns
    -------
    A dictionary of predictions whose leading dimensions are the batching
    dimensions of ``samples``.

    Raises
    ------
    RuntimeError
        If a key is present in both ``samples`` and ``newdata``, or if the samples
        have inconsistent batching dimensions.
    KeyError
        If a key in ``newdata`` is neither a variable nor a node of the model.
    ValueError
        If no sample belongs to a variable or node of the model (or of the
        submodel used for the requested predictions).
    """
    samples = dict(samples)
    # model nodes and weak variables are computed by the model itself, so sampled
    # values for them are discarded
    for name in self.model_nodes:
        samples.pop(name, None)
    for name, var in self.vars.items():
        if var.weak and name in samples:
            logger.debug(
                f"Key '{name}' belongs to a weak var. "
                "Removing it from samples dictionary."
            )
            samples.pop(name)

    newdata_keys = set(newdata) if newdata is not None else set()
    intersection = set(samples) & newdata_keys
    if intersection:
        raise RuntimeError(
            "The following keys are present in both 'samples' and 'newdata': "
            f"{list(intersection)} "
            "Any key should be present in only one of these arguments."
        )

    samples = jax.tree.map(jnp.asarray, samples)
    # jax.tree.map builds a new dict, so popping keys below never touches the
    # caller's dictionary
    newdata = jax.tree.map(jnp.asarray, newdata) if newdata is not None else {}

    # Deduce the batching dimensions: the leading axes a sample has in addition
    # to the current value of its variable or node.
    batch_shapes = set()
    for name, value in samples.items():
        if name in self.vars:
            model_ndim = jnp.asarray(self.vars[name].value).ndim
        elif name in self.nodes:
            model_ndim = jnp.asarray(self.nodes[name].value).ndim
        else:
            continue
        n_extra_dims = jnp.ndim(value) - model_ndim
        batch_shapes.add(jnp.shape(value)[:n_extra_dims])

    if not batch_shapes:
        raise ValueError(
            "No samples provided for the variables or nodes in the model."
        )
    if len(batch_shapes) > 1:
        raise RuntimeError("Found inconsistent batch shapes.")
    (batch_shape,) = batch_shapes
    n_batching_dim = len(batch_shape)

    predict_names = predict if predict is not None else []
    predicted_model_nodes = [
        name for name in self.model_nodes if name in predict_names
    ]
    other_predict_names = [
        name for name in predict_names if name not in predicted_model_nodes
    ]

    if predicted_model_nodes or not predict_names:
        # The full model is used without copying whenever a model node is
        # requested, or when no names are given (then all variables are returned).
        submodel = self
        # NOTE(review): when model nodes are requested, only those are returned
        # and other requested names are dropped - confirm this is intended.
        predict_names = predicted_model_nodes or list(self.vars)
    else:
        predict_nodes_: list[Var | Node] = []
        for name in other_predict_names:
            try:
                predict_nodes_.append(self.vars[name])
            except KeyError:
                predict_nodes_.append(self.nodes[name])
        submodel = self.parental_submodel(*predict_nodes_)

    for key in list(newdata):
        if key not in self.vars and key not in self.nodes:
            raise KeyError(f"{key} is not part of the model.")
        if key not in submodel.vars and key not in submodel.nodes:
            # not part of the submodel, so it cannot influence the predictions
            newdata.pop(key)

    # Apply newdata to a *copy* of the submodel state. Assigning the result back
    # to submodel.state would overwrite the model's own data when submodel is
    # self.
    base_state = submodel.update_state(newdata)

    submodel_names = list(submodel.vars) + list(submodel.nodes)
    known_names = set(submodel_names)
    filtered_samples = {k: v for k, v in samples.items() if k in known_names}
    if not filtered_samples:
        raise ValueError(
            "No samples provided for the variables or nodes in the submodel. "
            f"Nodes in submodel: {submodel_names}"
        )

    def predict_one(sample):
        # evaluate the requested names for a single (unbatched) sample
        updated_state = submodel.update_state(sample, base_state, inplace=False)
        return submodel.extract_position(predict_names, updated_state)

    def flatten_batch_dims(x):
        # merge all batching dimensions into a single leading axis for vmap
        return jnp.reshape(x, (-1,) + jnp.shape(x)[n_batching_dim:])

    def unflatten_batch_dims(x):
        # restore the original batching dimensions
        return jnp.reshape(x, batch_shape + jnp.shape(x)[1:])

    flat_samples = jax.tree.map(flatten_batch_dims, filtered_samples)
    flat_predictions = jax.vmap(predict_one, in_axes=0, out_axes=0)(flat_samples)
    return jax.tree.map(unflatten_batch_dims, flat_predictions)
[docs]
def diagnose(self, verbose: bool = False) -> pd.DataFrame:
"""
Provides a dataframe with diagnostic information about the model.

Each variable is updated via ``Var.update()`` and contributes one row. For the
variable's value node and, if present, its distribution node (column prefixes
``"value"`` and ``"log_prob"``), the number of ``nan`` and ``inf`` entries, the
size and the dtype of the node value are reported.

Parameters
----------
verbose
    If ``True``, additional columns are added, e.g. the number of input and
    output variables and summary statistics (mean, standard deviation, minimum,
    maximum) of the node values.

Returns
-------
A :class:`pandas.DataFrame` with one row per model variable.
"""
rows = []
for k, v in self.vars.items():
# refresh the variable before its node values are inspected
v.update()
row: dict[str, Any] = {}
row["name"] = k
row["has_dist"] = v.has_dist
if verbose:
row["n_input_vars"] = len(v.all_input_vars())
row["n_output_vars"] = len(v.all_output_vars())
row["parameter"] = v.parameter
# NOTE(review): the column key "observerd" looks like a typo of "observed"; it is
# kept unchanged here because renaming it changes the DataFrame's columns.
row["observerd"] = v.observed
row["strong"] = v.strong
# the first tuple entry is the column prefix; dist_node may be None (skipped)
for name, target in (("value", v.value_node), ("log_prob", v.dist_node)):
if target is None:
continue
row[f"{name}_n_nan"] = jnp.isnan(target.value).sum()
row[f"{name}_n_inf"] = jnp.isinf(target.value).sum()
# NOTE(review): jnp.size already returns a scalar, so the jnp.max is a no-op
row[f"{name}_size"] = jnp.max(jnp.size(target.value))
row[f"{name}_dtype"] = jnp.asarray(target.value).dtype
if verbose:
row[f"{name}_mean"] = jnp.mean(target.value)
row[f"{name}_sd"] = jnp.std(target.value)
row[f"{name}_min"] = jnp.min(target.value)
row[f"{name}_max"] = jnp.max(target.value)
row[f"{name}_n_input_nodes"] = len(target.all_input_nodes())
row[f"{name}_n_output_nodes"] = len(target.all_output_nodes())
if verbose:
row["value_node_name"] = v.value_node.name
if v.dist_node is not None:
row["dist_node_name"] = v.dist_node.name
rows.append(row)
return pd.DataFrame(rows)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Save and load models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def save_model(model: Any, file: str | IO[bytes]) -> None:
    """
    Saves a model to a `dill <https://github.com/uqfoundation/dill>`_ file.

    Parameters
    ----------
    model
        The model to be saved.
    file
        The file handler or path to save the model to. A path may be a
        :class:`str` or any :class:`os.PathLike` (e.g. :class:`pathlib.Path`);
        an existing file at that path is overwritten. A file handler must be
        opened in binary write mode.
    """
    # local import: only needed to recognize path-like arguments
    import os

    if isinstance(file, (str, os.PathLike)):
        with open(file, "wb") as handle:
            dill.dump(model, handle)
    else:
        dill.dump(model, file)
def load_model(file: str | IO[bytes]) -> Any:
    """
    Loads a model from a `dill <https://github.com/uqfoundation/dill>`_ file.

    .. warning::
        dill builds on pickle, and unpickling can execute arbitrary code. Only
        load files from sources you trust.

    Parameters
    ----------
    file
        The file handler or path to load the model from. A path may be a
        :class:`str` or any :class:`os.PathLike` (e.g. :class:`pathlib.Path`).
        A file handler must be opened in binary read mode.

    Returns
    -------
    The deserialized object.
    """
    # local import: only needed to recognize path-like arguments
    import os

    if isinstance(file, (str, os.PathLike)):
        with open(file, "rb") as handle:
            return dill.load(handle)
    return dill.load(file)
class TemporaryModel:
    """
    A context manager for creating a temporary model.

    Unnamed variables and nodes are named temporarily. When the context closes,
    :meth:`.Model.pop_nodes_and_vars` is called on the temporary model and the
    temporarily assigned names are reset to ``""``.

    Example use::

        with TemporaryModel(variable) as model:
            print(model.log_prob)

    Parameters
    ----------
    vars_and_nodes
        Variable-length collection of variables and nodes that should be used to build \
        the temporary model.
    verbose
        If ``verbose=True``, the automatically assigned names of both variables and \
        nodes are logged at ``INFO`` level. If ``verbose=False`` (default), the names \
        of variables are logged at ``INFO`` level and the names of nodes at ``DEBUG`` \
        level.
    silent
        If ``silent=True``, all messages about automatically assigned names are \
        logged at ``DEBUG`` level only.
    to_float32
        Passed on to the :class:`.GraphBuilder` that builds the temporary model.

    Raises
    ------
    ValueError
        If ``verbose`` and ``silent`` are both ``True``.
    """

    def __init__(
        self,
        *vars_and_nodes,
        verbose: bool = False,
        silent: bool = False,
        to_float32: bool = False,
    ):
        self.vars_and_nodes = vars_and_nodes
        self.verbose = verbose
        self.silent = silent
        self.to_float32 = to_float32
        if verbose and silent:
            raise ValueError(f"{verbose=} and {silent=} cannot both be True.")
        # the following attributes are populated by __enter__
        self.gb = None
        self.model = None
        self.var_names = None
        self.node_names = None
        self.vars = None
        self.nodes = None

    def _log_assigned_names(self, var_names, node_names) -> None:
        """Logs the temporarily assigned names at a level chosen by verbose/silent."""
        if self.silent:
            var_level = node_level = logging.DEBUG
        elif self.verbose:
            var_level = node_level = logging.INFO
        else:
            var_level, node_level = logging.INFO, logging.DEBUG

        if var_names:
            names_ = f"The automatically assigned names are: {var_names}. "
            logger.log(
                var_level, f"Unnamed variables were temporarily named. {names_}"
            )
        if node_names:
            names_ = f"The automatically assigned names are: {node_names}. "
            logger.log(node_level, f"Unnamed nodes were temporarily named. {names_}")

    def __enter__(self):
        gb = GraphBuilder(to_float32=self.to_float32).add(*self.vars_and_nodes)
        nodes, _vars = gb._all_nodes_and_vars()
        # names must be assigned before building, because model members need names
        automatically_set_names = gb._set_missing_names()
        var_names = automatically_set_names["vars"]
        node_names = automatically_set_names["nodes"]

        self._log_assigned_names(var_names, node_names)

        model = gb.build_model()

        # remember everything __exit__ needs to undo the temporary naming
        self.gb = gb
        self.model = model
        self.var_names = var_names
        self.node_names = node_names
        self.vars = _vars
        self.nodes = nodes
        return model

    def __exit__(self, exc_type, exc_value, traceback):
        self.model.pop_nodes_and_vars()

        # reset only the names that were assigned automatically in __enter__
        vars_dict = {var_.name: var_ for var_ in self.vars}
        nodes_dict = {node.name: node for node in self.nodes}
        for name in self.var_names:
            vars_dict[name].name = ""
        for name in self.node_names:
            nodes_dict[name].name = ""

        self.gb.nodes.clear()
        self.gb.vars.clear()
        return False  # Returning False means exceptions are not suppressed
[docs]
def log_prob_pointwise(
    vars_: dict[str, Var],
    samples: dict[str, jax.typing.ArrayLike],
    newdata: dict[str, jax.typing.ArrayLike] | None = None,
) -> dict[str, jax.Array]:
    """
    Returns a dictionary of pointwise log probabilities for the supplied variables.

    Parameters
    ----------
    vars_
        Dictionary of variables for which to evaluate log probs. Variables without
        a distribution are skipped.
    samples
        Dictionary of samples at which to evaluate log probs. If ``samples`` contains
        entries for weak variables or for nodes in :attr:`.model_nodes` they are
        ignored.
    newdata
        Dictionary of new data at which to evaluate log probs. The keys should
        correspond to variable or node names in the model whose values should be set
        to the given values before evaluating predictions. If ``None`` (default), the
        current variable values are used.

    Returns
    -------
    A dictionary with pointwise log probability evaluations as values and the
    :class:`.Dist` node names of the supplied variables as keys. If ``vars_`` is
    empty or none of its variables has a distribution, an empty dictionary is
    returned.

    Raises
    ------
    ValueError
        If a variable is not part of a model, or if a variable's distribution node
        has ``per_obs=False``.
    RuntimeError
        If the supplied variables belong to more than one model.
    """
    ll_names = []
    models = []
    for var in vars_.values():
        if not var.model:
            raise ValueError(f"{var} is not part of a model.")
        models.append(var.model)

        if var.dist_node is None:
            continue

        if not var.dist_node.per_obs:
            raise ValueError(
                f"{var} has Var.dist_node.per_obs=False. "
                "For point log probability computation, "
                "Var.dist_node.per_obs=True is required for "
                "all variables contributing to the likelihood."
            )

        # jnp.shape also works for plain Python scalars, which have no .shape
        value_shape = jnp.shape(var.value)
        log_prob_shape = jnp.shape(var.log_prob)
        if value_shape != log_prob_shape:
            msg = (
                f"{var}.value has shape {value_shape}, "
                f"while {var}.log_prob has shape {log_prob_shape}. This "
                f"suggests that the pointwise log prob for {var} may not be "
                "available, or that you may be using a multivariate distribution. "
                "Please double check."
            )
            logger.warning(msg)

        ll_names.append(var.dist_node.name)

    n_models = len(set(models))
    if n_models > 1:
        raise RuntimeError(
            "The supplied variables must all belong to the same model. "
            f"Found {n_models} different models."
        )

    if not ll_names:
        # Model.predict(predict=[]) would return *all* model variables instead of
        # log probs, so answer "nothing to evaluate" explicitly.
        return {}

    model = models[0]
    return model.predict(samples, predict=ll_names, newdata=newdata)