"""
No U-Turn Sampler (NUTS).
"""
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import ClassVar
import jax.numpy as jnp
from blackjax.adaptation.step_size import find_reasonable_step_size
from blackjax.mcmc import hmc, nuts
from jax.flatten_util import ravel_pytree
from .da import da_finalize, da_init, da_step
from .epoch import EpochState
from .kernel import (
DefaultTransitionInfo,
DefaultTuningInfo,
ModelMixin,
TransitionMixin,
TransitionOutcome,
TuningMixin,
TuningOutcome,
WarmupOutcome,
)
from .mm import tune_inv_mm_diag, tune_inv_mm_full
from .pytree import register_dataclass_as_pytree
from .types import Array, KeyArray, ModelState, Position
@register_dataclass_as_pytree
@dataclass
class NUTSKernelState:
"""
A dataclass for the state of a :class:`.NUTSKernel`, implementing the
:class:`.DAKernelState` protocol.
"""
step_size: float
inverse_mass_matrix: Array
error_sum: float = field(init=False)
log_avg_step_size: float = field(init=False)
mu: float = field(init=False)
def __post_init__(self):
da_init(self)
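# Usage sketch with illustrative values: the dual averaging fields are
# excluded from ``__init__`` and filled in by ``da_init``, so only the step
# size and the inverse mass matrix are passed explicitly, e.g.
#
#     kernel_state = NUTSKernelState(
#         step_size=0.1,
#         inverse_mass_matrix=jnp.ones(3),  # diagonal of a 3d mass matrix
#     )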
@register_dataclass_as_pytree
@dataclass
class NUTSTransitionInfo(DefaultTransitionInfo):
error_code: int
"""Dict of error codes and their meaning."""
acceptance_prob: float
position_moved: int
divergent: bool
"""
Whether the difference in energy between the original and the new state exceeded
the divergence threshold of 1000.
"""
turning: bool
"""Whether the expansion was stopped because the trajectory started turning."""
treedepth: int
"""The tree depth, that is, the number of times the trajectory was expanded."""
leapfrog: int
"""The number of computed leapfrog steps."""
def _error_code(*args: bool) -> int:
return jnp.array(args) @ (2 ** jnp.arange(len(args)))
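# Illustration of the bit encoding used by ``_error_code``: each boolean flag
# sets one bit, so, for example,
#
#     _error_code(True, False)   # -> 1 (divergent transition)
#     _error_code(False, True)   # -> 2 (maximum tree depth)
#     _error_code(True, True)    # -> 3 (divergent transition + maximum tree depth)
#
# which matches the codes documented in :attr:`NUTSKernel.error_book`.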
def _goose_info(nuts_info: nuts.NUTSInfo, max_treedepth: int) -> NUTSTransitionInfo:
error_code = _error_code(
nuts_info.is_divergent, nuts_info.num_trajectory_expansions == max_treedepth
)
return NUTSTransitionInfo(
error_code=error_code,
acceptance_prob=nuts_info.acceptance_probability,
        position_moved=-1,  # position moves are not tracked for NUTS
divergent=nuts_info.is_divergent,
turning=nuts_info.is_turning,
treedepth=nuts_info.num_trajectory_expansions,
leapfrog=nuts_info.num_integration_steps,
)
NUTSTuningInfo = DefaultTuningInfo
class NUTSKernel(
ModelMixin,
TransitionMixin[NUTSKernelState, NUTSTransitionInfo],
TuningMixin[NUTSKernelState, NUTSTuningInfo],
):
"""
A NUTS kernel with dual averaging and an inverse mass matrix tuner, implementing the
:class:`.Kernel` protocol.
"""
error_book: ClassVar[dict[int, str]] = {
0: "no errors",
1: "divergent transition",
2: "maximum tree depth",
3: "divergent transition + maximum tree depth",
}
"""Dict of error codes and their meaning."""
needs_history: ClassVar[bool] = True
"""Whether this kernel needs its history for tuning."""
identifier: str = ""
"""Kernel identifier, set by :class:`.EngineBuilder`"""
position_keys: tuple[str, ...]
"""Tuple of position keys handled by this kernel."""
def __init__(
self,
position_keys: Sequence[str],
initial_step_size: float | None = None,
initial_inverse_mass_matrix: Array | None = None,
max_treedepth: int = 10,
da_target_accept: float = 0.8,
da_gamma: float = 0.05,
da_kappa: float = 0.75,
da_t0: int = 10,
mm_diag: bool = True,
):
self.position_keys = tuple(position_keys)
self._model = None
self.initial_step_size = initial_step_size
self.initial_inverse_mass_matrix = initial_inverse_mass_matrix
self.max_treedepth = max_treedepth
self.da_target_accept = da_target_accept
self.da_gamma = da_gamma
self.da_kappa = da_kappa
self.da_t0 = da_t0
self.mm_diag = mm_diag
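    # Usage sketch, assuming a model with position keys "beta" and "sigma"
    # (all tuning parameters keep their defaults):
    #
    #     kernel = NUTSKernel(["beta", "sigma"])
    #
    # The step size and inverse mass matrix are then initialized automatically
    # in :meth:`.init_state`.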
def _blackjax_state(self, model_state: ModelState) -> hmc.HMCState:
return nuts.init(self.position(model_state), self.log_prob_fn(model_state))
def _blackjax_kernel(self) -> Callable:
return nuts.kernel(max_num_doublings=self.max_treedepth)
    def init_state(
        self, prng_key: KeyArray, model_state: ModelState
    ) -> NUTSKernelState:
"""
Initializes the kernel state with an identity inverse mass matrix and a
reasonable step size (unless explicit arguments were provided by the user).
"""
if self.initial_inverse_mass_matrix is None:
flat_position, _ = ravel_pytree(self.position(model_state))
if self.mm_diag:
inverse_mass_matrix = jnp.ones_like(flat_position)
else:
inverse_mass_matrix = jnp.eye(flat_position.size)
else:
inverse_mass_matrix = self.initial_inverse_mass_matrix
if self.initial_step_size is None:
blackjax_kernel = self._blackjax_kernel()
blackjax_state = self._blackjax_state(model_state)
log_prob_fn = self.log_prob_fn(model_state)
def kernel_generator(step_size: float) -> Callable:
return partial(
blackjax_kernel,
logprob_fn=log_prob_fn,
step_size=step_size,
inverse_mass_matrix=inverse_mass_matrix,
)
step_size = find_reasonable_step_size(
prng_key,
kernel_generator,
blackjax_state,
initial_step_size=0.001,
target_accept=self.da_target_accept,
)
else:
step_size = self.initial_step_size
return NUTSKernelState(step_size, inverse_mass_matrix)
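    # For a flattened position of size n, the identity initialization is
    # ``jnp.ones(n)`` if ``mm_diag=True`` (the diagonal stored as a vector) or
    # ``jnp.eye(n)`` if ``mm_diag=False`` (a dense n x n matrix). A
    # user-supplied ``initial_inverse_mass_matrix`` must match the
    # corresponding shape.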
def _standard_transition(
self,
prng_key: KeyArray,
kernel_state: NUTSKernelState,
model_state: ModelState,
epoch: EpochState,
) -> TransitionOutcome[NUTSKernelState, NUTSTransitionInfo]:
"""
Performs an MCMC transition *without* dual averaging.
"""
blackjax_kernel = self._blackjax_kernel()
blackjax_state = self._blackjax_state(model_state)
log_prob_fn = self.log_prob_fn(model_state)
blackjax_state, blackjax_info = blackjax_kernel(
prng_key,
blackjax_state,
log_prob_fn,
kernel_state.step_size,
kernel_state.inverse_mass_matrix,
)
info = _goose_info(blackjax_info, self.max_treedepth)
model_state = self.model.update_state(blackjax_state.position, model_state)
return TransitionOutcome(info, kernel_state, model_state)
def _adaptive_transition(
self,
prng_key: KeyArray,
kernel_state: NUTSKernelState,
model_state: ModelState,
epoch: EpochState,
) -> TransitionOutcome[NUTSKernelState, NUTSTransitionInfo]:
"""
Performs an MCMC transition *with* dual averaging.
"""
outcome = self._standard_transition(prng_key, kernel_state, model_state, epoch)
da_step(
outcome.kernel_state,
outcome.info.acceptance_prob,
epoch.time_in_epoch,
self.da_target_accept,
self.da_gamma,
self.da_kappa,
self.da_t0,
)
return outcome
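    # Sketch of the dual averaging update assumed to be implemented by
    # ``da_step`` (Nesterov-type dual averaging as in Hoffman & Gelman, 2014),
    # with t = epoch.time_in_epoch and eps the step size:
    #
    #     error_sum += da_target_accept - acceptance_prob
    #     log(eps_t)     = mu - sqrt(t) / (da_gamma * (t + da_t0)) * error_sum
    #     log(eps_bar_t) = t**(-da_kappa) * log(eps_t)
    #                      + (1 - t**(-da_kappa)) * log(eps_bar_{t-1})
    #
    # ``log_avg_step_size`` stores log(eps_bar); ``da_finalize`` adopts it as
    # the step size at the end of the epoch.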
def _tune_fast(
self,
prng_key: KeyArray,
kernel_state: NUTSKernelState,
model_state: ModelState,
epoch: EpochState,
history: Position | None = None,
) -> TuningOutcome[NUTSKernelState, NUTSTuningInfo]:
"""
        Currently does no tuning; returns the kernel state unchanged.
"""
info = NUTSTuningInfo(error_code=0, time=epoch.time)
return TuningOutcome(info, kernel_state)
def _tune_slow(
self,
prng_key: KeyArray,
kernel_state: NUTSKernelState,
model_state: ModelState,
epoch: EpochState,
history: Position | None = None,
) -> TuningOutcome[NUTSKernelState, NUTSTuningInfo]:
"""
Tunes the inverse mass vector or matrix using the samples from the last epoch.
"""
if history is not None:
history = Position({k: history[k] for k in self.position_keys})
if self.mm_diag:
new_inv_mm = tune_inv_mm_diag(history)
trace_fn = jnp.sum
else:
new_inv_mm = tune_inv_mm_full(history)
trace_fn = jnp.trace
old_inv_mm = kernel_state.inverse_mass_matrix
adjustment = jnp.sqrt(trace_fn(old_inv_mm) / trace_fn(new_inv_mm))
kernel_state.step_size = adjustment * kernel_state.step_size
kernel_state.inverse_mass_matrix = new_inv_mm
return self._tune_fast(prng_key, kernel_state, model_state, epoch, history)
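    # The rescaling above keeps the scale of the leapfrog moves roughly stable
    # under the new mass matrix: e.g. if trace(new_inv_mm) is four times
    # trace(old_inv_mm), then adjustment = sqrt(1/4) = 0.5 and the step size
    # is halved.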
    def start_epoch(
self,
prng_key: KeyArray,
kernel_state: NUTSKernelState,
model_state: ModelState,
epoch: EpochState,
) -> NUTSKernelState:
"""
Resets the state of the dual averaging algorithm.
"""
da_init(kernel_state)
return kernel_state
    def end_epoch(
self,
prng_key: KeyArray,
kernel_state: NUTSKernelState,
model_state: ModelState,
epoch: EpochState,
) -> NUTSKernelState:
"""
Sets the step size as found by the dual averaging algorithm.
"""
da_finalize(kernel_state)
return kernel_state
    def end_warmup(
self,
prng_key: KeyArray,
kernel_state: NUTSKernelState,
model_state: ModelState,
tuning_history: NUTSTuningInfo | None,
) -> WarmupOutcome[NUTSKernelState]:
"""
        Currently does nothing except return the unchanged kernel state.
"""
return WarmupOutcome(error_code=0, kernel_state=kernel_state)