NodeDistribution
class liesel.liesel.nodes.NodeDistribution(distribution, bijector=None, **inputs)
Bases: liesel.liesel.nodes.NodeComponent
A probability distribution of a node.
Computes the log-probability of a node from its inputs. Implemented as a thin wrapper around a TFP distribution.
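To illustrate the "thin wrapper" idea, here is a minimal sketch that does not use TFP or liesel. The names `Normal`, `InputNode`, `ThinDistributionWrapper`, and `build` are hypothetical stand-ins chosen for this example, not liesel's API. The wrapper stores its input nodes and, each time a log-probability is requested, initializes the distribution with the current input values and delegates the computation to it.

```python
import math
from dataclasses import dataclass


@dataclass
class Normal:
    """Hypothetical stand-in for a TFP distribution (not TFP itself)."""

    loc: float
    scale: float

    def log_prob(self, x: float) -> float:
        # Log-density of a normal distribution with mean loc and std. dev. scale.
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


@dataclass
class InputNode:
    """Hypothetical minimal node that only stores a value."""

    value: float


class ThinDistributionWrapper:
    """Hypothetical wrapper: builds the distribution from its input nodes on demand."""

    def __init__(self, distribution, **inputs: InputNode) -> None:
        self.distribution = distribution
        self.inputs = inputs

    def build(self):
        # Initialize the distribution with the current values of the inputs.
        return self.distribution(**{k: n.value for k, n in self.inputs.items()})

    def log_prob(self, value: float) -> float:
        # Delegate the actual computation to the wrapped distribution.
        return self.build().log_prob(value)


loc, scale = InputNode(0.0), InputNode(1.0)
dist = ThinDistributionWrapper(Normal, loc=loc, scale=scale)
print(dist.log_prob(0.0))  # standard normal log-density at 0
loc.value = 1.0  # changing an input changes the log-probability
print(dist.log_prob(0.0))
```

Rebuilding the distribution on every call keeps the wrapper stateless with respect to input values, at the cost of re-instantiating the distribution each time.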
- Parameters
  - distribution (Union[str, Any]) – The name of a TFP distribution as a string, or alternatively, a user-defined TFP-compatible distribution class. If a class is provided instead of a string, the user needs to make sure it uses the right NumPy implementation.
  - bijector (Union[str, Any, None]) – The name of a TFP bijector as a string, or alternatively, a user-defined TFP-compatible bijector class. If a class is provided instead of a string, the user needs to make sure it uses the right NumPy implementation. Defaults to None.
  - inputs (Node) – The inputs of the distribution. The keywords must match the arguments of the TFP distribution.
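The two parameter rules above (a distribution given either by name or as a class, and input keywords that must match the distribution's arguments) can be sketched with plain Python. Everything here is hypothetical: `Normal` and `Gamma` are hand-written stand-ins, and `REGISTRY`, `resolve_distribution`, and `unknown_input_keywords` are illustrative helpers, not liesel functions. A real implementation might instead look the name up with `getattr` on a TFP module.

```python
import inspect


class Normal:
    """Hypothetical stand-in for a TFP distribution class."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale


class Gamma:
    """Hypothetical stand-in for a TFP distribution class."""

    def __init__(self, concentration, rate):
        self.concentration = concentration
        self.rate = rate


# Hypothetical lookup table from names to distribution classes.
REGISTRY = {"Normal": Normal, "Gamma": Gamma}


def resolve_distribution(distribution):
    """Return a distribution class given either its name or the class itself."""
    if isinstance(distribution, str):
        try:
            return REGISTRY[distribution]
        except KeyError:
            raise ValueError(f"unknown distribution: {distribution!r}") from None
    return distribution  # assume a user-defined, compatible class


def unknown_input_keywords(cls, inputs):
    """List input keywords that are not arguments of the class constructor."""
    params = inspect.signature(cls).parameters  # excludes `self` for classes
    return sorted(set(inputs) - set(params))


cls = resolve_distribution("Gamma")
print(cls.__name__, unknown_input_keywords(cls, {"concentration": 2.0, "rate": 1.0}))
# prints: Gamma []
```

Checking the keywords against the constructor signature up front turns a mistyped input name into an early, readable error instead of a failure deep inside the distribution.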
Methods
- cdf(value, **kwargs) – The cumulative distribution function of the distribution.
- The TFP distribution initialized with the values of the inputs.
- jaxify() – Enables JAX NumPy for the node component.
- log_prob(value, **kwargs) – The log-probability (density) function of the distribution.
- mean(**kwargs) – The mean function of the distribution.
- sample([sample_shape, seed]) – The sampling function of the distribution.
- transform(bijector) – Transforms the distribution with a TFP bijector.
- unjaxify() – Disables JAX NumPy for the node component.
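Transforming a distribution with a bijector rests on the change-of-variables formula. The sketch below shows that formula only, using hand-written stand-ins named after their TFP counterparts (`Normal`, `ExpBijector`, `TransformedDistribution`); none of them is the real TFP or liesel class. It follows the TFP convention that the result is the distribution of the bijector's forward image; which direction `transform` applies the bijector in is an assumption not stated on this page.

```python
import math


class Normal:
    """Hypothetical stand-in for a TFP Normal distribution."""

    def __init__(self, loc: float, scale: float) -> None:
        self.loc, self.scale = loc, scale

    def log_prob(self, x: float) -> float:
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)


class ExpBijector:
    """Hypothetical bijector y = exp(x), mirroring the idea of an Exp bijector."""

    def forward(self, x: float) -> float:
        return math.exp(x)

    def inverse(self, y: float) -> float:
        return math.log(y)

    def inverse_log_det_jacobian(self, y: float) -> float:
        # d/dy log(y) = 1 / y, so log|1 / y| = -log(y) for y > 0.
        return -math.log(y)


class TransformedDistribution:
    """Distribution of Y = bijector.forward(X), where X follows `base`."""

    def __init__(self, base, bijector) -> None:
        self.base, self.bijector = base, bijector

    def log_prob(self, y: float) -> float:
        # Change of variables: log p_Y(y) = log p_X(g^{-1}(y)) + log|d g^{-1}(y)/dy|.
        x = self.bijector.inverse(y)
        return self.base.log_prob(x) + self.bijector.inverse_log_det_jacobian(y)


# Exp applied to a standard normal gives a log-normal distribution.
lognormal = TransformedDistribution(Normal(0.0, 1.0), ExpBijector())
print(lognormal.log_prob(1.0))  # equals the standard normal log-density at 0
```

The log-Jacobian term is what keeps the transformed density normalized; dropping it would give a function that no longer integrates to one over the new space.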
Attributes
- The inputs of the node component.
- Whether JAX NumPy is enabled for the node component.