# Source code for liesel.model.nodes

"""
Nodes and variables.
"""

from __future__ import annotations

import logging
import weakref
from abc import ABC, abstractmethod
from collections.abc import Callable, Hashable, Iterable, Sequence
from functools import wraps
from itertools import chain
from types import MappingProxyType
from typing import IO, TYPE_CHECKING, Any, Literal, NamedTuple, Self, TypeGuard, TypeVar

import jax
import jax.numpy as jnp
import pandas as pd
import tensorflow_probability.substrates.jax.bijectors as jb
import tensorflow_probability.substrates.jax.distributions as jd
import tensorflow_probability.substrates.numpy.bijectors as nb
import tensorflow_probability.substrates.numpy.distributions as nd

from ..distributions.nodist import NoDistribution
from .names import random_name
from .viz import plot_nodes, plot_vars

if TYPE_CHECKING:
    from ..goose import MCMCSpec
    from .model import Model

    type InferenceTypes = None | MCMCSpec | dict[str, MCMCSpec] | Any

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "Array",
    "Bijector",
    "Calc",
    "Value",
    "Dist",
    "Distribution",
    "Group",
    "InputGroup",
    "Node",
    "NodeState",
    "TransientCalc",
    "TransientDist",
    "TransientIdentity",
    "TransientNode",
    "Var",
]

# Type aliases used in annotations throughout this module. ``Array`` is
# intentionally loose (``Any``); distributions and bijectors may come from either
# the JAX or the NumPy substrate of TensorFlow Probability.
type Array = Any
type Distribution = jd.Distribution | nd.Distribution
type Bijector = jb.Bijector | nb.Bijector

# Element type for ``_unique_tuple``. Elements must be hashable because they are
# deduplicated via dictionary keys.
T = TypeVar("T", bound=Hashable)

# Module-level logger (used e.g. for best-effort update failures and bijector
# diagnostics).
logger = logging.getLogger(__name__)


def _unique_tuple(*args: Iterable[T]) -> tuple[T, ...]:
    return tuple(dict.fromkeys(chain(*args)))


def in_model_method(fn):
    """
    Restricts a node method to nodes that are part of a model.

    The wrapped method raises a :class:`RuntimeError` if ``self.model`` is falsy at
    call time; otherwise it delegates to ``fn`` unchanged.
    """

    @wraps(fn)
    def guarded(self, *args, **kwargs):
        if self.model:
            return fn(self, *args, **kwargs)
        raise RuntimeError(
            f"{repr(self)} is not part of a model, cannot call {fn.__name__}()"
        )

    return guarded


def in_model_getter(fn):
    """
    Restricts a property getter to nodes that are part of a model.

    Like :func:`in_model_method`, but the error message refers to the attribute
    name instead of a method call.
    """

    @wraps(fn)
    def guarded(self, *args, **kwargs):
        if self.model:
            return fn(self, *args, **kwargs)
        raise RuntimeError(
            f"{repr(self)} is not part of a model, cannot call '{fn.__name__}'"
        )

    return guarded


def no_model_method(fn):
    """
    Restricts a node method to nodes that are *not* part of a model.

    The wrapped method raises a :class:`RuntimeError` if ``self.model`` is truthy
    at call time; otherwise it delegates to ``fn`` unchanged.
    """

    @wraps(fn)
    def guarded(self, *args, **kwargs):
        if not self.model:
            return fn(self, *args, **kwargs)
        raise RuntimeError(
            f"{repr(self)} is part of a model, cannot call {fn.__name__}()"
        )

    return guarded


def no_model_setter(fn):
    """
    Restricts a property setter to nodes that are *not* part of a model.

    Like :func:`no_model_method`, but the error message refers to setting the
    attribute instead of calling a method.
    """

    @wraps(fn)
    def guarded(self, *args, **kwargs):
        if not self.model:
            return fn(self, *args, **kwargs)
        raise RuntimeError(
            f"{repr(self)} is part of a model, cannot set '{fn.__name__}'"
        )

    return guarded


def changes_model_graph(fn):
    """
    Marks a node method as one that modifies the structure of the model graph.

    If the node belongs to a locked model, the call is refused with a
    :class:`RuntimeError`. Otherwise ``fn`` runs, and afterwards the model (if any)
    is flagged as outdated and, unless it updates its graph lazily, its graph is
    rebuilt immediately.
    """

    @wraps(fn)
    def wrapped(self, *args, **kwargs):
        # Resolve the model once, before calling fn: fn may change attributes (such
        # as the node's name) that Node.model relies on to find the model.
        model = self.model

        if model and model.locked:
            raise RuntimeError(
                f"{repr(self)} is part of a locked model, cannot call {fn.__name__}(). "
                "To allow for changes to the model, you can set the Model.locked flag "
                "to False. "
                "ATTENTION: Note that, from v0.5, the default state for models will "
                "be Model.locked = False, so do not rely on this error if you are "
                "using the default."
            )

        result = fn(self, *args, **kwargs)

        if not model:
            return result

        model.outdated = True
        if not model.update_graph_lazily:
            model.update_graph()
        return result

    return wrapped


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Nodes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


class NodeState(NamedTuple):
    """
    The state of a node.

    Instances are returned by the ``state`` property of nodes and accepted by its
    setter, which restores the cached value and the outdated flag. The base
    implementation does not read ``extra``; subclasses may use it to carry
    additional information.
    """

    value: Any
    """The value of the node."""

    outdated: bool
    """Whether the node is outdated."""

    extra: Any = None
    """Optional extra information."""


class Node(ABC):
    """
    A node of a computational graph that can cache its value.

    Liesel represents statistical models as directed acyclic graphs (DAGs) of random
    variables (see :class:`.Var`) and computational nodes. The graph of random
    variables is built on top of the computational graph. The nodes of the
    computational graph will typically express computations in JAX returning arrays
    or pytrees_, but in general, they can represent arbitrary operations in Python.

    Nodes can cache the result of the operations they represent, improving the
    efficiency of the graph. The cached values are part of the model state (see
    :attr:`.Model.state`), and can be stored in a chain by Liesel's MCMC engine
    Goose.

    .. note::
        This class is an abstract class that cannot be initialized without defining
        the :meth:`.update` method. See below for the most important concrete node
        classes.

    Parameters
    ----------
    inputs
        Non-keyword inputs. Any inputs that are not already nodes or :class:`.Var`
        will be converted to :class:`.Value` nodes.
    _name
        The name of the node. If you do not specify a name, a unique name will be
        automatically generated upon initialization of a :class:`.Model`.
    _needs_seed
        Whether the node needs a seed / PRNG key.
    convert
        A function used to process values. It is handed to the :class:`.Value`
        nodes that are created for inputs which are not already nodes (and a
        :class:`.Value` node applies it to its own value). The default uses the
        function stored in :meth:`.Node.convert_value`, which is
        ``jax.numpy.asarray``.

    See Also
    --------
    .Calc :
        A node representing a general calculation/operation in JAX or Python.
    .Value : A node representing some static data.
    .Dist :
        A node representing a ``tensorflow_probability``
        :class:`~tfp.distributions.Distribution`.
    .Var :
        A variable in a statistical model, typically with a probability
        distribution.

    .. _pytrees: https://jax.readthedocs.io/en/latest/pytrees.html
    .. _TensorFlow Probability: https://www.tensorflow.org/probability
    """

    def __init__(
        self,
        *inputs: Any,
        _name: str = "",
        _needs_seed: bool = False,
        convert: Callable[[Any], Any] | Literal["default"] = "default",
        **kwinputs: Any,
    ):
        # The converter must be set before the inputs are wrapped, because
        # ``_to_node`` passes it on to the Value nodes it creates.
        if convert == "default":
            self._convert: Callable[[Any], Any] = self.convert_value
        else:
            self._convert = convert

        self._groups: dict[str, Group] = {}
        self._inputs = tuple(self._to_node(_input) for _input in inputs)
        self._kwinputs = {kw: self._to_node(_input) for kw, _input in kwinputs.items()}
        # Stand-in for a weak reference while the node is not part of a model.
        self._model: weakref.ref[Model] | Callable[[], None] = lambda: None
        self._name = _name
        self._needs_seed = _needs_seed
        self._seed_node: Value | None = None
        self._outdated = True
        self._outputs: tuple[Node, ...] = ()
        self._value: Any = None
        self._var: Var | None = None

        self.monitor = False
        """Whether the node should be monitored by an inference algorithm."""

    @staticmethod
    def convert_value(x: Any) -> Any:
        """
        The function used to process the value of this node, if
        ``convert="default"`` is supplied during init.

        Can be overwritten on subclasses to create node classes with different
        default conversion behavior. Make sure to overwrite it with a static method,
        for example (re-implementing the default behavior)::

            class MyNode(lsl.Node):

                @staticmethod
                def convert_value(x):
                    return jnp.asarray(x) if x is not None else x
        """
        return jnp.asarray(x) if x is not None else x

    def _add_output(self, output: Node) -> Node:
        self._outputs = _unique_tuple(self._outputs, [output])
        return self

    def _clear_outputs(self) -> Node:
        self._outputs = ()
        return self

    def _set_model(self, model: Model) -> Node:
        if self.model:
            raise RuntimeError(f"{repr(self)} can only be part of one model")
        self._model = weakref.ref(model)
        return self

    def _set_var(self, var: Var) -> Node:
        if self.var:
            raise RuntimeError(f"{repr(self)} can only be part of one var")
        self._var = var
        return self

    def _to_node(self, x: Any) -> Node:
        # Vars are referenced through their ``var_value_node``, not directly.
        if isinstance(x, Var):
            return x.var_value_node
        if not isinstance(x, Node):
            return Value(x, convert=self._convert)
        return x

    def _unset_model(self) -> Node:
        self._model = lambda: None
        return self

    def _unset_var(self) -> Node:
        self._var = None
        return self

    @changes_model_graph
    def add_inputs(self, *inputs: Any, **kwinputs: Any) -> Node:
        """Adds non-keyword and keyword input nodes to the existing ones."""
        inputs = self.inputs + inputs
        kwinputs = self.kwinputs | kwinputs
        self.set_inputs(*inputs, **kwinputs)
        return self

    def all_input_nodes(self) -> tuple[Node, ...]:
        """Returns all non-keyword and keyword input nodes as a unique tuple."""
        return _unique_tuple(self.inputs, self.kwinputs.values())

    @in_model_method
    def all_output_nodes(self) -> tuple[Node, ...]:
        """Returns all output nodes as a unique tuple."""
        return self.outputs

    def clear_state(self) -> Node:
        """Clears the state of the node."""
        self.state = NodeState(None, True)
        return self

    def flag_outdated(self) -> Node:
        """Flags the node and its recursive outputs as outdated."""
        self._outdated = True
        # Outputs are only tracked meaningfully inside a model.
        if self.model:
            for node in self._outputs:
                node.flag_outdated()
        return self

    @property
    def groups(self) -> MappingProxyType[str, Group]:
        """The groups that this node is a part of."""
        return MappingProxyType(self._groups)

    @property
    def inputs(self) -> tuple[Node, ...]:
        """The non-keyword input nodes."""
        return self._inputs

    @property
    def kwinputs(self) -> MappingProxyType[str, Node]:
        """The keyword input nodes."""
        return MappingProxyType(self._kwinputs)

    @property
    def model(self) -> Model | None:
        """The model the node is part of."""
        model = self._model()
        if model is None:
            return None
        if self.name in model.nodes:
            return model
        # The model no longer knows this node under its current name (e.g. after a
        # rename): drop the stale reference.
        self._unset_model()
        return None

    @property
    def name(self) -> str:
        """The name of the node."""
        return self._name

    @name.setter
    @changes_model_graph
    def name(self, name: str):
        self._name = name

    def ensure_name(self) -> Self:
        """
        Ensures that the node has a name.

        If the node already has a name, nothing happens. Otherwise, a unique random
        name is generated with a leading underscore.
        """
        if self.name:
            return self
        else:
            self.name = "_" + random_name()
            return self

    @property
    def seed_node(self) -> Value | None:
        """The node passed to this node as its ``seed`` keyword input, if any."""
        return self._seed_node

    @seed_node.setter
    @no_model_setter
    def seed_node(self, value: Value | None):
        kwinputs = dict(self.kwinputs)
        if value is None:
            # Clearing the seed: drop the "seed" keyword input if present. This
            # branch must not fall through to the seeding branch below.
            kwinputs.pop("seed", None)
        else:
            kwinputs["seed"] = value
        self.set_inputs(*self.inputs, **kwinputs)
        self._seed_node = value

    @property
    def needs_seed(self) -> bool:
        """Whether the node needs a seed / PRNG key."""
        return self._needs_seed

    @needs_seed.setter
    @changes_model_graph
    def needs_seed(self, needs_seed: bool):
        self._needs_seed = needs_seed
        if not needs_seed:
            kwinputs_without_seed = dict(self.kwinputs)
            kwinputs_without_seed.pop("seed", None)
            self.set_inputs(*self.inputs, **kwinputs_without_seed)

    @property
    def outdated(self) -> bool:
        """Whether the node is outdated. Always ``True`` outside of a model."""
        if not self.model:
            return True
        return self._outdated

    @property
    @in_model_getter
    def outputs(self) -> tuple[Node, ...]:
        """The output nodes."""
        return self._outputs

    @changes_model_graph
    def set_inputs(self, *inputs: Any, **kwinputs: Any) -> Node:
        """Sets the non-keyword and keyword input nodes."""
        self._inputs = tuple(self._to_node(_input) for _input in inputs)
        self._kwinputs.clear()
        kwinputs = {kw: self._to_node(_input) for kw, _input in kwinputs.items()}
        self._kwinputs.update(kwinputs)
        return self

    @property
    def state(self) -> NodeState:
        """
        The state of the node.

        For the default node, a :class:`.NodeState` with the value and the outdated
        flag, but subclasses can add extra information to the state.
        """
        return NodeState(self.value, self.outdated)

    @state.setter
    def state(self, state: NodeState):
        self._value = state.value
        self._outdated = state.outdated

    @abstractmethod
    def update(self) -> Node:
        """Updates the value of the node."""

    @property
    def value(self) -> Any:
        """
        The value of the node.

        Can only be set for a :class:`.Value` node, but not a :class:`.Calc` or
        :class:`.Dist` node. If the node is part of a :class:`.Model` ``m`` with
        ``m.auto_update == True``, setting the value of the node triggers an update
        of the model. The auto-update can be disabled to improve the performance if
        multiple model parameters are updated at once.
        """
        return self._value

    @property
    def var(self) -> Var | None:
        """The variable the node is part of."""
        return self._var

    @property
    def _iloc(self) -> tuple[Node | Var, ...]:
        # Positional inputs, with VarValue proxies mapped back to their variables.
        input_list: list[Node | Var] = []
        for input_ in self.inputs:
            if isinstance(input_, VarValue):
                # This should not happen in practice, but the check makes mypy happy.
                if input_.var is None:
                    raise RuntimeError(f"{input_}.var is None.")
                input_list.append(input_.var)
            else:
                input_list.append(input_)
        return tuple(input_list)

    @property
    def _loc(self) -> dict[str, Node | Var]:
        # Keyword inputs, with VarValue proxies mapped back to their variables.
        input_dict: dict[str, Node | Var] = {}
        for key, input_ in self.kwinputs.items():
            if isinstance(input_, VarValue):
                # This should not happen in practice, but the check makes mypy happy.
                if input_.var is None:
                    raise RuntimeError(f"{input_}.var is None.")
                input_dict[key] = input_.var
            else:
                input_dict[key] = input_
        return input_dict

    def _describe_inputs(self) -> str:
        """Lists the available index and keyword inputs, for lookup error messages."""
        # ``_iloc`` is a computed property; evaluate it once instead of per index.
        available_indices = dict(enumerate(self._iloc))
        available_keywords = str(self._loc).replace("'", '"')
        return (
            f" Available index-variable pairs: {available_indices}."
            f" Available keyword-variable pairs: {available_keywords}."
        )

    def __getitem__(self, key: int | str) -> Node | Var:
        if isinstance(key, int):
            try:
                return self._iloc[key]
            except IndexError as error:
                msg = f"{key} is out of bounds.{self._describe_inputs()}"
                raise IndexError(msg) from error
        elif isinstance(key, str):
            try:
                return self._loc[key]
            except KeyError as error:
                msg = f"{key} not found.{self._describe_inputs()}"
                raise KeyError(msg) from error
        else:
            raise ValueError(f"Key must be str or int, not {type(key)}.")

    def _iloc_replace(self, key: int, value: Node | Var | Any) -> None:
        inputs = list(self.inputs)
        inputs[key] = self._to_node(value)
        return self.set_inputs(*inputs, **self.kwinputs)

    def _loc_replace(self, key: str, value: Node | Var | Any) -> None:
        kwinputs = dict(self.kwinputs)
        if key not in kwinputs:
            raise KeyError(f"'{key}' is not the key of an existing keyword input.")
        kwinputs[key] = self._to_node(value)
        return self.set_inputs(*self.inputs, **kwinputs)

    @changes_model_graph
    def __setitem__(self, key: int | str, value: Node | Var | Any) -> None:
        if isinstance(key, int):
            # NOTE(review): integer keys index ``all_input_nodes()`` (positional
            # inputs followed by keyword inputs, deduplicated), whereas
            # ``__getitem__`` indexes positional inputs only — confirm intended.
            all_inputs = self.all_input_nodes()
            node_to_replace = all_inputs[key]
            for kwinputs_key, kwinputs_node in self.kwinputs.items():
                if node_to_replace is kwinputs_node:
                    return self._loc_replace(kwinputs_key, value)
            return self._iloc_replace(key, value)
        elif isinstance(key, str):
            return self._loc_replace(key, value)
        else:
            raise ValueError(f"Key must be str or int, not {type(key)}.")

    def __getstate__(self):
        # Weak references (and the lambda placeholder) cannot be pickled, so the
        # model is stored as a plain reference (or None) in the pickled state.
        state = self.__dict__.copy()
        state["_model"] = self._model()
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Re-wrap the model reference that __getstate__ unwrapped.
        if self._model is not None:
            self._model = weakref.ref(self._model)
        else:
            self._model = lambda: None

    def __repr__(self) -> str:
        return f'{type(self).__name__}(name="{self.name}")'
