DataclassInterface

class liesel.goose.DataclassInterface(log_prob_fn)

Bases: object

A model interface for a model state represented by a dataclass and a corresponding log-probability function.

Parameters:

log_prob_fn (Callable[[Any], float]) – A function that takes a model state and returns the log-probability. The model state is expected to be a dataclass.

See also

DictInterface

A model interface for a model state represented by a dict[str, Array] and a corresponding log-probability function.

LieselInterface

A model interface for a Liesel Model.

Examples

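For this example, we import JAX's NumPy module, Liesel's goose module, and the JAX substrate of TensorFlow Probability:

>>> import jax.numpy as jnp
>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd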

We define a dataclass representing the model state:

>>> from dataclasses import dataclass
>>> @dataclass
... class State:
...     x: jnp.ndarray
...     loc: jnp.ndarray
...     scale: jnp.ndarray

Next, we define a simple log_prob_fn for demonstration: the log-density of x under a normal distribution with mean loc and standard deviation scale:

>>> def log_prob_fn(model_state):
...     loc = model_state.loc
...     scale = model_state.scale
...     x = model_state.x
...     return tfd.Normal(loc, scale).log_prob(x)

We initialize the interface by passing the log_prob_fn to the constructor:

>>> interface = gs.DataclassInterface(log_prob_fn)

We evaluate the log-probability of a model state by calling the log_prob method:

>>> state = State(x=jnp.array(0.0), loc=jnp.array(0.0), scale=jnp.array(1.0))
>>> interface.log_prob(state)
Array(-0.9189385, dtype=float32)

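To update the model state, we pass a position (a dict mapping field names of the dataclass to new values) together with the current state to the update_state method, which returns the updated state: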

>>> position = {"x": jnp.array(1.0)}
>>> state = interface.update_state(position, state)

We can now evaluate the log-probability of the updated model state:

>>> interface.log_prob(state)
Array(-1.4189385, dtype=float32)
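
The two outputs follow from the normal log-density, where tfd.Normal(loc, scale) has mean μ = loc and standard deviation σ = scale. With μ = 0 and σ = 1, moving x from 0 to 1 only adds the quadratic penalty (1 − 0)²/2 = 0.5:

```latex
\begin{align*}
\log p(x \mid \mu, \sigma) &= -\tfrac{1}{2}\log(2\pi) - \log\sigma - \frac{(x - \mu)^2}{2\sigma^2} \\
\log p(0 \mid 0, 1) &= -\tfrac{1}{2}\log(2\pi) \approx -0.9189385 \\
\log p(1 \mid 0, 1) &= -\tfrac{1}{2}\log(2\pi) - \tfrac{1}{2} \approx -1.4189385
\end{align*}
```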

Methods

extract_position(position_keys, model_state)

Extracts a position from a model state.

log_prob(model_state)

Returns the log-probability from a model state.

update_state(position, model_state)

Updates and returns a model state given a position.
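
To make the three methods concrete, here is a dependency-free sketch in plain Python. It is not Liesel's implementation: the class name SketchDataclassInterface and the helper normal_log_prob are hypothetical, plain floats stand in for JAX arrays, and building update_state on dataclasses.replace (which returns a new object instead of mutating the input) is an assumption.

```python
import dataclasses
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable


@dataclass
class State:
    x: float
    loc: float
    scale: float


class SketchDataclassInterface:
    """Hypothetical stand-in for DataclassInterface (not Liesel's source)."""

    def __init__(self, log_prob_fn: Callable[[Any], float]):
        self._log_prob_fn = log_prob_fn

    def extract_position(
        self, position_keys: Iterable[str], model_state: Any
    ) -> Dict[str, Any]:
        # Read each requested field from the dataclass by name.
        return {key: getattr(model_state, key) for key in position_keys}

    def update_state(self, position: Dict[str, Any], model_state: Any) -> Any:
        # Assumption: return a copy with the given fields replaced.
        # dataclasses.replace leaves the input untouched and raises
        # TypeError for keys that are not fields of the dataclass.
        return dataclasses.replace(model_state, **position)

    def log_prob(self, model_state: Any) -> float:
        # Delegate to the user-supplied log-probability function.
        return self._log_prob_fn(model_state)


def normal_log_prob(state: State) -> float:
    # Log-density of Normal(loc, scale) evaluated at x.
    z = (state.x - state.loc) / state.scale
    return -0.5 * math.log(2 * math.pi) - math.log(state.scale) - 0.5 * z * z


interface = SketchDataclassInterface(normal_log_prob)
state = State(x=0.0, loc=0.0, scale=1.0)
new_state = interface.update_state({"x": 1.0}, state)
print(interface.extract_position(["x", "loc"], new_state))  # {'x': 1.0, 'loc': 0.0}
print(round(interface.log_prob(new_state), 7))  # -1.4189385
```

In this sketch, update_state leaves the original state unchanged and rejects position keys that are not dataclass fields; this page does not say whether Liesel behaves identically in those cases.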