Source code for liesel.goose.mh_kernel

"""
Metroplis Hastings kernel. This kernel allows for a user-defined proposal functions and
adds the MH step. Optional, the kernel supports a stepsize adaptation.
"""

from collections.abc import Callable, Sequence
from typing import ClassVar, NamedTuple

import jax

from .da import da_finalize, da_init, da_step
from .epoch import EpochState
from .kernel import (
    DefaultTransitionInfo,
    DefaultTuningInfo,
    ModelMixin,
    ReprMixin,
    TransitionMixin,
    TransitionOutcome,
    TuningOutcome,
    WarmupOutcome,
)
from .mh import mh_step
from .rw import RWKernelState
from .types import KeyArray, ModelState, Position, TuningInfo


class MHProposal(NamedTuple):
    r"""
    A proposed position bundled with its Metropolis-Hastings log-correction.

    Instances of this class are what a user-supplied proposal function hands
    back to the kernel.

    Parameters
    ----------
    position
        A dictionary mapping parameter names to their newly proposed values.
    log_correction
        The Metropolis-Hastings correction in the case of an asymmetric proposal
        distribution. Let :math:`q(x' | x)` be the density of the proposal ``x'``
        given the current state ``x``, then the ``log_correction`` is defined as
        :math:`\log[q(x | x') / q(x' | x)]`. For a symmetric proposal
        distribution both densities cancel and the correction is zero.

    See Also
    --------
    :class:`.MHKernel`
    """

    position: Position
    log_correction: float
    """
    Let :math:`q(x' | x)` be the proposal density, then
    :math:`log(q(x | x') / q(x' | x))` is the log_mh_correction.
    """
# The MH kernel does not define its own info containers; it reuses the default
# ones from the kernel module under kernel-specific names.
MHTransitionInfo = DefaultTransitionInfo
MHTuningInfo = DefaultTuningInfo