class TransientNode(Node):
    """
    A node that does not cache its value.

    A transient node counts as outdated exactly when at least one of its input
    nodes is outdated. The :attr:`.outdated` property evaluates this condition on
    demand instead of storing a flag.
    """

    @property
    def outdated(self) -> bool:
        """
        Whether the node is outdated.

        Evaluated on demand: ``True`` outside of a model, otherwise ``True`` if and
        only if any input node is outdated.
        """
        if self.model:
            return any(node.outdated for node in self.all_input_nodes())
        return True

    @property
    def state(self) -> NodeState:
        """The state of the node; the stored value is always ``None``."""
        return NodeState(None, self.outdated)

    @state.setter
    def state(self, state: NodeState):
        self._value, self._outdated = state.value, state.outdated

    def update(self):
        """Does nothing, since transient nodes have no cache to refresh."""
        return self

    @property
    @abstractmethod
    def value(self) -> Any:
        """
        The value of the node.

        Computed on-the-fly.
        """
class ArgGroup(NamedTuple):
    """
    A group of arguments as a named tuple of ``args`` and ``kwargs``.

    This is the type of the value produced by :class:`.InputGroup`.
    """

    args: list[Any]
    """The non-keyword arguments."""

    kwargs: dict[str, Any]
    """The keyword arguments."""
class InputGroup(TransientNode):
    """
    A node that groups its inputs for another node.

    Essentially, this node "forwards" the values of its inputs to its outputs as an
    :class:`.ArgGroup`.
    """

    @property
    def value(self) -> ArgGroup:
        """
        The current values of the inputs, bundled as an :class:`.ArgGroup`.

        Computed on-the-fly from the positional and keyword input nodes.
        """
        args = [_input.value for _input in self.inputs]
        kwargs = {kw: _input.value for kw, _input in self.kwinputs.items()}
        return ArgGroup(args, kwargs)
class Value(Node):
    r"""
    A :class:`.Node` subclass that holds constant values.

    Since the information represented by a value node does not change, it is always
    up-to-date. A common use case for value nodes is to cache computed values.

    - By default, value nodes *will* appear in the node graph created by
      :func:`.viz.plot_nodes`, but they will *not* appear in the model graph
      created by :func:`.viz.plot_vars`.
    - You can wrap a value node in a :class:`.Var` to make it appear in the model
      graph.

    Parameters
    ----------
    value
        The value of the node.
    _name
        The name of the node. If you do not specify a name, a unique name will be
        automatically generated upon initialization of a :class:`.Model`.
    convert
        A function used to process the value of this node. The default uses the
        function stored in :meth:`.Node.convert_value`, which is
        ``jax.numpy.asarray``.

    See Also
    --------
    .Calc :
        A node representing a general calculation/operation in JAX or Python.
    .Dist :
        A node representing a ``tensorflow_probability``
        :class:`~tfp.distributions.Distribution`.
    .Var :
        A variable in a statistical model, typically with a probability
        distribution.
    .param : A helper function to initialize a :class:`.Var` as a model parameter.
    .obs : A helper function to initialize a :class:`.Var` as an observed variable.

    Examples
    --------
    A simple constant node representing a constant value without a name:

    >>> nameless_node = lsl.Value(1.0)
    >>> nameless_node
    Value(name="")

    Adding this node to a model leads to an automatically generated name:

    >>> model = lsl.Model([nameless_node])
    >>> nameless_node.name.startswith("_")
    True

    A constant node with a name:

    >>> node = lsl.Value(1.0, _name="my_name")
    >>> node
    Value(name="my_name")
    """

    def __init__(
        self,
        value: Any,
        _name: str = "",
        convert: Callable[[Any], Any] | Literal["default"] = "default",
    ):
        super().__init__(_name=_name, convert=convert)
        # Routed through the property setter so that the converter is applied.
        self.value = value

    def flag_outdated(self) -> Value:
        """Stops the recursion setting outdated flags."""
        return self

    @property
    def outdated(self) -> bool:
        # A constant can never be stale.
        return False

    def update(self) -> Value:
        """Does nothing, since the value is never outdated."""
        return self

    @property
    def value(self) -> Any:
        return self._value

    @value.setter
    def value(self, value: Any):
        try:
            converted = self._convert(value)
        except TypeError as e:
            msg = (
                "Error during value conversion. If you updated the "
                "`.convert_value` method, please make sure to define "
                "it as a staticmethod."
            )
            raise TypeError(msg) from e
        self._value = converted

        # Outside of a model there is nothing to propagate.
        if not self.model:
            return

        for node in self.outputs:
            node.flag_outdated()

        if self.model.auto_update:
            self.model.update()
class Data(Value):
    """
    A :class:`.Node` subclass that holds constant data.

    This is an alias for :class:`.Value` and behaves exactly like it. Please
    consult :class:`.Value` for the full documentation.

    See Also
    --------
    .Value : The class this alias refers to.
    """
class Calc(Node):
    """
    A :class:`.Node` subclass that calculates its value based on its input nodes.

    Calculator nodes are a central element of the Liesel graph building toolkit.
    They wrap arbitrary calculations in pure JAX functions.

    - By default, calculator nodes *will* appear in the node graph created by
      :func:`.viz.plot_nodes`, but they will *not* appear in the model graph
      created by :func:`.viz.plot_vars`.
    - You can use :meth:`~.Var.new_calc` if you want your calculation to be
      treated as a model variable and thus be shown in :func:`.viz.plot_vars`.

    .. tip::
        The wrapped function must be jit-compilable by JAX. This mainly means that
        it must be a pure function, i.e. it must not have any side effects and,
        given the same input, it must always return the same output. Some special
        consideration is also required for loops and conditionals. Please consult
        the JAX docs_ for details.

    Parameters
    ----------
    function
        The function to be wrapped. Must be jit-compilable by JAX.
    *inputs
        Non-keyword inputs. Any inputs that are not already nodes or :class:`.Var`
        will be converted to :class:`.Value` nodes. The values of these inputs will
        be passed to the wrapped function in the same order they are entered here.
    _name
        The name of the node. If you do not specify a name, a unique name will be
        automatically generated upon initialization of a :class:`.Model`.
    _needs_seed
        Whether the node needs a seed / PRNG key.
    _update_on_init
        If ``True``, the calculator will try to evaluate its function upon
        initialization. If that fails, a warning is logged instead of raising.
    convert_inputs
        A function used to process the values of this node's inputs. The default
        uses the function stored in :meth:`.Node.convert_value`, which is
        ``jax.numpy.asarray``.
    **kwinputs
        Keyword inputs. Any inputs that are not already nodes or :class:`.Var`
        objects will be converted to :class:`.Value` nodes. The values of these
        inputs will be passed to the wrapped function as keyword arguments.

    See Also
    --------
    .Var.new_calc :
        Initializes a weak variable that is a function of other variables.
    .Var :
        A variable in a statistical model, typically with a probability
        distribution.
    .Var.new_param : Initializes a strong variable that acts as a model parameter.
    .Var.new_obs : Initializes a strong variable that holds observed data.
    .Var.new_value : Initializes a strong variable without a distribution.
    .Value : A node representing some static data.
    .Dist :
        A node representing a ``tensorflow_probability``
        :class:`~tfp.distributions.Distribution`.

    Examples
    --------
    A simple calculator node, taking the exponential value of an input parameter.

    >>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
    >>> scale = lsl.Calc(jnp.exp, log_scale)
    >>> print(scale.value)
    1.0

    The value of the calculator node is updated when :meth:`.Calc.update` is
    called.

    >>> scale.update()
    Calc(name="")
    >>> print(scale.value)
    1.0

    You can also use your own functions as long as they are jit-compilable by JAX.

    >>> def compute_variance(x):
    ...     return jnp.exp(x)**2
    >>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
    >>> variance = lsl.Calc(compute_variance, log_scale).update()
    >>> print(variance.value)
    1.0

    .. _docs: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
    """

    def __init__(
        self,
        function: Callable[..., Any],
        *inputs: Any,
        _name: str = "",
        _needs_seed: bool = False,
        _update_on_init: bool = True,
        convert_inputs: Callable[[Any], Any] | Literal["default"] = "default",
        **kwinputs: Any,
    ):
        super().__init__(
            *inputs,
            **kwinputs,
            _name=_name,
            _needs_seed=_needs_seed,
            convert=convert_inputs,
        )
        self._function = function
        self._update_on_init = _update_on_init

        if self._update_on_init:
            # Best effort: a failing evaluation at construction time is logged
            # (summary as warning, traceback at debug level) rather than raised.
            try:
                self.update()
            except Exception as e:
                logger.warning(
                    "%s was not updated during initialization, because the"
                    " following exception occurred: %r. See debug log for the"
                    " full traceback.",
                    self,
                    e,
                )
                logger.debug(
                    "%s was not updated during initialization, because the"
                    " following exception occurred:",
                    self,
                    exc_info=e,
                )

    @property
    def function(self) -> Callable[..., Any]:
        """The wrapped function."""
        return self._function

    @function.setter
    @changes_model_graph
    def function(self, function: Callable[..., Any]):
        self._function = function

    def update(self) -> Calc:
        """
        Evaluates the wrapped function on the current input values and caches the
        result.

        Raises
        ------
        RuntimeError
            If the wrapped function raises; the original exception is chained.
        """
        args = [_input.value for _input in self.inputs]
        kwargs = {kw: _input.value for kw, _input in self.kwinputs.items()}
        try:
            self._value = self.function(*args, **kwargs)
        except Exception as e:
            raise RuntimeError(f"Error while updating {self}.") from e
        self._outdated = False
        return self
class TransientCalc(TransientNode, Calc):
    """A transient calculator node that does not cache its value."""

    @property
    def value(self) -> Any:
        """
        The result of the wrapped function, recomputed on every access.

        Raises
        ------
        RuntimeError
            If the wrapped function raises; the original exception is chained.
        """
        positional = [node.value for node in self.inputs]
        keyword = {name: node.value for name, node in self.kwinputs.items()}
        try:
            return self.function(*positional, **keyword)
        except Exception as e:
            raise RuntimeError(f"Error while updating {self}.") from e
