# Source code for liesel.goose.nuts

"""
No U-Turn Sampler (NUTS).
"""

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import ClassVar

import jax.numpy as jnp
from blackjax import nuts as nuts_kernel
from blackjax.adaptation.step_size import find_reasonable_step_size
from blackjax.mcmc import hmc, nuts
from jax.flatten_util import ravel_pytree

from .da import da_finalize, da_init, da_step
from .epoch import EpochState
from .kernel import (
    DefaultTransitionInfo,
    DefaultTuningInfo,
    ModelMixin,
    ReprMixin,
    TransitionMixin,
    TransitionOutcome,
    TuningMixin,
    TuningOutcome,
    WarmupOutcome,
)
from .mm import tune_inv_mm_diag, tune_inv_mm_full
from .pytree import register_dataclass_as_pytree
from .types import Array, KeyArray, ModelState, Position


@register_dataclass_as_pytree
@dataclass
class NUTSKernelState:
    """
    A dataclass for the state of a :class:`.NUTSKernel`, implementing the
    :class:`.DAKernelState` protocol.

    Only ``step_size`` and ``inverse_mass_matrix`` are constructor arguments. The
    remaining fields are excluded from the generated ``__init__`` and are left to
    :func:`.da_init`, which is called on the new instance in ``__post_init__``.
    """

    # Step size passed to the blackjax NUTS kernel on each transition.
    step_size: float
    # Inverse mass matrix passed to the blackjax NUTS kernel. The kernel's default
    # initialization uses a vector of ones (diagonal case) or an identity matrix.
    inverse_mass_matrix: Array
    # The following three fields are not accepted by ``__init__``.
    # NOTE(review): presumably they hold the dual averaging bookkeeping and are
    # populated by ``da_init`` -- confirm against the ``.da`` module.
    error_sum: float = field(init=False)
    log_avg_step_size: float = field(init=False)
    mu: float = field(init=False)

    def __post_init__(self) -> None:
        # Runs right after the generated ``__init__``, so ``step_size`` and
        # ``inverse_mass_matrix`` are already set when ``da_init`` sees the object.
        da_init(self)


@register_dataclass_as_pytree
@dataclass
class NUTSTransitionInfo(DefaultTransitionInfo):
    """
    Information about a single transition of a :class:`.NUTSKernel`.

    Extends :class:`.DefaultTransitionInfo` with the NUTS-specific diagnostics
    reported by blackjax.
    """

    error_code: int
    """
    Error code of the transition. Bit 0 (value 1) flags a divergent transition,
    bit 1 (value 2) flags that the maximum tree depth was reached. See
    :attr:`.NUTSKernel.error_book` for the meaning of each code.
    """

    acceptance_prob: float
    """The acceptance rate reported by blackjax for this transition."""

    position_moved: int
    """
    Indicator of whether the position was moved. NUTS transitions always report
    the constant ``99`` here.
    """

    divergent: bool
    """
    Whether the difference in energy between the original and the new state exceeded
    the divergence threshold of 1000.
    """

    turning: bool
    """Whether the expansion was stopped because the trajectory started turning."""

    treedepth: int
    """The tree depth, that is, the number of times the trajectory was expanded."""

    leapfrog: int
    """The number of computed leapfrog steps."""


def _error_code(*args: bool) -> int:
    return jnp.array(args) @ (2 ** jnp.arange(len(args)))


def _goose_info(nuts_info: nuts.NUTSInfo, max_treedepth: int) -> NUTSTransitionInfo:
    """
    Converts a blackjax ``NUTSInfo`` object into a :class:`.NUTSTransitionInfo`.

    The error code combines two flags: whether the transition was divergent, and
    whether the number of trajectory expansions equals ``max_treedepth``.
    ``position_moved`` is always set to ``99``.
    """
    is_divergent = nuts_info.is_divergent
    n_expansions = nuts_info.num_trajectory_expansions
    hit_max_treedepth = n_expansions == max_treedepth

    return NUTSTransitionInfo(
        error_code=_error_code(is_divergent, hit_max_treedepth),
        acceptance_prob=nuts_info.acceptance_rate,
        position_moved=99,
        divergent=is_divergent,
        turning=nuts_info.is_turning,
        treedepth=n_expansions,
        leapfrog=nuts_info.num_integration_steps,
    )


# The NUTS kernel adds no tuning fields of its own, so the default tuning info
# class is reused under a NUTS-specific name.
NUTSTuningInfo = DefaultTuningInfo


