Dist
- class liesel.model.nodes.Dist(distribution, *inputs, _name='', _needs_seed=False, **kwinputs)
Bases: Node
A Node subclass that wraps a probability distribution. Distribution nodes wrap distribution classes that follow the tensorflow_probability Distribution interface. They can be used to represent observation models and priors. Distribution nodes appear in the node graph created by viz.plot_nodes(), but not in the model graph created by viz.plot_vars().
- Parameters:
  - distribution (Callable[..., Union[Distribution, Distribution]]) – The wrapped distribution class that follows the tensorflow_probability Distribution interface.
  - *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var instances will be converted to Value nodes. The values of these inputs will be passed to the wrapped distribution in the same order in which they are entered here.
  - _name (str) – The name of the node. If you do not specify a name, a unique name will be generated automatically when a Model is initialized. (default: '')
  - _needs_seed (bool) – Whether the node needs a seed / PRNG key. (default: False)
  - **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var instances will be converted to Value nodes. The values of these inputs will be passed to the wrapped distribution as keyword arguments.
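The forwarding rule for *inputs and **kwinputs can be illustrated with a small pure-Python sketch. This is not Liesel's implementation: `ToyValue`, `ToyDist`, `RecordingDistribution`, and the method name `build` are hypothetical stand-ins chosen for illustration. The sketch only shows that raw inputs get wrapped in value-holding nodes, that positional input values reach the distribution class in the order given, and that keyword input values arrive as keyword arguments.

```python
from typing import Any, Callable


class ToyValue:
    """Hypothetical constant node that just stores a value."""

    def __init__(self, value: Any) -> None:
        self.value = value


class RecordingDistribution:
    """Hypothetical distribution class that records how it was called."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.args = args
        self.kwargs = kwargs


class ToyDist:
    """Hypothetical sketch of a distribution node; not liesel's actual code."""

    def __init__(self, distribution: Callable[..., Any], *inputs: Any, **kwinputs: Any) -> None:
        self.distribution = distribution
        # Inputs that are not nodes yet are wrapped, mirroring the documented
        # conversion of plain inputs to Value nodes.
        self.inputs = tuple(self._wrap(x) for x in inputs)
        self.kwinputs = {key: self._wrap(x) for key, x in kwinputs.items()}

    @staticmethod
    def _wrap(x: Any) -> ToyValue:
        return x if isinstance(x, ToyValue) else ToyValue(x)

    def build(self) -> Any:
        # Positional input values are passed in the order they were given;
        # keyword input values are passed as keyword arguments.
        args = [node.value for node in self.inputs]
        kwargs = {key: node.value for key, node in self.kwinputs.items()}
        return self.distribution(*args, **kwargs)


loc = ToyValue(0.0)
node = ToyDist(RecordingDistribution, loc, scale=1.0)
dist = node.build()
print(dist.args, dist.kwargs)  # (0.0,) {'scale': 1.0}

loc.value = 2.0  # rebuilding picks up the current input values
print(node.build().args)  # (2.0,)
```

Because `loc` is already a `ToyValue`, it is stored as-is rather than re-wrapped, so changing its value is visible the next time the distribution is built; the plain float passed as `scale` is wrapped once at construction.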
See also
- Var – A variable in a statistical model, typically with a probability distribution.
- param – A helper function to initialize a Var as a model parameter.
- obs – A helper function to initialize a Var as an observed variable.
- MultivariateNormalDegenerate – A custom distribution class that implements a degenerate multivariate normal distribution in the tensorflow_probability Distribution interface.
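Examples
For the examples below, we import jax.numpy, liesel.model, and the JAX substrate of tensorflow_probability:
>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd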
Here we create an observation model for a normally distributed variable with fixed mean and scale. The log-probability of the variable y in the example below is None until the variable is updated.
>>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y")
>>> print(y.log_prob)
None
>>> y.update()
Var(name="y")
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)
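As a sanity check on the printed array: with loc = 0 and scale = 1, the values are the standard normal log-density log p(y) = -log(2π)/2 - y²/2, which is about -0.9189385 at y = 0 and 0.125 lower at y = ±0.5. The sketch below recomputes them in double precision with Python's standard library; `normal_log_prob` is a hypothetical helper, not part of Liesel or tensorflow_probability.

```python
import math


def normal_log_prob(y: float, loc: float = 0.0, scale: float = 1.0) -> float:
    """Log-density of a normal distribution, computed in double precision."""
    z = (y - loc) / scale
    return -0.5 * math.log(2.0 * math.pi) - math.log(scale) - 0.5 * z * z


values = [normal_log_prob(y) for y in (-0.5, 0.0, 0.5)]
print([round(v, 7) for v in values])  # [-1.0439385, -0.9189385, -1.0439385]
```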
Now we define the same observation model, but include the location and scale as parameters:
>>> loc = lsl.Var.new_param(0.0, name="loc")
>>> scale = lsl.Var.new_param(1.0, name="scale")
>>> dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y").update()
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)
Summed-up log-probability
You can set the per_obs attribute of a distribution node to False to sum up the log-probability of the distribution over all observations.
>>> dist.per_obs = False
>>> y.update().log_prob
Array(-3.0068154, dtype=float32)
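Summing the three per-observation values by hand reproduces this total up to float32 precision. The sketch below recomputes it in double precision with the standard library. The result, about -3.0068156, differs from the -3.0068154 shown above only in the last digit, which is consistent with the doctest having been evaluated in float32.

```python
import math

# Per-observation standard normal log-densities at y = -0.5, 0.0, 0.5.
half_log_2pi = 0.5 * math.log(2.0 * math.pi)
per_obs = [-half_log_2pi - 0.5 * y * y for y in (-0.5, 0.0, 0.5)]

# With per_obs = False, the node reports the sum over observations instead.
total = math.fsum(per_obs)
print(round(total, 7))  # -3.0068156
```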
Methods
add_inputs(*inputs, **kwinputs) – Adds non-keyword and keyword input nodes to the existing ones.
Returns all non-keyword and keyword input nodes as a unique tuple.
Returns all output nodes as a unique tuple.
Clears the state of the node.
Flags the node and its recursive outputs as outdated.
Initializes the distribution.
set_inputs(*inputs, **kwinputs) – Sets the non-keyword and keyword input nodes.
update() – Updates the value of the node.
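The method summary states that flagging a node as outdated also flags its recursive outputs. The sketch below illustrates that propagation on a small graph loosely shaped like the second example (loc and scale feeding a distribution, which feeds y). `ToyNode` and its methods are hypothetical stand-ins, not Liesel's classes, and the sketch leaves out value computation entirely.

```python
class ToyNode:
    """Hypothetical node with an outdated flag; not liesel's implementation."""

    def __init__(self, name: str, *inputs: "ToyNode") -> None:
        self.name = name
        self.inputs = inputs
        self.outputs = []  # nodes that take this node as an input
        self.outdated = False
        for node in inputs:
            node.outputs.append(self)

    def flag_outdated(self) -> None:
        # Flag this node, then every node that depends on it, recursively.
        self.outdated = True
        for node in self.outputs:
            node.flag_outdated()

    def update(self) -> "ToyNode":
        # A real node would recompute its value here; the sketch only clears the flag.
        self.outdated = False
        return self


loc = ToyNode("loc")
scale = ToyNode("scale")
dist = ToyNode("dist", loc, scale)
y = ToyNode("y", dist)

loc.flag_outdated()
print([n.name for n in (loc, scale, dist, y) if n.outdated])  # ['loc', 'dist', 'y']
```

Only nodes downstream of loc are flagged; scale is untouched because nothing it depends on changed.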
Attributes
Where to evaluate the distribution.
The wrapped distribution.
The groups that this node is a part of.
The non-keyword input nodes.
The keyword input nodes.
The log-probability of the distribution.
The model the node is part of.
The name of the node.
Whether the node needs a seed / PRNG key.
Whether the node is outdated.
The output nodes.
Whether the log-probability is stored per observation or summed up.
The state of the node.
The value of the node.
The variable the node is part of.
Whether the node should be monitored by an inference algorithm.