class TransientIdentity(TransientCalc):
    """
    A transient identity node that does not cache its value.

    The node passes the value of its single input through unchanged, i.e. it
    "forwards" that value to its outputs.
    """

    def __init__(self, _input: Any, _name: str = ""):
        super().__init__(lambda value: value, _input, _name=_name)
class VarValue(TransientIdentity):
    """
    A proxy node for the value of a :class:`.Var`.

    This node type is used to keep the references to a variable intact, even if the
    underlying value node is replaced. When a node's inputs are looked up by
    index or keyword, inputs of this type are reported as their variable instead.
    """
class Dist(Node):
    """
    A :class:`.Node` subclass that wraps a probability distribution.

    Distribution nodes wrap distribution classes that follow the
    ``tensorflow_probability`` :class:`~tfp.distributions.Distribution` interface.
    They can be used to represent observation models and priors.

    Distribution nodes *will* appear in the node graph created by
    :func:`.viz.plot_nodes`, but they will *not* appear in the model graph created
    by :func:`.viz.plot_vars`.

    Parameters
    ----------
    distribution
        The wrapped distribution class that follows the ``tensorflow_probability``
        :class:`~tfp.distributions.Distribution` interface.
    *inputs
        Non-keyword inputs. Any inputs that are not already nodes or :class:`.Var`
        will be converted to :class:`.Value` nodes. The values of these inputs will
        be passed to the wrapped distribution in the same order they are entered
        here.
    _name
        The name of the node. If you do not specify a name, a unique name will be
        automatically generated upon initialization of a :class:`.Model`.
    _needs_seed
        Whether the node needs a seed / PRNG key.
    bijectors
        Optional parameter bijector specification for transforming distribution
        parameters. If given, :meth:`.Dist.biject_parameters` is called with it
        during initialization. See that method for supported formats and behavior.
    convert_inputs
        A function used to process the values of this node's inputs. The default
        uses the function stored in :meth:`.Node.convert_value`, which is
        ``jax.numpy.asarray``.
    **kwinputs
        Keyword inputs. Any inputs that are not already nodes or :class:`.Var`
        objects will be converted to :class:`.Value` nodes. The values of these
        inputs will be passed to the wrapped distribution as keyword arguments.

    See Also
    --------
    .Var :
        A variable in a statistical model, typically with a probability
        distribution.
    .MultivariateNormalDegenerate :
        A custom distribution class that implements a degenerate multivariate
        normal distribution in the ``tensorflow_probability``
        :class:`~tfp.distributions.Distribution` interface.

    Examples
    --------
    For the examples below, we import ``tensorflow_probability``:

    >>> import tensorflow_probability.substrates.jax.distributions as tfd

    Creating an observation model for a normally distributed variable with fixed
    mean and scale. The log probability of the node ``y`` in the example below is
    ``None``, until the variable is updated.

    >>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
    >>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y")
    >>> print(y.log_prob)
    None

    >>> y.update()
    Var(name="y")
    >>> y.log_prob
    Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)

    Now we define the same observation model, but include the location and scale as
    parameters:

    >>> loc = lsl.Var.new_param(0.0, name="loc")
    >>> scale = lsl.Var.new_param(1.0, name="scale")
    >>> dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
    >>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y").update()
    >>> y.log_prob
    Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)

    .. rubric:: Summed-up log-probability

    You can set the ``per_obs`` attribute of a distribution node to ``False`` to sum
    up the log-probability of the distribution over all observations.

    >>> dist.per_obs = False
    >>> y.update().log_prob
    Array(-3.0068154, dtype=float32)
    """

    def __init__(
        self,
        distribution: Callable[..., Distribution],
        *inputs: Any,
        _name: str = "",
        _needs_seed: bool = False,
        bijectors: None
        | Literal["auto"]
        | dict[str, Bijector | Literal["auto"] | None]
        | Sequence[Bijector | Literal["auto"] | None] = None,
        convert_inputs: Callable[[Any], Any] | Literal["default"] = "default",
        **kwinputs: Any,
    ):
        super().__init__(
            *inputs,
            **kwinputs,
            _name=_name,
            _needs_seed=_needs_seed,
            convert=convert_inputs,
        )
        # Where the distribution is evaluated; set later (see the ``at`` property).
        self._at: Node | None = None
        self._distribution = distribution
        self._per_obs = True

        # Apply bijectors eagerly if provided: the parameter variables are
        # transformed right away, during construction of this node.
        if bijectors is not None:
            self.biject_parameters(bijectors=bijectors)
[docs] def all_input_nodes(self) -> tuple[Node, ...]: inputs = super().all_input_nodes() if self.at: inputs = _unique_tuple(inputs, [self.at]) return inputs
@property def at(self) -> Node | None: """Where to evaluate the distribution.""" return self._at @at.setter @changes_model_graph def at(self, at: Node | None): if self.var and at is not self.var.var_value_node: raise RuntimeError( f"{repr(self)} is part of a var, cannot set property `at`" ) self._at = at @property def distribution(self) -> Callable[..., Distribution]: """The wrapped distribution.""" return self._distribution @distribution.setter @changes_model_graph def distribution(self, distribution: Callable[..., Distribution]): self._distribution = distribution
[docs] def init_dist(self) -> Distribution: """Initializes the distribution.""" args = [_input.value for _input in self.inputs] kwargs = {kw: _input.value for kw, _input in self.kwinputs.items()} dist = self.distribution(*args, **kwargs) return dist
@property def log_prob(self) -> Array: """The log-probability of the distribution.""" return self.value @property def per_obs(self) -> bool: """Whether the log-probability is stored per observation or summed up.""" return self._per_obs @per_obs.setter @changes_model_graph def per_obs(self, per_obs: bool): self._per_obs = per_obs
[docs] def update(self) -> Dist: if not self.at: raise RuntimeError( f"{repr(self)} cannot evaluate log-prob, property `at` not set" ) log_prob = self.init_dist().log_prob(self.at.value) if not self.per_obs and hasattr(log_prob, "sum"): log_prob = log_prob.sum() self._value = log_prob self._outdated = False return self
[docs] def biject_parameters( self, bijectors: Literal["auto"] | dict[str, Bijector | Literal["auto"] | None] | Sequence[Bijector | Literal["auto"] | None] = "auto", inference: Literal["drop"] | None = None, ) -> Self: """ Transforms distribution parameters using bijectors with eager evaluation. This method applies bijectors to the distribution's parameters immediately. Only strong Var parameters (with parameter=True and not weak) will be transformed. Variable-level bijectors always take precedence over distribution-level bijectors. Parameters ---------- bijectors Bijector specification. Options: \ - ``"auto"``: Use default parameter bijectors from the distribution's \ parameter_properties(). Parameters without a default bijector \ (e.g., Wishart's df parameter) are skipped. \ - dict: Map parameter names to bijectors. Use ``"auto"`` for default, \ ``None`` to skip, or provide an explicit bijector instance/class. \ - Sequence: Bijectors for positional parameters in the order they \ appear in the distribution's __init__ signature. \ inference Inference information for transformed variables. If ``"drop"``, \ inference information is removed from original parameters. If ``None``, \ an error will be raised if inference is present for any parameter that \ is selected for bijection. Returns ------- Self Notes ----- A default bijector's forward is expected to map from unconstrained to constrained space. Consequently, its inverse is expected to map from constrained to unconstrained space. For custom TensorFlow Probability distributions, parameter_properties() must return a dictionary with parameter names as keys and ParameterProperties instances as values. The dict order must match the distribution's __init__ signature order. See Also -------- .Var.biject : Method for transforming individual variables. 
Examples -------- >>> scale = lsl.Var.new_param(1.0, name="scale") >>> rate = lsl.Var.new_param(1.0, name="rate") Auto-transform all parameters: >>> dist = lsl.Dist(tfd.Gamma, concentration=scale, rate=rate, bijectors="auto") >>> scale.weak # Now weak, transformed True Transform only specific parameters: >>> scale = lsl.Var.new_param(1.0, name="scale") >>> rate = lsl.Var.new_param(1.0, name="rate") >>> dist = lsl.Dist( ... tfd.Normal, ... loc=0.0, ... scale=scale, ... bijectors={"scale": "auto", "loc": None}, ... ) """ if inference is not None and inference != "drop": raise ValueError(f"Value {inference=} is not supported.") # Validate no mixing of positional and keyword inputs for auto bijectors if bijectors == "auto" and (self.inputs and self.kwinputs): raise ValueError( "Cannot use auto bijectors with mixed positional and keyword inputs. " "Please use either all positional or all keyword arguments for the " "distribution parameters." ) resolved = self._resolve_bijectors(bijectors) for param_name, (param_var, bijector) in resolved.items(): if param_var.weak: if bijector == "auto": logger.debug(f"Parameter '{param_name}' already weak. Skipping.") continue else: raise ValueError( f"Parameter '{param_name}' is weak, but explicit " f"bijector {bijector} provided." ) if param_var.auto_transform: if bijector == "auto": logger.debug( f"Parameter '{param_name}' has auto_transform=True. Skipping." ) continue else: raise ValueError( f"Parameter '{param_name}' has auto_transform=True, but " f"explicit bijector provided. Please resolve the conflict." ) if is_bijector_class(bijector): raise TypeError( f" For parameter {param_name} of {self}, you passed a bijector " f"class ({bijector}) instead of an instance." "This is currently not supported by Dist.biject_parameters." ) param_var.biject(bijector=bijector, inference=inference) return self
def _resolve_bijectors( self, bijectors: Literal["auto"] | dict[str, Bijector | Literal["auto"] | None] | Sequence[Bijector | Literal["auto"] | None], ) -> dict[str, tuple[Var, Bijector | Literal["auto"]]]: """Resolves bijector specs to parameter->(Var, Bijector) mappings.""" result: dict[str, tuple[Var, Bijector | Literal["auto"]]] = {} # Get default bijectors once - validates parameter_properties() default_bijectors = self.find_default_parameter_bijectors() param_names = list(default_bijectors) if self.inputs: if not len(self.inputs) <= len(param_names): raise ValueError( f"{self.distribution} has the parameters {param_names}, but got " f"{len(self.inputs)} inputs: {self.inputs}." ) if self.kwinputs: if not set(self.kwinputs) <= set(param_names): raise ValueError( f"{self.distribution} has the parameters {param_names}, but got " f"inputs: {list(self.kwinputs)} with values {self.kwinputs}." ) # Resolve bijector specification to dict if bijectors == "auto": bijector_dict = default_bijectors elif isinstance(bijectors, dict): if self.inputs: raise ValueError( "When dist inputs are supplied as positional arguments, " "bijectors have to be supplied positionally, too. Got inputs " f"{self.inputs} and bijectors {bijectors} for dist {self}." ) # Validate that all keys are valid parameter names invalid_keys = set(bijectors.keys()) - set(param_names) if invalid_keys: raise ValueError( f"Invalid parameter name(s) in bijectors dict: {invalid_keys}. " f"Valid parameter names are: {', '.join(param_names)}." ) bijector_dict = {} for param_name, bijector in bijectors.items(): if bijector == "auto": bijector_dict[param_name] = default_bijectors.get(param_name) else: bijector_dict[param_name] = bijector elif isinstance(bijectors, Sequence): if len(bijectors) > len(param_names): raise ValueError( f"Too many bijectors provided: got {len(bijectors)} bijectors " f"but distribution has only {len(param_names)} parameters " f"({', '.join(param_names)})." 
) # now the mapping of bijectors to dist parameters is being resolved bijector_dict = {} # this makes sure that positional bijectors refer to the order of # kwinputs given by the user, not the order of kwinputs as expected # by the distribution kwinput_names = list(self.kwinputs) if self.kwinputs else param_names for i, bijector in enumerate(bijectors): param_name = kwinput_names[i] if bijector == "auto": bijector_dict[param_name] = default_bijectors.get(param_name) else: bijector_dict[param_name] = bijector else: raise TypeError(f"Invalid bijectors type: {type(bijectors)}.") for i, input_node in enumerate(self.inputs): param_name = param_names[i] bijector = bijector_dict.get(param_name) if bijector is None: continue if isinstance(input_node, VarValue): param_var = input_node.var if param_var: if not param_var.parameter: logger.warning( f"{param_var} has parameter=False " "but a bijector is being applied.", ) result[param_name] = (param_var, bijector) else: raise ValueError( f"Got bijector {bijector} for parameter '{param_name}', given by " f"{input_node}, but only lsl.Var " "objects can be bijected. You can supply 'None' for this parameter " "if you do not want to biject, or supply a lsl.Var object." ) for param_name, input_node in self.kwinputs.items(): bijector = bijector_dict.get(param_name) if bijector is None: continue if isinstance(input_node, VarValue): param_var = input_node.var if param_var: if not param_var.parameter: logger.warning( f"{param_var} has parameter=False " "but a bijector is being applied.", ) result[param_name] = (param_var, bijector) else: raise ValueError( f"Got bijector {bijector} for parameter '{param_name}', given by " f"{input_node}, but only lsl.Var " "objects can be bijected. You can supply 'None' for this parameter " "if you do not want to biject, or supply a lsl.Var object." 
) return result @property def _dtype(self): dtypes_input = [jnp.asarray(v.value).dtype for v in self.inputs] dtypes_kwinput = [jnp.asarray(v.value).dtype for v in self.kwinputs.values()] dtypes = dtypes_input + dtypes_kwinput if len(set(dtypes)) > 1: raise TypeError(f"Found more than one dtype in inputs: {dtypes}") return dtypes[0]
[docs] def find_default_parameter_bijectors(self) -> dict[str, Bijector | None]: """Extracts default parameter bijectors from the wrapped distribution.""" try: param_props = self.init_dist().parameter_properties(dtype=self._dtype) except (AttributeError, TypeError) as e: # raise the same error type raise type(e)( "Error when accessing " f"parameter_properties() method on {self.distribution.__name__}. " "Cannot auto-transform parameters. " "This may indicate an issue with the TFP distribution or version. " "Either use a distribution that supports parameter_properties() or " "manually transform parameters with .transform()." ) from e if not isinstance(param_props, dict): raise TypeError( f"Distribution {self.distribution.__name__}'s " "parameter_properties() must return a dictionary, but returned " f"{type(param_props).__name__}. This may indicate an issue with " "a custom distribution implementation." ) bijectors: dict[str, Bijector | None] = {} for param_name, prop in param_props.items(): if not hasattr(prop, "default_constraining_bijector_fn"): raise AttributeError( f"Parameter property for '{param_name}' of " f"{self.distribution.__name__} does not have " "'default_constraining_bijector_fn' attribute. This may " "indicate an issue with the TFP distribution or version. " "Either use a distribution that supports " "parameter_properties() or manually transform parameters " "with .transform()." ) try: bijector = prop.default_constraining_bijector_fn() # type: ignore except Exception as e: raise type(e)( f"Error getting bijector for parameter '{param_name}' of " f"{self.distribution.__name__}: {e}" ) from e if bijector is None: raise ValueError( f"Expected a bijector or BIJECTOR_NOT_IMPLEMENTED for " f"parameter '{param_name}' of {self.distribution.__name__}, " f"but got None. If no default bijector is provided for " f"this parameter, the return value should be the " f"BIJECTOR_NOT_IMPLEMENTED method instead." 
) bijector_type_name = type(bijector).__name__ # TFP's way of indicating no default bijector exists if bijector_type_name == "BIJECTOR_NOT_IMPLEMENTED": bijectors[param_name] = None else: bijectors[param_name] = bijector return bijectors
class TransientDist(TransientNode, Dist):
    """
    A transient distribution node that does not cache its value.

    Every access to :attr:`.value` re-initializes the distribution and
    re-evaluates the log-probability of the ``at`` node.
    """

    @property
    def value(self) -> Any:
        target = self.at
        if not target:
            raise RuntimeError(
                f"{repr(self)} cannot evaluate log-prob, property `at` not set"
            )

        log_prob = self.init_dist().log_prob(target.value)
        # reduce to a scalar unless per-observation values were requested
        should_sum = (not self.per_obs) and hasattr(log_prob, "sum")
        return log_prob.sum() if should_sum else log_prob
