Source code for liesel.goose.mh

"""
Metropolis-Hastings.

This module uses the error codes 90-99.
"""

import jax
import jax.numpy as jnp

from .kernel import DefaultTransitionInfo
from .types import KeyArray, ModelInterface, ModelState, Position

__docformat__ = "numpy"

mh_error_book = {0: "no errors", 90: "nan acceptance prob"}
"""The error book of the :func:`.mh_step` function."""


[docs]def mh_step( prng_key: KeyArray, model: ModelInterface, proposal: Position, model_state: ModelState, log_correction: float = 0.0, ) -> tuple[DefaultTransitionInfo, ModelState]: r""" Decides if an MCMC proposal is accepted in a Metropolis-Hastings step. Parameters ---------- prng_key The key for JAX' pseudo-random number generator. model The model interface. proposal The proposal to be evaluated. model_state The current model state. log_correction The Metropolis-Hastings correction in the case of an asymmetric proposal distribution. Let ``q(x' | x)`` be the density of the proposal ``x'`` given the current state ``x``, then the ``log_correction`` is defined as ``log[q(x | x') / q(x' | x)]``. Returns ------- A tuple of a :class:`.TransitionInfo` and a :class:`.ModelState` (= a pytree). """ current_log_prob = model.log_prob(model_state) proposed_model_state = model.update_state(proposal, model_state) proposed_log_prob = model.log_prob(proposed_model_state) log_acc_prob = proposed_log_prob - current_log_prob + log_correction log_acc_prob, error_code = jax.lax.cond( jnp.isnan(log_acc_prob), lambda: (-jnp.inf, 90), lambda: (log_acc_prob, 0), ) acceptance_prob = jnp.clip(jnp.exp(log_acc_prob), a_max=1.0) do_accept = jax.random.uniform(prng_key) <= acceptance_prob info = DefaultTransitionInfo(error_code, acceptance_prob, do_accept) model_state = jax.lax.cond( do_accept, lambda: proposed_model_state, lambda: model_state, ) return info, model_state