NodeDistribution
- class liesel.liesel.nodes.NodeDistribution(distribution, bijector=None, **inputs)
Bases: liesel.liesel.nodes.NodeComponent
A probability distribution of a node.
Computes the log-probability of a node from the values of its inputs. It is implemented as a thin wrapper around a TensorFlow Probability (TFP) distribution.
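To make the "thin wrapper" idea concrete, here is a minimal sketch in plain Python. It uses neither Liesel nor TFP: `Node`, `Normal` and `SketchNodeDistribution` are hypothetical stand-ins. The only assumption is that the wrapper reads the current values of its input nodes, passes them as keyword arguments to the distribution class, and delegates `log_prob` to the resulting distribution object.

```python
import math


class Node:
    """Hypothetical minimal node that only stores a value."""

    def __init__(self, value):
        self.value = value


class Normal:
    """Hypothetical stand-in for a TFP-style distribution class."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def log_prob(self, value):
        # Log-density of a normal distribution with mean loc and std. dev. scale.
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class SketchNodeDistribution:
    """Hypothetical thin wrapper mapping input nodes to distribution arguments."""

    def __init__(self, distribution, **inputs):
        self.distribution = distribution
        self.inputs = inputs

    def _build_distribution(self):
        # Instantiate the distribution from the *current* input values.
        kwargs = {name: node.value for name, node in self.inputs.items()}
        return self.distribution(**kwargs)

    def log_prob(self, value):
        # Delegate to the wrapped distribution object.
        return self._build_distribution().log_prob(value)


loc, scale = Node(0.0), Node(1.0)
dist = SketchNodeDistribution(Normal, loc=loc, scale=scale)
print(dist.log_prob(0.0))  # standard normal log-density at 0

scale.value = 2.0  # changing an input node changes the log-probability
print(dist.log_prob(0.0))
```

Because the sketch rebuilds the distribution from the input values on every call, updating an input node is enough to change the log-probability. Whether Liesel caches the distribution object is not stated here.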
- Parameters
  - distribution (Union[str, Any]) – The name of a TFP distribution as a string or, alternatively, a user-defined TFP-compatible distribution class. If a class is provided instead of a string, the user needs to make sure it uses the right NumPy implementation (regular NumPy or JAX NumPy).
  - bijector (Union[str, Any, None]) – The name of a TFP bijector as a string or, alternatively, a user-defined TFP-compatible bijector class. If a class is provided instead of a string, the user needs to make sure it uses the right NumPy implementation (regular NumPy or JAX NumPy). Defaults to None.
  - inputs (Node) – The inputs of the distribution, passed as keyword arguments. The keywords must match the argument names of the TFP distribution.
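The parameter rules above can be illustrated with the following sketch. `distribution` may be a string name or a class, and the input keywords must match the distribution's arguments. This is not Liesel's code. `REGISTRY` is a hypothetical dict standing in for a module of TFP distributions. `resolve_distribution` and `check_input_names` are hypothetical helpers. The keyword check uses the standard library's `inspect.signature` on the class constructor.

```python
import inspect


class Normal:
    """Hypothetical stand-in for a TFP distribution class."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale


class Exponential:
    """Hypothetical stand-in for a TFP distribution class."""

    def __init__(self, rate):
        self.rate = rate


# Hypothetical name -> class lookup table standing in for a TFP module.
REGISTRY = {"Normal": Normal, "Exponential": Exponential}


def resolve_distribution(distribution):
    """Look a string up in REGISTRY; use a class as-is."""
    if isinstance(distribution, str):
        return REGISTRY[distribution]
    return distribution


def check_input_names(cls, input_names):
    """Raise TypeError if an input keyword is not a constructor argument."""
    params = inspect.signature(cls).parameters
    unknown = sorted(set(input_names) - set(params))
    if unknown:
        raise TypeError(f"{cls.__name__} does not accept {unknown}")


cls = resolve_distribution("Normal")
check_input_names(cls, ["loc", "scale"])  # OK: names match the constructor
try:
    check_input_names(cls, ["mu", "sigma"])  # wrong names are rejected
except TypeError as err:
    print(err)
```

Validating the names up front surfaces a misspelled input (for example `mu` instead of `loc`) as a clear error, rather than as a failure deep inside the distribution constructor.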
Methods
cdf(value, **kwargs) – The cumulative distribution function of the distribution.
The TFP distribution initialized with the values of the inputs.
jaxify() – Enables JAX NumPy for the node component.
log_prob(value, **kwargs) – The log-probability (density) function of the distribution.
mean(**kwargs) – The mean function of the distribution.
sample([sample_shape, seed]) – The sampling function of the distribution.
transform(bijector) – Transforms the distribution with a TFP bijector.
unjaxify() – Disables JAX NumPy for the node component.
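Transforming a distribution with a bijector g rests on the change-of-variables formula log p_Y(y) = log p_X(g^-1(y)) + log|d g^-1(y)/dy|. The sketch below applies this formula with a hypothetical `Exp` bijector to a standard normal, which yields a log-normal distribution. The bijector is modeled in the spirit of TFP's `Exp` bijector but is not TFP code. `normal_log_prob`, `Exp` and `transformed_log_prob` are illustrative names. The sketch does not claim to reproduce how `transform` is implemented internally.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    """Log-density of a normal distribution."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


class Exp:
    """Hypothetical bijector y = exp(x)."""

    def forward(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def inverse_log_det_jacobian(self, y):
        # d/dy log(y) = 1/y, so log|d g^-1(y)/dy| = -log(y)
        return -math.log(y)


def transformed_log_prob(base_log_prob, bijector, y):
    """Change of variables: log p_Y(y) = log p_X(g^-1(y)) + log|d g^-1(y)/dy|."""
    x = bijector.inverse(y)
    return base_log_prob(x) + bijector.inverse_log_det_jacobian(y)


y = 2.0
via_bijector = transformed_log_prob(normal_log_prob, Exp(), y)
# Closed-form log-normal log-density for comparison.
closed_form = -math.log(y) - 0.5 * math.log(2 * math.pi) - 0.5 * math.log(y) ** 2
print(via_bijector, closed_form)  # equal up to floating-point rounding
```

The Jacobian term is what keeps the transformed density normalized. Without it, the transformed density would not integrate to one.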
Attributes
The inputs of the node component.
Whether JAX NumPy is enabled for the node component.