Source code for liesel.goose.da

"""
Dual averaging.

This module uses the error codes 80-89.
"""

from dataclasses import dataclass
from typing import Protocol

import jax.numpy as jnp

from .pytree import register_dataclass_as_pytree


@register_dataclass_as_pytree
@dataclass
class DualAvgState:
    """
    Bookkeeping values carried between dual averaging updates.
    """

    error_sum: float
    """Accumulated error of the acceptance probability."""

    log_avg_step_size: float
    """Log of the averaged step size."""

    mu: float
    """Bias term of the step size proposals."""

    @classmethod
    def from_step_size(cls, step_size: float) -> "DualAvgState":
        """
        Builds a fresh dual averaging state from ``step_size``.

        The error sum starts at zero, the log average step size at
        ``log(step_size)`` and the bias ``mu`` at ``log(10 * step_size)``.
        """
        log_initial = jnp.log(step_size)
        log_bias = jnp.log(10.0 * step_size)
        return cls(error_sum=0.0, log_avg_step_size=log_initial, mu=log_bias)
class DAKernelState(Protocol):
    """
    Structural interface of a kernel state that supports dual averaging.

    Any object exposing a ``step_size`` and a ``da_state`` attribute satisfies
    this protocol.

    An introduction to dual averaging can be found in the blog post by Colin
    Carroll [#carroll]_ and in the Stan Reference Manual [#stan]_.

    .. [#carroll] `Colin Carroll, Step Size Adaptation in Hamiltonian Monte Carlo
       (2019) <https://colindcarroll.com/blog/step_size_adapt_hmc.html>`_.
    .. [#stan] `Stan Development Team, Stan Reference Manual (2021), Chapter 15.2
       <https://mc-stan.org/docs/2_28/reference-manual/hmc-algorithm-parameters.html>`_.
    """

    step_size: float
    """Step size used by the kernel."""

    da_state: DualAvgState | None
    """Dual averaging bookkeeping, or ``None`` if not yet initialized."""
def da_init(kernel_state: DAKernelState) -> None:
    """
    Creates (or re-creates) the dual averaging state of a :class:`.DAKernelState`.

    The new state is derived from the current ``kernel_state.step_size`` and
    stored in ``kernel_state.da_state``. Nothing is returned; the function is
    called purely for this side effect on ``kernel_state``.
    """
    fresh_state = DualAvgState.from_step_size(kernel_state.step_size)
    kernel_state.da_state = fresh_state
def da_step(
    kernel_state: DAKernelState,
    acceptance_prob: float,
    time_in_epoch: int,
    target_accept: float = 0.8,
    gamma: float = 0.05,
    kappa: float = 0.75,
    t0: int = 10,
) -> None:
    """
    Applies one dual averaging update to a :class:`.DAKernelState`.

    The function returns ``None``; it is called for its side effects on
    ``kernel_state``: the step size is overwritten with the new proposal, and
    the error sum and the log average step size of ``kernel_state.da_state``
    are updated.

    Parameters
    ----------
    kernel_state
        A kernel state implementing the :class:`.DAKernelState` protocol.
    acceptance_prob
        The acceptance probability of this MCMC iteration.
    time_in_epoch
        The number of completed MCMC iterations in this epoch.
    target_accept
        The target acceptance probability.
    gamma
        The adaptation regularization scale.
    kappa
        The adaptation relaxation exponent.
    t0
        The adaptation iteration offset.

    Raises
    ------
    RuntimeError
        If ``kernel_state.da_state`` is ``None``.

    Notes
    -----
    An introduction to dual averaging can be found in the blog post by Colin
    Carroll [#carroll]_ and in the Stan Reference Manual [#stan]_.

    .. [#carroll] `Colin Carroll, Step Size Adaptation in Hamiltonian Monte Carlo
       (2019) <https://colindcarroll.com/blog/step_size_adapt_hmc.html>`_.
    .. [#stan] `Stan Development Team, Stan Reference Manual (2021), Chapter 15.2
       <https://mc-stan.org/docs/2_28/reference-manual/hmc-algorithm-parameters.html>`_.
    """
    # The update is one-based: the first call of an epoch uses iteration 1.
    iteration = time_in_epoch + 1
    # Weight of the newest proposal in the running log average.
    weight = iteration ** (-kappa)

    averaging = kernel_state.da_state
    if averaging is None:
        raise RuntimeError("Dual averaging state has not been initialized.")

    # Accumulate how far this iteration was from the target acceptance rate.
    averaging.error_sum += target_accept - acceptance_prob

    # Shift the proposal away from the bias ``mu`` according to the scaled
    # error sum.
    scaled_error = averaging.error_sum * jnp.sqrt(iteration)
    damping = gamma * (t0 + iteration)
    log_proposal = averaging.mu - scaled_error / damping
    kernel_state.step_size = jnp.exp(log_proposal)

    # Blend the new proposal into the running average on the log scale.
    previous_log_avg = averaging.log_avg_step_size
    averaging.log_avg_step_size = (
        (1 - weight) * previous_log_avg + weight * log_proposal
    )
def da_finalize(kernel_state: DAKernelState) -> None:
    """
    Writes the averaged step size into a :class:`.DAKernelState`.

    ``kernel_state.step_size`` is set to the exponential of the log average
    step size stored in ``kernel_state.da_state``. Nothing is returned; the
    function is called purely for this side effect on ``kernel_state``.

    Raises
    ------
    RuntimeError
        If ``kernel_state.da_state`` is ``None``.
    """
    averaging = kernel_state.da_state
    if averaging is None:
        raise RuntimeError("Dual averaging state has not been initialized.")

    log_step = averaging.log_avg_step_size
    kernel_state.step_size = jnp.exp(log_step)