Source code for liesel.goose.hmc

"""
Hamiltonian/Hybrid Monte Carlo (HMC).
"""

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import ClassVar

import jax.numpy as jnp
from blackjax import hmc as hmc_kernel
from blackjax.adaptation.step_size import find_reasonable_step_size
from blackjax.mcmc import hmc
from jax.flatten_util import ravel_pytree

from .da import da_finalize, da_init, da_step
from .epoch import EpochState
from .kernel import (
    DefaultTransitionInfo,
    DefaultTuningInfo,
    ModelMixin,
    TransitionMixin,
    TransitionOutcome,
    TuningMixin,
    TuningOutcome,
    WarmupOutcome,
)
from .mm import tune_inv_mm_diag, tune_inv_mm_full
from .pytree import register_dataclass_as_pytree
from .types import Array, KeyArray, ModelState, Position


@register_dataclass_as_pytree
@dataclass
class HMCKernelState:
    """
    A dataclass for the state of a :class:`.HMCKernel`, implementing the
    :class:`.DAKernelState` protocol.

    Only ``step_size`` and ``inverse_mass_matrix`` are constructor arguments.
    The remaining fields are declared with ``init=False`` and are left to
    :func:`.da_init`, which is called from ``__post_init__``.
    """

    # Step size handed to the blackjax HMC kernel for each transition.
    step_size: float
    # Inverse mass matrix handed to the blackjax HMC kernel. The kernel
    # initializes it either as a vector (diagonal case) or as a square matrix.
    inverse_mass_matrix: Array
    # Dual averaging bookkeeping, excluded from ``__init__``.
    # NOTE(review): presumably all three are assigned by ``da_init`` -- confirm
    # against the ``.da`` module.
    error_sum: float = field(init=False)
    log_avg_step_size: float = field(init=False)
    mu: float = field(init=False)

    def __post_init__(self) -> None:
        # Set up the dual averaging part of the state right after construction.
        da_init(self)


@register_dataclass_as_pytree
@dataclass
class HMCTransitionInfo(DefaultTransitionInfo):
    """
    Information about a single HMC transition, as assembled by
    ``_goose_info`` from the info object returned by the blackjax HMC kernel.
    """

    # ``1 * is_divergent`` of the blackjax info, i.e. 1 for a divergent
    # transition and 0 otherwise (see ``HMCKernel.error_book``).
    error_code: int
    # The ``acceptance_rate`` reported by blackjax for this transition.
    acceptance_prob: float
    # The ``is_accepted`` flag reported by blackjax for this transition.
    position_moved: int
    # The ``is_divergent`` flag reported by blackjax for this transition.
    divergent: bool
    """
    Whether the difference in energy between the original and the new state exceeded
    the divergence threshold of 1000.
    """


def _goose_info(hmc_info: hmc.HMCInfo) -> HMCTransitionInfo:
    """
    Translates the transition info returned by the blackjax HMC kernel into
    a :class:`.HMCTransitionInfo`.

    The error code is derived from the divergence flag: multiplying by one
    turns the boolean into a number (1 if divergent, 0 otherwise).
    """
    diverged = hmc_info.is_divergent

    # Positional order: error_code, acceptance_prob, position_moved, divergent.
    return HMCTransitionInfo(
        1 * diverged,
        hmc_info.acceptance_rate,
        hmc_info.is_accepted,
        diverged,
    )


# The HMC kernel adds no tuning fields of its own, so its tuning info type is
# simply an alias for the default tuning info.
HMCTuningInfo = DefaultTuningInfo