class NUTSKernel(
    ModelMixin,
    TransitionMixin[NUTSKernelState, NUTSTransitionInfo],
    TuningMixin[NUTSKernelState, NUTSTuningInfo],
    ReprMixin,
):
    """
    A NUTS kernel with dual averaging and an inverse mass matrix tuner,
    implementing the :class:`.Kernel` protocol.

    Parameters
    ----------
    position_keys
        Sequence of position keys (variable names) handled by this kernel.
    initial_step_size
        Value at which to start step size tuning. If ``None``, a starting value
        is searched for with blackjax's ``find_reasonable_step_size`` when the
        kernel state is initialized.
    initial_inverse_mass_matrix
        Starting value for the inverse mass matrix (the precision matrix of the
        momentum). If ``None``, an identity matrix will be used here (stored as
        a vector of ones if ``mm_diag`` is True).
    max_treedepth
        The maximum number of times that the length of the trajectory is doubled
        before returning if no U-turn has been observed or no divergence has
        occurred. See the Stan reference manual [#stan]_ for more details.
    da_target_accept
        Target acceptance probability for dual averaging algorithm.
    da_gamma
        The adaptation regularization scale.
    da_kappa
        The adaptation relaxation exponent.
    da_t0
        The adaptation iteration offset.
    mm_diag
        Whether to use a diagonal mass matrix for drawing the momentum vector.
        If True, the inverse mass matrix will be tuned during adaptation using
        :func:`.tune_inv_mm_diag`. If set to False, the mass matrix will be
        tuned using :func:`.tune_inv_mm_full` instead.
    identifier
        A string acting as a unique identifier for this kernel.

    Notes
    -----
    For more information on step size tuning via dual averaging, see
    :func:`.da_step` and :class:`.DAKernelState`.

    .. [#stan] `Stan Development Team, Stan Reference Manual (2021), Chapter 15.2
       <https://mc-stan.org/docs/2_28/reference-manual/hmc-algorithm-parameters.html>`_.
    """

    error_book: ClassVar[dict[int, str]] = {
        0: "no errors",
        1: "divergent transition",
        2: "maximum tree depth",
        3: "divergent transition + maximum tree depth",
    }
    """Dict of error codes and their meaning."""

    needs_history: ClassVar[bool] = True
    """Whether this kernel needs its history for tuning."""

    identifier: str = ""
    """Kernel identifier, set by :class:`~.goose.EngineBuilder`"""

    position_keys: tuple[str, ...]
    """Tuple of position keys handled by this kernel."""

    def __init__(
        self,
        position_keys: Sequence[str],
        initial_step_size: float | None = None,
        initial_inverse_mass_matrix: Array | None = None,
        max_treedepth: int = 10,
        da_target_accept: float = 0.8,
        da_gamma: float = 0.05,
        da_kappa: float = 0.75,
        da_t0: int = 10,
        mm_diag: bool = True,
        identifier: str = "",
    ):
        self.position_keys = tuple(position_keys)
        self._model = None

        self.initial_step_size = initial_step_size
        self.initial_inverse_mass_matrix = initial_inverse_mass_matrix

        self.max_treedepth = max_treedepth

        self.da_target_accept = da_target_accept
        self.da_gamma = da_gamma
        self.da_kappa = da_kappa
        self.da_t0 = da_t0

        self.mm_diag = mm_diag

        self.identifier = identifier

    def _blackjax_state(self, model_state: ModelState) -> hmc.HMCState:
        """
        Builds the blackjax state from the kernel's position and log-probability
        function for the given model state.
        """
        return nuts.init(self.position(model_state), self.log_prob_fn(model_state))

    @property
    def _blackjax_kernel(self) -> Callable:
        """
        The blackjax NUTS kernel factory with the maximum number of doublings
        already bound to ``max_treedepth``.
        """
        return partial(nuts_kernel, max_num_doublings=self.max_treedepth)

    def init_state(
        self, prng_key: KeyArray, model_state: ModelState
    ) -> NUTSKernelState:
        """
        Initializes the kernel state with an identity inverse mass matrix and a
        reasonable step size (unless explicit arguments were provided by the user).
        """
        if self.initial_inverse_mass_matrix is None:
            flat_position, _ = ravel_pytree(self.position(model_state))

            # The diagonal case stores only the diagonal (a vector of ones), the
            # dense case stores the full identity matrix.
            if self.mm_diag:
                inverse_mass_matrix = jnp.ones_like(flat_position)
            else:
                inverse_mass_matrix = jnp.eye(flat_position.size)
        else:
            inverse_mass_matrix = self.initial_inverse_mass_matrix

        if self.initial_step_size is None:
            blackjax_kernel = self._blackjax_kernel
            blackjax_state = self._blackjax_state(model_state)
            log_prob_fn = self.log_prob_fn(model_state)

            def kernel_generator(step_size: float) -> Callable:
                # Returns the step function of a NUTS kernel that uses the
                # candidate step size together with the initial inverse mass
                # matrix chosen above.
                return blackjax_kernel(
                    logdensity_fn=log_prob_fn,
                    step_size=step_size,
                    inverse_mass_matrix=inverse_mass_matrix,
                ).step

            step_size = find_reasonable_step_size(
                prng_key,
                kernel_generator,
                blackjax_state,
                initial_step_size=0.001,
                target_accept=self.da_target_accept,
            )
        else:
            step_size = self.initial_step_size

        return NUTSKernelState(step_size, inverse_mass_matrix)

    def _standard_transition(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[NUTSKernelState, NUTSTransitionInfo]:
        """
        Performs an MCMC transition *without* dual averaging.
        """
        blackjax_state = self._blackjax_state(model_state)
        log_prob_fn = self.log_prob_fn(model_state)

        blackjax_kernel = self._blackjax_kernel(
            logdensity_fn=log_prob_fn,
            step_size=kernel_state.step_size,
            inverse_mass_matrix=kernel_state.inverse_mass_matrix,
        )

        blackjax_state, blackjax_info = blackjax_kernel.step(prng_key, blackjax_state)

        info = _goose_info(blackjax_info, self.max_treedepth)
        model_state = self.model.update_state(blackjax_state.position, model_state)
        return TransitionOutcome(info, kernel_state, model_state)

    def _adaptive_transition(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[NUTSKernelState, NUTSTransitionInfo]:
        """
        Performs an MCMC transition *with* dual averaging.
        """
        outcome = self._standard_transition(prng_key, kernel_state, model_state, epoch)

        # The return value of da_step is not used.
        # NOTE(review): the update is expected to happen in place on the kernel
        # state -- confirm against the ``.da`` module.
        da_step(
            outcome.kernel_state,
            outcome.info.acceptance_prob,
            epoch.time_in_epoch,
            self.da_target_accept,
            self.da_gamma,
            self.da_kappa,
            self.da_t0,
        )

        return outcome

    def _tune_fast(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[NUTSKernelState, NUTSTuningInfo]:
        """
        Currently does nothing.

        The kernel state is returned unchanged, together with a tuning info that
        carries error code 0 and the current epoch time.
        """
        info = NUTSTuningInfo(error_code=0, time=epoch.time)
        return TuningOutcome(info, kernel_state)

    def _tune_slow(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[NUTSKernelState, NUTSTuningInfo]:
        """
        Tunes the inverse mass vector or matrix using the samples from the last epoch.

        If no history is provided, the kernel state is left untouched.
        """
        if history is not None:
            # Only the positions handled by this kernel enter the tuning.
            history = Position({k: history[k] for k in self.position_keys})

            if self.mm_diag:
                new_inv_mm = tune_inv_mm_diag(history)
                trace_fn = jnp.sum  # type: ignore
            else:
                new_inv_mm = tune_inv_mm_full(history)
                trace_fn = jnp.trace  # type: ignore

            # Rescale the step size by sqrt(trace(old) / trace(new)) before the
            # inverse mass matrix is replaced.
            # NOTE(review): presumably this keeps the overall scale of the
            # proposals comparable across the swap -- confirm.
            old_inv_mm = kernel_state.inverse_mass_matrix
            adjustment = jnp.sqrt(trace_fn(old_inv_mm) / trace_fn(new_inv_mm))
            kernel_state.step_size = adjustment * kernel_state.step_size

            kernel_state.inverse_mass_matrix = new_inv_mm

        return self._tune_fast(prng_key, kernel_state, model_state, epoch, history)

    def start_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> NUTSKernelState:
        """
        Resets the state of the dual averaging algorithm.
        """
        da_init(kernel_state)
        return kernel_state

    def end_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> NUTSKernelState:
        """
        Sets the step size as found by the dual averaging algorithm.
        """
        da_finalize(kernel_state)
        return kernel_state

    def end_warmup(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        tuning_history: NUTSTuningInfo | None,
    ) -> WarmupOutcome[NUTSKernelState]:
        """
        Currently does nothing.

        The kernel state is returned unchanged with error code 0.
        """
        return WarmupOutcome(error_code=0, kernel_state=kernel_state)