# Source code for liesel.goose.nuts

"""
No U-Turn Sampler (NUTS).
"""

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import ClassVar

import jax.numpy as jnp
from blackjax import nuts as nuts_kernel
from blackjax.adaptation.step_size import find_reasonable_step_size
from blackjax.mcmc import hmc, nuts
from jax.flatten_util import ravel_pytree

from .da import da_finalize, da_init, da_step
from .epoch import EpochState
from .kernel import (
    DefaultTransitionInfo,
    DefaultTuningInfo,
    ModelMixin,
    TransitionMixin,
    TransitionOutcome,
    TuningMixin,
    TuningOutcome,
    WarmupOutcome,
)
from .mm import tune_inv_mm_diag, tune_inv_mm_full
from .pytree import register_dataclass_as_pytree
from .types import Array, KeyArray, ModelState, Position


@register_dataclass_as_pytree
@dataclass
class NUTSKernelState:
    """
    A dataclass for the state of a :class:`.NUTSKernel`, implementing the
    :class:`.DAKernelState` protocol.
    """

    step_size: float
    inverse_mass_matrix: Array
    # The remaining fields belong to the dual averaging algorithm. They are
    # excluded from the generated __init__ and filled in by ``da_init`` below.
    error_sum: float = field(init=False)
    log_avg_step_size: float = field(init=False)
    mu: float = field(init=False)

    def __post_init__(self):
        # Initializes error_sum, log_avg_step_size and mu via the shared
        # dual averaging helper.
        da_init(self)


@register_dataclass_as_pytree
@dataclass
class NUTSTransitionInfo(DefaultTransitionInfo):
    error_code: int
    """Dict of error codes and their meaning."""
    acceptance_prob: float
    position_moved: int
    divergent: bool
    """
    Whether the difference in energy between the original and the new state exceeded
    the divergence threshold of 1000.
    """

    turning: bool
    """Whether the expansion was stopped because the trajectory started turning."""

    treedepth: int
    """The tree depth, that is, the number of times the trajectory was expanded."""

    leapfrog: int
    """The number of computed leapfrog steps."""


def _error_code(*args: bool) -> int:
    return jnp.array(args) @ (2 ** jnp.arange(len(args)))


def _goose_info(nuts_info: nuts.NUTSInfo, max_treedepth: int) -> NUTSTransitionInfo:
    """Converts a blackjax ``NUTSInfo`` into a goose :class:`.NUTSTransitionInfo`."""

    hit_max_depth = nuts_info.num_trajectory_expansions == max_treedepth
    code = _error_code(nuts_info.is_divergent, hit_max_depth)

    return NUTSTransitionInfo(
        error_code=code,
        acceptance_prob=nuts_info.acceptance_rate,
        # NOTE(review): 99 appears to be a "not tracked" sentinel — confirm.
        position_moved=99,
        divergent=nuts_info.is_divergent,
        turning=nuts_info.is_turning,
        treedepth=nuts_info.num_trajectory_expansions,
        leapfrog=nuts_info.num_integration_steps,
    )


# The NUTS kernel reuses the default tuning info container unchanged.
NUTSTuningInfo = DefaultTuningInfo


