GibbsKernel


class liesel.goose.GibbsKernel(position_keys, transition_fn, identifier='')

Bases: ModelMixin, Kernel[Any, DefaultTransitionInfo, DefaultTuningInfo], ReprMixin

A Gibbs kernel implementing the Kernel protocol.

Parameters:
  • position_keys (Sequence[str]) – Sequence of position keys (variable names) handled by this kernel.

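  • transition_fn (Callable[[Any, Any], Position (dict[str, Any])]) – User-provided transition function. It receives a PRNG key and the current model state and returns a position, i.e. a dict mapping the kernel's position keys to their new values.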

  • identifier (str, default: '') – A string acting as a unique identifier for this kernel.
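The contract of transition_fn can be illustrated without liesel. The following is a minimal sketch, not liesel's implementation: it assumes the model state is a plain dict of variable values (the real model state is richer and is read through LieselInterface.extract_position, as in the example below) and uses a random.Random instance in place of a JAX PRNG key. The helper gibbs_step and the function toy_transition are hypothetical names introduced for illustration.

```python
import random
from typing import Any, Callable, Dict, Sequence

# Simplified stand-ins (assumptions of this sketch, not liesel types): the model
# state is a flat dict of variable values and the PRNG key is a random.Random.
ModelState = Dict[str, Any]
Position = Dict[str, Any]
TransitionFn = Callable[[Any, ModelState], Position]


def gibbs_step(
    position_keys: Sequence[str],
    transition_fn: TransitionFn,
    prng_key: Any,
    model_state: ModelState,
) -> ModelState:
    """Hypothetical helper: apply one user-supplied Gibbs update."""
    new_position = transition_fn(prng_key, model_state)
    # Sketch-only check: the update should cover exactly the kernel's keys.
    if set(new_position) != set(position_keys):
        raise KeyError(
            f"expected keys {sorted(position_keys)}, got {sorted(new_position)}"
        )
    # Return a new state; only the variables in position_keys change.
    updated = dict(model_state)
    updated.update(new_position)
    return updated


def toy_transition(prng_key: random.Random, model_state: ModelState) -> Position:
    # Toy conditional draw: "mu" ~ Normal(model_state["m"], 1).
    return {"mu": prng_key.gauss(model_state["m"], 1.0)}


initial = {"mu": 0.0, "m": 5.0}
state = gibbs_step(["mu"], toy_transition, random.Random(0), initial)
```

The key point is the shape of the exchange: the kernel hands the user function a key and the state, and only the returned position keys are overwritten.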

Examples

For this example, we import liesel.model, liesel.goose, tensorflow_probability, jax and jax.numpy as follows:

>>> import liesel.model as lsl
>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> import jax
>>> import jax.numpy as jnp

First, we set up a minimal model:

>>> mu = lsl.Var.new_param(0.0, name="mu")
>>> dist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
>>> model = lsl.Model([y])

Now we initialize the EngineBuilder and set the desired number of warmup and posterior samples:

>>> builder = gs.EngineBuilder(seed=1, num_chains=4)
>>> builder.set_duration(warmup_duration=1000, posterior_duration=1000)

Next, we set the model interface and initial values:

>>> interface = gs.LieselInterface(model)
>>> builder.set_model(interface)
>>> builder.set_initial_values(model.state)

We define a function to sample from the full conditional for the parameter "mu":

>>> def sample_mu(prng_key, model_state):
...     # extract relevant values from the model state
...     pos = interface.extract_position(
...         position_keys=["y"], model_state=model_state
...     )
...     # mu has no prior distribution (flat prior) and the observation
...     # scale is 1, so the full conditional is Normal(y_mean, 1 / n)
...     n = pos["y"].shape[0]
...     y_mean = pos["y"].mean()
...     # draw new value from the full conditional
...     draw = y_mean + jax.random.normal(prng_key) / jnp.sqrt(n)
...     # return key-value pair of variable name and new value
...     return {"mu": draw}
>>> builder.add_kernel(gs.GibbsKernel(["mu"], sample_mu))
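The closed form of this full conditional follows from the model above: "mu" is created without a distribution, which amounts to a flat prior, and each observation is normal with scale fixed at 1. Completing the square in the likelihood gives

```latex
p(\mu \mid y) \propto \prod_{i=1}^{n} \exp\!\left(-\tfrac{1}{2}(y_i - \mu)^2\right)
\propto \exp\!\left(-\tfrac{n}{2}(\mu - \bar{y})^2\right)
\quad\Longrightarrow\quad
\mu \mid y \sim \mathcal{N}\!\left(\bar{y},\ \tfrac{1}{n}\right).
```

For y = (1, 2, 3) this is a normal distribution with mean 2 and variance 1/3. Since it does not depend on the current value of "mu", every Gibbs draw is an exact draw from the posterior in this one-parameter model.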

Finally, we build the engine:

>>> engine = builder.build()

From here, you can continue with engine.sample_all_epochs() to draw samples from the posterior distribution.
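To make concrete what repeated transitions amount to, here is a pure-Python sketch (standard library only, no liesel or JAX) of a single chain of Gibbs updates for "mu". It assumes, as in the model above, a flat prior on "mu" and observation scale 1; run_chain is a hypothetical helper, and this is not how the engine runs internally (it uses JAX PRNG keys and, here, four chains).

```python
import math
import random
import statistics

# Data from the example; the observation scale is fixed at 1.
y = [1.0, 2.0, 3.0]
n = len(y)
y_mean = sum(y) / n


def sample_mu(rng: random.Random, mu_current: float) -> float:
    # Full conditional under a flat prior: Normal(mean=y_mean, sd=1/sqrt(n)).
    # It ignores mu_current, so each draw is an exact posterior draw.
    return rng.gauss(y_mean, 1.0 / math.sqrt(n))


def run_chain(num_draws: int, seed: int = 1) -> list:
    rng = random.Random(seed)
    mu = 0.0  # same initial value as in the model
    draws = []
    for _ in range(num_draws):
        mu = sample_mu(rng, mu)
        draws.append(mu)
    return draws


draws = run_chain(100_000)
print(statistics.fmean(draws), statistics.variance(draws))
```

The printed sample mean and variance should be close to 2 and 1/3, matching the closed-form full conditional.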

See also

Gibbs Sampling

Methods

end_epoch(prng_key, kernel_state, ...)

Currently does nothing.

end_warmup(prng_key, kernel_state, ...)

Currently does nothing.

init_state(prng_key, model_state)

Initializes an (empty) kernel state.

start_epoch(prng_key, kernel_state, ...)

Currently does nothing.

transition(prng_key, kernel_state, ...)

Performs an MCMC transition.

tune(prng_key, kernel_state, model_state, epoch)

Currently does nothing.

Attributes

error_book

Maps error codes to error messages.

identifier

An identifier for the kernel object that is set by the EngineBuilder if it is an empty string.

needs_history

Set to True if the kernel expects the sampling history for tuning.

position_keys

Keys for which the kernel handles the transition.