Dist

class liesel.model.nodes.Dist(distribution, *inputs, _name='', _needs_seed=False, **kwinputs)

Bases: Node

A Node subclass that wraps a probability distribution.

Distribution nodes wrap distribution classes that follow the tensorflow_probability Distribution interface. They can be used to represent observation models and priors.
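To illustrate what "following the tensorflow_probability Distribution interface" involves here, the plain-Python sketch below uses a hypothetical `ToyNormal` class. It assumes that a distribution node mainly needs two things from the wrapped class: a constructor that accepts the distribution's parameters and a `log_prob` method that evaluates the log-density at given values. Real tensorflow_probability distributions offer much more (sampling, batch and event shapes), and this is a sketch, not Liesel's implementation.

```python
import math


class ToyNormal:
    """Hypothetical, minimal stand-in for a tensorflow_probability distribution.

    Only the two pieces assumed to matter for a distribution node are
    implemented: a constructor taking the parameters and a log_prob method.
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def log_prob(self, x):
        # Elementwise normal log-density: -z^2/2 - log(scale) - log(2*pi)/2
        return [
            -0.5 * ((xi - self.loc) / self.scale) ** 2
            - math.log(self.scale)
            - 0.5 * math.log(2.0 * math.pi)
            for xi in x
        ]


# Construct the distribution from parameter values, then evaluate it at the
# observed values (the role played by the node's "at" input, by assumption).
dist = ToyNormal(loc=0.0, scale=1.0)
log_prob = dist.log_prob([-0.5, 0.0, 0.5])
print(log_prob)
```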

Distribution nodes will appear in the node graph created by viz.plot_nodes(), but they will not appear in the model graph created by viz.plot_vars().

Parameters:
  • distribution (Callable[..., Union[Distribution, Distribution]]) – The wrapped distribution class that follows the tensorflow_probability Distribution interface.

  • *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Data nodes. The values of these inputs will be passed to the wrapped distribution in the same order they are entered here.

  • _name (str) – The name of the node. If you do not specify a name, a unique name will be automatically generated upon initialization of a Model. (default: '')

  • _needs_seed (bool) – Whether the node needs a seed / PRNG key. (default: False)

  • **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Data nodes. The values of these inputs will be passed to the wrapped distribution as keyword arguments.
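The handling of *inputs and **kwinputs described above can be sketched in plain Python. `ToyData`, `ToyDist`, and `record_call` are hypothetical names for illustration only; the sketch omits Var handling, naming, and seeds, and is not Liesel's actual code. It shows non-node inputs being wrapped, and the input values being passed on positionally (in order) and by keyword.

```python
class ToyData:
    """Hypothetical stand-in for a Data node: it just holds a fixed value."""

    def __init__(self, value):
        self.value = value


class ToyDist:
    """Hypothetical sketch of how a distribution node might store its inputs."""

    def __init__(self, distribution, *inputs, **kwinputs):
        self.distribution = distribution
        # Wrap every input that is not already a node.
        self.inputs = tuple(self._as_node(x) for x in inputs)
        self.kwinputs = {k: self._as_node(v) for k, v in kwinputs.items()}

    @staticmethod
    def _as_node(x):
        return x if isinstance(x, ToyData) else ToyData(x)

    def init_dist(self):
        # Positional values keep their order; keyword values keep their names.
        args = [node.value for node in self.inputs]
        kwargs = {k: node.value for k, node in self.kwinputs.items()}
        return self.distribution(*args, **kwargs)


def record_call(*args, **kwargs):
    # Stand-in "distribution class" that just reports what it received.
    return args, kwargs


node = ToyDist(record_call, 0.0, scale=ToyData(1.0))
print(node.init_dist())  # ((0.0,), {'scale': 1.0})
```

With tfd.Normal in place of `record_call`, the same call pattern would amount to `tfd.Normal(0.0, scale=1.0)`.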

See also

Var

A variable in a statistical model, typically with a probability distribution.

param

A helper function to initialize a Var as a model parameter.

obs

A helper function to initialize a Var as an observed variable.

MultivariateNormalDegenerate

A custom distribution class that implements a degenerate multivariate normal distribution in the tensorflow_probability Distribution interface.

Examples

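For the examples below, we import liesel.model, jax.numpy, and the JAX substrate of the tensorflow_probability distributions:

>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd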

We create an observation model for a normally distributed variable with fixed mean and scale. In the example below, the log-probability of the variable y is None until the variable is updated.

>>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
>>> y = lsl.obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y")
>>> print(y.log_prob)
None
>>> y.update()
Var(name="y")
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)

Now we define the same observation model, but include the location and scale as parameters:

>>> loc = lsl.param(0.0, name="loc")
>>> scale = lsl.param(1.0, name="scale")
>>> dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
>>> y = lsl.obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y").update()
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)
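The printed values can be checked by hand from the normal log-density, using the parameter values mu = 0 and sigma = 1 from both examples:

```latex
\begin{aligned}
\log p(x \mid \mu, \sigma) &= -\tfrac{1}{2}\log(2\pi) - \log\sigma - \frac{(x - \mu)^2}{2\sigma^2} \\
\log p(0 \mid 0, 1) &= -\tfrac{1}{2}\log(2\pi) \approx -0.9189385 \\
\log p(\pm 0.5 \mid 0, 1) &= -\tfrac{1}{2}\log(2\pi) - \tfrac{0.25}{2} \approx -0.9189385 - 0.125 = -1.0439385
\end{aligned}
```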

Summed-up log-probability

You can set the per_obs attribute of a distribution node to False to sum up the log-probability of the distribution over all observations.

>>> dist.per_obs = False
>>> y.update().log_prob
Array(-3.0068154, dtype=float32)
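The effect of per_obs can be sketched as a simple reduction over the per-observation values. `reduce_log_prob` is a hypothetical helper that mirrors the documented behaviour; it is not part of Liesel. The input values are the ones printed in the examples above.

```python
def reduce_log_prob(log_probs, per_obs=True):
    """Hypothetical helper mirroring the documented per_obs behaviour:
    keep one log-probability per observation, or sum them up."""
    if per_obs:
        return list(log_probs)
    return sum(log_probs)


# Per-observation values printed in the examples above.
per_obs_values = [-1.0439385, -0.9189385, -1.0439385]

print(reduce_log_prob(per_obs_values, per_obs=True))
print(reduce_log_prob(per_obs_values, per_obs=False))
```

The summed value agrees with -3.0068154 above up to float32 rounding.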

Methods

add_inputs(*inputs, **kwinputs)

Adds non-keyword and keyword input nodes to the existing ones.

all_input_nodes()

Returns all non-keyword and keyword input nodes as a unique tuple.

all_output_nodes()

Returns all output nodes as a unique tuple.

clear_state()

Clears the state of the node.

flag_outdated()

Flags the node and its recursive outputs as outdated.

init_dist()

Initializes the distribution.

set_inputs(*inputs, **kwinputs)

Sets the non-keyword and keyword input nodes.

update()

Updates the value of the node.
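The interplay of flag_outdated() and update() can be sketched with a hypothetical `ToyNode` class in plain Python. It follows only the one-line descriptions above: flagging marks a node and its recursive outputs as outdated, and updating recomputes the value from the inputs. State, groups, seeds, and model handling are deliberately left out, so this is not Liesel's Node implementation.

```python
class ToyNode:
    """Hypothetical node: func computes the value from the input values."""

    def __init__(self, name, func, *inputs):
        self.name = name
        self.func = func
        self.inputs = inputs
        self.outputs = []
        self.value = None
        self.outdated = True
        for node in inputs:
            node.outputs.append(self)  # register this node as an output

    def flag_outdated(self):
        # Mark this node and, recursively, every node downstream of it.
        self.outdated = True
        for node in self.outputs:
            node.flag_outdated()

    def update(self):
        # Recompute from the current input values and clear the flag.
        self.value = self.func(*(node.value for node in self.inputs))
        self.outdated = False
        return self


x = ToyNode("x", lambda: 2.0).update()
y = ToyNode("y", lambda v: v * 3.0, x).update()
z = ToyNode("z", lambda v: v + 1.0, y).update()

x.flag_outdated()
flags_after_flagging = [node.outdated for node in (x, y, z)]
print(flags_after_flagging)  # [True, True, True]

# Updating in input-to-output order refreshes the whole chain again.
for node in (x, y, z):
    node.update()
print(z.value)  # 7.0
```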

Attributes

at

Where to evaluate the distribution.

distribution

The wrapped distribution.

groups

The groups that this node is a part of.

inputs

The non-keyword input nodes.

kwinputs

The keyword input nodes.

log_prob

The log-probability of the distribution.

model

The model the node is part of.

name

The name of the node.

needs_seed

Whether the node needs a seed / PRNG key.

outdated

Whether the node is outdated.

outputs

The output nodes.

per_obs

Whether the log-probability is stored per observation or summed up.

state

The state of the node.

value

The value of the node.

var

The variable the node is part of.

monitor

Whether the node should be monitored by an inference algorithm.