# Source code for liesel.distributions.mvn_degen

"""
The degenerate, i.e. rank-deficient, multivariate normal distribution.
"""

from __future__ import annotations

from functools import cached_property
from typing import Any

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.substrates.jax import tf2jax as tf

Array = Any


def _rank(eigenvalues: Array, tol: float = 1e-6) -> Array | float:
    """
    Computes the rank of a matrix based on the provided eigenvalues. The rank is taken
    to be the number of non-zero eigenvalues.

    Can handle batches.
    """
    mask = eigenvalues > tol
    rank = jnp.sum(mask, axis=-1)
    return rank


def _log_pdet(
    eigenvalues: Array, rank: Array | float | None = None, tol: float = 1e-6
) -> Array | float:
    """
    Computes the log of the pseudo-determinant of a matrix based on the provided
    eigenvalues. If the rank is provided, it is used to select the non-zero eigenvalues.
    If the rank is not provided, it is computed by counting the non-zero eigenvalues. An
    eigenvalue is deemed to be non-zero if it is greater than the numerical tolerance
    ``tol``.

    Can handle batches.
    """
    if rank is None:
        mask = eigenvalues > tol
    else:
        max_index = eigenvalues.shape[-1] - rank

        def fn(i, x):
            return x.at[..., i].set(i >= max_index)

        mask = jax.lax.fori_loop(0, eigenvalues.shape[-1], fn, eigenvalues)

    selected = jnp.where(mask, eigenvalues, 1.0)
    log_pdet = jnp.sum(jnp.log(selected), axis=-1)
    return log_pdet


class MultivariateNormalDegenerate(tfd.Distribution):
    """
    A potentially degenerate multivariate normal distribution.

    Provides the alternative constructor :meth:`.from_penalty`.

    Parameters
    ----------
    loc
        The location (= mean) vector.
    prec
        The precision matrix (= a pseudo-inverse of the variance-covariance matrix).
    rank
        The rank of the precision matrix. Optional.
    log_pdet
        The log-pseudo-determinant of the precision matrix. Optional.
    validate_args
        Python ``bool``, default ``False``. When ``True``, distribution parameters
        are checked for validity despite possibly degrading runtime performance.
        When ``False``, invalid inputs may silently render incorrect outputs.
    allow_nan_stats
        Python ``bool``, default ``True``. When ``True``, statistics (e.g., mean,
        mode, variance) use the value ``NaN`` to indicate the result is undefined.
        When ``False``, an exception is raised if one or more of the statistic's
        batch members are undefined.
    name
        Python ``str``, name prefixed to ``Ops`` created by this class.
    tol
        Numerical tolerance for determining which eigenvalues of the distribution's
        precision matrices should be treated as zeros. Used in :attr:`.rank` and
        :attr:`.log_pdet`, if they are computed by the class. Also used in
        :meth:`.sample`.

    Notes
    -----
    * If they are not provided as arguments, ``rank`` and ``log_pdet`` are computed
      based on the eigenvalues of the precision matrix ``prec``. This is an
      expensive operation and can be avoided by specifying the corresponding
      arguments.
    * When you draw samples from the distribution via :meth:`.sample`, it is always
      necessary to compute the eigendecomposition of the distribution's precision
      matrices once and cache it, because sampling requires both the eigenvalues
      and eigenvectors.
    """

    def __init__(
        self,
        loc: Array,
        prec: Array,
        rank: Array | int | None = None,
        log_pdet: Array | float | None = None,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "MultivariateNormalDegenerate",
        tol: float = 1e-6,
    ):
        # Captured before any locals are added; passed through to the tfd base class.
        parameters = dict(locals())

        self._tol = tol
        self._rank = rank
        self._log_pdet = log_pdet

        # necessary for correct broadcasting over event size
        loc = jnp.atleast_1d(loc)

        if not prec.shape[-2] == prec.shape[-1]:
            raise ValueError(
                "`prec` must be square (the last two dimensions must be equal)."
            )

        try:
            jnp.broadcast_shapes(prec.shape[-1], loc.shape[-1])
        except ValueError:
            raise ValueError(
                f"The event sizes of `prec` ({prec.shape[-1]}) and `loc` "
                f"({loc.shape[-1]}) cannot be broadcast together. If you "
                "are trying to use batches for `loc`, you may need to add a "
                "dimension for the event size."
            )

        prec_batches = jnp.shape(prec)[:-2]
        loc_batches = jnp.shape(loc)[:-1]
        self._broadcast_batch_shape = jnp.broadcast_shapes(prec_batches, loc_batches)

        # Left-pad the batch dimensions of `prec` and `loc` so both carry the
        # full broadcast batch rank; later matmuls/additions then broadcast cleanly.
        nbatch = len(self.batch_shape)
        self._prec = jnp.expand_dims(prec, tuple(range(nbatch - len(prec_batches))))
        self._loc = jnp.expand_dims(loc, tuple(range(nbatch - len(loc_batches))))

        super().__init__(
            dtype=prec.dtype,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )
