Defining a custom MCMC kernel

Custom Metropolis-Hastings kernel

The easiest way to use a custom MCMC kernel in liesel.goose is to provide a proposal function for an MHKernel. The function must accept a pseudo-random number key, a model state and a step size as arguments. It must also be compatible with just-in-time compilation via JAX (i.e., pure, without side effects). It returns an MHProposal, which simply wraps the proposed value and the Metropolis-Hastings log-correction factor. The MHKernel handles the acceptance/rejection logic and comes with dual averaging for step size tuning. You can switch this on by passing da_tune_step_size=True as a keyword argument to the kernel. In this case, make sure that the initial step size (default: 1) and the target acceptance probability (default: 0.234) suit your model.

As an example, a random walk kernel (like RWKernel) can be implemented with

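>>> import liesel.goose as gs
>>> import tensorflow_probability.substrates.jax.distributions as tfd
>>> model = ...  # the lsl.Model that contains the parameter
>>> param_name = ...  # name of the parameter variable to be sampled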
>>> def rw_proposal(prng_key, model_state, step_size):
...     pos = model.extract_position([param_name], model_state)
...     current = pos[param_name]
...
...     proposal_dist = tfd.Normal(loc=current, scale=step_size)
...     proposed = proposal_dist.sample(seed=prng_key)
...
...     backward_dist = tfd.Normal(loc=proposed, scale=step_size)
...     backward_log_prob = backward_dist.log_prob(current)
...     forward_log_prob = proposal_dist.log_prob(proposed)
...     log_correction = (backward_log_prob - forward_log_prob).sum()
...     return gs.MHProposal({param_name: proposed}, log_correction)

It can then be attached to the parameter variable with

>>> model.vars[param_name].inference = gs.MCMCSpec(
...     gs.MHKernel,
...     kernel_kwargs={"proposal_fn": rw_proposal, "da_tune_step_size": True},
... )

Because the Gaussian random walk proposal is symmetric, the log correction factor is always zero. We still compute it explicitly here for demonstration purposes.
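The following sketch checks the symmetry argument numerically. It draws a Gaussian random walk proposal and evaluates both log densities. To avoid depending on TensorFlow Probability, it uses jax.scipy.stats.norm.logpdf instead of tfd.Normal; this swap is an assumption of the sketch and is not what the proposal function above does. The forward and backward terms cancel, so the correction is zero.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key = jax.random.key(0)
current = jnp.array([0.3, -1.2, 2.0])
step_size = 0.5

# Gaussian random walk proposal centred at the current value.
proposed = current + step_size * jax.random.normal(key, current.shape)

# log q(current | proposed) - log q(proposed | current)
backward_log_prob = norm.logpdf(current, loc=proposed, scale=step_size)
forward_log_prob = norm.logpdf(proposed, loc=current, scale=step_size)
log_correction = (backward_log_prob - forward_log_prob).sum()
```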

A custom proposal function for an MHKernel is convenient to write. It does not, however, cover kernels that need additional hyperparameters or specialized tuning. For such cases, liesel.goose lets users write their own kernel classes that implement the Kernel protocol.

The next section shows you how to write such a fully custom kernel class.

Fully customized MCMC kernel

Any Python class that implements the Kernel protocol can be used as an MCMC kernel class in liesel.goose. The protocol requires the implementation of several attributes and methods, the most important of which are Kernel.transition() and Kernel.tune(). These methods are called by the engine and need to be pure and jittable.
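The sketch below illustrates what "pure and jittable" means in practice, using a toy transition written as a plain JAX function. The name toy_transition and its arguments are hypothetical and not part of the Kernel protocol. All randomness comes from the explicit key argument, and all results are returned rather than stored. As a result, jax.jit can compile the function, and identical inputs always produce identical outputs.

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_transition(prng_key, position, step_size):
    """A pure random walk move: no global state is read or written."""
    noise = jax.random.normal(prng_key, position.shape)
    return position + step_size * noise


key = jax.random.key(42)
position = jnp.zeros(3)

first = toy_transition(key, position, 0.1)
second = toy_transition(key, position, 0.1)  # same key, same inputs

# Fresh randomness comes from splitting the key, not from hidden state.
key, subkey = jax.random.split(key)
third = toy_transition(subkey, position, 0.1)
```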

Overview

The transition method. The transition method moves the part of the model state handled by the kernel using a valid MCMC step, e.g., a Metropolis-Hastings step. Its signature is:

>>> class Kernel:
...
...     def transition(
...         self,
...         prng_key: KeyArray,
...         kernel_state: KernelState,
...         model_state: ModelState,
...         epoch: EpochState,
...     ) -> TransitionOutcome[KernelState, TransitionInfo]:
...         ...

The Kernel.transition() method must be pure, and MCMC transitions generally involve random number generation. Therefore, a key for pseudo-random number generation (PRNG) must be passed as an argument. The method also receives the kernel state, the model state and the epoch state. It returns a TransitionOutcome object, which wraps three things: the new kernel state, the new model state, and a TransitionInfo object with meta-information about the transition, e.g., an error code or the acceptance probability. An error code of zero indicates that the transition did not produce an error.

All inputs and outputs must be valid pytrees (i.e., arrays or nested lists, tuples or dicts of arrays). The structure of these objects, e.g., the shapes of the arrays in the kernel state, must not change between transitions. Within this constraint, kernels are free to define specialized KernelState and TransitionInfo classes; a kernel state can be any pytree.
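The fixed-structure requirement is what allows JAX to compile a whole chain as a loop. In the sketch below, a hypothetical ToyKernelState is carried through jax.lax.scan. ToyKernelState is a NamedTuple, which JAX treats as a pytree automatically. jax.lax.scan requires the carry returned by each step to match the carry passed in exactly: same tree structure, same shapes, same dtypes.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyKernelState(NamedTuple):
    step_size: jax.Array
    n_accepted: jax.Array


def step(state, key):
    # A toy "acceptance" decision driven by the key.
    accepted = jax.random.uniform(key) < 0.5
    new_state = ToyKernelState(
        step_size=state.step_size,  # same shape and dtype as before
        n_accepted=state.n_accepted + accepted.astype(jnp.int32),
    )
    return new_state, accepted  # second output: per-iteration info


init = ToyKernelState(
    step_size=jnp.array(0.5, dtype=jnp.float32),
    n_accepted=jnp.array(0, dtype=jnp.int32),
)
keys = jax.random.split(jax.random.key(0), 100)
final, accepted = jax.lax.scan(step, init, keys)
```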

The tune method. The Kernel.tune() method updates the kernel hyperparameters at the end of an adaptation epoch. It receives the following arguments: the PRNG key, the kernel state, the model state, the epoch state and (optionally) the history, i.e., the samples from the previous epoch. It returns a TuningOutcome object that wraps the new kernel state and some meta-information about the tuning process, e.g., an error code. As with the transition, the TuningInfo class can be kernel-specific but must be a valid pytree.

The signature of the Kernel.tune() method is as follows:

>>> class Kernel:
...
...     def tune(
...         self,
...         prng_key: KeyArray,
...         kernel_state: KernelState,
...         model_state: ModelState,
...         epoch: EpochState,
...         history: Position | None,
...     ) -> TuningOutcome[KernelState, TuningInfo]:
...         ...
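A kernel that sets needs_history = True receives the previous epoch's samples in history. The sketch below is a hypothetical illustration, not what RWKernel does. It shows a tuning rule that rescales the random walk step size to the posterior spread observed in the history, using the common 2.38/sqrt(d) heuristic for Gaussian random walk proposals. It assumes that history is a dict mapping each position key to an array of draws, with the iteration axis first.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


@jax.jit
def tuned_step_size(history):
    """Hypothetical tuning rule: scale the step size to the observed spread.

    `history` is assumed to map position keys to arrays of shape
    (num_iterations, ...), one row per draw from the previous epoch.
    """
    # Flatten every draw into one vector: shape (num_iterations, d).
    draws = jax.vmap(lambda draw: ravel_pytree(draw)[0])(history)
    d = draws.shape[1]
    # Average marginal standard deviation across the d dimensions.
    mean_sd = jnp.mean(jnp.std(draws, axis=0))
    return 2.38 / jnp.sqrt(d) * mean_sd


# Example: 1000 draws of a scalar parameter with spread about 0.1.
history = {"mu": 0.1 * jax.random.normal(jax.random.key(1), (1000,))}
new_step_size = tuned_step_size(history)  # roughly 2.38 * 0.1
```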

Step-by-step tutorial

We will now go through the definition of the RWKernel step by step.

The kernel state

First, we define the KernelState. Since we plan to use dual averaging for step size tuning in this kernel class, we define a kernel state that follows the DAKernelState protocol.

from dataclasses import dataclass, field  # general dataclass functionality
from liesel.goose.pytree import (
    register_dataclass_as_pytree,  # dataclasses must be registered as pytrees with jax
)
from liesel.goose import da  # dual averaging functionality


@register_dataclass_as_pytree
@dataclass
class RWKernelState:
    """
    A dataclass for the state of a ``RWKernel``, implementing the
    :class:`.DAKernelState` protocol.
    """

    step_size: float
    error_sum: float = field(default=0.0, init=False)
    log_avg_step_size: float = field(default=0.0, init=False)
    mu: float = field(init=False)

    def __post_init__(self):
        da.da_init(self)
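The da helpers manage the three auxiliary fields of RWKernelState. To show what such fields typically track, the sketch below implements the standard dual averaging scheme from the NUTS paper (Hoffman and Gelman, 2014). It uses a hypothetical NamedTuple, DAState, that mirrors the field names above, and its default hyperparameters are copied from RWKernel's constructor. This is not liesel.goose's da module, whose exact update rules may differ. In the sketch, error_sum is the running sum of target minus observed acceptance probabilities, and mu is the usual shrinkage point log(10 * initial step size).

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class DAState(NamedTuple):
    step_size: jax.Array
    error_sum: jax.Array          # running sum of (target - acceptance prob)
    log_avg_step_size: jax.Array  # weighted average of log step sizes
    mu: jax.Array                 # shrinkage point, log(10 * initial step)


def da_init(step_size):
    step_size = jnp.asarray(step_size, dtype=jnp.float32)
    zero = jnp.zeros((), dtype=jnp.float32)
    return DAState(step_size, zero, zero, jnp.log(10.0 * step_size))


def da_step(state, accept_prob, t, target_accept=0.234,
            gamma=0.05, kappa=0.75, t0=10):
    """One dual averaging update; t is the 1-based iteration in the epoch."""
    error_sum = state.error_sum + (target_accept - accept_prob)
    log_step = state.mu - jnp.sqrt(t) / gamma * error_sum / (t + t0)
    eta = t ** (-kappa)
    log_avg = eta * log_step + (1.0 - eta) * state.log_avg_step_size
    return DAState(jnp.exp(log_step), error_sum, log_avg, state.mu)


def da_finalize(state):
    # After adaptation, switch to the smoother averaged step size.
    return state._replace(step_size=jnp.exp(state.log_avg_step_size))


def toy_iteration(state, t):
    # Toy acceptance model: larger steps are accepted less often.
    accept_prob = jnp.exp(-state.step_size)
    return da_step(state, accept_prob, t), None


ts = jnp.arange(1, 1001, dtype=jnp.float32)
final, _ = jax.lax.scan(toy_iteration, da_init(1.0), ts)
tuned = da_finalize(final).step_size
```

With this toy acceptance model, the finalized step size settles near -log(0.234), about 1.45. At that step size, the acceptance probability equals the target.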

The kernel class

We now define the actual kernel class. The class inherits from two mixins provided by liesel.goose.

The ModelMixin gives the kernel access to the model and provides convenience methods such as ModelMixin.position(), which extracts the part of the model state handled by this kernel.

The TransitionMixin provides the public TransitionMixin.transition() method. Internally, it dispatches to _standard_transition or _adaptive_transition, depending on the current epoch. This means that we only have to implement these two methods.
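TransitionMixin's source is not reproduced here. The sketch below only shows one way such run-time dispatch can be written in JAX, with hypothetical toy functions standing in for _standard_transition and _adaptive_transition. Because jax.lax.cond chooses the branch at run time, the adaptation flag can be a traced value inside a jitted function. Both branches must return outputs with the same structure.

```python
import jax


def standard_transition(step_size, x):
    # Toy move that leaves the tuning parameter unchanged.
    return x + step_size, step_size


def adaptive_transition(step_size, x):
    # Same move, plus a toy change to the tuning parameter.
    x_new, step_size = standard_transition(step_size, x)
    return x_new, step_size * 0.9


@jax.jit
def transition(is_adaptation, step_size, x):
    # Select the branch from a (possibly traced) boolean flag.
    return jax.lax.cond(
        is_adaptation, adaptive_transition, standard_transition, step_size, x
    )


x_adapt, s_adapt = transition(True, 1.0, 0.0)   # adaptive branch
x_std, s_std = transition(False, 1.0, 0.0)      # standard branch
```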

import jax
import liesel.goose as gs


class RWKernel(
    gs.ModelMixin, gs.TransitionMixin[RWKernelState, gs.DefaultTransitionInfo]
):
    error_book = {0: "no errors", 90: "nan acceptance prob"}
    """Dict of error codes and their meaning."""

    needs_history = False
    """Whether this kernel needs its history for tuning."""

    identifier: str = ""
    """Kernel identifier, set by :class:`~.goose.EngineBuilder`"""

    position_keys: tuple[str, ...]
    """Tuple of position keys handled by this kernel."""

At the beginning of the class, we define a few class attributes required by the kernel protocol.

The error_book maps error codes to human-readable messages. By convention, an error code of zero means that no error occurred.

The needs_history attribute tells the engine whether the kernel requires the samples from the previous epoch for tuning. This random walk kernel does not use the history, so we set it to False.

The identifier is set by the EngineBuilder and can be used to distinguish between kernels. Finally, position_keys stores the names of the model variables handled by this kernel.

The constructor stores the user-supplied settings. The most important argument is position_keys, which determines which model variables are updated by this kernel.

The remaining arguments configure the initial step size and the dual averaging algorithm. These values are stored on the kernel object, but they are not part of the kernel state. The mutable, chain-specific part of the kernel is stored separately in the RWKernelState.

    def __init__(
        self,
        position_keys: list[str] | tuple[str, ...],
        initial_step_size: float = 1.0,
        da_target_accept: float = 0.234,
        da_gamma: float = 0.05,
        da_kappa: float = 0.75,
        da_t0: int = 10,
        identifier: str = "",
    ):
        self._model = None
        self.position_keys = tuple(position_keys)
        self.initial_step_size = initial_step_size
        self.da_target_accept = da_target_accept
        self.da_gamma = da_gamma
        self.da_kappa = da_kappa
        self.da_t0 = da_t0
        self.identifier = identifier
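One way to see why the constructor arguments live on the kernel object while the step size lives in the kernel state: when several chains run in parallel, the per-chain state can be batched with jax.vmap, while the hyperparameters stay fixed Python values shared by every chain. The sketch assumes this vectorized setup purely for illustration. per_chain_update and its multiplicative rule are hypothetical toys, not the dual averaging update.

```python
import jax
import jax.numpy as jnp

# Hyperparameter: a plain Python value shared by all chains,
# analogous to attributes such as self.da_target_accept.
TARGET_ACCEPT = 0.234


def per_chain_update(step_size, accept_prob):
    """Hypothetical toy rule: grow the step if acceptance is above target."""
    return step_size * jnp.exp(accept_prob - TARGET_ACCEPT)


# Chain-specific state: one step size per chain (four chains here).
step_sizes = jnp.array([0.5, 1.0, 2.0, 4.0])
accept_probs = jnp.array([0.6, 0.3, 0.2, 0.05])

# vmap maps over the leading (chain) axis of the per-chain inputs only.
new_step_sizes = jax.vmap(per_chain_update)(step_sizes, accept_probs)
```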

Before sampling starts, the engine calls init_state. This method creates the initial kernel state for one chain. In our case, the only user-facing state variable is the current step size.

    def init_state(self, prng_key, model_state: gs.ModelState) -> RWKernelState:
        """
        Initializes the kernel state.
        """
        return RWKernelState(step_size=self.initial_step_size)

Next, we implement the non-adaptive transition. This method performs one ordinary Metropolis-Hastings random walk step.

First, we split the pseudo-random number key. One key is used to generate the proposal, and the other key is used inside the Metropolis-Hastings accept/reject step.

    def _standard_transition(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
    ) -> gs.TransitionOutcome[RWKernelState, gs.DefaultTransitionInfo]:
        """
        Performs an MCMC transition *without* dual averaging.
        """

        key, subkey = jax.random.split(prng_key)
        step_size = kernel_state.step_size
        ...
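The key is split because a JAX key must not be reused for two independent random decisions: the same key always yields the same numbers. The following self-contained sketch uses variable names that mirror the method above. It draws the proposal noise from key and the accept/reject uniform from subkey.

```python
import jax
import jax.numpy as jnp

prng_key = jax.random.key(0)
key, subkey = jax.random.split(prng_key)

# One key for the proposal, the other for the accept/reject decision.
proposal_noise = jax.random.normal(key, (3,))
log_uniform = jnp.log(jax.random.uniform(subkey))

# Reusing a key reproduces exactly the same draws.
same_noise_again = jax.random.normal(key, (3,))
```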

The current position is extracted from the model state. Since the position can be a pytree, we flatten it into a single vector before adding Gaussian noise. After the proposal has been generated, we transform it back into the original pytree structure.

This lets the same implementation work for scalar, vector-valued, or structured model positions.

    def _standard_transition(...):
        # ... (continued)
        # random walk proposal
        position = self.position(model_state)
        flat_position, unravel_fn = jax.flatten_util.ravel_pytree(position)
        step = step_size * jax.random.normal(key, flat_position.shape)
        flat_proposal = flat_position + step
        proposal = unravel_fn(flat_proposal)
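The standalone sketch below shows what the flattening does, using a hypothetical structured position that holds a scalar and a vector parameter. jax.flatten_util is a submodule that may need an explicit import. Importing it directly, as done here, avoids relying on another package having imported it.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A structured position: one scalar and one vector-valued parameter.
position = {"mu": jnp.array(0.5), "beta": jnp.array([1.0, -2.0, 3.0])}

# One flat vector with 1 + 3 = 4 entries, plus a function to undo it.
flat_position, unravel_fn = ravel_pytree(position)

flat_proposal = flat_position + 0.1   # e.g. add a random walk step
proposal = unravel_fn(flat_proposal)  # back to {"mu": ..., "beta": ...}
```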

Finally, we pass the proposal to mh_step(). This function evaluates the proposed model state and performs the Metropolis-Hastings accept/reject step.

The result is returned as a TransitionOutcome, which contains the transition information, the kernel state, and the updated model state.

    def _standard_transition(...):
        # ... (continued)
        # metropolis-hastings calibration
        info, model_state = gs.mh_step(subkey, self.model, proposal, model_state)
        return gs.TransitionOutcome(info, kernel_state, model_state)
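mh_step() belongs to liesel.goose, and its internals are not shown here. The following is a rough, hypothetical sketch of what a Metropolis-Hastings accept/reject step computes. As a simplification, mh_accept takes a log-density function directly instead of a model. It forms log_ratio as the difference in log densities plus the log correction, and accepts with probability min(1, exp(log_ratio)).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def mh_accept(prng_key, log_prob_fn, current, proposal, log_correction=0.0):
    """Hypothetical Metropolis-Hastings accept/reject step (not gs.mh_step)."""
    # log [p(proposal) q(current | proposal)] - log [p(current) q(proposal | current)]
    log_ratio = log_prob_fn(proposal) - log_prob_fn(current) + log_correction
    acceptance_prob = jnp.minimum(1.0, jnp.exp(log_ratio))
    # Accepting when log(u) < log_ratio happens with probability acceptance_prob.
    accept = jnp.log(jax.random.uniform(prng_key)) < log_ratio
    new_position = jnp.where(accept, proposal, current)
    return new_position, acceptance_prob, accept


def log_prob_fn(x):
    # Standard normal target, summed over dimensions.
    return norm.logpdf(x).sum()


key = jax.random.key(1)
# Moving towards the mode: the ratio exceeds 1, so the move is always accepted.
new, prob, accepted = mh_accept(key, log_prob_fn, jnp.array([3.0]), jnp.array([0.0]))
# Moving away from the mode: accepted with probability exp(-4.5).
_, prob_away, _ = mh_accept(key, log_prob_fn, jnp.array([0.0]), jnp.array([3.0]))
```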

The adaptive transition starts by performing the same Metropolis-Hastings step as above. It then updates the dual averaging state using the observed acceptance probability from the transition.

The dual averaging update modifies the kernel state in place. It uses the current acceptance probability, the time within the current epoch, and the dual averaging hyperparameters stored on the kernel object.

    def _adaptive_transition(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
    ) -> gs.TransitionOutcome[RWKernelState, gs.DefaultTransitionInfo]:
        """
        Performs an MCMC transition *with* dual averaging.
        """

        outcome = self._standard_transition(prng_key, kernel_state, model_state, epoch)

        da.da_step(
            outcome.kernel_state,
            outcome.info.acceptance_prob,
            epoch.time_in_epoch,
            self.da_target_accept,
            self.da_gamma,
            self.da_kappa,
            self.da_t0,
        )

        return outcome

The tune method is called by the engine at the end of a tuning epoch. This particular kernel does not perform any additional tuning at the end of an epoch, because the step size adaptation already happens during the adaptive transitions.

Still, the method must be implemented to satisfy the kernel protocol. We therefore return a successful TuningOutcome with the unchanged kernel state.

    def tune(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
        history: gs.Position | None = None,
    ) -> gs.TuningOutcome[RWKernelState, gs.DefaultTuningInfo]:
        """
        Currently does nothing.
        """

        info = gs.DefaultTuningInfo(error_code=0, time=epoch.time)
        return gs.TuningOutcome(info, kernel_state)

At the beginning of each adaptation epoch, we reset the dual averaging state. This is done in start_epoch.

This reset does not discard the current step size itself. Instead, it reinitializes the auxiliary quantities used internally by the dual averaging algorithm.

    def start_epoch(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
    ) -> RWKernelState:
        """
        Resets the state of the dual averaging algorithm.
        """

        da.da_init(kernel_state)
        return kernel_state

At the end of an adaptation epoch, we finalize the dual averaging update. This replaces the current step size with the averaged step size found during the epoch.

    def end_epoch(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
    ) -> RWKernelState:
        """
        Sets the step size as found by the dual averaging algorithm.
        """

        da.da_finalize(kernel_state)
        return kernel_state

Finally, the engine calls end_warmup after all warmup epochs have finished. This hook can be used for final warmup-specific adjustments. Our random walk kernel does not need any such adjustment, so we simply return the unchanged kernel state.

    def end_warmup(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        tuning_history: gs.TuningInfo | None,
    ) -> gs.WarmupOutcome[RWKernelState]:
        """
        Currently does nothing.
        """

        return gs.WarmupOutcome(error_code=0, kernel_state=kernel_state)

This completes the kernel class. The main logic is contained in _standard_transition, which constructs a random walk proposal and delegates the Metropolis-Hastings correction to mh_step(). The adaptive version adds one more step: it updates the step size using dual averaging based on the observed acceptance probability.

For reference, here is the complete kernel class:

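import jax
import jax.flatten_util  # makes jax.flatten_util.ravel_pytree available
import liesel.goose as gs
from liesel.goose import da  # dual averaging functionality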


class RWKernel(
    gs.ModelMixin, gs.TransitionMixin[RWKernelState, gs.DefaultTransitionInfo]
):
    error_book = {0: "no errors", 90: "nan acceptance prob"}
    """Dict of error codes and their meaning."""

    needs_history = False
    """Whether this kernel needs its history for tuning."""

    identifier: str = ""
    """Kernel identifier, set by :class:`~.goose.EngineBuilder`"""

    position_keys: tuple[str, ...]
    """Tuple of position keys handled by this kernel."""

    def __init__(
        self,
        position_keys: list[str] | tuple[str, ...],
        initial_step_size: float = 1.0,
        da_target_accept: float = 0.234,
        da_gamma: float = 0.05,
        da_kappa: float = 0.75,
        da_t0: int = 10,
        identifier: str = "",
    ):
        self._model = None
        self.position_keys = tuple(position_keys)
        self.initial_step_size = initial_step_size
        self.da_target_accept = da_target_accept
        self.da_gamma = da_gamma
        self.da_kappa = da_kappa
        self.da_t0 = da_t0
        self.identifier = identifier

    def init_state(self, prng_key, model_state: gs.ModelState) -> RWKernelState:
        """
        Initializes the kernel state.
        """
        return RWKernelState(step_size=self.initial_step_size)

    def _standard_transition(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
    ) -> gs.TransitionOutcome[RWKernelState, gs.DefaultTransitionInfo]:
        """
        Performs an MCMC transition *without* dual averaging.
        """

        key, subkey = jax.random.split(prng_key)
        step_size = kernel_state.step_size

        # random walk proposal
        position = self.position(model_state)
        flat_position, unravel_fn = jax.flatten_util.ravel_pytree(position)
        step = step_size * jax.random.normal(key, flat_position.shape)
        flat_proposal = flat_position + step
        proposal = unravel_fn(flat_proposal)

        # metropolis-hastings calibration
        info, model_state = gs.mh_step(subkey, self.model, proposal, model_state)
        return gs.TransitionOutcome(info, kernel_state, model_state)

    def _adaptive_transition(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
    ) -> gs.TransitionOutcome[RWKernelState, gs.DefaultTransitionInfo]:
        """
        Performs an MCMC transition *with* dual averaging.
        """

        outcome = self._standard_transition(prng_key, kernel_state, model_state, epoch)

        da.da_step(
            outcome.kernel_state,
            outcome.info.acceptance_prob,
            epoch.time_in_epoch,
            self.da_target_accept,
            self.da_gamma,
            self.da_kappa,
            self.da_t0,
        )

        return outcome

    def tune(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
        history: gs.Position | None = None,
    ) -> gs.TuningOutcome[RWKernelState, gs.DefaultTuningInfo]:
        """
        Currently does nothing.
        """

        info = gs.DefaultTuningInfo(error_code=0, time=epoch.time)
        return gs.TuningOutcome(info, kernel_state)

    def start_epoch(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
    ) -> RWKernelState:
        """
        Resets the state of the dual averaging algorithm.
        """

        da.da_init(kernel_state)
        return kernel_state

    def end_epoch(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        epoch: gs.EpochState,
    ) -> RWKernelState:
        """
        Sets the step size as found by the dual averaging algorithm.
        """

        da.da_finalize(kernel_state)
        return kernel_state

    def end_warmup(
        self,
        prng_key,
        kernel_state: RWKernelState,
        model_state: gs.ModelState,
        tuning_history: gs.TuningInfo | None,
    ) -> gs.WarmupOutcome[RWKernelState]:
        """
        Currently does nothing.
        """

        return gs.WarmupOutcome(error_code=0, kernel_state=kernel_state)

Trying out our new kernel

To confirm that the kernel runs, we apply it to a very simple model: 100 normally distributed observations with unknown mean mu and known scale 1.

import liesel.model as lsl
import tensorflow_probability.substrates.jax.distributions as tfd

mu = lsl.Var.new_param(0.0, name="mu", inference=gs.MCMCSpec(RWKernel))
y = lsl.Var.new_obs(
    value=jax.random.normal(jax.random.key(13), (100,)) + 0.5,
    dist=lsl.Dist(tfd.Normal, loc=mu, scale=1.0),
    name="y",
)
model = lsl.Model(y)

results = gs.LieselMCMC(model).run_for_epochs(
    seed=7, num_chains=4, adaptation=500, posterior=500
)
gs.Summary(results)

Parameter summary:

parameter  index  kernel     mean   sd     q_0.05  q_0.5  q_0.95  sample_size  ess_bulk  ess_tail  rhat
mu         ()     kernel_00  0.461  0.095  0.298   0.456  0.634   2000         319.798   212.696   1.016

Acceptance probabilities:

kernel     positions  phase      acceptance_probability  position_moved
kernel_00  mu         posterior  0.204                   0.204
kernel_00  mu         warmup     0.222                   0.224
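For this model, the result can be cross-checked analytically. Assume, as the model code suggests, that mu has no prior distribution (an improper flat prior) and that the scale is known to be 1. Then the posterior of mu is Normal(mean(y), 1/sqrt(n)). The sketch below recomputes the simulated data with the same key. The posterior standard deviation is 0.1 for n = 100, in line with the sd of 0.095 in the summary, and the MCMC mean should lie close to mean(y).

```python
import jax
import jax.numpy as jnp

# Same simulated data as in the model above.
y = jax.random.normal(jax.random.key(13), (100,)) + 0.5

# Flat prior on mu and known scale 1:
# the posterior of mu is Normal(mean(y), 1 / sqrt(n)).
posterior_mean = jnp.mean(y)
posterior_sd = 1.0 / jnp.sqrt(y.shape[0])  # 0.1 for n = 100
```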