class NUTSKernel(
    ModelMixin,
    TransitionMixin[NUTSKernelState, NUTSTransitionInfo],
    TuningMixin[NUTSKernelState, NUTSTuningInfo],
):
    """
    A NUTS kernel with dual averaging and an inverse mass matrix tuner, implementing
    the :class:`.Kernel` protocol.
    """

    error_book: ClassVar[dict[int, str]] = {
        0: "no errors",
        1: "divergent transition",
        2: "maximum tree depth",
        3: "divergent transition + maximum tree depth",
    }
    """Dict of error codes and their meaning."""

    needs_history: ClassVar[bool] = True
    """Whether this kernel needs its history for tuning."""

    identifier: str = ""
    """Kernel identifier, set by :class:`~.goose.EngineBuilder`"""

    position_keys: tuple[str, ...]
    """Tuple of position keys handled by this kernel."""

    def __init__(
        self,
        position_keys: Sequence[str],
        initial_step_size: float | None = None,
        initial_inverse_mass_matrix: Array | None = None,
        max_treedepth: int = 10,
        da_target_accept: float = 0.8,
        da_gamma: float = 0.05,
        da_kappa: float = 0.75,
        da_t0: int = 10,
        mm_diag: bool = True,
    ):
        # Stores the configuration; ``None`` for the step size or inverse mass
        # matrix means "determine automatically in init_state".
        self.position_keys = tuple(position_keys)
        self._model = None
        self.initial_step_size = initial_step_size
        self.initial_inverse_mass_matrix = initial_inverse_mass_matrix
        self.max_treedepth = max_treedepth
        self.da_target_accept = da_target_accept
        self.da_gamma = da_gamma
        self.da_kappa = da_kappa
        self.da_t0 = da_t0
        self.mm_diag = mm_diag

    def _blackjax_state(self, model_state: ModelState) -> hmc.HMCState:
        # Builds the blackjax sampler state from the goose model state.
        return nuts.init(self.position(model_state), self.log_prob_fn(model_state))

    @property
    def _blackjax_kernel(self) -> Callable:
        # The blackjax NUTS kernel factory with the tree depth bound applied.
        return partial(nuts_kernel, max_num_doublings=self.max_treedepth)
    def init_state(self, prng_key, model_state):
        """
        Initializes the kernel state with an identity inverse mass matrix and a
        reasonable step size (unless explicit arguments were provided by the user).
        """

        if self.initial_inverse_mass_matrix is None:
            flat_position, _ = ravel_pytree(self.position(model_state))

            if self.mm_diag:
                # Diagonal inverse mass matrix, stored as a vector of ones.
                inverse_mass_matrix = jnp.ones_like(flat_position)
            else:
                # Dense inverse mass matrix, stored as an identity matrix.
                inverse_mass_matrix = jnp.eye(flat_position.size)
        else:
            inverse_mass_matrix = self.initial_inverse_mass_matrix

        if self.initial_step_size is None:
            blackjax_kernel = self._blackjax_kernel
            blackjax_state = self._blackjax_state(model_state)
            log_prob_fn = self.log_prob_fn(model_state)

            # blackjax's step size search expects a function mapping a step
            # size to a kernel step function.
            def kernel_generator(step_size: float) -> Callable:
                return blackjax_kernel(
                    logdensity_fn=log_prob_fn,
                    step_size=step_size,
                    inverse_mass_matrix=inverse_mass_matrix,
                ).step

            step_size = find_reasonable_step_size(
                prng_key,
                kernel_generator,
                blackjax_state,
                initial_step_size=0.001,
                target_accept=self.da_target_accept,
            )
        else:
            step_size = self.initial_step_size

        return NUTSKernelState(step_size, inverse_mass_matrix)
    def _standard_transition(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[NUTSKernelState, NUTSTransitionInfo]:
        """
        Performs an MCMC transition *without* dual averaging.
        """

        blackjax_state = self._blackjax_state(model_state)
        log_prob_fn = self.log_prob_fn(model_state)

        # Instantiates the blackjax NUTS kernel with the current step size
        # and inverse mass matrix.
        blackjax_kernel = self._blackjax_kernel(
            logdensity_fn=log_prob_fn,
            step_size=kernel_state.step_size,
            inverse_mass_matrix=kernel_state.inverse_mass_matrix,
        )

        blackjax_state, blackjax_info = blackjax_kernel.step(prng_key, blackjax_state)

        info = _goose_info(blackjax_info, self.max_treedepth)
        # Writes the proposed position back into the goose model state.
        model_state = self.model.update_state(blackjax_state.position, model_state)
        return TransitionOutcome(info, kernel_state, model_state)

    def _adaptive_transition(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[NUTSKernelState, NUTSTransitionInfo]:
        """
        Performs an MCMC transition *with* dual averaging.
        """

        outcome = self._standard_transition(prng_key, kernel_state, model_state, epoch)

        # Updates the dual averaging state in place with this transition's
        # acceptance probability.
        da_step(
            outcome.kernel_state,
            outcome.info.acceptance_prob,
            epoch.time_in_epoch,
            self.da_target_accept,
            self.da_gamma,
            self.da_kappa,
            self.da_t0,
        )

        return outcome

    def _tune_fast(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[NUTSKernelState, NUTSTuningInfo]:
        """
        Currently does nothing.
        """

        info = NUTSTuningInfo(error_code=0, time=epoch.time)
        return TuningOutcome(info, kernel_state)

    def _tune_slow(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[NUTSKernelState, NUTSTuningInfo]:
        """
        Tunes the inverse mass vector or matrix using the samples from the
        last epoch.
        """

        if history is not None:
            # Restricts the history to the position keys handled by this kernel.
            history = Position({k: history[k] for k in self.position_keys})

            if self.mm_diag:
                new_inv_mm = tune_inv_mm_diag(history)
                trace_fn = jnp.sum  # type: ignore
            else:
                new_inv_mm = tune_inv_mm_full(history)
                trace_fn = jnp.trace  # type: ignore

            old_inv_mm = kernel_state.inverse_mass_matrix
            # Rescales the step size by the square root of the trace ratio of
            # the old and new inverse mass matrices.
            adjustment = jnp.sqrt(trace_fn(old_inv_mm) / trace_fn(new_inv_mm))
            kernel_state.step_size = adjustment * kernel_state.step_size
            kernel_state.inverse_mass_matrix = new_inv_mm

        return self._tune_fast(prng_key, kernel_state, model_state, epoch, history)
    def start_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> NUTSKernelState:
        """
        Resets the state of the dual averaging algorithm.
        """

        # da_init overwrites the dual averaging fields of the kernel state
        # in place.
        da_init(kernel_state)
        return kernel_state
    def end_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> NUTSKernelState:
        """
        Sets the step size as found by the dual averaging algorithm.
        """

        # da_finalize writes the averaged step size into the kernel state
        # in place.
        da_finalize(kernel_state)
        return kernel_state
    def end_warmup(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        tuning_history: NUTSTuningInfo | None,
    ) -> WarmupOutcome[NUTSKernelState]:
        """
        Currently does nothing.
        """

        # Error code 0 signals "no errors"; the kernel state passes through
        # unchanged.
        return WarmupOutcome(error_code=0, kernel_state=kernel_state)