# Source code for liesel.goose.mm

"""
Inverse mass matrix tuner.

This module uses the error codes 70-79.
"""

import jax
import jax.numpy as jnp

from .types import Array, Position

# Vectorized ``jnp.ravel``: maps over the leading axis (axis 0) of its input,
# flattens each slice to 1-D, and stacks the flattened slices along axis 0 of
# the output.
_vravel = jax.vmap(jnp.ravel, in_axes=0, out_axes=0)


def _history_to_matrix(history: Position) -> Array:
    """
    Collects all values of the history into one two-dimensional array.

    Every value in ``history`` is passed through ``_vravel``, which flattens
    each slice along the leading axis. The flattened blocks are then joined
    column-wise, in the iteration order of ``history.values()``.
    """
    flattened_blocks = []
    for samples in history.values():
        flattened_blocks.append(_vravel(samples))
    return jnp.column_stack(flattened_blocks)


def tune_inv_mm_diag(history: Position) -> Array:
    """
    Tunes an inverse mass vector with the sample variances of the history.

    Parameters
    ----------
    history
        Holds the history of the position. It is to be understood as in
        :meth:`liesel.goose.types.Kernel.tune`.

    Returns
    -------
    A one-dimensional array holding the sample variance (``ddof=1``) of each
    column of the flattened history, with 0.001 added to every entry.
    """
    matrix = _history_to_matrix(history)

    # Column-wise variance over the rows of the flattened history.
    var = jnp.var(matrix, axis=0, ddof=1)
    var = jnp.atleast_1d(var)

    # Constant offset keeps the entries away from zero for columns with
    # (almost) no variation in the history.
    var = var + 0.001
    return var
def tune_inv_mm_full(history: Position) -> Array:
    """
    Tunes an inverse mass matrix with the sample variance-covariance matrix of
    the history.

    Parameters
    ----------
    history
        Holds the history of the position. It is to be understood as in
        :meth:`liesel.goose.types.Kernel.tune`.

    Returns
    -------
    A two-dimensional array holding the sample covariance matrix of the
    columns of the flattened history, with 0.001 added to every diagonal
    entry.
    """
    matrix = _history_to_matrix(history)

    # rowvar=False: columns are the variables, rows are the observations.
    cov = jnp.cov(matrix, rowvar=False)
    cov = jnp.atleast_2d(cov)

    # stan regularization, see:
    # https://github.com/stan-dev/stan/blob/v2.28.2/src/stan/mcmc/covar_adaptation.hpp#L27-L29
    # n = matrix.shape[0]
    # cov = (n / (n + 5.0)) * cov
    # cov = cov.at[jnp.diag_indices_from(cov)].add(0.001 * (5.0 / (n + 5.0)))

    # Instead of the Stan scheme above, only a constant is added to the
    # diagonal.
    cov = cov.at[jnp.diag_indices_from(cov)].add(0.001)
    return cov