class NoDist(Dist):
    """
    Stand-in distribution node for variables without a distribution.

    It wraps :class:`NoDistribution`, reports no input or output nodes, does
    nothing on :meth:`.update`, and always has the value ``0.0``.
    """

    def __init__(self):
        super().__init__(NoDistribution)

    def all_input_nodes(self) -> tuple[Node, ...]:
        return tuple()

    def all_output_nodes(self) -> tuple[Node, ...]:
        return tuple()

    @property
    def outputs(self) -> tuple[Node, ...]:
        return tuple()

    def update(self) -> NoDist:
        # nothing to recompute
        return self

    @property
    def value(self) -> float:
        return 0.0


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Variable ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def is_bijector_class(obj) -> TypeGuard[type[Any]]:
    """
    Returns ``True`` if ``obj`` is a class deriving from the JAX-substrate
    TFP ``Bijector`` (a bijector class rather than a bijector instance).
    """
    if not isinstance(obj, type):
        return False
    return issubclass(obj, jb.Bijector)
class Var:
    """
    A variable in a statistical model.

    A variable in Liesel is often a random variable, e.g. an observed or latent
    variable with a probability distribution (see :meth:`~.Var.new_obs`), or a
    model parameter with a prior distribution (see :meth:`~.Var.new_param`).
    Other quantities can also be declared as variables, e.g. fixed data like
    hyperparameters or design matrices (see :meth:`~.Var.new_value`), or
    quantities that are computed from other nodes, e.g. structured additive
    predictors in semi-parametric regression models (see
    :meth:`~.Var.new_calc`).

    .. tip::
        You should initialize variables through one of the four constructors:
        :meth:`.new_param`, :meth:`.new_obs`, :meth:`.new_calc`, and
        :meth:`.new_value`.

    .. rubric:: Accessing inputs

    :class:`.Calc` and :class:`.Dist` objects support access to their inputs via
    square-bracket syntax. Thus, with a :class:`.Var` object, you can use square
    bracket indexing on its attributes :attr:`.Var.value_node` and
    :attr:`.Var.dist_node`. You can access both keyword and positional arguments
    this way.

    >>> import tensorflow_probability.substrates.jax.distributions as tfd

    Access keyword inputs to a calculator :attr:`.Var.value_node`:

    >>> a = lsl.Var.new_value(2.0, name="a")
    >>> b = lsl.Var.new_calc(lambda x: x + 1.0, x=a)
    >>> b.value_node["x"]
    Var(name="a")

    Access positional inputs to a calculator :attr:`.Var.value_node`:

    >>> a = lsl.Var.new_value(2.0, name="a")
    >>> b = lsl.Var.new_calc(lambda x: x + 1.0, a)
    >>> b.value_node[0]
    Var(name="a")

    Access keyword inputs to a distribution :attr:`.Var.dist_node`:

    >>> a = lsl.Var.new_value(2.0, name="a")
    >>> b = lsl.Var.new_obs(1.0, lsl.Dist(tfd.Normal, loc=a, scale=1.0))
    >>> b.dist_node["loc"]
    Var(name="a")

    Access positional inputs to a distribution :attr:`.Var.dist_node`:

    >>> a = lsl.Var.new_value(2.0, name="a")
    >>> b = lsl.Var.new_obs(1.0, lsl.Dist(tfd.Normal, a, scale=1.0))
    >>> b.dist_node[0]
    Var(name="a")

    .. note::
        When accessing keyword arguments, you do *not* use the
        :attr:`.Var.name` attribute of the input variable or node you are
        looking for, but the *argument name*. Consider this case from above::

            a = lsl.Var.new_value(2.0, name="a")
            b = lsl.Var.new_obs(1.0, lsl.Dist(tfd.Normal, loc=a, scale=1.0))
            b.dist_node["loc"]

        Here, we retrieve the variable ``a`` with the name ``"a"``. But for the
        indexing, we use the *argument name* ``"loc"`` from the call to
        ``lsl.Dist``.

    .. rubric:: Swapping out inputs

    You can also use square-bracket indexing on :attr:`.Var.value_node` and
    :attr:`.Var.dist_node` to swap out existing inputs. This allows you to
    easily make changes to your model.

    Swap out inputs to a calculator via :attr:`.Var.value_node`:

    >>> a = lsl.Var.new_value(2.0, name="a")
    >>> b = lsl.Var.new_calc(lambda x: x + 1.0, x=a)
    >>> c = lsl.Var.new_value(3.0, name="c")
    >>> b.value_node["x"] = c
    >>> b.value_node["x"]
    Var(name="c")

    Swap out inputs to a distribution via :attr:`.Var.dist_node`:

    >>> a = lsl.Var.new_value(2.0, name="a")
    >>> b = lsl.Var.new_obs(1.0, lsl.Dist(tfd.Normal, loc=a, scale=1.0))
    >>> c = lsl.Var.new_value(3.0, name="c")
    >>> b.dist_node["loc"] = c
    >>> b.dist_node["loc"]
    Var(name="c")

    Parameters
    ----------
    value
        The value of the variable.
    dist
        The probability distribution of the variable.
    name
        The name of the variable. If you do not specify a name, a unique name will be \
        automatically generated upon initialization of a :class:`.Model`.
    inference
        Additional information that can be used to set up inference algorithms.
    bijector
        Bijector for variable transformation. If ``"auto"``, uses the default event \
        space bijector defined by the variable's distribution. \
        If ``None``, no transformation takes place. If not ``None``, the variable will \
        call :meth:`.biject` with this bijector upon initialization. \
        Any supplied inference information will be passed to the bijected \
        variable.
    convert
        A function used to process the value of this variable.
        The default uses the function stored in :meth:`.Var.convert_value`, which
        is ``jax.numpy.asarray``.
    distribution
        Deprecated argument name for the probability distribution of the variable,
        kept for backwards-compatibility. Please use the new name ``dist``.

    Raises
    ------
    ValueError
        If both ``dist`` and ``distribution`` are given.

    See Also
    --------
    .Var.new_obs :
        Initializes a strong variable that holds observed data.
    .Var.new_param :
        Initializes a strong variable that acts as a model parameter.
    .Var.new_calc :
        Initializes a weak variable that is a function of other variables.
    .Var.new_value :
        Initializes a strong variable without a distribution.
    :meth:`.Var.transform` :
        Transforms a variable by adding a new transformed variable as an input.
        This is useful for variables that are constrained to a certain domain,
        e.g. positive values.
    .Calc :
        A node representing a general calculation/operation
        in JAX or Python. Use this instead of :meth:`~.Var.new_calc` if you want
        to hide your calculation in the model graph produced by
        :func:`.plot_vars`.
    .Value :
        A node representing a static value.
        Use this instead of :meth:`~.Var.new_value` if you want to hide your
        value in the model graph produced by :func:`.plot_vars`.
    .Dist :
        A node representing a ``tensorflow_probability``
        :class:`~tfp.distributions.Distribution`.
    """

    # Fixed attribute layout: instances have no __dict__, so every attribute
    # assigned in __init__ (or elsewhere) must be listed here.
    __slots__ = (
        "info",
        "inference",
        "_auto_transform",
        "_bijected_var",
        "_dist_node",
        "_groups",
        "_name",
        "_observed",
        "_parameter",
        "_role",
        "_value_node",
        "_var_value_node",
        "_convert",
    )

    def __init__(
        self,
        value: Any,
        dist: Dist | None = None,
        name: str = "",
        inference: InferenceTypes = None,
        bijector: None | Bijector | Literal["auto"] = None,
        convert: Callable[[Any], Any] | Literal["default"] = "default",
        distribution: Dist | None = None,
    ):
        # ``distribution`` is a deprecated alias of ``dist``; both at once is
        # ambiguous.
        if dist is not None and distribution is not None:
            raise ValueError(
                "Values for 'dist' and 'distribution' provided. "
                "Please provide the distribution only via 'dist'; the name "
                "'distribution' is deprecated."
            )
        if dist is None:
            dist = distribution

        self._name = name

        # "default" defers to convert_value, which subclasses may override.
        if convert == "default":
            self._convert: Callable[[Any], Any] = self.convert_value
        else:
            self._convert = convert

        # Placeholder nodes. The dist_node setter below reads self._dist_node
        # (to detach the old node), so a NoDist must exist before it runs.
        self._value_node: Node = Value(None, convert=lambda x: x)
        self._dist_node: Dist = NoDist()
        # The VarValue node wraps the value node and is registered to this
        # variable via _set_var.
        self._var_value_node: VarValue = VarValue(
            self._value_node, _name=f"{self._name}_var_value"
        )
        self._var_value_node._set_var(self)

        # use setters
        self.value_node = value  # type: ignore # unfrozen
        self.dist_node = dist  # type: ignore # unfrozen

        self._auto_transform = False
        self._bijected_var: Var | None = None
        self._groups: dict[str, Group] = {}
        self._observed = False
        self._parameter = False
        self._role = ""
        self.info: dict[str, Any] = {}
        """Additional meta-information about the variable as a dict."""
        self.inference = inference

        # Apply bijector eagerly if provided. This runs last, after all
        # attributes above have been initialized.
        if bijector is not None:
            self.biject(bijector=bijector, inference=inference)
[docs] @staticmethod def convert_value(x: Any) -> Any: """ The function used to process the value of this variable, if ``convert="default"`` is supplied during init. Can be overwritten on subclasses to create variable classes with different default conversion behavior. Make sure to overwrite it with a static method, for example (re-implementing the default behavior):: class MyVar(lsl.Var): @staticmethod def convert_value(x): return jnp.asarray(x) if x is not None else x """ return jnp.asarray(x) if x is not None else x
@classmethod
def new_param(
    cls,
    value: Any,
    dist: Dist | None = None,
    name: str = "",
    inference: InferenceTypes = None,
    bijector: None | Bijector | Literal["auto"] = None,
    convert: Callable[[Any], Any] | Literal["default"] = "default",
    distribution: Dist | None = None,
) -> Var:
    """
    Initializes a strong variable that acts as a model parameter.

    A parameter is a strong variable that may have a distribution. If it has
    one, its :attr:`~.Var.log_prob` contributes to the model's log prior, i.e.
    :attr:`~.Model.log_prior`. The value node of the new variable is monitored.
    If a bijector is applied, the ``parameter`` flag is set on the transformed
    variable (:attr:`.bijected_var`) instead of the returned variable.

    Parameters
    ----------
    value
        The value of the variable.
    dist
        The probability distribution of the variable.
    name
        The name of the variable. If you do not specify a name, a unique name will \
        be automatically generated upon initialization of a :class:`.Model`.
    inference
        Additional information that can be used to set up inference algorithms.
    bijector
        Bijector for variable transformation. If ``"auto"``, uses the default \
        event space bijector defined by the variable's distribution. \
        If ``None``, no transformation takes place. If not ``None``, the variable \
        will call :meth:`.biject` with this bijector upon initialization. \
        Any supplied inference information will be passed to the bijected \
        variable.
    convert
        A function used to process the value of this variable.
        The default uses the function stored in :meth:`.Var.convert_value`, which
        is ``jax.numpy.asarray``.
    distribution
        Deprecated argument name for the probability distribution of the variable,
        kept for backwards-compatibility. Please use the new name ``dist``.

    See Also
    --------
    .Var.new_obs : Initializes a strong variable that holds observed data.
    .Var.new_calc :
        Initializes a weak variable that is a function of other variables.
    .Var.new_value : Initializes a strong variable without a distribution.

    Examples
    --------
    A simple parameter without a distribution and without a name:

    >>> x = lsl.Var.new_param(1.0)
    >>> x
    Var(name="")

    A simple parameter with a normal prior:

    >>> prior = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
    >>> x = lsl.Var.new_param(1.0, dist=prior)
    >>> x
    Var(name="")
    """
    new_var = cls(
        value,
        dist,
        name,
        inference=inference,
        bijector=bijector,
        convert=convert,
        distribution=distribution,
    )
    new_var.value_node.monitor = True

    # after bijection, the transformed variable is the actual parameter
    flagged = new_var if new_var.bijected_var is None else new_var.bijected_var
    flagged.parameter = True
    return new_var
