Dist

class liesel.model.nodes.Dist(distribution, *inputs, _name='', _needs_seed=False, **kwinputs)

Bases: Node

A Node subclass that wraps a probability distribution.

Distribution nodes wrap distribution classes that follow the tensorflow_probability Distribution interface. They can be used to represent observation models and priors.

Distribution nodes will appear in the node graph created by viz.plot_nodes(), but they will not appear in the model graph created by viz.plot_vars().

Parameters:
  • distribution (Callable[..., TypeAliasType]) – The wrapped distribution class that follows the tensorflow_probability Distribution interface.

  • *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var instances will be converted to Value nodes. The values of these inputs will be passed to the wrapped distribution in the order in which they are given here.

  • _name (str) – The name of the node. If you do not specify a name, a unique name will be automatically generated upon initialization of a Model. (default: '')

  • _needs_seed (bool) – Whether the node needs a seed / PRNG key. (default: False)

  • **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var instances will be converted to Value nodes. The values of these inputs will be passed to the wrapped distribution as keyword arguments.
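To make the input-handling rules above concrete, here is a small stdlib-only sketch. It is not liesel's implementation: ToyValue, ToyNormal, as_node, and init_toy_dist are hypothetical names, and the sketch assumes that the wrapped class only needs to accept the input values as constructor arguments and expose a log_prob method, as tensorflow_probability distributions do.

```python
import math
from typing import Any, Callable


class ToyValue:
    """Hypothetical stand-in for a Value node: it only holds a value."""

    def __init__(self, value: Any) -> None:
        self.value = value


class ToyNormal:
    """Hypothetical normal distribution with a TFP-style log_prob method."""

    def __init__(self, loc: float, scale: float) -> None:
        self.loc = loc
        self.scale = scale

    def log_prob(self, x: float) -> float:
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


def as_node(x: Any) -> Any:
    # Raw inputs are wrapped; objects that already carry a .value are kept as is.
    return x if hasattr(x, "value") else ToyValue(x)


def init_toy_dist(distribution: Callable[..., Any], *inputs: Any, **kwinputs: Any) -> Any:
    """Hypothetical helper mirroring the documented input rules."""
    nodes = [as_node(x) for x in inputs]
    kwnodes = {k: as_node(v) for k, v in kwinputs.items()}
    # Positional values keep the order in which they were given;
    # keyword values are passed on as keyword arguments.
    args = [n.value for n in nodes]
    kwargs = {k: n.value for k, n in kwnodes.items()}
    return distribution(*args, **kwargs)
```

For example, init_toy_dist(ToyNormal, 0.0, scale=1.0) constructs ToyNormal(0.0, scale=1.0), so the first positional input becomes loc and the keyword input becomes scale.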

See also

Var

A variable in a statistical model, typically with a probability distribution.

param

A helper function to initialize a Var as a model parameter.

obs

A helper function to initialize a Var as an observed variable.

MultivariateNormalDegenerate

A custom distribution class that implements a degenerate multivariate normal distribution in the tensorflow_probability Distribution interface.

Examples

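For the examples below, we import jax.numpy, liesel.model, and the JAX substrate of tensorflow_probability:

>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd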

Create an observation model for a normally distributed variable with fixed location and scale. In the example below, the log-probability of the variable y is None until the variable is updated.

>>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y")
>>> print(y.log_prob)
None
>>> y.update()
Var(name="y")
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)
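The printed values can be checked by hand. For a normal distribution with location loc and scale scale, the log-density at y is -0.5 * log(2 * pi) - log(scale) - 0.5 * ((y - loc) / scale) ** 2. The stdlib-only sketch below (the helper name normal_log_prob is hypothetical) evaluates this formula in double precision; rounded to seven decimals, it agrees with the float32 values returned by y.log_prob.

```python
import math


def normal_log_prob(y: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log-density of a normal distribution evaluated at y."""
    z = (y - loc) / scale
    return -0.5 * math.log(2.0 * math.pi) - math.log(scale) - 0.5 * z * z


# Evaluate at the same observations as in the example above.
values = [normal_log_prob(y) for y in [-0.5, 0.0, 0.5]]
print([round(v, 7) for v in values])  # [-1.0439385, -0.9189385, -1.0439385]
```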

Now we define the same observation model, but include the location and scale as parameters:

>>> loc = lsl.Var.new_param(0.0, name="loc")
>>> scale = lsl.Var.new_param(1.0, name="scale")
>>> dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y").update()
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)

Summed-up log-probability

You can set the per_obs attribute of a distribution node to False to sum up the log-probability of the distribution over all observations.

>>> dist.per_obs = False
>>> y.update().log_prob
Array(-3.0068154, dtype=float32)
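Conceptually, setting per_obs to False replaces the per-observation log-probabilities by their sum. The toy class below is a hypothetical illustration of that switch, not liesel's implementation; ToyLogProbNode and its log_prob_fn argument are invented for this sketch. It computes in double precision, so its sum (about -3.0068156) differs from the float32 result above only in the last digit.

```python
import math
from typing import Callable, List, Optional, Union


class ToyLogProbNode:
    """Hypothetical node that evaluates a log-density at several observations."""

    def __init__(self, log_prob_fn: Callable[[float], float], at: List[float]) -> None:
        self.log_prob_fn = log_prob_fn
        self.at = at
        self.per_obs = True  # store one value per observation by default
        self.log_prob: Optional[Union[float, List[float]]] = None  # None until updated

    def update(self) -> "ToyLogProbNode":
        per_obs_values = [self.log_prob_fn(x) for x in self.at]
        # With per_obs=False, only the sum over all observations is stored.
        self.log_prob = per_obs_values if self.per_obs else math.fsum(per_obs_values)
        return self


def std_normal_log_prob(x: float) -> float:
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * x * x


node = ToyLogProbNode(std_normal_log_prob, at=[-0.5, 0.0, 0.5])
node.per_obs = False
print(round(node.update().log_prob, 5))  # -3.00682
```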

Methods

all_input_nodes()

Returns all non-keyword and keyword input nodes as a unique tuple.

init_dist()

Initializes the distribution.

update()

Updates the value of the node.
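all_input_nodes() is documented as returning the non-keyword and keyword input nodes together as a unique tuple. One order-preserving way to obtain that behavior in plain Python is sketched below. The function toy_all_input_nodes and the class ToyNode are hypothetical, and whether liesel deduplicates in exactly this way (by object identity, keeping first occurrences, positional inputs first) is an assumption.

```python
from typing import Any, Dict, List, Set, Tuple


def toy_all_input_nodes(inputs: Tuple[Any, ...], kwinputs: Dict[str, Any]) -> Tuple[Any, ...]:
    """Hypothetical helper: positional then keyword inputs, each object once."""
    seen: Set[int] = set()
    unique: List[Any] = []
    for node in (*inputs, *kwinputs.values()):
        # Compare by identity, so a node passed twice is kept only once.
        if id(node) not in seen:
            seen.add(id(node))
            unique.append(node)
    return tuple(unique)


class ToyNode:
    """Hypothetical placeholder for an input node."""

    def __init__(self, name: str) -> None:
        self.name = name


a = ToyNode("a")
b = ToyNode("b")
# The node `a` appears both as a positional and as a keyword input.
nodes = toy_all_input_nodes((a, b), {"c": a})
print([n.name for n in nodes])  # ['a', 'b']
```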

Attributes

at

Where to evaluate the distribution.

distribution

The wrapped distribution.

log_prob

The log-probability of the distribution.

per_obs

Whether the log-probability is stored per observation or summed up.

monitor

Whether the node should be monitored by an inference algorithm.