Dist

class liesel.model.Dist(distribution, *inputs, _name='', _needs_seed=False, bijectors=None, convert_inputs='default', **kwinputs)

Bases: Node

A Node subclass that wraps a probability distribution.

Distribution nodes wrap distribution classes that follow the tensorflow_probability Distribution interface. They can be used to represent observation models and priors.

Distribution nodes will appear in the node graph created by viz.plot_nodes(), but they will not appear in the model graph created by viz.plot_vars().

Parameters:
  • distribution (Callable[..., TypeAliasType]) – The wrapped distribution class that follows the tensorflow_probability Distribution interface.

  • *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Value nodes. The values of these inputs will be passed to the wrapped distribution as positional arguments, in the same order in which they are given here.

  • _name (str, default: '') – The name of the node. If you do not specify a name, a unique name will be automatically generated upon initialization of a Model.

  • _needs_seed (bool, default: False) – Whether the node needs a seed / PRNG key.

  • bijectors (None | Literal['auto'] | dict[str, TypeAliasType | Literal['auto'] | None] | Sequence[TypeAliasType | Literal['auto'] | None], default: None) – Optional parameter bijector specification for transforming distribution parameters. See Dist.biject_parameters() for supported formats and behavior.

  • convert_inputs (Callable[[Any], Any] | Literal['default'], default: 'default') – A function used to process the values of this node’s inputs. The default uses the function stored in convert_value, which is jax.numpy.asarray.

  • **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Value nodes. The values of these inputs will be passed to the wrapped distribution as keyword arguments.

See also

Var

A variable in a statistical model, typically with a probability distribution.

MultivariateNormalDegenerate

A custom distribution class that implements a degenerate multivariate normal distribution in the tensorflow_probability Distribution interface.

Examples

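For the examples below, we import JAX NumPy, Liesel, and the JAX substrate of tensorflow_probability:

>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd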

Create an observation model for a normally distributed variable with fixed mean and scale. In the example below, the log-probability of the variable y is None until the variable is updated.

>>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y")
>>> print(y.log_prob)
None
>>> y.update()
Var(name="y")
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)
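The values above can be checked by hand: for a standard normal, the log-density is -0.5 * log(2 * pi) - x**2 / 2, which gives about -0.9189385 at x = 0 and subtracts a further 0.125 at x = ±0.5. This sketch uses only JAX, not Liesel, and compares the closed form with jax.scipy.stats.norm.logpdf.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

x = jnp.array([-0.5, 0.0, 0.5])

# Closed-form standard normal log-density: -0.5 * log(2 * pi) - x**2 / 2
manual = -0.5 * jnp.log(2 * jnp.pi) - x**2 / 2

# Reference value from jax.scipy.stats
reference = norm.logpdf(x, loc=0.0, scale=1.0)

print(manual)  # approximately [-1.0439385, -0.9189385, -1.0439385]
```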

Now we define the same observation model, but include the location and scale as parameters:

>>> loc = lsl.Var.new_param(0.0, name="loc")
>>> scale = lsl.Var.new_param(1.0, name="scale")
>>> dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y").update()
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)

Summed-up log-probability

You can set the per_obs attribute of a distribution node to False to sum up the log-probability of the distribution over all observations.

>>> dist.per_obs = False
>>> y.update().log_prob
Array(-3.0068154, dtype=float32)

Methods

  • all_input_nodes() – Returns all non-keyword and keyword input nodes as a tuple without duplicates.

  • biject_parameters([bijectors, inference]) – Transforms distribution parameters using bijectors with eager evaluation.

  • find_default_parameter_bijectors() – Extracts default parameter bijectors from the wrapped distribution.

  • init_dist() – Initializes the distribution.

  • update() – Updates the value of the node.

Attributes

  • at – Where to evaluate the distribution.

  • distribution – The wrapped distribution.

  • log_prob – The log-probability of the distribution.

  • per_obs – Whether the log-probability is stored per observation or summed up.

  • monitor – Whether the node should be monitored by an inference algorithm.