Dist
class liesel.model.Dist(distribution, *inputs, _name='', _needs_seed=False, bijectors=None, convert_inputs='default', **kwinputs)
Bases: Node

A Node subclass that wraps a probability distribution class following the tensorflow_probability Distribution interface. Distribution nodes can represent observation models and priors. They appear in the node graph created by viz.plot_nodes(), but not in the model graph created by viz.plot_vars().

Parameters:
- distribution (Callable[..., TypeAliasType]) – The wrapped distribution class that follows the tensorflow_probability Distribution interface.
- *inputs (Any) – Non-keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Value nodes. The values of these inputs will be passed to the wrapped distribution in the same order they are entered here.
- _name (str, default: '') – The name of the node. If you do not specify a name, a unique name will be generated automatically upon initialization of a Model.
- _needs_seed (bool, default: False) – Whether the node needs a seed / PRNG key.
- bijectors (None | Literal['auto'] | dict[str, TypeAliasType | Literal['auto'] | None] | Sequence[TypeAliasType | Literal['auto'] | None], default: None) – Optional bijector specification for transforming distribution parameters. See Dist.biject_parameters() for supported formats and behavior.
- convert_inputs (Callable[[Any], Any] | Literal['default'], default: 'default') – A function used to process the values of this node's inputs. The default uses the function stored in convert_value, which is jax.numpy.asarray.
- **kwinputs (Any) – Keyword inputs. Any inputs that are not already nodes or Var objects will be converted to Value nodes. The values of these inputs will be passed to the wrapped distribution as keyword arguments.
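To make the input handling concrete, here is a minimal pure-Python sketch of the idea described in the parameter list: plain values are wrapped in a value holder, and at evaluation time the current values are passed to the distribution class, positional inputs in order and keyword inputs by name. `ToyValue`, `ToyNormal`, and `ToyDist` are hypothetical stand-ins, not Liesel or tensorflow_probability classes, and the sketch deliberately ignores the model graph, seeds, bijectors, and input conversion.

```python
import math


class ToyValue:
    """Hypothetical stand-in for a Value node: it only holds a value."""

    def __init__(self, value):
        self.value = value


class ToyNormal:
    """Hypothetical distribution class with a tfp-like constructor and log_prob()."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


class ToyDist:
    """Hypothetical stand-in for a Dist node (no graph, seeds, or bijectors)."""

    def __init__(self, distribution, *inputs, **kwinputs):
        self.distribution = distribution
        # Wrap plain values so that every input exposes a .value attribute.
        self.inputs = [self._wrap(x) for x in inputs]
        self.kwinputs = {k: self._wrap(v) for k, v in kwinputs.items()}

    @staticmethod
    def _wrap(x):
        return x if isinstance(x, ToyValue) else ToyValue(x)

    def make_dist(self):
        # Read the current input values: positional inputs in order, keyword inputs by name.
        args = [node.value for node in self.inputs]
        kwargs = {name: node.value for name, node in self.kwinputs.items()}
        return self.distribution(*args, **kwargs)

    def log_prob(self, at):
        dist = self.make_dist()
        return [dist.log_prob(x) for x in at]


loc = ToyValue(0.0)  # shared input "node"
by_position = ToyDist(ToyNormal, loc, 1.0)  # calls ToyNormal(0.0, 1.0)
by_keyword = ToyDist(ToyNormal, loc=loc, scale=1.0)  # calls ToyNormal(loc=0.0, scale=1.0)
print(by_position.log_prob([0.0]) == by_keyword.log_prob([0.0]))  # True

loc.value = 0.5  # both toy nodes see the new value on their next evaluation
print(by_keyword.log_prob([0.5])[0] == ToyNormal(0.0, 1.0).log_prob(0.0))  # True
```

Because `loc` is a shared holder rather than a copied number, changing its value affects every toy node that uses it, which loosely mirrors why parameters are passed as nodes or Var objects.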
See also
- Var: A variable in a statistical model, typically with a probability distribution.
- MultivariateNormalDegenerate: A custom distribution class that implements a degenerate multivariate normal distribution in the tensorflow_probability Distribution interface.
Examples
For the examples below, we import jax.numpy, liesel.model, and tensorflow_probability:

>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd
We first create an observation model for a normally distributed variable with fixed mean and scale. The log-probability of the variable y below is None until the variable is updated.

>>> dist = lsl.Dist(tfd.Normal, loc=0.0, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y")
>>> print(y.log_prob)
None

>>> y.update()
Var(name="y")
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)
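The numbers in this output follow directly from the standard normal log-density; worked by hand and rounded to seven decimals to match the float32 output:

```latex
\begin{aligned}
\log p(y_i) &= -\tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}y_i^2, \qquad \tfrac{1}{2}\log(2\pi) \approx 0.9189385,\\
\log p(0) &\approx -0.9189385,\\
\log p(\pm 0.5) &\approx -0.9189385 - \tfrac{1}{2}(0.25) = -1.0439385.
\end{aligned}
```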
Now we define the same observation model, but include the location and scale as parameters:
>>> loc = lsl.Var.new_param(0.0, name="loc")
>>> scale = lsl.Var.new_param(1.0, name="scale")
>>> dist = lsl.Dist(tfd.Normal, loc=loc, scale=scale)
>>> y = lsl.Var.new_obs(jnp.array([-0.5, 0.0, 0.5]), dist, name="y").update()
>>> y.log_prob
Array([-1.0439385, -0.9189385, -1.0439385], dtype=float32)
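With loc and scale as parameters, the log-probability becomes a function of their values: log p(x | loc, scale) = -(x - loc)²/(2 scale²) - log(scale) - log(2π)/2. The stdlib-only sketch below evaluates this formula directly. `normal_log_prob` is a hypothetical helper, not part of Liesel, and the shifted location 0.5 is only an illustrative value showing how the per-observation log-probabilities would move if the loc parameter changed.

```python
import math


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log-density of a normal distribution with mean loc and standard deviation scale."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


y = [-0.5, 0.0, 0.5]
at_default = [normal_log_prob(x, loc=0.0, scale=1.0) for x in y]
at_shifted = [normal_log_prob(x, loc=0.5, scale=1.0) for x in y]
print([round(v, 7) for v in at_default])  # [-1.0439385, -0.9189385, -1.0439385]
print([round(v, 7) for v in at_shifted])  # [-1.4189385, -1.0439385, -0.9189385]
```

With loc = 0.0 and scale = 1.0 the sketch reproduces the Liesel output above up to float32 rounding; with the shifted location, the observation at 0.5 becomes the most likely one.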
Summed-up log-probability
You can set the per_obs attribute of a distribution node to False to sum up the log-probability of the distribution over all observations.

>>> dist.per_obs = False
>>> y.update().log_prob
Array(-3.0068154, dtype=float32)
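Conceptually, switching per_obs off replaces the vector of per-observation log-probabilities by its sum. The stdlib-only sketch below illustrates this for the standard normal example; `std_normal_log_prob` and `log_prob` are hypothetical helpers, not Liesel API. In double precision the sum is about -3.0068156, which agrees with the float32 result -3.0068154 up to rounding in the last digit.

```python
import math


def std_normal_log_prob(x: float) -> float:
    """Log-density of a standard normal distribution at x."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def log_prob(observations, per_obs=True):
    """Hypothetical helper mirroring the per_obs switch.

    per_obs=True returns one log-probability per observation;
    per_obs=False returns their sum.
    """
    values = [std_normal_log_prob(x) for x in observations]
    return values if per_obs else math.fsum(values)


y = [-0.5, 0.0, 0.5]
print(len(log_prob(y, per_obs=True)))  # 3
print(round(log_prob(y, per_obs=False), 6))  # -3.006816
```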
Methods
Returns all non-keyword and keyword input nodes as a unique tuple.
biject_parameters([bijectors, inference]): Transforms distribution parameters using bijectors with eager evaluation.
Extracts default parameter bijectors from the wrapped distribution.
Initializes the distribution.
update(): Updates the value of the node.
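Parameter bijectors are typically used so that inference can operate on an unconstrained real number while the distribution still receives a valid parameter, for example a strictly positive scale. This section does not spell out which default bijectors are extracted from a given distribution, so the stdlib sketch below uses a softplus transform purely as an illustration of the idea; `softplus` and `softplus_inverse` are hypothetical helpers, not Liesel or tensorflow_probability functions.

```python
import math


def softplus(x: float) -> float:
    """Maps the real line to (0, inf); stable form of log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y: float) -> float:
    """Inverse of softplus for y > 0."""
    if y <= 0.0:
        raise ValueError("softplus_inverse requires y > 0")
    # log(exp(y) - 1) rewritten as y + log(1 - exp(-y)) to avoid overflow for large y.
    return y + math.log(-math.expm1(-y))


unconstrained = -3.0  # any real number is allowed on this side
scale = softplus(unconstrained)  # always a valid, positive scale
print(scale > 0.0)  # True
print(abs(softplus_inverse(scale) - unconstrained) < 1e-12)  # True
```

In this picture, a sampler or optimizer would update `unconstrained`, while the wrapped distribution only ever sees the transformed `scale`.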
Attributes
Where to evaluate the distribution.
The wrapped distribution.
The log-probability of the distribution.
Whether the log-probability is stored per observation or summed up.
Whether the node should be monitored by an inference algorithm.