NamedTupleInterface
class liesel.goose.NamedTupleInterface(log_prob_fn)
Bases: object
A model interface for a model state represented by a NamedTuple and a corresponding log-probability function.
Parameters:
log_prob_fn (Callable[[Any], float]) – A function that takes a model state and returns the log-probability. The model state is expected to be a NamedTuple.
See also
DictInterface
    A model interface for a model state represented by a dict[str, Array] and a corresponding log-probability function.
DataclassInterface
    A model interface for a model state represented by a dataclass and a corresponding log-probability function.
LieselInterface
    A model interface for a Liesel Model.
Examples
For this example, we import tensorflow_probability, along with jax.numpy and liesel.goose, as follows:
>>> import jax.numpy as jnp
>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd
We define a subclass of NamedTuple representing the model state:
>>> from typing import NamedTuple
...
>>> class State(NamedTuple):
...     x: jnp.ndarray
...     loc: jnp.ndarray
...     scale: jnp.ndarray
Now we define a very simple log_prob_fn for the sake of demonstration:
>>> def log_prob_fn(model_state):
...     loc = model_state.loc
...     scale = model_state.scale
...     x = model_state.x
...     return tfd.Normal(loc, scale).log_prob(x)
We initialize the interface by passing the log_prob_fn to the constructor:
>>> interface = gs.NamedTupleInterface(log_prob_fn)
We evaluate the log-probability of a model state by calling the log_prob method:
>>> state = State(x=jnp.array(0.0), loc=jnp.array(0.0), scale=jnp.array(1.0))
>>> interface.log_prob(state)
Array(-0.9189385, dtype=float32)
We update the model state by passing a position to the update_state method:
>>> position = {"x": jnp.array(1.0)}
>>> state = interface.update_state(position, state)
We can now evaluate the log-probability of the updated model state:
>>> interface.log_prob(state)
Array(-1.4189385, dtype=float32)
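The two log-probabilities in this example can be checked by hand, since the log-density of a normal distribution is -0.5 * z**2 - log(scale) - 0.5 * log(2 * pi) with z = (x - loc) / scale. The following is a minimal sketch using only the standard library. The helper name normal_log_prob is hypothetical and not part of Liesel or tensorflow_probability.

```python
import math


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log-density of Normal(loc, scale) at x (hypothetical helper)."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


# State before the update: x = 0, loc = 0, scale = 1.
before = normal_log_prob(0.0, 0.0, 1.0)
# State after the update: x = 1, loc and scale unchanged.
after = normal_log_prob(1.0, 0.0, 1.0)

print(round(before, 7), round(after, 7))  # -0.9189385 -1.4189385
```

Moving x one standard deviation away from loc lowers the log-probability by exactly 0.5, which matches the difference between the two doctest outputs.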
Methods
extract_position(position_keys, model_state)
    Extracts a position from a model state.
log_prob(model_state)
    Returns the log-probability from a model state.
update_state(position, model_state)
    Updates and returns a model state given a position.
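To illustrate how the three methods relate, here is a plain-Python stand-in that mirrors the documented signatures. It is a sketch under assumptions, not Liesel's implementation: the class name SketchInterface is hypothetical, positions are assumed to be plain dicts keyed by field name, and the log-density is left unnormalized to keep the example short.

```python
from typing import Any, Callable, NamedTuple, Sequence


class State(NamedTuple):
    x: float
    loc: float
    scale: float


class SketchInterface:
    """Hypothetical stand-in with the same three method signatures."""

    def __init__(self, log_prob_fn: Callable[[Any], float]):
        self._log_prob_fn = log_prob_fn

    def extract_position(self, position_keys: Sequence[str], model_state: Any) -> dict:
        # Read the requested fields off the named tuple.
        return {key: getattr(model_state, key) for key in position_keys}

    def update_state(self, position: dict, model_state: Any) -> Any:
        # Named tuples are immutable, so return a copy with the fields replaced.
        return model_state._replace(**position)

    def log_prob(self, model_state: Any) -> float:
        return self._log_prob_fn(model_state)


def log_prob_fn(state: State) -> float:
    # Unnormalized normal log-density (assumption made for brevity).
    return -0.5 * ((state.x - state.loc) / state.scale) ** 2


interface = SketchInterface(log_prob_fn)
state = State(x=0.0, loc=0.0, scale=1.0)

new_state = interface.update_state({"x": 1.0}, state)
position = interface.extract_position(["x", "loc"], new_state)
```

In this sketch, update_state leaves the original state untouched and returns a new named tuple, and extract_position is its inverse for the selected keys.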