# Signature of a user-supplied proposal function. The kernel invokes it as
# ``proposal_fn(key, model_state, step_size)`` and expects an ``MHProposal``.
MHProposalFn = Callable[[KeyArray, ModelState, float], MHProposal]
class MHKernel(ModelMixin, TransitionMixin[RWKernelState, MHTransitionInfo], ReprMixin):
    """
    A Metropolis-Hastings kernel implementing the :class:`.Kernel` protocol.

    The kernel asks a user-supplied proposal function for a new position and
    then accepts or rejects it in a Metropolis-Hastings step. Optionally, the
    step size handed to the proposal function is tuned with dual averaging.

    Parameters
    ----------
    position_keys
        Sequence of position keys (variable names) handled by this kernel.
    proposal_fn
        Custom proposal function that proposes a new state. Needs to be provided
        by the user. It is called as ``proposal_fn(key, model_state, step_size)``
        and must return an :class:`.MHProposal`.
    initial_step_size
        Value at which to start step size tuning.
    da_tune_step_size
        If ``True``, the step size passed as an argument to the proposal function
        is tuned using the dual averaging algorithm. Step size is tuned on the fly
        during all adaptive epochs. If ``False``, the step size is left at
        ``initial_step_size``.
    da_target_accept
        Target acceptance probability for dual averaging algorithm.
    da_gamma
        The adaptation regularization scale.
    da_kappa
        The adaptation relaxation exponent.
    da_t0
        The adaptation iteration offset.
    identifier
        A string acting as a unique identifier for this kernel.

    Examples
    --------

    To begin, we import ``tensorflow_probability``, ``jax`` and ``jax.numpy`` as
    follows:

    >>> import tensorflow_probability.substrates.jax.distributions as tfd
    >>> import jax
    >>> import jax.numpy as jnp

    Then, we set up a minimal model:

    >>> mu = lsl.Var.new_param(0.0, name="mu")
    >>> dist = lsl.Dist(tfd.Normal, loc=mu, scale=1.0)
    >>> y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), dist, name="y")
    >>> model = lsl.Model([y])

    Now we initialize the EngineBuilder and set the desired number of warmup and
    posterior samples:

    >>> builder = gs.EngineBuilder(seed=1, num_chains=4)
    >>> builder.set_duration(warmup_duration=1000, posterior_duration=1000)

    Next, we set the model interface and initial values:

    >>> interface = gs.LieselInterface(model)
    >>> builder.set_model(interface)
    >>> builder.set_initial_values(model.state)

    We define a function to propose new values for the parameter ``"mu"``:

    >>> def mu_proposal(key, model_state, step_size):
    ...     # extract relevant values from model state
    ...     pos = interface.extract_position(
    ...         position_keys=["mu"], model_state=model_state
    ...     )
    ...     mu_current = pos["mu"]
    ...     # draw epsilon
    ...     epsilon = jax.random.uniform(key, minval=-0.5, maxval=0.5)
    ...     mu_proposed = mu_current + epsilon
    ...     pos = {"mu": mu_proposed}
    ...     return gs.MHProposal(pos, log_correction=0.0)

    Note that in this case, the ``"log correction"`` is 0, as the uniform
    distribution used to generate proposals is symmetric. Also note that this
    proposal function does not make use of ``step_size``; a proposal function
    has to scale its proposals by ``step_size`` itself for step size tuning to
    have an effect.

    >>> builder.add_kernel(gs.MHKernel(["mu"], mu_proposal))

    Finally, we build the engine:

    >>> engine = builder.build()

    From here, you can continue with :meth:`~.goose.Engine.sample_all_epochs` to
    draw samples from your posterior distribution.

    See Also
    --------
    :class:`.MHProposal`
    """

    error_book: ClassVar[dict[int, str]] = {0: "no errors", 90: "nan acceptance prob"}
    """Dict of error codes and their meaning."""
    needs_history: ClassVar[bool] = False
    """Whether this kernel needs its history for tuning."""
    identifier: str = ""
    """Kernel identifier, set by :class:`.EngineBuilder`"""
    position_keys: tuple[str, ...]
    """Tuple of position keys handled by this kernel."""

    def __init__(
        self,
        position_keys: Sequence[str],
        proposal_fn: MHProposalFn,
        initial_step_size: float = 1.0,
        da_tune_step_size: bool = False,
        da_target_accept: float = 0.234,
        da_gamma: float = 0.05,
        da_kappa: float = 0.75,
        da_t0: int = 10,
        identifier: str = "",
    ):
        self._model = None
        # Stored as a tuple so the keys cannot be mutated after construction.
        self.position_keys = tuple(position_keys)
        self._proposal_fn = proposal_fn
        self.initial_step_size = initial_step_size

        # Dual averaging configuration.
        self.da_tune_step_size = da_tune_step_size
        self.da_target_accept = da_target_accept
        self.da_gamma = da_gamma
        self.da_kappa = da_kappa
        self.da_t0 = da_t0

        self.identifier = identifier

    def init_state(self, prng_key: KeyArray, model_state: ModelState) -> RWKernelState:
        """
        Initializes the kernel state.

        The returned state starts at ``initial_step_size``. Neither ``prng_key``
        nor ``model_state`` is used.
        """
        return RWKernelState(step_size=self.initial_step_size)

    def _standard_transition(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[RWKernelState, DefaultTransitionInfo]:
        """
        Performs an MCMC transition *without* dual averaging.

        The proposal function is called with the current step size, and the
        resulting proposal is passed on to the Metropolis-Hastings step. The
        kernel state is returned as received.
        """
        # Independent keys for drawing the proposal and for the MH step.
        key, subkey = jax.random.split(prng_key)
        step_size = kernel_state.step_size

        # generate a proposal
        proposal = self._proposal_fn(key, model_state, step_size)

        # metropolis-hastings calibration
        info, model_state = mh_step(
            subkey,
            self.model,
            proposal.position,
            model_state,
            proposal.log_correction,
        )

        return TransitionOutcome(info, kernel_state, model_state)

    def _adaptive_transition(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[RWKernelState, DefaultTransitionInfo]:
        """
        Performs an MCMC transition *with* dual averaging.

        This is a standard transition, followed by a dual averaging step if
        ``da_tune_step_size`` is ``True``.
        """
        outcome = self._standard_transition(prng_key, kernel_state, model_state, epoch)

        if self.da_tune_step_size:
            # The return value of da_step is not used; presumably it updates
            # outcome.kernel_state in place -- confirm against the da module.
            da_step(
                outcome.kernel_state,
                outcome.info.acceptance_prob,
                epoch.time_in_epoch,
                self.da_target_accept,
                self.da_gamma,
                self.da_kappa,
                self.da_t0,
            )

        return outcome

    def tune(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[RWKernelState, DefaultTuningInfo]:
        """
        Currently does nothing.

        The kernel state is returned as received, together with a tuning info
        that carries error code 0 and the current epoch time.
        """
        info = MHTuningInfo(error_code=0, time=epoch.time)
        return TuningOutcome(info, kernel_state)

    def start_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> RWKernelState:
        """Resets the state of the dual averaging algorithm."""
        da_init(kernel_state)
        return kernel_state

    def end_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> RWKernelState:
        """
        Sets the step size as found by the dual averaging algorithm.

        The step size is only replaced if ``da_tune_step_size`` is ``True``.
        Otherwise, the kernel state is returned as received.
        """
        # Without tuning, no dual averaging steps are performed during the epoch
        # (see _adaptive_transition), so there is no tuned step size to apply and
        # the user-supplied step size must be kept.
        if self.da_tune_step_size:
            da_finalize(kernel_state)
        return kernel_state

    def end_warmup(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        tuning_history: TuningInfo | None,
    ) -> WarmupOutcome[RWKernelState]:
        """
        Currently does nothing.

        The kernel state is returned as received, with error code 0.
        """
        return WarmupOutcome(error_code=0, kernel_state=kernel_state)