"""
No U-Turn Sampler (NUTS).
"""
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import ClassVar
import jax.numpy as jnp
from blackjax import nuts as nuts_kernel
from blackjax.adaptation.step_size import find_reasonable_step_size
from blackjax.mcmc import hmc, nuts
from jax.flatten_util import ravel_pytree
from .da import da_finalize, da_init, da_step
from .epoch import EpochState
from .kernel import (
DefaultTransitionInfo,
DefaultTuningInfo,
ModelMixin,
ReprMixin,
TransitionMixin,
TransitionOutcome,
TuningMixin,
TuningOutcome,
WarmupOutcome,
)
from .mm import tune_inv_mm_diag, tune_inv_mm_full
from .pytree import register_dataclass_as_pytree
from .types import Array, KeyArray, ModelState, Position
@register_dataclass_as_pytree
@dataclass
class NUTSKernelState:
    """
    Kernel state of a :class:`.NUTSKernel`.

    Implements the :class:`.DAKernelState` protocol. Only ``step_size`` and
    ``inverse_mass_matrix`` are constructor arguments. The remaining fields are
    excluded from ``__init__``; the freshly built object is handed to
    :func:`.da_init` in ``__post_init__``, which is expected to initialize them.
    """

    step_size: float
    inverse_mass_matrix: Array

    # Dual averaging bookkeeping, not accepted by the constructor.
    error_sum: float = field(init=False)
    log_avg_step_size: float = field(init=False)
    mu: float = field(init=False)

    def __post_init__(self) -> None:
        # Start every new kernel state with a freshly initialized dual averaging state.
        da_init(self)
@register_dataclass_as_pytree
@dataclass
class NUTSTransitionInfo(DefaultTransitionInfo):
    """
    Information about a single transition of a :class:`.NUTSKernel`.
    """

    error_code: int
    """
    Bit-encoded error code of the transition: bit 0 is set for a divergent
    transition, bit 1 is set if the maximum tree depth was reached. The meaning of
    each resulting value is listed in :attr:`.NUTSKernel.error_book`.
    """

    acceptance_prob: float
    """
    Acceptance rate reported by blackjax for the transition. During adaptation,
    this value is fed into the dual averaging step size tuner.
    """

    position_moved: int
    """
    Indicator of whether the position changed. The NUTS kernel always reports 99
    here (presumably the code for "unknown"; confirm against
    :class:`.DefaultTransitionInfo`).
    """

    divergent: bool
    """
    Whether the difference in energy between the original and the new state exceeded
    the divergence threshold of 1000.
    """

    turning: bool
    """Whether the expansion was stopped because the trajectory started turning."""

    treedepth: int
    """The tree depth, that is, the number of times the trajectory was expanded."""

    leapfrog: int
    """The number of computed leapfrog steps."""
def _error_code(*args: bool) -> int:
return jnp.array(args) @ (2 ** jnp.arange(len(args)))
def _goose_info(nuts_info: nuts.NUTSInfo, max_treedepth: int) -> NUTSTransitionInfo:
    """
    Converts a blackjax NUTS info object into a :class:`.NUTSTransitionInfo`.

    The error code combines two flags: a divergent transition (lowest bit) and a
    trajectory that was expanded exactly ``max_treedepth`` times (second bit).
    """
    diverged = nuts_info.is_divergent
    n_expansions = nuts_info.num_trajectory_expansions
    hit_max_treedepth = n_expansions == max_treedepth

    return NUTSTransitionInfo(
        error_code=_error_code(diverged, hit_max_treedepth),
        acceptance_prob=nuts_info.acceptance_rate,
        # Fixed value; whether the position moved is not derived from nuts_info.
        position_moved=99,
        divergent=diverged,
        turning=nuts_info.is_turning,
        treedepth=n_expansions,
        leapfrog=nuts_info.num_integration_steps,
    )