@classmethod
def new_obs(
    cls,
    value: Any,
    dist: Dist | None = None,
    name: str = "",
    distribution: Dist | None = None,
    convert: Callable[[Any], Any] | Literal["default"] = "default",
) -> Var:
    """
    Initializes a strong variable that holds observed data.

    An observed variable is a strong variable that can have a distribution.
    If it does have a distribution, its :attr:`~.Var.log_prob` is counted in a
    model's log likelihood, i.e. :attr:`~.Model.log_lik`.

    Parameters
    ----------
    value
        The value of the variable.
    dist
        The probability distribution of the variable.
    name
        The name of the variable. If you do not specify a name, a unique name will \
        be automatically generated upon initialization of a :class:`.Model`.
    distribution
        Deprecated argument name for the probability distribution of the variable,
        kept for backwards-compatibility. Please use the new name ``dist``.
    convert
        A function used to process the value of this variable.
        The default uses the function stored in :meth:`.Var.convert_value`, which
        is ``jax.numpy.asarray``.

    See Also
    --------
    .Var.new_param : Initializes a strong variable that acts as a model parameter.
    .Var.new_calc :
        Initializes a weak variable that is a function of other variables.
    .Var.new_value : Initializes a strong variable without a distribution.

    Examples
    --------
    A simple observed variable without a distribution and without a name:

    >>> x = lsl.Var.new_obs(1.0)
    >>> x
    Var(name="")

    A simple observed variable with a normal distribution:

    >>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
    >>> x = lsl.Var.new_obs(1.0, dist=dist)
    >>> x
    Var(name="")
    """
    var = cls(value, dist, name, distribution=distribution, convert=convert)
    var.observed = True
    return var
@classmethod
def new_calc(
    cls,
    function: Callable[..., Any],
    *inputs: Any,
    dist: Dist | None = None,
    name: str = "",
    _needs_seed: bool = False,
    _update_on_init: bool = True,
    convert_inputs: Callable[[Any], Any] | Literal["default"] = "default",
    cache: bool = True,
    distribution: Dist | None = None,
    **kwinputs: Any,
) -> Var:
    """
    Initializes a weak variable that is a function of other variables.

    A calculating variable can wrap arbitrary calculations in pure JAX functions.

    .. tip::
        The wrapped function must be jit-compilable by JAX. This mainly means
        that it must be a pure function, i.e. it must not have any side effects
        and, given the same input, it must always return the same output. Some
        special consideration is also required for loops and conditionals.
        Please consult the JAX docs_ for details.

    Parameters
    ----------
    function
        The function to be wrapped. Must be jit-compilable by JAX.
    *inputs
        Non-keyword inputs. Any inputs that are not already nodes or :class:`.Var` \
        will be converted to :class:`.Value` nodes. The values of these inputs \
        will be passed to the wrapped function in the same order they are entered \
        here.
    dist
        The probability distribution of the variable.
    name
        The name of the variable. If you do not specify a name, a unique name \
        will be automatically generated upon initialization of a :class:`.Model`.
    _needs_seed
        Whether the node needs a seed / PRNG key.
    _update_on_init
        If ``True``, the calculator will try to evaluate its function upon \
        initialization.
    convert_inputs
        A function used to process the values of this variable's inputs.
        The default uses the function stored in :meth:`.Var.convert_value`, which
        is ``jax.numpy.asarray``.
    cache
        If ``False``, this variable will not store a cache of its value. This means,
        the ``function`` is evaluated every single time that the value of this
        variable is requested. This can save memory, if the computations are
        trivial (such as prepending an axis to an array), but it can greatly
        slow down computations otherwise (such as when the function performs a
        matrix inversion). Internally, if ``cache=True``, this variable wraps a
        :class:`.Calc`, and if ``cache=False``, it wraps a :class:`.TransientCalc`.
    distribution
        Deprecated argument name for the probability distribution of the variable,
        kept for backwards-compatibility. Please use the new name ``dist``.
    **kwinputs
        Keyword inputs. Any inputs that are not already nodes or :class:`.Var`
        objects will be converted to :class:`.Value` nodes. The values of these
        inputs will be passed to the wrapped function as keyword arguments.

    Notes
    -----
    Internally, this constructor initializes and wraps a :class:`.Calc` node
    (or a :class:`.TransientCalc` node if ``cache=False``).

    See Also
    --------
    .Var.new_param : Initializes a strong variable that acts as a model parameter.
    .Var.new_obs : Initializes a strong variable that holds observed data.
    .Var.new_value : Initializes a strong variable without a distribution.
    .Calc : The calculator node class.

    Examples
    --------
    A simple calculator node, taking the exponential value of an input parameter.

    >>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
    >>> scale = lsl.Var.new_calc(jnp.exp, log_scale, name="scale")
    >>> print(scale.value)
    1.0

    You can also use your own functions as long as they are jit-compilable by JAX.

    >>> def compute_variance(x):
    ...     return jnp.exp(x)**2
    >>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
    >>> variance = lsl.Var.new_calc(compute_variance, log_scale, name="variance")
    >>> print(variance.value)
    1.0

    The value of the calculating variable is updated when :meth:`~.Var.update` is
    called.

    >>> log_scale = lsl.Var.new_param(0.0, name="log_scale")
    >>> scale = lsl.Var.new_calc(jnp.exp, log_scale, name="scale")
    >>> print(scale.value)
    1.0
    >>> log_scale.value = 1.0
    >>> print(scale.value)
    1.0
    >>> print(scale.update().value)
    2.7182817

    .. _docs: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
    """  # noqa: E501
    if convert_inputs == "default":
        convert_inputs = cls.convert_value

    # TransientCalc re-evaluates the function on every access instead of caching
    calc_class = Calc if cache else TransientCalc

    calc = calc_class(
        function,
        *inputs,
        _name=f"{name}_calc",
        _needs_seed=_needs_seed,
        _update_on_init=_update_on_init,
        convert_inputs=convert_inputs,
        **kwinputs,
    )

    # the calc node already produces the final value, so no further conversion
    var = cls(
        calc, dist=dist, distribution=distribution, name=name, convert=lambda x: x
    )
    return var
@classmethod
def new_value(
    cls,
    value: Any,
    name: str = "",
    inference: InferenceTypes = None,
    convert: Callable[[Any], Any] | Literal["default"] = "default",
) -> Var:
    """
    Initializes a strong variable without a distribution.

    Parameters
    ----------
    value
        The value of the variable.
    name
        The name of the variable. If you do not specify a name, a unique name will \
        be automatically generated upon initialization of a :class:`.Model`.
    inference
        Additional information that can be used to set up inference algorithms.
    convert
        A function used to process the value of this variable.
        The default uses the function stored in :meth:`.Var.convert_value`, which
        is ``jax.numpy.asarray``.

    See Also
    --------
    .Var.new_param : Initializes a strong variable that acts as a model parameter.
    .Var.new_obs :
        Initializes a strong variable that acts as an observed variable.
    .Var.new_calc :
        Initializes a weak variable that is a function of other variables.

    Examples
    --------
    A simple value variable without a name:

    >>> x = lsl.Var.new_value(1.0)
    >>> x
    Var(name="")
    """
    var = cls(value, name=name, inference=inference, convert=convert)
    return var
def get_inference(self, key: str | None = None) -> InferenceTypes:
    """
    Returns the inference information stored on this variable.

    Parameters
    ----------
    key
        If :attr:`.inference` is a dict, the key of the entry to return. Ignored
        if :attr:`.inference` is not a dict.

    Raises
    ------
    ValueError
        If :attr:`.inference` is a dict and ``key`` is ``None``.
    KeyError
        If :attr:`.inference` is a dict without an entry for ``key``.
    """
    inference = self.inference
    if not isinstance(inference, dict):
        return inference

    if key is None:
        raise ValueError(f"{key=} is invalid. Possible keys: {list(inference)}.")

    try:
        return inference[key]
    except KeyError:
        raise KeyError(
            f"No inference information for {key=}. "
            f"Possible keys: {list(inference)}."
        ) from None
[docs] def all_input_nodes( self, to: Literal["value_node", "dist_node", "both"] = "both" ) -> tuple[Node, ...]: """Returns all input *nodes* as a unique tuple.""" inputs1 = self.value_node.all_input_nodes() inputs2 = self._dist_node.all_input_nodes() match to: case "value_node": return inputs1 case "dist_node": return inputs2 case "both": return _unique_tuple(inputs1, inputs2)
def all_input_vars(
    self, to: Literal["value_node", "dist_node", "both"] = "both"
) -> tuple[Var, ...]:
    """
    Returns all input *variables* as a unique tuple.

    The returned tuple also contains input variables that are indirect inputs of
    this variable through nodes without variables.

    Parameters
    ----------
    to
        Whether to start from the inputs of the value node, the distribution
        node, or both.
    """
    pending = list(self.all_input_nodes(to=to))
    # a set keeps membership checks O(1) instead of scanning a list
    visited: set[Node] = set()
    # dict used as an insertion-ordered set of the variables found
    found: dict[Var, None] = {}

    while pending:
        node = pending.pop()
        if node in visited:
            continue
        visited.add(node)

        if node.var and node.var is not self:
            found[node.var] = None
        else:
            # node without its own variable: look through it to its inputs
            pending.extend(node.all_input_nodes())

    return tuple(found)
@changes_model_graph
def transform(
    self,
    bijector: type[jb.Bijector] | jb.Bijector | None = None,
    *bijector_args,
    inference: InferenceTypes | Literal["drop"] = None,
    name: str | None = None,
    **bijector_kwargs,
) -> Var:
    """
    Transforms the variable, making it a function of a new variable.

    Creates a new variable on the unconstrained space ``R**n`` with the
    appropriate transformed distribution, turning the original variable into a
    weak variable without an associated distribution. The transformation is
    performed using TFP's bijector classes; see the
    `TFP bijectors documentation \
    <https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors>`_.

    - **Stored transformed variable**: After the transformation, you can access
      the transformed variable via :attr:`.bijected_var`.

    Parameters
    ----------
    bijector
        The bijector used to map the new transformed variable to this variable \
        (forward transformation). If ``None``, the experimental default event \
        space bijector (see tensorflow probability documentation) is used. \
        If a bijector class is \
        passed, it is instantiated with the arguments ``bijector_args`` and \
        ``bijector_kwargs``. If a bijector instance is passed, it is used \
        directly.
    bijector_args
        The arguments passed on to the init function of the bijector.
    inference
        Additional information that can be used to set up inference algorithms for \
        the new, transformed variable. If ``"drop"``, the inference \
        information will be dropped from the original variable. \
        The new variable will have no inference information. \
        If ``None`` (default), the new variable will likewise have no inference \
        information, but an error will be raised if there is inference information \
        on the original variable.
    name
        Name for the new, transformed variable. If ``None`` (default), the new \
        name will be ``<old_name>_transformed``, where ``<old_name>`` is \
        a placeholder for the current variable's name.
    bijector_kwargs
        The keyword arguments passed on to the init function of the bijector.

    Returns
    -------
    The new transformed variable which acts as an input to this variable.

    Raises
    ------
    RuntimeError
        If ``bijector`` is ``None`` and the variable has no distribution, or its
        distribution has no default event space bijector. Also, if a bijector
        instance is passed together with nonempty bijector arguments.
    ValueError
        If a bijector class is passed without any bijector arguments. Also, if
        in the arguments to :meth:`.transform` is ``inference=None`` but the
        variable attribute :attr:`.inference` is not ``None``.
    TypeError
        If ``bijector`` is neither ``None``, a bijector class, nor a bijector
        instance.

    See Also
    --------
    .biject : Similar method, but with a slightly different API and returns self \
        instead of the transformed variable.

    Notes
    -----
    This is a simplified pseudo-code illustration of what this method does:

    .. code-block:: python

        import tensorflow_probability.substrates.jax.bijectors as tfb
        import tensorflow_probability.substrates.jax.distributions as tfd

        def transform(original_var: lsl.Var, bijector: tfb.Bijector):
            original_dist = original_var.dist_node.distribution
            dist_inputs = original_var.dist_node.inputs

            # transform the distribution
            new_dist = tfd.TransformedDistribution(
                original_dist, tfb.Invert(bijector)
            )

            # transform initial value
            new_value = bijector.inverse(original_var.value)

            # initialise the new variable
            new_var = lsl.Var(
                new_value,
                lsl.Dist(new_dist, *dist_inputs),
                name=f"{original_var.name}_transformed"
            )
            new_var.parameter = original_var.parameter

            # define the original variable as a function of the new variable
            original_var.value_node = lsl.Calc(bijector.forward, new_var)
            original_var.parameter = False

            # return the new variable
            return new_var

    The value of the attribute :attr:`~liesel.model.nodes.Var.parameter` is
    transferred to the transformed variable and set to ``False`` on the original
    variable. The attributes :attr:`~liesel.model.nodes.Var.observed` and
    :attr:`~liesel.model.nodes.Var.role` have the default values for the
    transformed variable and remain unchanged on the original variable.

    Examples
    --------
    >>> import tensorflow_probability.substrates.jax.distributions as tfd
    >>> import tensorflow_probability.substrates.jax.bijectors as tfb

    Assume we have a variable ``scale`` that is constrained to be positive, and
    we want to include the log-transformation of this variable in the model.
    We first set up the parameter var with its distribution:

    >>> prior = lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0)
    >>> scale = lsl.Var.new_param(1.0, prior, name="scale")

    Then we transform the variable to the log-scale:

    >>> log_scale = scale.transform(tfb.Exp())
    >>> log_scale
    Var(name="scale_transformed")

    Now the ``log_scale`` has a log probability, and the ``scale`` variable
    does not:

    >>> log_scale.update().log_prob
    Array(-3.6720574, dtype=float32)

    >>> scale.update().log_prob
    0.0
    """
    if inference is None and self.inference:
        raise ValueError(
            f"{self} has inference information in the .inference attribute. "
            "To proceed with transformation, the .inference information needs to "
            "be explicitly removed. You can transform with ``inference='drop'``."
        )

    if is_bijector_class(bijector) and not (bijector_args or bijector_kwargs):
        raise ValueError(
            "You passed a bijector class instead of an instance, but did not "
            "provide any arguments for the bijector. You should either provide "
            "arguments or pass an instance of the bijector class instead."
        )

    # use default event space bijector if bijector is None
    use_default_bijector = bijector is None
    default_bijector = None

    if use_default_bijector and self.dist_node is not None:
        dist_inst = self.dist_node.init_dist()
        # Var/Node arguments are passed on by their current value
        _args = [
            arg.value if isinstance(arg, Var | Node) else arg for arg in bijector_args
        ]
        _kwargs = {
            key: val.value if isinstance(val, Var | Node) else val
            for key, val in bijector_kwargs.items()
        }
        default_bijector = dist_inst.experimental_default_event_space_bijector(
            *_args, **_kwargs
        )

    if use_default_bijector and default_bijector is None:
        if self.dist_node is not None:
            msg = (
                f"{self} has distribution without default event space bijector. "
                "No bijector was given."
            )
        else:
            msg = (
                f"{self} has no distribution, so there is no default event space "
                "bijector to be found. No bijector was given."
            )
        raise RuntimeError(msg)

    if isinstance(bijector, jb.Bijector) and (bijector_args or bijector_kwargs):
        raise RuntimeError(
            "You passed a bijector instance and nonempty bijector"
            " arguments. You should either initialise your bijector"
            " directly with the arguments, or pass a bijector class"
            " instead. The first option is preferred, if the bijector"
            " arguments are constant."
        )

    if self.dist_node is None and is_bijector_class(bijector):
        tvar = _transform_var_without_dist_with_bijector_class(
            self, bijector, *bijector_args, **bijector_kwargs
        )
    elif self.dist_node is None and isinstance(bijector, jb.Bijector):
        tvar = _transform_var_without_dist_with_bijector_instance(self, bijector)
    elif is_bijector_class(bijector) or use_default_bijector:
        # avoid infinite recursion
        self.auto_transform = False
        tvar = _transform_var_with_bijector_class(
            self, bijector, *bijector_args, **bijector_kwargs
        )
        self.dist_node = None
    elif isinstance(bijector, jb.Bijector):
        # avoid infinite recursion
        self.auto_transform = False
        tvar = _transform_var_with_bijector_instance(self, bijector)
        self.dist_node = None
    else:
        raise TypeError(
            f"Argument {bijector=} is of invalid type {type(bijector)}."
        )

    # the parameter role moves to the new, transformed variable
    tvar.parameter = self.parameter  # type: ignore
    self.parameter = False

    # keep unnamed variables unnamed, including the new variable's nodes
    if not self.name:
        tvar.name = ""
        tvar.value_node.name = ""
        if tvar.dist_node is not None:
            tvar.dist_node.name = ""

    if name is not None:
        tvar.name = name

    # inference information never stays on the original variable
    self.inference = None
    if inference != "drop":
        tvar.inference = inference

    self.bijected_var = tvar
    return tvar
