FlatLogProb

class liesel.model.FlatLogProb(model, position_keys, component='log_prob', diff_mode='forward')

Bases: object

Interface for evaluating the unnormalized log probability of a Liesel model.

It also provides access to the first and second derivatives. The methods FlatLogProb.grad() and FlatLogProb.hessian() are flattened, which means they take a single flat array of position values as input and return arrays rather than dictionaries.
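To illustrate what a flat position looks like, the sketch below uses JAX's `jax.flatten_util.ravel_pytree` to turn a dictionary of variable values into one 1-D array and back. This only illustrates the concept: the variable names are hypothetical, and Liesel's own flattening (for example, the order in which variables are concatenated) may differ.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical position with two variables: a vector "x" and a scalar "s".
position = {"x": jnp.array([1.0, 2.0]), "s": jnp.array(0.5)}

# ravel_pytree concatenates all leaves into one 1-D array and returns a
# function that maps a flat array back to the original dictionary structure.
# JAX flattens dictionaries in sorted key order, so "s" comes before "x".
flat_position, unravel = ravel_pytree(position)
print(flat_position)

# Recover the dictionary form from the flat array.
restored = unravel(flat_position)
print(restored["x"], restored["s"])
```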

Parameters:
  • model (Model) – A Liesel model instance.

  • position_keys (Sequence[str]) – Names of the variables at which to evaluate the log probability. Other variables will be kept fixed at their current values in the model state.

  • component (Literal['log_prob', 'log_lik', 'log_prior'], default: 'log_prob') – Which component of the model’s log probability to evaluate.

  • diff_mode (Literal['forward', 'reverse'], default: 'forward') – Which auto-diff mode to use for the Hessian.
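The three `component` options are related by log_prob = log_lik + log_prior. The following plain-JAX sketch shows this for a hypothetical toy model with one parameter `mu` and three observations; the dictionary `components` only mimics the idea of selecting a component and is not Liesel's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy model: prior mu ~ N(0, 1), observations y_i ~ N(mu, 1).
y = jnp.array([0.3, -0.1, 0.8])

def log_prior(mu):
    return norm.logpdf(mu, loc=0.0, scale=1.0).sum()

def log_lik(mu):
    return norm.logpdf(y, loc=mu[0], scale=1.0).sum()

def log_prob(mu):
    # The full unnormalized log probability is the sum of both components.
    return log_lik(mu) + log_prior(mu)

# Hypothetical mapping from a component name to the function to evaluate.
components = {"log_prob": log_prob, "log_lik": log_lik, "log_prior": log_prior}

mu = jnp.array([0.2])
for name, fn in components.items():
    # Value and gradient of each component at the same position.
    print(name, fn(mu), jax.grad(fn)(mu))
```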

See also

LogProb

A similar class that returns gradients and Hessians as dictionaries.
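The difference between a dictionary-based and a flat interface can be sketched with plain JAX: differentiating a function of a dictionary returns a dictionary-shaped gradient, while wrapping the same function to accept a flat array returns a flat gradient. The function `log_prob_dict` is a hypothetical stand-in, not part of Liesel.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.stats import norm

def log_prob_dict(position):
    # Standard normal log density summed over all entries of "x".
    return norm.logpdf(position["x"]).sum()

position = {"x": jnp.array([1.0, 2.0])}

# Dictionary form: the gradient has the same structure as the input.
grad_dict = jax.grad(log_prob_dict)(position)

# Flat form: wrap the function so it accepts and returns plain arrays.
flat_position, unravel = ravel_pytree(position)
grad_flat = jax.grad(lambda flat: log_prob_dict(unravel(flat)))(flat_position)

print(grad_dict["x"], grad_flat)
```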

Examples

We initialize a very basic Liesel model:

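>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> x = lsl.Var.new_param(
...     jnp.zeros(2),
...     dist=lsl.Dist(tfd.Normal, loc=0.0, scale=1.0),
...     name="x",
... )
>>> model = lsl.Model([x])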

Now we initialize the log prob object:

>>> lp = lsl.FlatLogProb(model, ["x"])

Next, we define an array of new values at which to evaluate the log probability, its gradient, and its Hessian:

>>> xnew = jnp.array([1.0, 2.0])
>>> lp(xnew)
Array(-4.3378773, dtype=float32)
>>> lp.grad(xnew)
Array([-1., -2.], dtype=float32)
>>> lp.hessian(xnew)
Array([[-1., -0.],
       [-0., -1.]], dtype=float32)
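The outputs above can be checked by hand. For a standard normal, log N(x_i | 0, 1) = -0.5 log(2π) - x_i²/2, so at x = (1, 2) the sum is -log(2π) - 2.5 ≈ -4.3379, the gradient is -x, and the Hessian is minus the identity. The sketch below recomputes these quantities with plain JAX, independently of Liesel.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_prob(x):
    # Sum of standard normal log densities: -0.5 * log(2*pi) - 0.5 * x_i**2
    return norm.logpdf(x, loc=0.0, scale=1.0).sum()

xnew = jnp.array([1.0, 2.0])
value = log_prob(xnew)
grad = jax.grad(log_prob)(xnew)
hess = jax.hessian(log_prob)(xnew)

# Closed form of the log probability: -log(2*pi) - 0.5 * (1 + 4)
expected = -math.log(2 * math.pi) - 2.5
print(value, expected)
print(grad)  # equals -xnew
print(hess)  # equals minus the 2x2 identity
```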

Methods

  • grad(flat_position) – Gradient of the log probability function with respect to flat_position.

  • hessian(flat_position) – Hessian of the log probability function with respect to flat_position.

  • log_prob(flat_position) – Log probability function evaluated at the provided flat_position.
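How the three methods relate can be sketched with a hypothetical stand-in class written in plain JAX. `ToyFlatLogProb` is not Liesel's implementation; in particular, mapping `diff_mode="forward"` to `jax.jacfwd` and `"reverse"` to `jax.jacrev`, each applied on top of `jax.grad`, is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class ToyFlatLogProb:
    """Hypothetical flat log-prob interface for illustration (not Liesel code)."""

    def __init__(self, log_prob_fn, diff_mode="forward"):
        self._fn = log_prob_fn
        self._diff_mode = diff_mode

    def log_prob(self, flat_position):
        # Evaluate the log probability at a flat array of position values.
        return self._fn(flat_position)

    def __call__(self, flat_position):
        return self.log_prob(flat_position)

    def grad(self, flat_position):
        # Reverse-mode gradient; returns an array with the same shape as the input.
        return jax.grad(self._fn)(flat_position)

    def hessian(self, flat_position):
        # Differentiate the gradient again, in the requested mode.
        jac = jax.jacfwd if self._diff_mode == "forward" else jax.jacrev
        return jac(jax.grad(self._fn))(flat_position)


lp = ToyFlatLogProb(lambda x: norm.logpdf(x).sum())
x = jnp.array([1.0, 2.0])
print(lp(x), lp.grad(x), lp.hessian(x))
```

Both modes yield the same Hessian up to floating-point error; they differ only in how the second derivative is computed.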