Source code for liesel.goose.da

"""
Dual averaging.

This module uses the error codes 80-89.
"""

from typing import Protocol

import jax.numpy as jnp


class DAKernelState(Protocol):
    """
    Structural interface of a kernel state that supports dual averaging.

    Any object exposing the four attributes below can be passed to
    :func:`.da_init`, :func:`.da_step` and :func:`.da_finalize`.

    Introductions to dual averaging can be found in the blog post by Colin
    Carroll [#carroll]_ and in the Stan Reference Manual [#stan]_.

    .. [#carroll] `Colin Carroll, Step Size Adaptation in Hamiltonian Monte Carlo
       (2019) <https://colindcarroll.com/2019/04/21/step-size-adaptation-in-hamiltonian-monte-carlo>`_.
    .. [#stan] `Stan Development Team, Stan Reference Manual (2021), Chapter 15.2
       <https://mc-stan.org/docs/2_28/reference-manual/hmc-algorithm-parameters.html>`_.
    """

    step_size: float
    """The kernel's step size."""

    error_sum: float
    """
    Accumulated error of the acceptance probability. Managed by
    :func:`.da_step`; not meant to be set by the user.
    """

    log_avg_step_size: float
    """
    Logarithm of the averaged step size. Managed by :func:`.da_step`;
    not meant to be set by the user.
    """

    mu: float
    """
    Bias of the step size proposals. Read by :func:`.da_step`;
    not meant to be set by the user.
    """
def da_init(kernel_state: DAKernelState) -> None:
    """
    Puts a :class:`.DAKernelState` into its initial dual averaging state.

    Can also be used to reset a state that has already been adapted. The
    function mutates ``kernel_state`` in place and returns ``None``.
    """
    # All derived quantities are computed from the current step size,
    # which itself is left untouched.
    step_size = kernel_state.step_size

    kernel_state.mu = jnp.log(10.0 * step_size)
    kernel_state.log_avg_step_size = jnp.log(step_size)
    kernel_state.error_sum = 0.0
def da_step(
    kernel_state: DAKernelState,
    acceptance_prob: float,
    time_in_epoch: int,
    target_accept: float = 0.8,
    gamma: float = 0.05,
    kappa: float = 0.75,
    t0: int = 10,
) -> None:
    """
    Performs a dual averaging update on a :class:`.DAKernelState`.

    The function mutates ``kernel_state`` in place (``error_sum``,
    ``step_size`` and ``log_avg_step_size`` are overwritten) and returns
    ``None``.

    Parameters
    ----------
    kernel_state
        A kernel state implementing the :class:`.DAKernelState` protocol.
    acceptance_prob
        The acceptance probability of this MCMC iteration.
    time_in_epoch
        The number of completed MCMC iterations in this epoch.
    target_accept
        The target acceptance probability.
    gamma
        The adaptation regularization scale.
    kappa
        The adaptation relaxation exponent.
    t0
        The adaptation iteration offset.
    """
    state = kernel_state

    # One-based iteration counter of the current update.
    step = time_in_epoch + 1

    # Accumulate how far the observed acceptance probability is off target.
    state.error_sum += target_accept - acceptance_prob

    # The proposal for the log step size is mu minus the scaled error sum.
    shrinkage = (state.error_sum * jnp.sqrt(step)) / (gamma * (t0 + step))
    log_step_size = state.mu - shrinkage
    state.step_size = jnp.exp(log_step_size)

    # Weighted running average of the log step sizes; the weight of the
    # newest proposal is step ** (-kappa).
    weight = step ** (-kappa)
    state.log_avg_step_size = (
        (1 - weight) * state.log_avg_step_size + weight * log_step_size
    )
def da_finalize(kernel_state: DAKernelState) -> None:
    """
    Writes the final step size into a :class:`.DAKernelState`.

    The step size is replaced by the exponential of the state's
    ``log_avg_step_size``. The function mutates ``kernel_state`` in place
    and returns ``None``.
    """
    averaged_step_size = jnp.exp(kernel_state.log_avg_step_size)
    kernel_state.step_size = averaged_step_size