NodeDistribution

class liesel.liesel.nodes.NodeDistribution(distribution, bijector=None, **inputs)

Bases: liesel.liesel.nodes.NodeComponent

A probability distribution of a node.

Computes the log-probability of a node from its inputs. Implemented as a thin wrapper around a TFP distribution.
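
The wrapper idea can be sketched in plain Python. This is a minimal sketch under stated assumptions, not Liesel or TFP code: `Node`, `Normal`, `DISTRIBUTIONS` and `SketchNodeDistribution` are hypothetical stand-ins. A real `NodeDistribution` looks the name up among the TFP distributions and reads the values of actual Liesel nodes, and its methods also accept `**kwargs`, which the sketch leaves out. What it shows is the pattern: the distribution object is rebuilt from the current input values each time it is evaluated.

```python
import math
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for a Liesel node: it only carries a value."""
    value: float


class Normal:
    """Hypothetical stand-in for a TFP-style distribution class."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)

    def cdf(self, x):
        return 0.5 * (1 + math.erf((x - self.loc) / (self.scale * math.sqrt(2))))

    def mean(self):
        return self.loc


# Hypothetical registry standing in for "look up the TFP class by name".
DISTRIBUTIONS = {"Normal": Normal}


class SketchNodeDistribution:
    """Hypothetical thin wrapper: builds the distribution from input values."""

    def __init__(self, distribution, **inputs):
        # Accept either a name (looked up in the registry) or a class.
        if isinstance(distribution, str):
            distribution = DISTRIBUTIONS[distribution]
        self._cls = distribution
        self.inputs = inputs

    def distribution(self):
        # Instantiate the distribution with the *current* input values.
        kwargs = {name: node.value for name, node in self.inputs.items()}
        return self._cls(**kwargs)

    def log_prob(self, value):
        return self.distribution().log_prob(value)

    def cdf(self, value):
        return self.distribution().cdf(value)

    def mean(self):
        return self.distribution().mean()


loc, scale = Node(0.0), Node(1.0)
dist = SketchNodeDistribution("Normal", loc=loc, scale=scale)
print(dist.log_prob(0.0))  # -0.5 * log(2 * pi), about -0.9189
loc.value = 2.0            # changing an input changes the next evaluation
print(dist.mean())         # 2.0
```

Rebuilding the distribution on every call keeps the wrapper stateless with respect to the input values, so updating a node is enough to change the log-probability.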

Parameters
  • distribution (Union[str, Any]) –

    The name of a TFP distribution as a string, or alternatively, a user-defined TFP-compatible distribution class.

    If a class is provided instead of a string, the user needs to make sure it uses the right NumPy implementation (plain NumPy or JAX NumPy, matching whether JAX NumPy is enabled for the node component).

  • bijector (Union[str, Any, None]) –

    The name of a TFP bijector as a string, or alternatively, a user-defined TFP-compatible bijector class (default: None).

    If a class is provided instead of a string, the user needs to make sure it uses the right NumPy implementation (plain NumPy or JAX NumPy, matching whether JAX NumPy is enabled for the node component).

  • inputs (Node) – The inputs of the distribution. The keywords must match the argument names of the TFP distribution, e.g. loc and scale for a Normal distribution.
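
Because the keywords of inputs are passed on as constructor arguments, a misspelled name only surfaces once the distribution is built. The hypothetical helper `check_input_names` below sketches an up-front check with `inspect.signature`. `Normal` is again a stand-in with TFP-style `loc` and `scale` arguments, and it is not claimed that Liesel performs such a check itself.

```python
import inspect


class Normal:
    """Hypothetical stand-in with TFP-style constructor arguments."""

    def __init__(self, loc, scale, validate_args=False, name="Normal"):
        self.loc, self.scale = loc, scale


def check_input_names(dist_cls, **inputs):
    """Hypothetical helper: compare input keywords with the constructor."""
    params = inspect.signature(dist_cls.__init__).parameters
    # Only consider arguments that can be passed by keyword (skip self, *args, **kwargs).
    named = {
        n: p for n, p in params.items()
        if n != "self" and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    }
    unknown = set(inputs) - set(named)
    if unknown:
        raise TypeError(f"{dist_cls.__name__} has no argument(s): {sorted(unknown)}")
    # Required arguments are those without a default value.
    required = {n for n, p in named.items() if p.default is inspect.Parameter.empty}
    missing = required - set(inputs)
    if missing:
        raise TypeError(f"{dist_cls.__name__} is missing: {sorted(missing)}")


check_input_names(Normal, loc=0.0, scale=1.0)    # passes silently
try:
    check_input_names(Normal, mean=0.0, sd=1.0)  # wrong keyword names
except TypeError as err:
    print(err)
```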

Methods

cdf(value, **kwargs)

The cumulative distribution function of the distribution.

distribution()

The TFP distribution initialized with the values of the inputs.

jaxify()

Enables JAX NumPy for the node component.

log_prob(value, **kwargs)

The log-probability (density) function of the distribution.

mean(**kwargs)

The mean function of the distribution.

sample([sample_shape, seed])

The sampling function of the distribution.

transform(bijector)

Transforms the distribution with a TFP bijector.

unjaxify()

Disables JAX NumPy for the node component.
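
Transforming a distribution with a bijector g rests on the change-of-variables formula: if Y = g(X), then log p_Y(y) = log p_X(g⁻¹(y)) + log|d g⁻¹(y)/dy|. The sketch below applies it with g = exp to a standard normal, which gives the log-normal density. `Normal`, `Exp` and `TransformedSketch` are hypothetical stand-ins rather than TFP's own bijector and transformed-distribution classes, and the sketch does not claim to reproduce what `transform` does internally.

```python
import math


class Normal:
    """Hypothetical stand-in for a TFP-style base distribution."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


class Exp:
    """Hypothetical stand-in for an exp bijector: y = exp(x)."""

    def forward(self, x):
        return math.exp(x)

    def inverse(self, y):
        return math.log(y)

    def inverse_log_det_jacobian(self, y):
        # d/dy log(y) = 1/y, so log|d g^{-1}(y)/dy| = -log(y)
        return -math.log(y)


class TransformedSketch:
    """Distribution of Y = bijector.forward(X) with X ~ base."""

    def __init__(self, base, bijector):
        self.base, self.bijector = base, bijector

    def log_prob(self, y):
        # Change of variables: map y back to x, then correct for the Jacobian.
        x = self.bijector.inverse(y)
        return self.base.log_prob(x) + self.bijector.inverse_log_det_jacobian(y)


lognormal = TransformedSketch(Normal(0.0, 1.0), Exp())
y = 2.0
# Closed-form log-normal(0, 1) log-density for comparison.
direct = -math.log(y) - 0.5 * math.log(2 * math.pi) - 0.5 * math.log(y) ** 2
print(abs(lognormal.log_prob(y) - direct) < 1e-12)  # True
```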

Attributes

inputs

The inputs of the node component.

jaxified

Whether JAX NumPy is enabled for the node component.
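
The jaxify()/unjaxify() pair and the jaxified attribute describe a switch between plain NumPy and JAX NumPy. A minimal sketch of such a switch follows. `BackendSwitch`, its `np` property and its `total` method are hypothetical and not Liesel's implementation. JAX is imported lazily, so the sketch runs where JAX is not installed as long as the JAX backend is not actually used.

```python
import numpy


class BackendSwitch:
    """Hypothetical sketch of a component that can toggle its array backend."""

    def __init__(self):
        self._jaxified = False

    @property
    def jaxified(self):
        """Whether JAX NumPy is enabled."""
        return self._jaxified

    def jaxify(self):
        self._jaxified = True

    def unjaxify(self):
        self._jaxified = False

    @property
    def np(self):
        # Import JAX only when the JAX backend is actually requested.
        if self._jaxified:
            import jax.numpy as jnp
            return jnp
        return numpy

    def total(self, values):
        # Same code path for both backends; only the array module differs.
        return self.np.sum(self.np.asarray(values))


c = BackendSwitch()
print(c.jaxified, c.total([1.0, 2.0, 3.0]))
```

Routing every array call through one module reference keeps the numerical code identical for both backends, which is also why a user-supplied distribution or bijector class has to match the active NumPy implementation.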