# Source code for liesel.goose.rw

"""
Random walk sampler.
"""

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import ClassVar

import jax
import jax.flatten_util

from .da import da_finalize, da_init, da_step
from .epoch import EpochState
from .kernel import (
    DefaultTransitionInfo,
    DefaultTuningInfo,
    ModelMixin,
    TransitionMixin,
    TransitionOutcome,
    TuningOutcome,
    WarmupOutcome,
)
from .mh import mh_step
from .pytree import register_dataclass_as_pytree
from .types import KeyArray, ModelState, Position, TuningInfo


@register_dataclass_as_pytree
@dataclass
class RWKernelState:
    """
    A dataclass for the state of a ``RWKernel``, implementing the
    :class:`.DAKernelState` protocol.
    """

    # Scale (standard deviation) of the Gaussian random walk proposal.
    step_size: float
    # Dual averaging accumulator; reset to 0.0 on init and updated by
    # ``da_step`` during adaptive transitions.
    error_sum: float = field(default=0.0, init=False)
    # Log of the averaged step size tracked by dual averaging.
    log_avg_step_size: float = field(default=0.0, init=False)
    # Dual averaging shrinkage target; not a constructor argument —
    # presumably initialized by ``da_init`` below (TODO confirm in .da).
    mu: float = field(init=False)

    def __post_init__(self):
        # Populate the non-init dual averaging fields (e.g. ``mu``).
        da_init(self)


# The random walk kernel needs no extra transition/tuning bookkeeping, so it
# reuses the default info containers from the kernel module under local names.
RWTransitionInfo = DefaultTransitionInfo
RWTuningInfo = DefaultTuningInfo


class RWKernel(ModelMixin, TransitionMixin[RWKernelState, RWTransitionInfo]):
    """
    A random walk kernel.

    Uses Gaussian proposals, Metropolis-Hastings correction and dual
    averaging. Implements the :class:`.Kernel` protocol.

    The kernel uses a default Metropolis-Hastings target acceptance
    probability of 0.234, which is optimal for a random walk sampler (in a
    certain sense). See Gelman et al. (1997) Weak convergence and optimal
    scaling of random walk Metropolis algorithms:
    https://doi.org/10.1214/aoap/1034625254.
    """

    error_book: ClassVar[dict[int, str]] = {
        0: "no errors",
        90: "nan acceptance prob",
    }
    """Dict of error codes and their meaning."""

    needs_history: ClassVar[bool] = False
    """Whether this kernel needs its history for tuning."""

    identifier: str = ""
    """Kernel identifier, set by :class:`~.goose.EngineBuilder`"""

    position_keys: tuple[str, ...]
    """Tuple of position keys handled by this kernel."""

    def __init__(
        self,
        position_keys: Sequence[str],
        initial_step_size: float = 1.0,
        da_target_accept: float = 0.234,
        da_gamma: float = 0.05,
        da_kappa: float = 0.75,
        da_t0: int = 10,
    ):
        # The model interface is attached later (via ModelMixin); until then
        # the kernel holds no model.
        self._model = None
        self.position_keys = tuple(position_keys)
        # Step size used when a fresh kernel state is created.
        self.initial_step_size = initial_step_size
        # Dual averaging hyperparameters, forwarded verbatim to ``da_step``.
        self.da_target_accept = da_target_accept
        self.da_gamma = da_gamma
        self.da_kappa = da_kappa
        self.da_t0 = da_t0

    def init_state(self, prng_key, model_state):
        """
        Initializes the kernel state.
        """
        return RWKernelState(step_size=self.initial_step_size)

    def _standard_transition(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[RWKernelState, DefaultTransitionInfo]:
        """
        Performs an MCMC transition *without* dual averaging.
        """
        key, subkey = jax.random.split(prng_key)
        step_size = kernel_state.step_size

        # Random walk proposal: flatten the (possibly nested) position into a
        # single vector, perturb it with isotropic Gaussian noise scaled by
        # the current step size, and unflatten back to the original pytree.
        position = self.position(model_state)
        flat_position, unravel_fn = jax.flatten_util.ravel_pytree(position)
        step = step_size * jax.random.normal(key, flat_position.shape)
        flat_proposal = flat_position + step
        proposal = unravel_fn(flat_proposal)

        # Metropolis-Hastings correction: accept or reject the proposal.
        info, model_state = mh_step(subkey, self.model, proposal, model_state)
        return TransitionOutcome(info, kernel_state, model_state)

    def _adaptive_transition(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> TransitionOutcome[RWKernelState, DefaultTransitionInfo]:
        """
        Performs an MCMC transition *with* dual averaging.
        """
        outcome = self._standard_transition(prng_key, kernel_state, model_state, epoch)

        # Adapt the step size in place based on the realized acceptance
        # probability of this transition.
        da_step(
            outcome.kernel_state,
            outcome.info.acceptance_prob,
            epoch.time_in_epoch,
            self.da_target_accept,
            self.da_gamma,
            self.da_kappa,
            self.da_t0,
        )

        return outcome

    def tune(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
        history: Position | None = None,
    ) -> TuningOutcome[RWKernelState, DefaultTuningInfo]:
        """
        Currently does nothing.
        """
        info = RWTuningInfo(error_code=0, time=epoch.time)
        return TuningOutcome(info, kernel_state)

    def start_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> RWKernelState:
        """
        Resets the state of the dual averaging algorithm.
        """
        da_init(kernel_state)
        return kernel_state

    def end_epoch(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        epoch: EpochState,
    ) -> RWKernelState:
        """
        Sets the step size as found by the dual averaging algorithm.
        """
        da_finalize(kernel_state)
        return kernel_state

    def end_warmup(
        self,
        prng_key: KeyArray,
        kernel_state: RWKernelState,
        model_state: ModelState,
        tuning_history: TuningInfo | None,
    ) -> WarmupOutcome[RWKernelState]:
        """
        Currently does nothing.
        """
        return WarmupOutcome(error_code=0, kernel_state=kernel_state)