def biject(
    self,
    bijector: Bijector | type[jb.Bijector] | Literal["auto"] | None = "auto",
    *bijector_args,
    inference: InferenceTypes | Literal["drop"] = None,
    name: str | None = None,
    **bijector_kwargs,
) -> Var:
    """
    Transforms the variable using a bijector.

    - **Eager evaluation**: The transformation is applied immediately.
    - **Stored transformed variable**: Access via :attr:`.bijected_var`.

    This method is similar to :meth:`.transform`, but with key differences:

    - **Returns self**: Returns the original variable (now weakened) for chaining.
    - **Default "auto"**: Default uses the distribution's event space bijector.
    - **None means no transformation**: ``bijector=None`` means skip
      transformation.

    Parameters
    ----------
    bijector
        The bijector for transformation. If ``"auto"``, uses the default event
        space bijector. If ``None``, no transformation. If a bijector class,
        instantiated with args/kwargs. If instance, used directly.
    bijector_args
        Arguments for bijector init if a class is provided.
    inference
        Inference information for the transformed variable.
    name
        Name for transformed variable. Default: ``h(<old_name>)``. If this
        variable has no name, the transformed variable is left unnamed, too.
    bijector_kwargs
        Keyword arguments for bijector init if a class is provided.

    Returns
    -------
    The original variable (self), now weak and depending on the transformed
    variable. Access the transformed variable via :attr:`.bijected_var`.

    Raises
    ------
    TypeError
        If ``bijector`` is not a bijector instance, a bijector class, ``"auto"``
        or ``None``.

    See Also
    --------
    .transform : Similar method that returns the transformed variable instead.
    .bijected_var : Property to access the transformed variable.
    """
    # If bijector is None, no transformation is applied
    if bijector is None:
        return self

    # isinstance check first, so bijector objects are never compared to a string
    is_auto = isinstance(bijector, str) and bijector == "auto"
    if not (
        is_auto or isinstance(bijector, jb.Bijector) or is_bijector_class(bijector)
    ):
        raise TypeError(
            f"Argument {bijector=} is of invalid type {type(bijector)}. "
            f"Expected Bijector, type[Bijector], 'auto', or None."
        )

    # An unnamed variable would otherwise produce the name "h()" for every
    # transformed variable. Leaving name=None lets transform() keep it unnamed.
    if name is None and self.name:
        name = f"h({self.name})"

    # Delegate to transform(), which handles class vs instance uniformly
    self.transform(
        None if is_auto else bijector,
        *bijector_args,
        inference=inference,
        name=name,
        **bijector_kwargs,
    )
    return self
@property
def bijected_var(self) -> Var | None:
    """
    Transformed variable.

    Either supplied manually or automatically created by :meth:`.biject`.
    ``None`` if no transformed variable has been set.
    """
    return self._bijected_var

@bijected_var.setter
def bijected_var(self, value: Var):
    """
    Sets the transformed variable.

    Raises
    ------
    TypeError
        If ``value`` is not a :class:`.Var`.
    ValueError
        If the var value node of ``value`` is neither a positional nor a keyword
        input of this variable's value node.
    """
    if not isinstance(value, Var):
        raise TypeError(f"Bijected var must be a lsl.Var, got type {type(value)}.")

    # The transformed variable must feed directly into this variable's value node.
    in_inputs = value.var_value_node in self.value_node.inputs
    in_kwinputs = value.var_value_node in self.value_node.kwinputs.values()
    if not (in_inputs or in_kwinputs):
        raise ValueError(f"{value} is not in the inputs or kwinputs of {self}")

    self._bijected_var = value
[docs] @in_model_method def all_output_nodes( self, of: Literal["value_node", "dist_node", "both"] = "both" ) -> tuple[Node, ...]: """Returns all output *nodes* as a unique tuple.""" match of: case "value_node": nodes = list(self.value_node.all_output_nodes()) nodes.extend(self.var_value_node.all_output_nodes()) case "dist_node": nodes = list(self._dist_node.all_output_nodes()) case "both": nodes = list(self.value_node.all_output_nodes()) nodes.extend(self.var_value_node.all_output_nodes()) nodes.extend(self._dist_node.all_output_nodes()) nodes = [node for node in nodes if node is not self.var_value_node] return _unique_tuple(nodes)
@in_model_method
def all_output_vars(
    self, of: Literal["value_node", "dist_node", "both"] = "both"
) -> tuple[Var, ...]:
    """
    Returns all output *variables* as a unique tuple.

    The returned tuple also contains output variables that are indirect outputs
    of this variable through nodes without variables.

    Parameters
    ----------
    of
        Passed on to :meth:`.all_output_nodes` to select the starting nodes.

    Raises
    ------
    RuntimeError
        If the variable is not part of a model.
    """
    pending = list(self.all_output_nodes(of=of))
    # A set keeps the membership test O(1); nodes are hashable, since
    # all_output_nodes() deduplicates them via _unique_tuple().
    visited = set()
    found = []

    while pending:
        node = pending.pop()
        if node in visited:
            continue
        visited.add(node)

        if node.var and node.var is not self:
            # Stop at the first node that belongs to another variable.
            found.append(node.var)
        else:
            # Nodes without a variable (or of this variable) are traversed.
            pending.extend(node.all_output_nodes())

    return _unique_tuple(found)
@property
def auto_transform(self) -> bool:
    """
    Whether the variable should automatically be transformed to the unconstrained
    space ``R**n`` upon model initialization.
    """
    return self._auto_transform

@auto_transform.setter
def auto_transform(self, auto_transform: bool):
    self._auto_transform = auto_transform

@property
def dist_node(self) -> Dist | None:
    """
    The distribution node of the variable.

    ``None`` if the variable has no probability distribution.
    """
    return self._dist_node if self.has_dist else None

@dist_node.setter
@changes_model_graph
def dist_node(self, dist_node: Dist | None):
    # A falsy value (e.g. ``None``) removes the distribution. Internally, the
    # variable then holds a ``NoDist`` placeholder node.
    if not dist_node:
        dist_node = NoDist()

    if self.name and not dist_node.name:
        dist_node.name = f"{self.name}_log_prob"  # type: ignore # unfrozen

    old_dist_node = self._dist_node
    model = old_dist_node.model

    if model:
        # Suspend automatic updates and defer graph updates while the old
        # distribution node is detached. The previous settings are restored
        # even if detaching fails.
        auto_update_before = model.auto_update
        model.auto_update = False
        lazy_before = model.update_graph_lazily
        model.update_graph_lazily = True

        try:
            inputs = old_dist_node.inputs + tuple(old_dist_node.kwinputs.values())

            old_dist_node._unset_var()
            old_dist_node.set_inputs()
            model._nodes = {
                n.name: n for n in model._nodes.values() if n is not old_dist_node
            }
            old_dist_node.at = None

            # Let the model remove parental submodels of the former inputs that
            # became disconnected; for var value nodes also those of their vars.
            for nv in inputs:
                model._remove_disconnected_parental_submodel(nv)
                if isinstance(nv, VarValue):
                    assert nv.var
                    model._remove_disconnected_parental_submodel(nv.var)
        finally:
            model.update_graph_lazily = lazy_before
            model.auto_update = auto_update_before
    else:
        old_dist_node._unset_var()
        old_dist_node.at = None

    dist_node._set_var(self)
    dist_node.at = self.var_value_node  # type: ignore # unfrozen
    self._dist_node = dist_node

@property
def groups(self) -> MappingProxyType[str, Group]:
    """The groups that this variable is a part of."""
    return MappingProxyType(self._groups)

@property
def has_dist(self) -> bool:
    """Whether the variable has a probability distribution."""
    return not isinstance(self._dist_node, NoDist)

@property
def log_prob(self) -> Array:
    """
    The log-probability of the variable.

    A variable without a probability distribution has a log-probability of 0.0.
    """
    return self._dist_node.value

@property
def model(self) -> Model | None:
    """The model the variable is part of."""
    return self.value_node.model

@property
def name(self) -> str:
    """The name of the variable."""
    return self._name

@name.setter
@changes_model_graph
def name(self, name: str):
    # Node names that are empty or still follow the default naming scheme derived
    # from the old variable name are renamed along with the variable.
    if name and self.value_node.name in ("", f"{self.name}_value"):
        self.value_node.name = f"{name}_value"  # type: ignore # unfrozen
        self.var_value_node.name = f"{name}_var_value"  # type: ignore # unfrozen

    if name and self._dist_node.name in ("", f"{self.name}_log_prob"):
        self._dist_node.name = f"{name}_log_prob"  # type: ignore # unfrozen

    self._name = name

@property
def nodes(self) -> list[Node]:
    """
    The nodes of the variable as a list.

    Contains the value node and the var value node, plus the distribution node
    if the variable has a distribution.
    """
    nodes = [self.value_node, self.var_value_node]
    if self.dist_node:
        nodes.append(self.dist_node)
    return nodes

@property
def observed(self) -> bool:
    """
    Whether the variable is observed.

    If a variable is observed and has an associated probability distribution,
    its log-probability is automatically added to the model log-likelihood
    (see :attr:`.Model.log_lik`).

    See Also
    --------
    .obs : Helper function to declare a variable as observed.
    .Model.log_prior : The log-prior of a Liesel model.
    .Var.parameter : Whether the variable is a parameter. If a variable is
        a parameter, it is not observed.

    Notes
    -----
    We recommend to use the :func:`.obs` helper function to declare an observed
    variable.
    """
    return self._observed

@observed.setter
@changes_model_graph
def observed(self, observed: bool):
    if self.parameter and observed is True:
        raise ValueError("Cannot set observed flag to True if parameter=True")
    self._observed = observed

@property
def parameter(self) -> bool:
    """
    Whether the variable is a parameter.

    If the variable is a parameter and has an associated probability
    distribution, its log-probability is added to the model's
    :attr:`~.Model.log_prior`.

    See Also
    --------
    .param : Helper function to declare a variable as a parameter.
    .Model.log_prior : The log-prior of a Liesel model.
    .Var.observed : Whether the variable is observed. If a variable is
        a parameter, it is not observed.

    Notes
    -----
    We recommend to use the :func:`.param` helper function to declare a variable
    as a parameter.
    """
    return self._parameter

@parameter.setter
@changes_model_graph
def parameter(self, parameter: bool):
    if self.observed and parameter is True:
        raise ValueError("Cannot set parameter flag to True if observed=True")
    self._parameter = parameter

@property
def role(self) -> str:
    """The role of the variable."""
    return self._role

@role.setter
def role(self, role: str):
    self._role = role

@property
def strong(self) -> bool:
    """
    Whether the variable is strong.

    A strong variable is one whose value is defined outside of the model, for
    example, if it represents some observed data or a parameter (parameters are
    usually set by an inference algorithm such as an optimizer or sampler).

    In contrast, a weak variable is one whose value is defined within the model,
    that is, it is a deterministic function of some other nodes. An
    exp-transformation mapping a real-valued parameter to a positive number, for
    example, would be weak.

    Here, a variable is strong if its value node is a :class:`.Value` node.

    See Also
    --------
    .weak : Whether the variable is weak. In general, ``strong = not weak``.
    """
    return isinstance(self.value_node, Value)
