GibbsKernel
class liesel.goose.GibbsKernel(position_keys, transition_fn, identifier='')
Bases: ModelMixin, Kernel[Any, DefaultTransitionInfo, DefaultTuningInfo], ReprMixin
A Gibbs kernel implementing the Kernel protocol.
- Parameters:
- position_keys (Sequence[str]) – Sequence of position keys (variable names) handled by this kernel.
- transition_fn (Callable[[Any, Any], Position]) – Custom transition function that must be provided by the user. It is called with a PRNG key and the current model state and must return a Position, i.e. a dict[str, Any] mapping each position key to its new value.
- identifier (str, default: '') – A string acting as a unique identifier for this kernel.
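To make the transition_fn contract concrete, here is a minimal standalone sketch in JAX. It does not use liesel: the model state is a hypothetical plain dict standing in for the real model state, and the Normal(0, sigma) draw is an arbitrary choice that only illustrates the call signature (a PRNG key and a model state in, a dict keyed by the position keys out).

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a model state: a plain dict mapping variable
# names to values. In liesel, the model state is handled by the model
# interface instead, so this is for illustration only.
model_state = {"mu": jnp.zeros(()), "sigma": jnp.ones(())}


def transition_fn(prng_key, model_state):
    """Draw a new value for every position key handled by the kernel."""
    # In a Gibbs update the draw should depend only on the other variables,
    # not on the current "mu"; the Normal(0, sigma) distribution is arbitrary.
    draw = model_state["sigma"] * jax.random.normal(prng_key)
    # Return a dict keyed by the position keys (here only "mu").
    return {"mu": draw}


key = jax.random.PRNGKey(0)
new_position = transition_fn(key, model_state)

# Written with jax.numpy only, the function can also be traced by jax.jit.
jitted_position = jax.jit(transition_fn)(key, model_state)
```

Writing the function with jax.numpy and jax.random only keeps it traceable, which is a sensible precaution if the sampling engine compiles its kernels.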
Examples
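For this example, we import tensorflow_probability, jax and jax.numpy, as well as liesel.model and liesel.goose, as follows:
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> import jax
>>> import jax.numpy as jnp
>>> import liesel.model as lsl
>>> import liesel.goose as gs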
First, we set up a minimal model:
>>> mu = lsl.Var.new_param(0.0, name="mu")
>>> dist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
>>> model = lsl.Model([y])
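In this model, y_i ~ Normal(mu, 1) and mu is created without a distribution. Assuming that a variable without a distribution contributes nothing to the log prior (i.e. a flat prior on mu), the log posterior of mu is just the log-likelihood up to a constant. The standalone JAX sketch below, which does not call liesel, writes that down and checks that its gradient vanishes at the sample mean of y.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Observations from the example model.
y = jnp.array([1.0, 2.0, 3.0])


def log_posterior(mu):
    # Assumption: mu has no prior distribution (flat prior), so the log
    # posterior is the log-likelihood of y_i ~ Normal(mu, 1), up to a constant.
    return jnp.sum(norm.logpdf(y, loc=mu, scale=1.0))


# The gradient vanishes at the sample mean, which is the posterior mode.
grad_at_mean = jax.grad(log_posterior)(y.mean())
```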
Now we initialize the EngineBuilder and set the desired number of warmup and posterior samples:
>>> builder = gs.EngineBuilder(seed=1, num_chains=4)
>>> builder.set_duration(warmup_duration=1000, posterior_duration=1000)
Next, we set the model interface and initial values:
>>> interface = gs.LieselInterface(model)
>>> builder.set_model(interface)
>>> builder.set_initial_values(model.state)
We define a function to sample from the full conditional of the parameter "mu":
>>> def sample_mu(prng_key, model_state):
...     # extract the observed data from the model state
...     pos = interface.extract_position(
...         position_keys=["y"], model_state=model_state
...     )
...     # mu has no prior (flat prior) and y_i ~ N(mu, 1), so the
...     # full conditional of mu is N(mean(y), 1 / n)
...     n = len(pos["y"])
...     y_mean = pos["y"].mean()
...     # draw new value from the full conditional
...     draw = y_mean + jax.random.normal(prng_key) / jnp.sqrt(n)
...     # return key-value pair of variable name and new value
...     return {"mu": draw}
>>> builder.add_kernel(gs.GibbsKernel(["mu"], sample_mu))
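Under the example model (y_i ~ N(mu, 1) for i = 1, ..., n, with a flat prior on mu because it has no distribution), the full conditional of mu has a closed form, so a Gibbs transition function can sample from it exactly. Completing the square gives:

```latex
\begin{aligned}
p(\mu \mid y) &\propto \prod_{i=1}^{n} \exp\!\left(-\tfrac{1}{2}(y_i - \mu)^2\right)
  = \exp\!\left(-\tfrac{1}{2}\sum_{i=1}^{n}(y_i - \mu)^2\right) \\
\sum_{i=1}^{n}(y_i - \mu)^2 &= \sum_{i=1}^{n}(y_i - \bar{y})^2 + n(\mu - \bar{y})^2 \\
\Rightarrow\quad p(\mu \mid y) &\propto \exp\!\left(-\tfrac{n}{2}(\mu - \bar{y})^2\right),
  \qquad \mu \mid y \sim \mathcal{N}\!\left(\bar{y}, \tfrac{1}{n}\right)
\end{aligned}
```

A valid draw is therefore the sample mean of y plus standard normal noise scaled by 1/sqrt(n); it does not depend on the current value of mu.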
Finally, we build the engine:
>>> engine = builder.build()
From here, you can continue with engine.sample_all_epochs() to draw samples from your posterior distribution.
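Because mu is the only parameter in this model, each exact Gibbs draw for mu is a draw from the posterior itself. If the transition function samples the exact full conditional, the engine's posterior samples should have a mean close to mean(y) = 2 and a variance close to 1/n = 1/3. The standalone JAX sketch below emulates that check without liesel by drawing directly from the full conditional for 4 chains of 1000 iterations; it is not a substitute for inspecting the engine's actual results.

```python
import jax
import jax.numpy as jnp

# Observations from the example model.
y = jnp.array([1.0, 2.0, 3.0])
n = y.shape[0]


def draw_mu(prng_key):
    # One draw from the full conditional N(mean(y), 1/n) of the example model,
    # assuming a flat prior on mu.
    return y.mean() + jax.random.normal(prng_key) / jnp.sqrt(n)


# Emulate 4 chains x 1000 posterior iterations with independent keys.
keys = jax.random.split(jax.random.PRNGKey(1), 4 * 1000)
draws = jax.vmap(draw_mu)(keys).reshape(4, 1000)

posterior_mean = draws.mean()
posterior_var = draws.var()
```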
Methods
- end_epoch(prng_key, kernel_state, ...) – Currently does nothing.
- end_warmup(prng_key, kernel_state, ...) – Currently does nothing.
- init_state(prng_key, model_state) – Initializes an (empty) kernel state.
- start_epoch(prng_key, kernel_state, ...) – Currently does nothing.
- transition(prng_key, kernel_state, ...) – Performs an MCMC transition.
- tune(prng_key, kernel_state, model_state, epoch) – Currently does nothing.
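For this kernel, a transition presumably amounts to calling transition_fn with a PRNG key and the current model state and writing the returned values back. Models with several parameter blocks can then use one GibbsKernel per block, each updating its keys conditional on the current values of the others (assumption: the engine applies the kernels in the order they were added within each iteration). The JAX sketch below illustrates this systematic-scan Gibbs scheme on a standard bivariate normal with correlation rho. The helpers sample_x, sample_y and gibbs_step and the dict-based state are hypothetical and are not liesel's implementation of transition.

```python
import jax
import jax.numpy as jnp

rho = 0.8  # correlation of the hypothetical bivariate normal target
sd = (1 - rho**2) ** 0.5  # conditional standard deviation


def sample_x(prng_key, state):
    # Full conditional x | y ~ N(rho * y, 1 - rho^2)
    return {"x": rho * state["y"] + sd * jax.random.normal(prng_key)}


def sample_y(prng_key, state):
    # Full conditional y | x ~ N(rho * x, 1 - rho^2)
    return {"y": rho * state["x"] + sd * jax.random.normal(prng_key)}


def gibbs_step(state, prng_key):
    # Apply each transition function in turn and merge its output into the
    # state, so the update of y conditions on the freshly drawn x.
    key_x, key_y = jax.random.split(prng_key)
    state = {**state, **sample_x(key_x, state)}
    state = {**state, **sample_y(key_y, state)}
    return state, jnp.stack([state["x"], state["y"]])


init = {"x": jnp.zeros(()), "y": jnp.zeros(())}
keys = jax.random.split(jax.random.PRNGKey(0), 20000)
_, samples = jax.lax.scan(gibbs_step, init, keys)
```

Each block only ever sees the other block's latest value, which is what makes the chain target the joint distribution even though no single update looks at both variables at once.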
Attributes
- error_book – Maps error codes to error messages.
- identifier – An identifier for the kernel object that is set by the EngineBuilder if it is an empty string.
- needs_history – Is set to True if the kernel expects the history for tuning.
- position_keys – Keys for which the kernel handles the transition.