"""
Inverse mass matrix tuner.

This module uses the error codes 70-79.
"""

import jax
import jax.numpy as jnp

from .types import Array, Position

__docformat__ = "numpy"

_vravel = jax.vmap(jnp.ravel, in_axes=0, out_axes=0)

def _history_to_matrix(history: Position) -> Array:
    """
    Stacks the flattened position history into one matrix.

    Every value of ``history`` is flattened sample by sample with ``_vravel``.
    The resulting ``(n, k_i)`` blocks are joined side by side, so row ``i``
    holds all coordinates of sample ``i``.
    """
    columns = [_vravel(leaf) for leaf in history.values()]
    # Every block is 2-D here, so joining along axis 1 is exactly what
    # ``jnp.column_stack`` does.
    return jnp.concatenate(columns, axis=1)

def tune_inv_mm_diag(history: Position) -> Array:
    """
    Tunes an inverse mass vector with the sample variances of the history.

    The position history is flattened into a matrix with one row per sample.
    The unbiased (``ddof=1``) sample variance of every column is then
    computed, and a constant of ``0.001`` is added to each entry.

    Parameters
    ----------
    history
        Holds the history of the position. It is to be understood as in
        :meth:`liesel.goose.types.Kernel.tune`. At least two samples are
        needed, because ``ddof=1`` divides by ``n - 1``.

    Returns
    -------
    A 1-dimensional array with one regularized variance per flattened
    coordinate of the position.
    """
    matrix = _history_to_matrix(history)
    var = jnp.atleast_1d(jnp.var(matrix, axis=0, ddof=1))
    # Regularize so that every entry stays strictly positive, even for a
    # coordinate whose sample variance is zero.
    return var + 0.001
def tune_inv_mm_full(history: Position) -> Array:
    """
    Tunes an inverse mass matrix with the sample variance-covariance matrix of
    the history.

    The position history is flattened into a matrix with one row per sample.
    Its unbiased sample covariance matrix is computed (``jnp.cov`` normalizes
    by ``n - 1`` by default), and ``0.001`` is added to the diagonal.

    Parameters
    ----------
    history
        Holds the history of the position. It is to be understood as in
        :meth:`liesel.goose.types.Kernel.tune`.

    Returns
    -------
    A square 2-dimensional array with one row and column per flattened
    coordinate of the position.
    """
    matrix = _history_to_matrix(history)
    # ``jnp.cov`` squeezes the result to a 0-d array when there is only one
    # coordinate; ``atleast_2d`` turns it back into a ``(1, 1)`` matrix.
    cov = jnp.atleast_2d(jnp.cov(matrix, rowvar=False))

    # Stan regularizes the estimate by shrinking it towards a scaled identity:
    #
    #     n = matrix.shape[0]
    #     cov = (n / (n + 5.0)) * cov
    #     cov = cov.at[jnp.diag_indices_from(cov)].add(0.001 * (5.0 / (n + 5.0)))
    #
    # Here only a fixed jitter of 0.001 is added to the diagonal. This helps
    # keep the matrix positive definite when the sample covariance is singular.
    # JAX arrays are immutable, so the update goes through ``.at[...]``.
    return cov.at[jnp.diag_indices_from(cov)].add(0.001)