def update(self) -> Var:
    """
    Updates the variable.

    Recomputes the value node first and the distribution node second, then
    returns the variable itself to allow chaining.
    """
    # Attribute lookup happens lazily, right before each update call.
    for node_attr in ("value_node", "_dist_node"):
        getattr(self, node_attr).update()
    return self
@property
def value(self) -> Any:
    """
    The value of the variable.

    The value can only be set if the variable is strong. If the variable is part
    of a :class:`.Model` ``m`` with ``m.auto_update == True``, setting the value
    triggers an update of the model. Disabling the auto-update can improve
    performance when several model parameters are updated at once.
    """
    return self.value_node.value

@value.setter
def value(self, value: Any):
    if not self.weak:
        self.value_node.value = value  # type: ignore # data node
        return
    # Weak variables compute their value from other nodes.
    raise RuntimeError(f"{repr(self)} is weak, cannot set value")

@property
def value_node(self) -> Node:
    """The value node of the variable."""
    return self._value_node

@value_node.setter
@changes_model_graph
def value_node(self, value_node: Any):
    new_node = value_node
    # Variables are wrapped in an identity calculator, raw values in a Value node.
    if isinstance(new_node, Var):
        new_node = Calc(lambda x: x, new_node, convert_inputs=self._convert)
    elif not isinstance(new_node, Node):
        new_node = Value(new_node, convert=self._convert)

    other_model = new_node.model
    if other_model and other_model is not self.model:
        raise RuntimeError(
            f"{repr(new_node)} and {self} must be part of no "
            "model, or the same model."
        )

    if self.name and not new_node.name:
        new_node.name = f"{self.name}_value"

    # Detach the old value node and drop it from the model's node registry.
    old_node = self.value_node
    old_node._unset_var()
    current_model = self.model
    if current_model:
        current_model._nodes = {
            n.name: n for n in current_model._nodes.values() if n is not old_node
        }

    new_node._set_var(self)
    self._value_node = new_node
    self._var_value_node.set_inputs(self._value_node)

@property
def var_value_node(self) -> VarValue:
    """The proxy node for the value of the variable."""
    return self._var_value_node

@property
def weak(self) -> bool:
    """
    Whether the variable is weak.

    A weak variable is a variable whose value is defined within the model, that
    is, it is a deterministic function of some other nodes.

    See Also
    --------
    .strong : Whether the variable is strong. In general, ``weak = not strong``.
    """
    return not self.strong

def __repr__(self) -> str:
    return '{}(name="{}")'.format(type(self).__name__, self.name)

def _plot(
    self, which: Literal["vars", "nodes"] = "vars", verbose: bool = False, **kwargs
) -> None:
    """
    Shared implementation of :meth:`.plot_vars` and :meth:`.plot_nodes`.

    Plots the parental subgraph of this variable. If the variable is not part of
    a model, a temporary model is built around it for plotting; ``verbose`` is
    only passed on to that temporary model.
    """

    def draw(graph_model):
        # Returns whether ``which`` was recognized, and the plotting result.
        if which == "vars":
            subgraph = graph_model.var_parental_subgraph(self)
            return True, plot_vars(subgraph, **kwargs)
        if which == "nodes":
            candidates = (self.value_node, self.dist_node, self.var_value_node)
            own_nodes = [node for node in candidates if node is not None]
            subgraph = graph_model.node_parental_subgraph(*own_nodes)
            return True, plot_nodes(subgraph, **kwargs)
        return False, None

    if self.model is not None:
        handled, result = draw(self.model)
        if handled:
            return result

    from liesel.model.model import TemporaryModel

    try:
        to_float32 = not jax.config.jax_enable_x64  # type: ignore
    except Exception:
        # Fallback in case the config attribute changes: infer the x64 flag
        # from the default float dtype.
        to_float32 = jnp.array(1.0).dtype == jnp.dtype("float32")

    with TemporaryModel(self, verbose=verbose, to_float32=to_float32) as temp_model:
        draw(temp_model)
def plot(
    self,
    show: bool = True,
    save_path: str | None | IO = None,
    width: int = 14,
    height: int = 10,
    prog: Literal[
        "dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
    ] = "dot",
    verbose: bool = False,
    legend: bool = True,
) -> None:
    """
    Plots the variables of the Liesel sub-model that terminates in this variable.

    Wraps :func:`~.viz.plot_vars`. Alias for :meth:`.Var.plot_vars`.

    Parameters
    ----------
    show
        Whether to show the plot in a new window.
    save_path
        Path to save the plot. If not provided, the plot will not be saved.
    width
        Width of the plot in inches.
    height
        Height of the plot in inches.
    prog
        Layout parameter. Available layouts: circo, dot (the default), fdp,
        neato, osage, patchwork, sfdp, twopi.
    verbose
        Only used if the variable is not part of a model, in which case a
        temporary model is created for plotting. If ``True``, the message that is
        logged if unnamed nodes are automatically named for plotting contains a
        list of the automatically assigned names.
    legend
        Whether to draw the legend.

    See Also
    --------
    .Var.plot_vars : Plots the variables of the Liesel sub-model that terminates
        in this variable.
    .Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
        this variable.
    .Model.plot_vars : Plots the variables of a Liesel model.
    .Model.plot_nodes : Plots the nodes of a Liesel model.
    .viz.plot_vars : Plots the variables of a Liesel model.
    .viz.plot_nodes : Plots the nodes of a Liesel model.
    """
    return self.plot_vars(
        verbose=verbose,
        show=show,
        save_path=save_path,
        width=width,
        height=height,
        prog=prog,
        legend=legend,
    )
def plot_vars(
    self,
    show: bool = True,
    save_path: str | None | IO = None,
    width: int = 14,
    height: int = 10,
    prog: Literal[
        "dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
    ] = "dot",
    verbose: bool = False,
    legend: bool = True,
) -> None:
    """
    Plots the variables of the Liesel sub-model that terminates in this variable.

    Wraps :func:`~.viz.plot_vars`.

    Parameters
    ----------
    show
        Whether to show the plot in a new window.
    save_path
        Path to save the plot. If not provided, the plot will not be saved.
    width
        Width of the plot in inches.
    height
        Height of the plot in inches.
    prog
        Layout parameter. Available layouts: circo, dot (the default), fdp,
        neato, osage, patchwork, sfdp, twopi.
    verbose
        Only used if the variable is not part of a model, in which case a
        temporary model is created for plotting. If ``True``, the message that is
        logged if unnamed nodes are automatically named for plotting contains a
        list of the automatically assigned names.
    legend
        Whether to draw the legend.

    See Also
    --------
    .Var.plot_vars : Plots the variables of the Liesel sub-model that terminates
        in this variable.
    .Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
        this variable.
    .Model.plot_vars : Plots the variables of a Liesel model.
    .Model.plot_nodes : Plots the nodes of a Liesel model.
    .viz.plot_vars : Plots the variables of a Liesel model.
    .viz.plot_nodes : Plots the nodes of a Liesel model.
    """
    return self._plot(
        which="vars",
        verbose=verbose,
        show=show,
        save_path=save_path,
        width=width,
        height=height,
        prog=prog,
        legend=legend,
    )
def predict(
    self,
    samples: dict[str, jax.typing.ArrayLike],
    newdata: dict[str, jax.typing.ArrayLike] | None = None,
) -> Array:
    """
    Returns an array of predictions for this variable.

    Parameters
    ----------
    samples
        Dictionary of samples at which to evaluate predictions.
    newdata
        Dictionary of new data at which to evaluate predictions. The keys should
        correspond to variable or node names in the model whose values should be
        set to the given values before evaluating predictions. Entries that are
        part of the model but not of this variable's parental submodel are
        ignored.

    Returns
    -------
    The predictions for this variable, as returned by the submodel's ``predict``
    method under this variable's name.

    Raises
    ------
    KeyError
        If a key in ``newdata`` is neither a variable nor a node name of the
        model. If this variable is not part of a model, only its parental
        submodel can be checked, which is more strict.
    """
    if self.model is not None:
        model = self.model
        submodel = model.parental_submodel(self)
    else:
        from liesel.model.model import TemporaryModel

        try:
            to_float32 = not jax.config.jax_enable_x64  # type: ignore
        except Exception:
            # Fallback in case the config attribute changes: infer the x64 flag
            # from the default float dtype.
            to_float32 = jnp.array(1.0).dtype == jnp.dtype("float32")

        with TemporaryModel(self, silent=True, to_float32=to_float32) as temp_model:
            submodel = temp_model.parental_submodel(self)
        model = submodel

    def is_known(m, key) -> bool:
        # Keys may refer to either variables or nodes.
        return key in m.vars or key in m.nodes

    newdata = newdata if newdata is not None else {}
    newdata = newdata.copy()

    for key in list(newdata.keys()):
        if not is_known(model, key):
            msg = f"{key} is not part of the model."
            if self.model is None:
                msg += (
                    f" Note that {self} is not part of a model, so this check "
                    f"can only use the inputs to {self}, which is more strict."
                )
            raise KeyError(msg)

        # Data for parts of the model that do not influence this variable is
        # irrelevant for the submodel.
        if not is_known(submodel, key):
            newdata.pop(key, None)

    pred = submodel.predict(samples=samples, predict=[self.name], newdata=newdata)
    return pred[self.name]
def diagnose(self, verbose: bool = False) -> pd.DataFrame:
    """
    Provides a dataframe with diagnostic information about this variable's
    submodel.

    The submodel is the parental submodel of this variable. If the variable is
    not part of a model, it is derived from a temporary model built around the
    variable.

    Parameters
    ----------
    verbose
        Passed on to the submodel's ``diagnose`` method.
    """
    if self.model is not None:
        submodel = self.model.parental_submodel(self)
    else:
        from liesel.model.model import TemporaryModel

        try:
            to_float32 = not jax.config.jax_enable_x64  # type: ignore
        except Exception:
            # Fallback in case the config attribute changes: infer the x64 flag
            # from the default float dtype.
            to_float32 = jnp.array(1.0).dtype == jnp.dtype("float32")

        with TemporaryModel(self, silent=True, to_float32=to_float32) as temp_model:
            submodel = temp_model.parental_submodel(self)

    # Diagnose the submodel in both cases, as documented; previously, the whole
    # model was diagnosed if the variable was part of one.
    return submodel.diagnose(verbose=verbose)
def sample(
    self,
    shape: Sequence[int],
    seed: jax.Array,
    posterior_samples: dict[str, jax.typing.ArrayLike] | None = None,
    fixed: Sequence[str] = (),
    newdata: dict[str, jax.typing.ArrayLike] | None = None,
    dists: dict[str, Dist] | None = None,
) -> dict[str, Array]:
    """
    Draws samples from the parental model of this variable.

    If the variable is part of a model, its parental submodel is used;
    otherwise, sampling happens in a temporary model built around the variable.

    Parameters
    ----------
    shape
        Sample shape.
    seed
        The seed is split and distributed to the seed nodes of the model.
        Must be a jax RNG key array that satisfies
        ``jnp.issubdtype(key.dtype, jax.dtypes.prng_key)``.
        See :mod:`jax.random` and
        https://docs.jax.dev/en/latest/jep/9263-typed-keys.html for more details.
    posterior_samples
        Dictionary of samples at which to evaluate predictions. All values of the
        dictionary are assumed to have two leading dimensions corresponding to
        ``(nchains, niteration)``.
    fixed
        The names of the nodes or variables to be excluded from the simulation.
        By default, no nodes or variables are skipped.
    newdata
        Dictionary of new data at which to produce samples. The keys should
        correspond to variable or node names in the model whose values should be
        set to the given values before sampling. If ``None`` (default), the
        current variable values are used.
    dists
        Can be used to provide a dictionary of variable names and :class:`.Dist`
        instances to use in sampling. If ``None`` (default), samples are drawn for
        each variable using their :attr:`.Var.dist_node`.

    Notes
    -----
    When compiling this function with ``jax.jit``, the arguments ``shape``,
    ``fixed``, and ``dists`` must be static.

    Returns
    -------
    A dictionary of variable and node names and their sampled values. Includes
    only sampled variables.
    """
    sample_kwargs = dict(
        shape=shape,
        seed=seed,
        posterior_samples=posterior_samples,
        fixed=fixed,
        newdata=newdata,
        dists=dists,
    )

    if self.model:
        return self.model.parental_submodel(self).sample(**sample_kwargs)

    from .model import TemporaryModel

    try:
        to_float32 = not jax.config.jax_enable_x64  # type: ignore
    except Exception:
        # Fallback in case the config attribute changes: infer the x64 flag from
        # the default float dtype.
        to_float32 = jnp.array(1.0).dtype == jnp.dtype("float32")

    with TemporaryModel(self, silent=True, to_float32=to_float32) as temp_model:
        drawn = temp_model.sample(**sample_kwargs)
    return drawn
def plot_nodes(
    self,
    show: bool = True,
    save_path: str | None | IO = None,
    width: int = 14,
    height: int = 10,
    prog: Literal[
        "dot", "circo", "fdp", "neato", "osage", "patchwork", "sfdp", "twopi"
    ] = "dot",
    verbose: bool = False,
) -> None:
    """
    Plots the nodes of the Liesel sub-model that terminates in this variable.

    Wraps :func:`~.viz.plot_nodes`.

    Parameters
    ----------
    show
        Whether to show the plot in a new window.
    save_path
        Path to save the plot. If not provided, the plot will not be saved.
    width
        Width of the plot in inches.
    height
        Height of the plot in inches.
    prog
        Layout parameter. Available layouts: circo, dot (the default), fdp,
        neato, osage, patchwork, sfdp, twopi.
    verbose
        Only used if the variable is not part of a model, in which case a
        temporary model is created for plotting. If ``True``, the message that is
        logged if unnamed nodes are automatically named for plotting contains a
        list of the automatically assigned names.

    See Also
    --------
    .Var.plot_vars : Plots the variables of the Liesel sub-model that terminates
        in this variable.
    .Var.plot_nodes : Plots the nodes of the Liesel sub-model that terminates in
        this variable.
    .Model.plot_vars : Plots the variables of a Liesel model.
    .Model.plot_nodes : Plots the nodes of a Liesel model.
    .viz.plot_vars : Plots the variables of a Liesel model.
    .viz.plot_nodes : Plots the nodes of a Liesel model.
    """
    return self._plot(
        which="nodes",
        verbose=verbose,
        show=show,
        save_path=save_path,
        width=width,
        height=height,
        prog=prog,
    )