[docs] @classmethod def from_penalty( cls, loc: Array, var: Array, pen: Array, rank: Array | int | None = None, log_pdet: Array | float | None = None, validate_args: bool = False, allow_nan_stats: bool = True, name: str = "MultivariateNormalDegenerate", ) -> MultivariateNormalDegenerate: """ Alternative constructor based on a penalty matrix and an inverse smoothing parameter. Sometimes, the precision matrix of a degenerate multivariate normal distribution is decomposed into a penalty matrix ``pen`` and an inverse smoothing parameter ``var``. Using this constructor, a degenerate multivariate normal distribution can be initialized from such a decomposition. Parameters ---------- loc The location (= mean) vector. var The variance (= inverse smoothing) parameter. pen The (potentially rank-deficient) penalty matrix. rank The rank of the penalty matrix. Optional. log_pdet The log-pseudo-determinant of the penalty matrix. Optional. validate_args Python ``bool``, default ``False``. When ``True`` distribution parameters are checked for validity despite possibly degrading runtime performance. When ``False`` invalid inputs may silently render incorrect outputs. allow_nan_stats Python ``bool``, default ``True``. When ``True``, statistics (e.g., mean, mode, variance) use the value ``NaN`` to indicate the result is undefined. When ``False``, an exception is raised if one or more of the statistic's batch members are undefined. name Python ``str`` name prefixed to ``Ops`` created by this class. Warnings -------- If the log-pseudo-determinant is provided as an argument, it must be of the penalty matrix ``pen``, **not** of the precision matrix. Notes ----- If they are not provided as arguments, the constructor computes ``rank`` and ``log_pdet`` based on the eigenvalues of the penalty matrix ``pen``. This is an expensive operation and can be avoided by specifying the corresponding arguments. 
""" prec = pen / jnp.expand_dims(var, axis=(-2, -1)) evals = jax.numpy.linalg.eigvalsh(pen) rank = _rank(evals) if rank is None else rank log_pdet = _log_pdet(evals, rank=rank) if log_pdet is None else log_pdet log_pdet_prec = log_pdet - rank * jnp.log(var) mvnd = cls( loc=loc, prec=prec, rank=rank, log_pdet=log_pdet_prec, validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=name, ) return mvnd
@cached_property def eig(self) -> tuple[Array, Array]: """Eigenvalues and eigenvectors of the distribution's precision matrices.""" return jnp.linalg.eigh(self._prec) @cached_property def _sqrt_pcov(self) -> Array: """ Square roots of the distribution's pseudo-covariance matrices. Let ``prec = Q @ A @ Q.T`` be the eigendecomposition of the precision matrix. In essence, this property returns ``Q @ jnp.sqrt(1/A)``. """ eigenvalues, evecs = self.eig sqrt_eval = jnp.sqrt(1 / eigenvalues) sqrt_eval = jnp.where(eigenvalues < self._tol, 0.0, sqrt_eval) event_shape = sqrt_eval.shape[-1] shape = sqrt_eval.shape + (event_shape,) r = tuple(range(event_shape)) diags = jnp.zeros(shape).at[..., r, r].set(sqrt_eval) return evecs @ diags @cached_property def rank(self) -> Array | float: """Ranks of the distribution's precision matrices.""" if self._rank is not None: return self._rank evals, _ = self.eig return _rank(evals, tol=self._tol) @cached_property def log_pdet(self) -> Array | float: """Log-pseudo-determinants of the distribution's precision matrices.""" if self._log_pdet is not None: return self._log_pdet evals, _ = self.eig return _log_pdet(evals, self.rank, tol=self._tol) def _sample_n(self, n, seed=None) -> Array: shape = [n] + self.batch_shape + self.event_shape # The added dimension at the end here makes sure that matrix multiplication # with the "sqrt pcov" matrices works out correctly. z = jax.random.normal(key=seed, shape=shape + [1]) # Add a dimension at 0 for the sample size. sqrt_pcov = jnp.expand_dims(self._sqrt_pcov, 0) centered_samples = jnp.reshape(sqrt_pcov @ z, shape) # Add a dimension at 0 for the sample size. 
loc = jnp.expand_dims(self._loc, 0) return centered_samples + loc def _log_prob(self, x: Array) -> Array | float: x = x - self._loc # necessary for correct broadcasting in the quadratic form x = jnp.expand_dims(x, axis=-2) x_T = jnp.swapaxes(x, -2, -1) prob1 = -jnp.squeeze(x @ self._prec @ x_T, axis=(-2, -1)) prob2 = self.rank * jnp.log(2 * jnp.pi) - self.log_pdet return 0.5 * (prob1 - prob2) def _event_shape(self): return tf.TensorShape((jnp.shape(self._prec)[-1],)) def _event_shape_tensor(self): return jnp.array((jnp.shape(self._prec)[-1],), dtype=jnp.int32) def _batch_shape(self): return tf.TensorShape(self._broadcast_batch_shape) def _batch_shape_tensor(self): return jnp.array(self._broadcast_batch_shape, dtype=jnp.int32)