Dist

class liesel.model.Dist(distribution, *inputs, _name='', _needs_seed=False, bijectors=None, convert_inputs='default', **kwinputs)

Bases: Node

A Node subclass that wraps a probability distribution.

Distribution nodes wrap distribution classes that follow the tensorflow_probability Distribution interface. They can be used to represent observation models and priors.

Distribution nodes will appear in the node graph created by viz.plot_nodes(), but they will not appear in the model graph created by viz.plot_vars().

Parameters:

See also

Var

A variable in a statistical model, typically with a probability distribution.

MultivariateNormalDegenerate

A custom distribution class that implements a degenerate multivariate normal distribution in the tensorflow_probability Distribution interface.

Examples

The examples below assume the following imports:

>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd

The following example creates an observation model for a normally distributed variable with fixed location and scale. The log-probability of the variable y is None until the variable is updated.

>>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y")
>>> print(y.log_prob)
None
>>> y.update()
Var(name="y")
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)
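The three entries are the standard normal log-density evaluated at each observation, log N(x | 0, 1) = -0.5 log(2π) - x²/2. The plain-Python check below reproduces these numbers without using Liesel; it only assumes that formula.

```python
import math


def std_normal_logpdf(x):
    # log N(x | 0, 1) = -0.5 * log(2 * pi) - x**2 / 2
    return -0.5 * math.log(2 * math.pi) - 0.5 * x ** 2


values = [-0.5, 0.0, 0.5]
log_probs = [std_normal_logpdf(x) for x in values]
print([round(lp, 7) for lp in log_probs])  # [-1.0439385, -0.9189385, -1.0439385]
```

This computation runs in float64, while the Liesel output above is float32, so the last digit may differ slightly in other cases.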

Now we define the same observation model, but include the location and scale as parameters:

>>> loc = lsl.Var.new_param(0.0, name="loc")
>>> scale = lsl.Var.new_param(1.0, name="scale")
>>> dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y").update()
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)
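Because loc and scale now enter the distribution node as inputs, the log-probability depends on their current values. The sketch below mimics that idea in plain Python with two hypothetical classes, ToyDist and Param; they are illustrative stand-ins, not part of Liesel, and they skip Liesel's node graph and update mechanics entirely.

```python
import math


class Param:
    """Hypothetical parameter holder with a mutable value (stand-in for a parameter Var)."""

    def __init__(self, value):
        self.value = value


class ToyDist:
    """Hypothetical, simplified distribution wrapper; NOT Liesel's implementation.

    Stores a log-density function and named inputs. An input can be a plain
    number (fixed) or an object with a ``value`` attribute (a parameter).
    """

    def __init__(self, log_density, **kwinputs):
        self.log_density = log_density
        self.kwinputs = kwinputs

    def log_prob_at(self, values):
        # Read the *current* value of every input at evaluation time;
        # plain numbers have no ``value`` attribute and are used as-is.
        params = {k: getattr(v, "value", v) for k, v in self.kwinputs.items()}
        return [self.log_density(x, **params) for x in values]


def normal_log_density(x, loc, scale):
    # log N(x | loc, scale) = -log(scale) - 0.5*log(2*pi) - (x - loc)**2 / (2*scale**2)
    z = (x - loc) / scale
    return -math.log(scale) - 0.5 * math.log(2 * math.pi) - 0.5 * z ** 2


loc = Param(0.0)
scale = Param(1.0)
dist = ToyDist(normal_log_density, loc=loc, scale=scale)
obs = [-0.5, 0.0, 0.5]

before = dist.log_prob_at(obs)  # same values as the example above (up to float32 rounding)
scale.value = 2.0               # change the parameter value ...
after = dist.log_prob_at(obs)   # ... and the log-probability changes with it
```

Resolving inputs at evaluation time, rather than when the wrapper is built, is what lets the same wrapper serve both fixed inputs and parameters that change during inference.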

Summed-up log-probability

Set the per_obs attribute of a distribution node to False to sum the log-probability of the distribution over all observations. The log-probability is then a single scalar instead of one value per observation.

>>> dist.per_obs = False
>>> y.update().log_prob
Array(-3.0068154, dtype=float32)
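The summed value is the total of the three per-observation values shown earlier. For independent observations this equals the log of the joint density, since the log of a product is the sum of the logs. A plain-Python check, independent of Liesel:

```python
import math


def std_normal_logpdf(x):
    # log N(x | 0, 1) = -0.5 * log(2 * pi) - x**2 / 2
    return -0.5 * math.log(2 * math.pi) - 0.5 * x ** 2


obs = [-0.5, 0.0, 0.5]
per_obs = [std_normal_logpdf(x) for x in obs]

total = sum(per_obs)  # the summed log-probability
joint = math.log(math.prod(math.exp(lp) for lp in per_obs))  # log of the joint density

print(round(total, 6))  # -3.006816
```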

Methods

all_input_nodes()

Returns all non-keyword and keyword input nodes as a unique tuple.

biject_parameters([bijectors, inference])

Transforms distribution parameters using bijectors with eager evaluation.

find_default_parameter_bijectors()

Extracts default parameter bijectors from the wrapped distribution.

init_dist()

Initializes the distribution.

update()

Updates the value of the node.
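biject_parameters transforms distribution parameters with bijectors. As general background (this is the standard change-of-variables rule, not Liesel's code), a bijector such as exp lets a positive parameter like scale be represented by an unconstrained real number u, with scale = exp(u). The log-density on the u scale then gains the log-Jacobian term log|d exp(u)/du| = u. The sketch below assumes a hypothetical Exponential(1) prior on the scale, log p(s) = -s, and checks numerically that the transformed density still integrates to one.

```python
import math


def exp_forward(u):
    # Hypothetical exp bijector: unconstrained u -> positive scale.
    return math.exp(u)


def exp_inverse(s):
    # Inverse map: positive scale -> unconstrained u.
    return math.log(s)


def exp_forward_log_det_jacobian(u):
    # log |d exp(u) / du| = log(exp(u)) = u
    return u


def log_prior_scale(s):
    # Assumed Exponential(rate=1) prior on the scale: log p(s) = -s for s > 0.
    return -s


def log_prior_unconstrained(u):
    # Change of variables: log p_U(u) = log p_S(exp(u)) + log |d exp(u) / du|
    return log_prior_scale(exp_forward(u)) + exp_forward_log_det_jacobian(u)


# Riemann-sum check over u in [-30, 5): p_U should integrate to about one.
du = 1e-3
grid = [-30.0 + i * du for i in range(35_000)]
mass = sum(math.exp(log_prior_unconstrained(u)) * du for u in grid)
```

Without the log-Jacobian term, the density on the u scale would not integrate to one, which is why a parameter transformation has to adjust the log-density rather than only the parameter value.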

Attributes

at

Where to evaluate the distribution.

distribution

The wrapped distribution.

log_prob

The log-probability of the distribution.

per_obs

Whether the log-probability is stored per observation or summed up.

monitor

Whether the node should be monitored by an inference algorithm.