MHKernel
- class liesel.goose.MHKernel(position_keys, proposal_fn, initial_step_size=1.0, da_tune_step_size=False, da_target_accept=0.234, da_gamma=0.05, da_kappa=0.75, da_t0=10, identifier='')
Bases: ModelMixin, TransitionMixin[RWKernelState, DefaultTransitionInfo], ReprMixin

A Metropolis-Hastings kernel implementing the Kernel protocol.

- Parameters:
position_keys (Sequence[str]) – Sequence of position keys (variable names) handled by this kernel.
proposal_fn (Callable[[Any, Any, float], MHProposal]) – Custom proposal function that proposes a new state. Needs to be provided by the user.
initial_step_size (float, default: 1.0) – Value at which to start step size tuning.
da_tune_step_size (default: False) – If True, the step size passed as an argument to the proposal function is tuned using the dual averaging algorithm. Step size is tuned on the fly during all adaptive epochs.
da_target_accept (float, default: 0.234) – Target acceptance probability for dual averaging algorithm.
da_gamma (float, default: 0.05) – The adaptation regularization scale.
da_kappa (float, default: 0.75) – The adaptation relaxation exponent.
da_t0 (int, default: 10) – The adaptation iteration offset.
identifier (str, default: '') – A string acting as a unique identifier for this kernel.
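To make the roles of the da_* parameters more concrete, here is a minimal, Liesel-independent sketch of a dual averaging step size update in the style of Hoffman and Gelman's NUTS paper. The function name dual_averaging_step_size, the mapping of da_gamma, da_kappa and da_t0 onto the update rule, and the shrinkage point log(10 * initial_step_size) are assumptions made for illustration; the kernel's internal implementation may differ in its details.

```python
import math


def dual_averaging_step_size(
    accept_probs,
    initial_step_size=1.0,
    target_accept=0.234,
    gamma=0.05,
    kappa=0.75,
    t0=10,
):
    """Return the averaged step size after seeing `accept_probs`.

    Hypothetical helper, not part of Liesel. `accept_probs` is a sequence
    of acceptance probabilities, one per MCMC iteration.
    """
    # Shrinkage point for the log step size (assumed, as in the NUTS paper).
    mu = math.log(10.0 * initial_step_size)
    h_bar = 0.0  # running average of (target - observed acceptance)
    log_eps_bar = math.log(initial_step_size)  # averaged log step size

    for m, alpha in enumerate(accept_probs, start=1):
        # t0 damps the influence of the earliest iterations.
        eta = 1.0 / (m + t0)
        h_bar = (1.0 - eta) * h_bar + eta * (target_accept - alpha)
        # gamma controls how strongly the step size is pulled towards mu.
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        # kappa sets how quickly the weight of new iterates decays.
        weight = m ** (-kappa)
        log_eps_bar = weight * log_eps + (1.0 - weight) * log_eps_bar

    return math.exp(log_eps_bar)


# Acceptance above the target pushes the step size up, below pushes it down.
too_cautious = dual_averaging_step_size([1.0] * 50)
too_bold = dual_averaging_step_size([0.0] * 50)
```

In this sketch, a chain that accepts every proposal ends up with a step size above the initial value, while a chain that rejects every proposal ends up with a step size below it.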
Examples
To begin, we import liesel, tensorflow_probability, jax and jax.numpy as follows:
>>> import liesel.model as lsl
>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> import jax
>>> import jax.numpy as jnp
Then, we set up a minimal model:
>>> mu = lsl.Var.new_param(0.0, name="mu")
>>> dist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
>>> y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
>>> model = lsl.Model([y])
Now we initialize the EngineBuilder and set the desired number of warmup and posterior samples:
>>> builder = gs.EngineBuilder(seed=1, num_chains=4)
>>> builder.set_duration(warmup_duration=1000, posterior_duration=1000)
Next, we set the model interface and initial values:
>>> interface = gs.LieselInterface(model)
>>> builder.set_model(interface)
>>> builder.set_initial_values(model.state)
We define a function to propose new values for the parameter "mu":
>>> def mu_proposal(key, model_state, step_size):
...     # extract relevant values from model state
...     pos = interface.extract_position(
...         position_keys=["mu"], model_state=model_state
...     )
...     mu_current = pos["mu"]
...     # draw epsilon
...     epsilon = jax.random.uniform(key, minval=-0.5, maxval=0.5)
...     mu_proposed = mu_current + epsilon
...     pos = {"mu": mu_proposed}
...     return gs.MHProposal(pos, log_correction=0.0)
Note that in this case, the "log correction" is 0, as the uniform distribution used to generate proposals is symmetric. We then add the kernel to the builder:
>>> builder.add_kernel(gs.MHKernel(["mu"], mu_proposal))
Finally, we build the engine:
>>> engine = builder.build()
From here, you can continue with sample_all_epochs() to draw samples from your posterior distribution.
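The role of the log correction in the example above can be illustrated without Liesel. The sketch below is plain Python and assumes the common convention that the correction equals log q(x | x') - log q(x' | x) for a proposal density q; all function names (mh_step, symmetric_uniform, multiplicative_lognormal, run_chain) are hypothetical helpers, not part of the library. It contrasts the symmetric uniform proposal from the example, where the correction is zero, with an asymmetric multiplicative proposal for a positive parameter, where it is not.

```python
import math
import random


def mh_step(x, log_prob, propose, rng):
    """One Metropolis-Hastings step; returns the next state."""
    x_new, log_correction = propose(x, rng)
    # Log acceptance ratio: target ratio plus the proposal correction.
    log_alpha = log_prob(x_new) - log_prob(x) + log_correction
    if rng.random() < math.exp(min(0.0, log_alpha)):
        return x_new
    return x


def symmetric_uniform(x, rng):
    # q(x' | x) equals q(x | x'), so the correction vanishes.
    return x + rng.uniform(-0.5, 0.5), 0.0


def multiplicative_lognormal(x, rng, scale=1.0):
    # x' = x * exp(scale * z) with z ~ N(0, 1) keeps x' positive but is
    # asymmetric: log q(x | x') - log q(x' | x) = log(x') - log(x).
    x_new = x * math.exp(scale * rng.gauss(0.0, 1.0))
    return x_new, math.log(x_new) - math.log(x)


def run_chain(x0, log_prob, propose, n, seed):
    rng = random.Random(seed)
    x, draws = x0, []
    for _ in range(n):
        x = mh_step(x, log_prob, propose, rng)
        draws.append(x)
    return draws


Y = [1.0, 2.0, 3.0]


def normal_log_post(mu):
    # Unnormalized log posterior of mu for y ~ Normal(mu, 1), flat prior.
    return -0.5 * sum((y - mu) ** 2 for y in Y)


def exp_log_density(x):
    # Unnormalized log density of an Exponential(1) target on x > 0.
    return -x if x > 0 else -math.inf


mu_draws = run_chain(0.0, normal_log_post, symmetric_uniform, 20000, seed=1)
x_draws = run_chain(1.0, exp_log_density, multiplicative_lognormal, 20000, seed=1)
```

Leaving out the log(x') - log(x) term in the second proposal would make the chain target a different distribution, which is why an asymmetric proposal function has to report a nonzero correction.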
Methods
end_epoch(prng_key, kernel_state, ...) – Sets the step size as found by the dual averaging algorithm.
end_warmup(prng_key, kernel_state, ...) – Currently does nothing.
init_state(prng_key, model_state) – Initializes the kernel state.
start_epoch(prng_key, kernel_state, ...) – Resets the state of the dual averaging algorithm.
tune(prng_key, kernel_state, model_state, epoch) – Currently does nothing.
Attributes
Dict of error codes and their meaning.
Kernel identifier, set by EngineBuilder.
Whether this kernel needs its history for tuning.
Tuple of position keys handled by this kernel.