Source code for liesel.distributions.mvn_degen

"""
The degenerate, i.e. rank-deficient, multivariate normal distribution.
"""

from __future__ import annotations

from typing import Any

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.substrates.jax import tf2jax as tf

Array = Any


def _rank_and_log_pdet(
    prec: Array,
    rank: Array | int | None = None,
    log_pdet: Array | float | None = None,
    tol: float = 1e-6,
) -> tuple[Array | int, Array | float]:
    """
    Computes the rank and the log-pseudo-determinant of the positive semi-definite
    precision matrix ``prec``.

    Can handle batches.

    If both the rank and the log-pseudo-determinant are provided, the function simply
    returns the provided arguments. If only the rank is provided, it is used to select
    the non-zero eigenvalues. If the rank is not provided, it is computed by counting
    the non-zero eigenvalues. An eigenvalue is deemed to be non-zero if it is greater
    than the numerical tolerance ``tol``.
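
    Examples
    --------
    A minimal sketch of the expected behaviour (the matrix is illustrative):

    >>> import jax.numpy as jnp
    >>> prec = jnp.diag(jnp.array([2.0, 1.0, 0.0]))  # a rank-2 precision matrix
    >>> rank, log_pdet = _rank_and_log_pdet(prec)
    >>> int(rank)
    2
    >>> bool(jnp.isclose(log_pdet, jnp.log(2.0)))
    True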
    """

    if log_pdet is not None and rank is not None:
        return rank, log_pdet

    eigenvals = jnp.linalg.eigvalsh(prec)

    if rank is None:
        mask = eigenvals > tol
        rank = jnp.sum(mask, axis=-1)
    else:
        # ``eigvalsh`` returns the eigenvalues in ascending order, so the non-zero
        # eigenvalues are the last ``rank`` entries of ``eigenvals``
        max_index = eigenvals.shape[-1] - rank

        def fn(i, x):
            return x.at[..., i].set(i >= max_index)

        mask = jax.lax.fori_loop(0, eigenvals.shape[-1], fn, eigenvals)

    if log_pdet is None:
        # replace the eigenvalues deemed to be zero with ones so that they do not
        # contribute to the sum of the logs
        selected = jnp.where(mask, eigenvals, 1.0)
        log_pdet = jnp.sum(jnp.log(selected), axis=-1)

    return rank, log_pdet


class MultivariateNormalDegenerate(tfd.Distribution):
    """
    A potentially degenerate multivariate normal distribution.

    Provides the alternative constructor :meth:`.from_penalty`.

    Parameters
    ----------
    loc
        The location (= mean) vector.
    prec
        The precision matrix (= a pseudo-inverse of the variance-covariance matrix).
    rank
        The rank of the precision matrix. Optional.
    log_pdet
        The log-pseudo-determinant of the precision matrix. Optional.
    validate_args
        Python ``bool``, default ``False``. When ``True``, distribution parameters \
        are checked for validity despite possibly degrading runtime performance. \
        When ``False``, invalid inputs may silently render incorrect outputs.
    allow_nan_stats
        Python ``bool``, default ``True``. When ``True``, statistics (e.g., mean, \
        mode, variance) use the value ``NaN`` to indicate the result is undefined. \
        When ``False``, an exception is raised if one or more of the statistic's \
        batch members are undefined.
    name
        Python ``str``, name prefixed to ``Ops`` created by this class.

    Notes
    -----
    If they are not provided as arguments, the constructor computes ``rank`` and
    ``log_pdet`` based on the eigenvalues of the precision matrix ``prec``. This is
    an expensive operation and can be avoided by specifying the corresponding
    arguments.
    """

    def __init__(
        self,
        loc: Array,
        prec: Array,
        rank: Array | int | None = None,
        log_pdet: Array | float | None = None,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "MultivariateNormalDegenerate",
    ):
        parameters = dict(locals())

        self._prec = prec

        # necessary for correct broadcasting over event size
        self._loc = jnp.atleast_1d(loc)

        if not self._prec.shape[-2] == self._prec.shape[-1]:
            raise ValueError(
                "`prec` must be square (the last two dimensions must be equal)."
            )

        try:
            jnp.broadcast_shapes(self._prec.shape[-1], self._loc.shape[-1])
        except ValueError:
            raise ValueError(
                f"The event sizes of `prec` ({self._prec.shape[-1]}) and `loc` "
                f"({self._loc.shape[-1]}) cannot be broadcast together. If you "
                "are trying to use batches for `loc`, you may need to add a "
                "dimension for the event size."
            )

        self._broadcast_batch_shape = jnp.broadcast_shapes(
            jnp.shape(self._prec)[:-2], jnp.shape(self._loc)[:-1]
        )

        self._rank, self._log_pdet = _rank_and_log_pdet(
            self._prec, rank, log_pdet, tol=1e-6
        )

        super().__init__(
            dtype=prec.dtype,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            name=name,
        )
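
    # A minimal construction sketch (the values are illustrative, not taken from
    # the library's documentation). For a rank-deficient precision matrix,
    # ``rank`` and ``log_pdet`` are computed from the eigenvalues unless provided:
    #
    #     import jax.numpy as jnp
    #
    #     prec = jnp.array([[1.0, -1.0], [-1.0, 1.0]])  # a rank-1 precision matrix
    #     mvn = MultivariateNormalDegenerate(loc=jnp.zeros(2), prec=prec)
    #     mvn.log_prob(jnp.array([0.5, 0.5]))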

    @classmethod
    def from_penalty(
        cls,
        loc: Array,
        var: Array,
        pen: Array,
        rank: Array | int | None = None,
        log_pdet: Array | float | None = None,
        validate_args: bool = False,
        allow_nan_stats: bool = True,
        name: str = "MultivariateNormalDegenerate",
    ) -> MultivariateNormalDegenerate:
        """
        Alternative constructor based on a penalty matrix and an inverse smoothing
        parameter.

        Sometimes, the precision matrix of a degenerate multivariate normal
        distribution is decomposed into a penalty matrix ``pen`` and an inverse
        smoothing parameter ``var``. Using this constructor, a degenerate
        multivariate normal distribution can be initialized from such a
        decomposition.

        Parameters
        ----------
        loc
            The location (= mean) vector.
        var
            The variance (= inverse smoothing) parameter.
        pen
            The (potentially rank-deficient) penalty matrix.
        rank
            The rank of the penalty matrix. Optional.
        log_pdet
            The log-pseudo-determinant of the penalty matrix. Optional.
        validate_args
            Python ``bool``, default ``False``. When ``True``, distribution
            parameters are checked for validity despite possibly degrading runtime
            performance. When ``False``, invalid inputs may silently render
            incorrect outputs.
        allow_nan_stats
            Python ``bool``, default ``True``. When ``True``, statistics (e.g.,
            mean, mode, variance) use the value ``NaN`` to indicate the result is
            undefined. When ``False``, an exception is raised if one or more of
            the statistic's batch members are undefined.
        name
            Python ``str``, name prefixed to ``Ops`` created by this class.

        Warnings
        --------
        If the log-pseudo-determinant is provided as an argument, it must be the
        log-pseudo-determinant of the penalty matrix ``pen``, **not** of the
        precision matrix.

        Notes
        -----
        If they are not provided as arguments, the constructor computes ``rank``
        and ``log_pdet`` based on the eigenvalues of the penalty matrix ``pen``.
        This is an expensive operation and can be avoided by specifying the
        corresponding arguments.
        """

        prec = pen / jnp.expand_dims(var, axis=(-2, -1))

        rank, log_pdet = _rank_and_log_pdet(pen, rank, log_pdet, tol=1e-6)
        log_pdet_prec = log_pdet - rank * jnp.log(var)

        mvnd = cls(
            loc=loc,
            prec=prec,
            rank=rank,
            log_pdet=log_pdet_prec,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name,
        )

        return mvnd
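
    # A minimal sketch of the penalty-based constructor (values are illustrative).
    # The precision matrix is assembled as ``pen / var``, and the provided or
    # computed ``log_pdet`` refers to the penalty matrix, not the precision matrix:
    #
    #     import jax.numpy as jnp
    #
    #     pen = jnp.array([[1.0, -1.0], [-1.0, 1.0]])  # a rank-1 penalty matrix
    #     mvn = MultivariateNormalDegenerate.from_penalty(
    #         loc=jnp.zeros(2), var=2.0, pen=pen
    #     )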

    def _log_prob(self, x: Array) -> Array | float:
        x = x - self._loc

        # necessary for correct broadcasting in the quadratic form
        x = jnp.expand_dims(x, axis=-2)
        x_T = jnp.swapaxes(x, -2, -1)

        prob1 = -jnp.squeeze(x @ self._prec @ x_T, axis=(-2, -1))
        prob2 = self._rank * jnp.log(2 * jnp.pi) - self._log_pdet

        return 0.5 * (prob1 - prob2)

    def _event_shape(self):
        return tf.TensorShape((jnp.shape(self._prec)[-1],))

    def _event_shape_tensor(self):
        return jnp.array((jnp.shape(self._prec)[-1],), dtype=jnp.int32)

    def _batch_shape(self):
        return tf.TensorShape(self._broadcast_batch_shape)

    def _batch_shape_tensor(self):
        return jnp.array(self._broadcast_batch_shape, dtype=jnp.int32)
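
# A quick full-rank sanity check (a sketch under the assumption that ``prec`` is
# invertible; not part of the module): with full rank, the degenerate density
# reduces to the ordinary multivariate normal density, so it can be compared
# against TFP's non-degenerate distribution:
#
#     import jax.numpy as jnp
#     import tensorflow_probability.substrates.jax.distributions as tfd
#
#     prec = jnp.array([[2.0, 0.0], [0.0, 4.0]])
#     mvn = MultivariateNormalDegenerate(loc=jnp.zeros(2), prec=prec)
#     ref = tfd.MultivariateNormalFullCovariance(
#         loc=jnp.zeros(2), covariance_matrix=jnp.linalg.inv(prec)
#     )
#     x = jnp.array([0.3, -0.1])
#     assert jnp.allclose(mvn.log_prob(x), ref.log_prob(x))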