DictInterface
class liesel.goose.DictInterface(log_prob_fn)
Bases: object
A model interface for a model state represented by a dict[str, Array] and a corresponding log-probability function.
Parameters:
log_prob_fn (Callable[[Any], float]) – A function that takes a model state and returns the log-probability. The model state is expected to be a dict[str, Array].
See also
DataclassInterface
A model interface for a model state represented by a dataclass and a corresponding log-probability function.
LieselInterface
A model interface for a Liesel Model.
Examples
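For this example, we import tensorflow_probability, together with jax.numpy and liesel.goose, as follows:
>>> import jax.numpy as jnp
>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd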
Now we define a very simple log_prob_fn for the sake of demonstration:
>>> def log_prob_fn(model_state):
...     loc = model_state["loc"]
...     scale = model_state["scale"]
...     x = model_state["x"]
...     return tfd.Normal(loc, scale).log_prob(x)
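To make explicit what the log_prob_fn above computes, here is a dependency-free sketch of the same normal log-density. The name log_prob_fn_plain is hypothetical, and it works on plain Python floats rather than JAX arrays, so it only illustrates the formula; it is not meant to be passed to goose, which expects a JAX-traceable function like the one above.

```python
import math


def log_prob_fn_plain(model_state: dict) -> float:
    """Hypothetical stand-in for log_prob_fn: log N(x | loc, scale) with plain floats."""
    loc = model_state["loc"]
    scale = model_state["scale"]
    x = model_state["x"]
    z = (x - loc) / scale
    # Normal log-density: -log(scale) - 0.5 * log(2 * pi) - 0.5 * z^2
    return -math.log(scale) - 0.5 * math.log(2.0 * math.pi) - 0.5 * z**2


state = {"x": 0.0, "loc": 0.0, "scale": 1.0}
print(round(log_prob_fn_plain(state), 7))  # -0.9189385
```

The rounded value agrees with the Array(-0.9189385, dtype=float32) printed by interface.log_prob in the example.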
We initialize the interface by passing the log_prob_fn to the constructor:
>>> interface = gs.DictInterface(log_prob_fn)
We evaluate the log-probability of a model state by calling the log_prob method:
>>> state = {"x": jnp.array(0.0), "loc": jnp.array(0.0), "scale": jnp.array(1.0)}
>>> interface.log_prob(state)
Array(-0.9189385, dtype=float32)
We update the model state by passing a position to the update_state method:
>>> position = {"x": jnp.array(1.0)}
>>> state = interface.update_state(position, state)
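Note that the result of update_state is reassigned to state. Assuming that update_state overwrites the entries named in the position and keeps all other entries (an assumption drawn from the method description, not from the liesel source), the hypothetical helper update_state_sketch below mimics that behaviour with plain dicts and leaves the input state untouched.

```python
def update_state_sketch(position: dict, model_state: dict) -> dict:
    """Hypothetical sketch: overwrite model_state entries with those in position."""
    # Build a new dict so the caller's model_state is not mutated.
    return {**model_state, **position}


state = {"x": 0.0, "loc": 0.0, "scale": 1.0}
new_state = update_state_sketch({"x": 1.0}, state)
print(new_state)  # {'x': 1.0, 'loc': 0.0, 'scale': 1.0}
print(state)      # {'x': 0.0, 'loc': 0.0, 'scale': 1.0}
```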
We can now evaluate the log-probability of the updated model state:
>>> interface.log_prob(state)
Array(-1.4189385, dtype=float32)
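Both printed values can be checked by hand from the log-density of the normal distribution:

```latex
\log p(x \mid \mu, \sigma) = -\log \sigma - \tfrac{1}{2}\log(2\pi) - \frac{(x - \mu)^2}{2\sigma^2}
```

With loc = 0 and scale = 1, the state with x = 0 leaves only -(1/2) log(2π) ≈ -0.9189385. Moving to x = 1 subtracts a further (1 - 0)^2 / 2 = 0.5, which gives ≈ -1.4189385.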
Methods
extract_position(position_keys, model_state)
Extracts a position from a model state.
log_prob(model_state)
Returns the log-probability of a model state.
update_state(position, model_state)
Updates and returns a model state given a position.
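The example above does not use extract_position. Assuming that it selects the entries named in position_keys from the model state (an assumption; goose may wrap the result in its own position type), the hypothetical helper extract_position_sketch below illustrates that selection with plain dicts.

```python
def extract_position_sketch(position_keys: list[str], model_state: dict) -> dict:
    """Hypothetical sketch: keep only the entries named in position_keys."""
    # A key missing from model_state raises KeyError, as plain dict indexing does.
    return {key: model_state[key] for key in position_keys}


state = {"x": 1.0, "loc": 0.0, "scale": 1.0}
position = extract_position_sketch(["x", "loc"], state)
print(position)  # {'x': 1.0, 'loc': 0.0}
```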