class HMCKernel(
    ModelMixin,
    TransitionMixin[HMCKernelState, HMCTransitionInfo],
    TuningMixin[HMCKernelState, HMCTuningInfo],
):
    """
    A HMC kernel with dual averaging and an inverse mass matrix tuner,
    implementing the :class:`.Kernel` protocol.

    The transitions themselves are carried out by the blackjax HMC kernel.
    Adaptive transitions additionally perform a dual averaging step, and slow
    tuning re-estimates the inverse mass matrix from the history of the
    last epoch.
    """

    error_book: ClassVar[dict[int, str]] = {0: "no errors", 1: "divergent transition"}
    """Dict of error codes and their meaning."""

    needs_history: ClassVar[bool] = True
    """Whether this kernel needs its history for tuning."""

    identifier: str = ""
    """Kernel identifier, set by :class:`~.goose.EngineBuilder`"""

    position_keys: tuple[str, ...]
    """Tuple of position keys handled by this kernel."""

    def __init__(
        self,
        position_keys: Sequence[str],
        initial_step_size: float | None = None,
        initial_inverse_mass_matrix: Array | None = None,
        num_integration_steps: int = 10,
        da_target_accept: float = 0.8,
        da_gamma: float = 0.05,
        da_kappa: float = 0.75,
        da_t0: int = 10,
        mm_diag: bool = True,
    ):
        """
        Stores the kernel configuration. No model is attached yet.

        If ``initial_step_size`` or ``initial_inverse_mass_matrix`` are left
        at ``None``, :meth:`.init_state` determines them. ``mm_diag`` selects
        between a diagonal and a full inverse mass matrix.
        """
        self.position_keys = tuple(position_keys)
        self._model = None

        # Optional user-supplied starting values for the kernel state.
        self.initial_step_size = initial_step_size
        self.initial_inverse_mass_matrix = initial_inverse_mass_matrix

        # Integrator configuration.
        self.num_integration_steps = num_integration_steps

        # Dual averaging configuration.
        self.da_target_accept = da_target_accept
        self.da_gamma = da_gamma
        self.da_kappa = da_kappa
        self.da_t0 = da_t0

        # Diagonal (True) or full (False) inverse mass matrix.
        self.mm_diag = mm_diag

    def _blackjax_state(self, model_state: ModelState) -> hmc.HMCState:
        """
        Builds the blackjax HMC state for the position of this kernel in
        the given model state.
        """
        position = self.position(model_state)
        logdensity_fn = self.log_prob_fn(model_state)
        return hmc.init(position, logdensity_fn)

    @property
    def _blackjax_kernel(self) -> Callable:
        """The blackjax factory used to build the HMC kernel."""
        return hmc_kernel

    def init_state(self, prng_key, model_state):
        """
        Initializes the kernel state with an identity inverse mass matrix
        and a reasonable step size (unless explicit arguments were provided
        by the user).
        """
        inv_mm = self.initial_inverse_mass_matrix
        if inv_mm is None:
            # Identity, stored as a vector of ones in the diagonal case.
            flat, _ = ravel_pytree(self.position(model_state))
            inv_mm = jnp.ones_like(flat) if self.mm_diag else jnp.eye(flat.size)

        if self.initial_step_size is not None:
            return HMCKernelState(self.initial_step_size, inv_mm)

        # No step size given: let blackjax search for a reasonable one.
        make_kernel = self._blackjax_kernel
        reference_state = self._blackjax_state(model_state)
        logdensity_fn = self.log_prob_fn(model_state)

        def step_fn_for(step_size: float) -> Callable:
            # Returns the step function of a HMC kernel with this step size.
            return make_kernel(
                logdensity_fn=logdensity_fn,
                step_size=step_size,
                inverse_mass_matrix=inv_mm,
                num_integration_steps=self.num_integration_steps,
            ).step

        found_step_size = find_reasonable_step_size(
            prng_key,
            step_fn_for,
            reference_state,
            initial_step_size=0.001,
            target_accept=self.da_target_accept,
        )

        return HMCKernelState(found_step_size, inv_mm)

    def _standard_transition(
        self,
        prng_key: KeyArray,
        kernel_state: HMCKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[HMCKernelState, HMCTransitionInfo]:
        """
        Performs an MCMC transition *without* dual averaging.
        """
        current = self._blackjax_state(model_state)
        logdensity_fn = self.log_prob_fn(model_state)

        sampler = self._blackjax_kernel(
            logdensity_fn=logdensity_fn,
            step_size=kernel_state.step_size,
            inverse_mass_matrix=kernel_state.inverse_mass_matrix,
            num_integration_steps=self.num_integration_steps,
        )

        proposed, raw_info = sampler.step(prng_key, current)

        # Write the (possibly unchanged) position back into the model state.
        new_model_state = self.model.update_state(proposed.position, model_state)

        return TransitionOutcome(_goose_info(raw_info), kernel_state, new_model_state)

    def _adaptive_transition(
        self,
        prng_key: KeyArray,
        kernel_state: HMCKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[HMCKernelState, HMCTransitionInfo]:
        """
        Performs an MCMC transition *with* dual averaging.
        """
        result = self._standard_transition(prng_key, kernel_state, model_state, epoch)

        # The return value of the dual averaging step is not used; the outcome
        # is returned with the same kernel state object that was passed in.
        da_step(
            result.kernel_state,
            result.info.acceptance_prob,
            epoch.time_in_epoch,
            self.da_target_accept,
            self.da_gamma,
            self.da_kappa,
            self.da_t0,
        )

        return result

    def _tune_fast(
        self,
        prng_key: KeyArray,
        kernel_state: HMCKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[HMCKernelState, HMCTuningInfo]:
        """
        Currently does nothing.
        """
        return TuningOutcome(
            HMCTuningInfo(error_code=0, time=epoch.time),
            kernel_state,
        )

    def _tune_slow(
        self,
        prng_key: KeyArray,
        kernel_state: HMCKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[HMCKernelState, HMCTuningInfo]:
        """
        Tunes the inverse mass vector or matrix using the samples from the last epoch.
        """
        if history is not None:
            # Only the position keys handled by this kernel are relevant.
            history = Position({key: history[key] for key in self.position_keys})

            # The "trace" of a diagonal inverse mass matrix, which is stored
            # as a vector, is the sum of its elements.
            if self.mm_diag:
                tuner, trace_of = tune_inv_mm_diag, jnp.sum
            else:
                tuner, trace_of = tune_inv_mm_full, jnp.trace

            tuned_inv_mm = tuner(history)
            previous_inv_mm = kernel_state.inverse_mass_matrix

            # Rescale the step size by the square root of the ratio of the
            # old and the new trace.
            adjustment = jnp.sqrt(trace_of(previous_inv_mm) / trace_of(tuned_inv_mm))

            kernel_state.step_size = adjustment * kernel_state.step_size
            kernel_state.inverse_mass_matrix = tuned_inv_mm

        return self._tune_fast(prng_key, kernel_state, model_state, epoch, history)

    def start_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: HMCKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> HMCKernelState:
        """
        Resets the state of the dual averaging algorithm.
        """
        da_init(kernel_state)
        return kernel_state

    def end_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: HMCKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> HMCKernelState:
        """
        Sets the step size as found by the dual averaging algorithm.
        """
        da_finalize(kernel_state)
        return kernel_state

    def end_warmup(
        self,
        prng_key: KeyArray,
        kernel_state: HMCKernelState,
        model_state: ModelState,
        tuning_history: HMCTuningInfo | None,
    ) -> WarmupOutcome[HMCKernelState]:
        """
        Currently does nothing.
        """
        return WarmupOutcome(error_code=0, kernel_state=kernel_state)