def ensure_name(self) -> Self:
    """
    Makes sure that the variable has a name.

    A variable that already has a (non-empty) name is left untouched. An unnamed
    variable receives a random name prefixed with an underscore.

    Returns
    -------
    The variable itself, to allow chaining.
    """
    if not self.name:
        self.name = "_" + random_name()
    return self
def _transform_var_with_bijector_instance(var: Var, bijector_inst: jb.Bijector) -> Var:
    """
    Transforms a variable that has a distribution using a bijector instance.

    The returned variable carries a ``jd.TransformedDistribution`` built from
    ``var``'s distribution and the inverse of ``bijector_inst``. Its value is the
    inverse-transformed value of ``var``. Afterwards, ``var`` is weak: its value
    node computes ``bijector_inst.forward`` of the transformed variable.

    Raises
    ------
    RuntimeError
        If ``var`` has no distribution.
    AttributeError
        If ``var`` is weak but its value node has no ``function`` attribute, or
        no ``_update_on_init`` attribute.
    """
    if var.dist_node is None:  # type: ignore
        raise RuntimeError(f"{var} has no distribution")

    InputDist = var.dist_node.distribution
    inputs = var.dist_node.inputs
    kwinputs = var.dist_node.kwinputs

    bijector_inv = jb.Invert(bijector_inst)

    def transform_dist(*args, **kwargs):
        return jd.TransformedDistribution(InputDist(*args, **kwargs), bijector_inv)

    transformed_dist = Dist(
        transform_dist,
        *inputs,
        _name="",
        _needs_seed=var.dist_node.needs_seed,
        bijectors=None,
        convert_inputs=jnp.asarray,
        **kwinputs,
    )
    transformed_dist.per_obs = var.dist_node.per_obs

    if var.weak:
        try:
            value_function = var.value_node.function  # type: ignore
        except AttributeError as e:
            raise AttributeError(
                "Trying to transform a weak variable without calculator node."
            ) from e

        # Compose the original calculation with the inverse bijector.
        def forward(*args, **kwargs):
            return bijector_inv.forward(value_function(*args, **kwargs))

        value_inputs = var.value_node.inputs
        value_kwinputs = var.value_node.kwinputs
        value_node_needs_seed = var.value_node.needs_seed
        value_node_update_on_init = var.value_node._update_on_init  # type: ignore

        transformed_var = Var(
            Calc(
                forward,
                *value_inputs,
                _name="",
                _needs_seed=value_node_needs_seed,
                _update_on_init=value_node_update_on_init,
                convert_inputs=jnp.asarray,
                **value_kwinputs,
            ),
            transformed_dist,
            name=f"{var.name}_transformed",
        )
    else:
        transformed_var = Var(
            bijector_inv.forward(var.value),
            transformed_dist,
            name=f"{var.name}_transformed",
        )

    var.value_node = Calc(bijector_inst.forward, transformed_var)
    return transformed_var


def _transform_var_with_bijector_class(
    var: Var, bijector_cls: type[jb.Bijector] | None, *args, **kwargs
) -> Var:
    """
    Transforms a variable that has a distribution using a bijector class.

    The bijector is instantiated inside the transformed distribution from
    ``*args`` and ``**kwargs``, which are grouped into an :class:`.InputGroup`.
    If ``bijector_cls`` is ``None``, the distribution's
    ``experimental_default_event_space_bijector`` is called with these arguments
    instead. Afterwards, ``var`` is weak and computed from the transformed
    variable.

    Raises
    ------
    RuntimeError
        If ``var`` has no distribution.
    AttributeError
        If ``var`` is weak but its value node has no ``function`` attribute, or
        no ``_update_on_init`` attribute.
    """
    if var.dist_node is None:  # type: ignore
        raise RuntimeError(f"{var} has no distribution")

    InputDist = var.dist_node.distribution
    dist_inputs = InputGroup(
        *var.dist_node.inputs,
        **var.dist_node.kwinputs,  # type: ignore
    )
    bijector_inputs = InputGroup(*args, **kwargs)

    # define distribution "class" for the transformed var
    def transform_dist(dist_args: ArgGroup, bijector_args: ArgGroup):
        tfp_dist = InputDist(*dist_args.args, **dist_args.kwargs)
        bjargs, bjkwargs = bijector_args.args, bijector_args.kwargs

        if bijector_cls is None:
            default_bijector_cls = tfp_dist.experimental_default_event_space_bijector
            bijector_inst = default_bijector_cls(*bjargs, **bjkwargs)
        else:
            bijector_inst = bijector_cls(*bjargs, **bjkwargs)

        bijector_inv = jb.Invert(bijector_inst)
        transformed_dist = jd.TransformedDistribution(
            tfp_dist, bijector_inv, validate_args=tfp_dist.validate_args
        )
        return transformed_dist

    dist_node_transformed = Dist(
        transform_dist,
        dist_inputs,
        bijector_inputs,
        _name="",
        _needs_seed=var.dist_node.needs_seed,
        bijectors=None,
    )
    dist_node_transformed.per_obs = var.dist_node.per_obs

    # The (inverted) bijector of an instance of the transformed distribution is
    # used to compute the initial transformed value below.
    bijector_inv = dist_node_transformed.init_dist().bijector

    if var.weak:
        try:
            value_function = var.value_node.function  # type: ignore
        except AttributeError as e:
            raise AttributeError(
                "Trying to transform a weak variable without calculator node."
            ) from e

        # Compose the original calculation with the inverse bijector.
        def forward(*args, **kwargs):
            return bijector_inv.forward(value_function(*args, **kwargs))

        value_inputs = var.value_node.inputs
        value_kwinputs = var.value_node.kwinputs
        value_node_needs_seed = var.value_node.needs_seed
        value_node_update_on_init = var.value_node._update_on_init  # type: ignore

        transformed_var = Var(
            Calc(
                forward,
                *value_inputs,
                _name="",
                _needs_seed=value_node_needs_seed,
                _update_on_init=value_node_update_on_init,
                convert_inputs=jnp.asarray,
                **value_kwinputs,
            ),
            dist_node_transformed,
            name=f"{var.name}_transformed",
        )
    else:
        transformed_var = Var(
            bijector_inv.forward(var.value),
            dist_node_transformed,
            name=f"{var.name}_transformed",
        )

    # Rebuild the bijector from the current inputs on every evaluation; the
    # inverse of the inverted bijector maps back to the original space.
    def bijector_fn(value, dist_inputs, bijector_inputs):
        bijector = transform_dist(dist_inputs, bijector_inputs).bijector
        return bijector.inverse(value)

    var.value_node = Calc(bijector_fn, transformed_var, dist_inputs, bijector_inputs)
    return transformed_var


def _transform_var_without_dist_with_bijector_instance(
    var: Var, bijector_inst: jb.Bijector
) -> Var:
    """
    Transforms a variable without a distribution using a bijector instance.

    The transformed variable holds ``bijector_inst.inverse`` of ``var``'s value
    (as a calculator if ``var`` is weak). Afterwards, ``var`` is computed as
    ``bijector_inst.forward`` of the transformed variable.
    """
    if var.strong:
        transformed_var = Var(
            bijector_inst.inverse(var.value),
            name=f"{var.name}_transformed",
        )
    else:
        transformed_var = Var.new_calc(
            bijector_inst.inverse,
            var.value_node,
            name=f"{var.name}_transformed",
        )

    var.value_node = Calc(bijector_inst.forward, transformed_var)
    return transformed_var


def _transform_var_without_dist_with_bijector_class(
    var: Var, bijector_cls: type[jb.Bijector] | None, *args, **kwargs
) -> Var:
    """
    Transforms a variable without a distribution using a bijector class.

    The bijector is instantiated with ``*args`` and ``**kwargs``, which may be
    plain values or Liesel nodes/variables.

    Raises
    ------
    TypeError
        If ``bijector_cls`` is ``None``. Without a distribution there is no
        default event space bijector to fall back on.
    """
    if bijector_cls is None:
        raise TypeError(
            f"Cannot transform {var} without a bijector: the variable has no "
            "distribution, so no default event space bijector is available."
        )

    def _as_value(obj):
        # Liesel nodes and variables expose their value; plain values pass through.
        try:
            return obj.value
        except AttributeError:
            return obj

    def bijection_inverse(x, *bjargs, **bjkwargs):
        # Bijector arguments may be passed directly as values, or as Liesel Vars
        # and Nodes. This inverse is executed only once in the initialization of
        # the transformed variable.
        arg_values = [_as_value(arg) for arg in bjargs]
        kwarg_values = {key: _as_value(val) for key, val in bjkwargs.items()}

        bijector_inst = bijector_cls(*arg_values, **kwarg_values)
        bijector_inv = jb.Invert(bijector_inst)
        return bijector_inv(x)

    def bijection_forward(x, *bjargs, **bjkwargs):
        bijector_inst = bijector_cls(*bjargs, **bjkwargs)
        return bijector_inst(x)

    if var.strong:
        transformed_var = Var(
            bijection_inverse(var.value, *args, **kwargs),
            name=f"{var.name}_transformed",
        )
    else:
        transformed_var = Var.new_calc(
            bijection_inverse,
            var.value_node,
            *args,
            **kwargs,
            name=f"{var.name}_transformed",
        )

    var.value_node = Calc(bijection_forward, transformed_var, *args, **kwargs)
    return transformed_var


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Group ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


class Group:
    """
    A group holds a collection of related :class:`.Var` and/or :class:`.Node`
    objects.

    They allow you to do three basic things:

    1. Store related nodes together for easier access.
    2. Access their member nodes and variables via ``group["name"]``, where
       ``"name"`` is the group-specific name, which can be different from the
       :attr:`.Var.name` / :attr:`.Node.name`.
    3. Easily retrieve a variable's or a node's value from a :attr:`.Model.state`
       based on their group-specific name via :meth:`.value_from`.

    Parameters
    ----------
    name
        The group's name. Must be unique among the groups of its members, and
        within a model.
    **nodes_and_vars
        An arbitrary number of nodes or variables. The keywords will be used as the
        group-specific names of the respective objects.

    Raises
    ------
    RuntimeError
        If a member already belongs to a group called ``name``, or if the same
        object is passed more than once. In this case, none of the members is
        registered with the new group.

    See Also
    --------
    * :attr:`.Node.groups` and :attr:`.Var.groups` are :obj:`MappingProxyType`
      objects (basically read-only dictionaries) of the groups whose member the
      respective object is.
    * :meth:`.GraphBuilder.groups` and :meth:`.Model.groups` are methods that
      collect and return all groups within the graph/model.

    Notes
    -----
    Note the following:

    - Groups can only be filled upon initialization.
    - After initialization, variables and nodes cannot be removed from a group.

    Examples
    --------
    Add a variable to a group:

    >>> my_var = lsl.Var(0.0, name="long_unique_variable_name")
    >>> grp = lsl.Group(name="demo_group", short_name=my_var)
    >>> grp
    Group(name="demo_group")

    Access the variable by its group-specific name:

    >>> grp["short_name"]
    Var(name="long_unique_variable_name")

    Retrieve the value of a variable from a model state:

    >>> model_state = {my_var.value_node.name: lsl.NodeState(10.0, False)}
    >>> grp.value_from(model_state, "short_name")
    10.0
    """

    def __init__(self, name: str, **nodes_and_vars: Node | Var) -> None:
        self._name = name
        self._nodes_and_vars = nodes_and_vars

        # Validate all members before registering the group with any of them, so
        # that a failure does not leave earlier members pointing to a group whose
        # construction raised. The same object passed twice is rejected, just as
        # it would be on its second registration.
        seen: set[int] = set()
        for member in self._nodes_and_vars.values():
            if name in member.groups or id(member) in seen:
                raise RuntimeError(
                    f"{repr(member)} is already a member of a group "
                    f"with the name {repr(name)}"
                )
            seen.add(id(member))

        for member in self._nodes_and_vars.values():
            member._groups[name] = self

        self._nodes = {
            key: obj
            for key, obj in self._nodes_and_vars.items()
            if isinstance(obj, Node)
        }
        self._vars = {
            key: obj
            for key, obj in self._nodes_and_vars.items()
            if isinstance(obj, Var)
        }

    @property
    def name(self) -> str:
        """The group's name."""
        return self._name

    def value_from(self, model_state: dict[str, NodeState], name: str) -> Array:
        """
        Retrieves the value of a node or variable that is a member of the group
        from a model state.

        Parameters
        ----------
        model_state
            The state of a Liesel model, i.e. a :class:`~.Model.state`.
        name
            The name of the node or variable within this group.

        Returns
        -------
        The value of the node or variable. For a variable, the state entry of its
        value node is used.
        """
        member = self[name]
        value_name = (
            member.value_node.name if isinstance(member, Var) else member.name
        )
        return model_state[value_name].value

    @property
    def vars(self) -> MappingProxyType[str, Var]:
        """A mapping of the variables in the group with their names as keys."""
        return MappingProxyType(self._vars)

    @property
    def nodes(self) -> MappingProxyType[str, Node]:
        """A mapping of the nodes in the group with their names as keys."""
        return MappingProxyType(self._nodes)

    @property
    def nodes_and_vars(self) -> MappingProxyType[str, Node | Var]:
        """A mapping of all group members with their names as keys."""
        return MappingProxyType(self._nodes_and_vars)

    def __contains__(self, key) -> bool:
        return key in self._nodes_and_vars

    def __getitem__(self, key) -> Var | Node:
        return self._nodes_and_vars[key]

    def __repr__(self) -> str:
        return f'{type(self).__name__}(name="{self.name}")'