# The NUTS kernel has no tuning info of its own: the default tuning info class is
# reused under a kernel-specific name.
NUTSTuningInfo = DefaultTuningInfo
class NUTSKernel(
    ModelMixin,
    TransitionMixin[NUTSKernelState, NUTSTransitionInfo],
    TuningMixin[NUTSKernelState, NUTSTuningInfo],
    ReprMixin,
):
    """
    A NUTS kernel with dual averaging and an inverse mass matrix tuner, implementing the
    :class:`.Kernel` protocol.

    Parameters
    ----------
    position_keys
        Sequence of position keys (variable names) handled by this kernel.
    initial_step_size
        Value at which to start step size tuning. If ``None``, a reasonable step
        size is searched for when the kernel state is initialized.
    initial_inverse_mass_matrix
        Starting value for the inverse mass matrix (the precision matrix of the
        momentum). If ``None``, an identity matrix will be used here.
    max_treedepth
        The maximum number of times that the length of the trajectory is doubled before
        returning if no U-turn has been observed or no divergence has occurred.
        See the Stan reference manual [#stan]_ for more details.
    da_target_accept
        Target acceptance probability for dual averaging algorithm.
    da_gamma
        The adaptation regularization scale.
    da_kappa
        The adaptation relaxation exponent.
    da_t0
        The adaptation iteration offset.
    mm_diag
        Whether to use a diagonal mass matrix for drawing the momentum vector.
        If True, the inverse mass matrix will be tuned during adaptation using
        :func:`.tune_inv_mm_diag`. If set to False, the mass matrix will be tuned
        using :func:`.tune_inv_mm_full` instead.
    identifier
        A string acting as a unique identifier for this kernel.

    Notes
    -----
    For more information on step size tuning via dual averaging,
    see :func:`.da_step` and :class:`.DAKernelState`.

    .. [#stan] `Stan Development Team, Stan Reference Manual (2021), Chapter 15.2
       <https://mc-stan.org/docs/2_28/reference-manual/hmc-algorithm-parameters.html>`_.
    """

    error_book: ClassVar[dict[int, str]] = {
        0: "no errors",
        1: "divergent transition",
        2: "maximum tree depth",
        3: "divergent transition + maximum tree depth",
    }
    """Dict of error codes and their meaning."""

    needs_history: ClassVar[bool] = True
    """Whether this kernel needs its history for tuning."""

    identifier: str = ""
    """Kernel identifier, set by :class:`~.goose.EngineBuilder`"""

    position_keys: tuple[str, ...]
    """Tuple of position keys handled by this kernel."""

    def __init__(
        self,
        position_keys: Sequence[str],
        initial_step_size: float | None = None,
        initial_inverse_mass_matrix: Array | None = None,
        max_treedepth: int = 10,
        da_target_accept: float = 0.8,
        da_gamma: float = 0.05,
        da_kappa: float = 0.75,
        da_t0: int = 10,
        mm_diag: bool = True,
        identifier: str = "",
    ):
        # Stored as a tuple so that any sequence type can be passed in.
        self.position_keys = tuple(position_keys)
        self._model = None

        self.initial_step_size = initial_step_size
        self.initial_inverse_mass_matrix = initial_inverse_mass_matrix

        self.max_treedepth = max_treedepth

        self.da_target_accept = da_target_accept
        self.da_gamma = da_gamma
        self.da_kappa = da_kappa
        self.da_t0 = da_t0

        self.mm_diag = mm_diag
        self.identifier = identifier

    def _blackjax_state(self, model_state: ModelState) -> hmc.HMCState:
        """
        Builds the blackjax state from the kernel's position and log-probability
        function for the given model state.
        """
        return nuts.init(self.position(model_state), self.log_prob_fn(model_state))

    @property
    def _blackjax_kernel(self) -> Callable:
        """
        The blackjax NUTS kernel factory with ``max_num_doublings`` already bound to
        the kernel's maximum tree depth.
        """
        return partial(nuts_kernel, max_num_doublings=self.max_treedepth)

    def init_state(
        self, prng_key: KeyArray, model_state: ModelState
    ) -> NUTSKernelState:
        """
        Initializes the kernel state with an identity inverse mass matrix and a
        reasonable step size (unless explicit arguments were provided by the user).
        """
        if self.initial_inverse_mass_matrix is None:
            flat_position, _ = ravel_pytree(self.position(model_state))

            # A diagonal inverse mass matrix is represented as a vector, a full
            # one as a square matrix.
            if self.mm_diag:
                inverse_mass_matrix = jnp.ones_like(flat_position)
            else:
                inverse_mass_matrix = jnp.eye(flat_position.size)
        else:
            inverse_mass_matrix = self.initial_inverse_mass_matrix

        if self.initial_step_size is None:
            blackjax_kernel = self._blackjax_kernel
            blackjax_state = self._blackjax_state(model_state)
            log_prob_fn = self.log_prob_fn(model_state)

            def kernel_generator(step_size: float) -> Callable:
                # Returns the step function of a NUTS kernel for a candidate
                # step size, as needed by the step size search.
                return blackjax_kernel(
                    logdensity_fn=log_prob_fn,
                    step_size=step_size,
                    inverse_mass_matrix=inverse_mass_matrix,
                ).step

            step_size = find_reasonable_step_size(
                prng_key,
                kernel_generator,
                blackjax_state,
                initial_step_size=0.001,
                target_accept=self.da_target_accept,
            )
        else:
            step_size = self.initial_step_size

        return NUTSKernelState(step_size, inverse_mass_matrix)

    def _standard_transition(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[NUTSKernelState, NUTSTransitionInfo]:
        """
        Performs an MCMC transition *without* dual averaging.

        The kernel state is returned unchanged. The ``epoch`` argument is not used
        by this method.
        """
        blackjax_state = self._blackjax_state(model_state)
        log_prob_fn = self.log_prob_fn(model_state)

        blackjax_kernel = self._blackjax_kernel(
            logdensity_fn=log_prob_fn,
            step_size=kernel_state.step_size,
            inverse_mass_matrix=kernel_state.inverse_mass_matrix,
        )

        blackjax_state, blackjax_info = blackjax_kernel.step(prng_key, blackjax_state)

        info = _goose_info(blackjax_info, self.max_treedepth)
        model_state = self.model.update_state(blackjax_state.position, model_state)
        return TransitionOutcome(info, kernel_state, model_state)

    def _adaptive_transition(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[NUTSKernelState, NUTSTransitionInfo]:
        """
        Performs an MCMC transition *with* dual averaging.
        """
        outcome = self._standard_transition(prng_key, kernel_state, model_state, epoch)

        # The return value of da_step is not used: it is relied on to update the
        # kernel state in place.
        da_step(
            outcome.kernel_state,
            outcome.info.acceptance_prob,
            epoch.time_in_epoch,
            self.da_target_accept,
            self.da_gamma,
            self.da_kappa,
            self.da_t0,
        )

        return outcome

    def _tune_fast(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[NUTSKernelState, NUTSTuningInfo]:
        """
        Currently does nothing: the kernel state is returned unchanged together
        with an error-free tuning info.
        """
        info = NUTSTuningInfo(error_code=0, time=epoch.time)
        return TuningOutcome(info, kernel_state)

    def _tune_slow(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[NUTSKernelState, NUTSTuningInfo]:
        """
        Tunes the inverse mass vector or matrix using the samples from the last epoch.

        If no history is provided, the inverse mass matrix and the step size are
        left untouched.
        """
        # Everything that depends on the tuned inverse mass matrix must stay inside
        # this guard, because no new matrix exists without a history.
        if history is not None:
            # Only the positions handled by this kernel enter the tuning.
            history = Position({k: history[k] for k in self.position_keys})

            if self.mm_diag:
                new_inv_mm = tune_inv_mm_diag(history)
                trace_fn = jnp.sum  # type: ignore
            else:
                new_inv_mm = tune_inv_mm_full(history)
                trace_fn = jnp.trace  # type: ignore

            # Rescale the step size by the square root of the ratio of the old to
            # the new trace of the inverse mass matrix.
            old_inv_mm = kernel_state.inverse_mass_matrix
            adjustment = jnp.sqrt(trace_fn(old_inv_mm) / trace_fn(new_inv_mm))
            kernel_state.step_size = adjustment * kernel_state.step_size

            kernel_state.inverse_mass_matrix = new_inv_mm

        return self._tune_fast(prng_key, kernel_state, model_state, epoch, history)

    def start_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> NUTSKernelState:
        """
        Resets the state of the dual averaging algorithm.
        """
        da_init(kernel_state)
        return kernel_state

    def end_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> NUTSKernelState:
        """
        Sets the step size as found by the dual averaging algorithm.
        """
        da_finalize(kernel_state)
        return kernel_state

    def end_warmup(
        self,
        prng_key: KeyArray,
        kernel_state: NUTSKernelState,
        model_state: ModelState,
        tuning_history: NUTSTuningInfo | None,
    ) -> WarmupOutcome[NUTSKernelState]:
        """
        Currently does nothing: the kernel state is returned unchanged with error
        code 0.
        """
        return WarmupOutcome(error_code=0, kernel_state=kernel_state)