"""
The degenerate, i.e. rank-deficient, multivariate normal distribution.
"""
from __future__ import annotations
from functools import cached_property
from typing import Any
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import tensorflow_probability.substrates.jax.distributions as tfd
from tensorflow_probability.python.internal import reparameterization
from tensorflow_probability.substrates.jax import tf2jax as tf
Array = Any
def _rank(eigenvalues: Array, tol: float = 1e-6) -> Array | float:
"""
Computes the rank of a matrix based on the provided eigenvalues. The rank is taken
to be the number of non-zero eigenvalues.
Can handle batches.
"""
mask = eigenvalues > tol
rank = jnp.sum(mask, axis=-1)
return rank
def _log_pdet(
eigenvalues: Array, rank: Array | float | None = None, tol: float = 1e-6
) -> Array | float:
"""
Computes the log of the pseudo-determinant of a matrix based on the provided
eigenvalues. If the rank is provided, it is used to select the non-zero eigenvalues.
If the rank is not provided, it is computed by counting the non-zero eigenvalues. An
eigenvalue is deemed to be non-zero if it is greater than the numerical tolerance
``tol``.
Can handle batches.
"""
if rank is None:
mask = eigenvalues > tol
else:
max_index = eigenvalues.shape[-1] - rank
def fn(i, x):
return x.at[..., i].set(i >= max_index)
mask = jax.lax.fori_loop(0, eigenvalues.shape[-1], fn, eigenvalues)
selected = jnp.where(mask, eigenvalues, 1.0)
log_pdet = jnp.sum(jnp.log(selected), axis=-1)
return log_pdet
[docs]
class MultivariateNormalDegenerate(tfd.Distribution):
"""
A potentially degenerate multivariate normal distribution.
Provides the alternative constructor :meth:`.from_penalty` and sampling
via :meth:`.sample`.
This is a simplified code-based illustration of how the log-probability for an array
``x`` is evaluated::
xc = x - loc
log_prob = -0.5 * (rank * np.log(2*np.pi) - log_pdet) -0.5 * (xc.T @ prec @ xc)
Parameters
----------
loc
The location (= mean) vector.
prec
The precision matrix (= a pseudo-inverse of the variance-covariance matrix).
rank
The rank of the precision matrix. Optional.
log_pdet
The log-pseudo-determinant of the precision matrix. Optional.
validate_args
Python ``bool``, default ``False``. When ``True``, distribution parameters \
are checked for validity despite possibly degrading runtime performance. \
When ``False``, invalid inputs may silently render incorrect outputs.
allow_nan_stats
Python ``bool``, default ``True``. When ``True``, statistics (e.g., mean, \
mode, variance) use the value ``NaN`` to indicate the result is undefined. \
When ``False``, an exception is raised if one or more of the statistic's \
batch members are undefined.
name
Python ``str``, name prefixed to ``Ops`` created by this class.
tol
Numerical tolerance for determining which eigenvalues of the distribution's \
precision matrices should be treated as zeros. Used in :attr:`.rank` and \
:attr:`.log_pdet`, if they are computed by the class. Also used in \
:meth:`.sample`.
Notes
-----
* If they are not provided as arguments, ``rank`` and ``log_pdet`` are computed
based on the eigenvalues of the precision matrix ``prec``. This is an expensive
operation and can be avoided by specifying the corresponding arguments.
* When you draw samples from the distribution via :meth:`.sample`, it is always
necessary to compute the eigendecomposition of the distribution's precision
matrices once and cache it, because sampling requires both the eigenvalues and
eigenvectors.
**Details on sampling**
To draw samples from a denegerate multivariate normal distribution, we 1) draw
standard normal samples with mean zero and variance one, 2) transform these samples
to have the desired covariance structure, and 3) add the desired mean.
The main problem is to find out how we have to transform the standard normal samples
in step 2. Say that we have a singular :math:`(m \\times m)` precision matrix
:math:`P`. We can view it as
the generalized inverse of a variance-covariance matrix :math:`\\Sigma`. We can
obtain :math:`\\Sigma` by finding the eigenvalue decomposition of the precision
matrix, i.e.
.. math::
P = QA^+Q^T,
where :math:`Q` is the orthogonal matrix of eigenvectors and :math:`A^+` is the
diagonal matrix of eigenvalues. Note that, if the precision matrix is singular, then
:math:`\\text{diag}(A^+)` contains zeroes. Now we take the inverse of the non-zero
entries of :math:`\\text{diag}(A^+)`, while the zero entries remain at zero,
resulting in a matrix :math:`A`. We can now write
.. math::
\\Sigma = Q A Q^T.
Now we can go through the three steps in detail. We first draw a vector of the
desired length :math:`z \\sim N(0, I)` from a standard normal distribution.
:math:`I` is the identity matrix of appropriate dimension. Next, we transform the
sample by applying :math:`x = Q A^{1/2}z`, such that :math:`\\text{Cov}(x) =
\\Sigma`:
.. math::
\\text{Cov}(x) & = Q A^{1/2} I (A^{1/2})^T Q^T \\
& = Q A Q^T \\
& = \\Sigma.
In the last step, we add the desired mean :math:`\\mu` to :math:`x`.
Note that the distribution is not a proper distribution on :math:`\\mathbb{R}^m`,
where :math:`m` refers to the number of columns and rows of :math:`P`.
Any vector in the null space of :math:`P` can be added to any
:math:`x \\in \\mathbb{R}^m` without changing the density. The samples generated
using the procedure described above are orthogonal to the null space of :math:`P`.
"""
def __init__(
self,
loc: Array,
prec: Array,
rank: Array | int | None = None,
log_pdet: Array | float | None = None,
validate_args: bool = False,
allow_nan_stats: bool = True,
name: str = "MultivariateNormalDegenerate",
tol: float = 1e-6,
):
parameters = dict(locals())
self._tol = tol
self._rank = rank
self._log_pdet = log_pdet
# necessary for correct broadcasting over event size
loc = jnp.atleast_1d(loc)
if not prec.shape[-2] == prec.shape[-1]:
raise ValueError(
"`prec` must be square (the last two dimensions must be equal)."
)
try:
jnp.broadcast_shapes(prec.shape[-1], loc.shape[-1])
except ValueError:
raise ValueError(
f"The event sizes of `prec` ({prec.shape[-1]}) and `loc` "
f"({loc.shape[-1]}) cannot be broadcast together. If you "
"are trying to use batches for `loc`, you may need to add a "
"dimension for the event size."
)
prec_batches = jnp.shape(prec)[:-2]
loc_batches = jnp.shape(loc)[:-1]
self._broadcast_batch_shape = jnp.broadcast_shapes(prec_batches, loc_batches)
nbatch = len(self.batch_shape)
self._prec = jnp.expand_dims(prec, tuple(range(nbatch - len(prec_batches))))
self._loc = jnp.expand_dims(loc, tuple(range(nbatch - len(loc_batches))))
super().__init__(
dtype=prec.dtype,
reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
name=name,
)
[docs]
@classmethod
def from_penalty(
cls,
loc: Array,
var: Array,
pen: Array,
rank: Array | int | None = None,
log_pdet: Array | float | None = None,
validate_args: bool = False,
allow_nan_stats: bool = True,
name: str = "MultivariateNormalDegenerate",
) -> MultivariateNormalDegenerate:
"""
Alternative constructor based on a penalty matrix and an inverse smoothing
parameter.
Sometimes, the precision matrix of a degenerate multivariate normal
distribution is decomposed into a penalty matrix ``pen`` and an inverse
smoothing parameter ``var``. Using this constructor, a degenerate multivariate
normal distribution can be initialized from such a decomposition.
Parameters
----------
loc
The location (= mean) vector.
var
The variance (= inverse smoothing) parameter.
pen
The (potentially rank-deficient) penalty matrix.
rank
The rank of the penalty matrix. Optional.
log_pdet
The log-pseudo-determinant of the penalty matrix. Optional.
validate_args
Python ``bool``, default ``False``. When ``True`` distribution parameters
are checked for validity despite possibly degrading runtime performance.
When ``False`` invalid inputs may silently render incorrect outputs.
allow_nan_stats
Python ``bool``, default ``True``. When ``True``, statistics (e.g., mean,
mode, variance) use the value ``NaN`` to indicate the result is undefined.
When ``False``, an exception is raised if one or more of the statistic's
batch members are undefined.
name
Python ``str`` name prefixed to ``Ops`` created by this class.
Warnings
--------
If the log-pseudo-determinant is provided as an argument, it must be
of the penalty matrix ``pen``, **not** of the precision matrix.
Notes
-----
If they are not provided as arguments, the constructor computes ``rank`` and
``log_pdet`` based on the eigenvalues of the penalty matrix ``pen``. This is
an expensive operation and can be avoided by specifying the corresponding
arguments.
"""
prec = pen / jnp.expand_dims(var, axis=(-2, -1))
if rank is None or log_pdet is None:
evals = jax.numpy.linalg.eigvalsh(pen)
rank = _rank(evals) if rank is None else rank
log_pdet = _log_pdet(evals, rank=rank) if log_pdet is None else log_pdet
log_pdet_prec = log_pdet - rank * jnp.log(var)
mvnd = cls(
loc=loc,
prec=prec,
rank=rank,
log_pdet=log_pdet_prec,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=name,
)
return mvnd
[docs]
@classmethod
def from_penalty_smooth(
cls,
loc: Array,
smooth: Array,
pen: Array,
rank: Array | int | None = None,
log_pdet: Array | float | None = None,
validate_args: bool = False,
allow_nan_stats: bool = True,
name: str = "MultivariateNormalDegenerate",
) -> MultivariateNormalDegenerate:
"""
Alternative constructor based on a penalty matrix and an inverse smoothing
parameter.
Sometimes, the precision matrix of a degenerate multivariate normal
distribution is decomposed into a penalty matrix ``pen`` and an inverse
smoothing parameter ``var``. Using this constructor, a degenerate multivariate
normal distribution can be initialized from such a decomposition.
Parameters
----------
loc
The location (= mean) vector.
smooth
The smoothing (= inverse variance) parameter.
pen
The (potentially rank-deficient) penalty matrix.
rank
The rank of the penalty matrix. Optional.
log_pdet
The log-pseudo-determinant of the penalty matrix. Optional.
validate_args
Python ``bool``, default ``False``. When ``True`` distribution parameters
are checked for validity despite possibly degrading runtime performance.
When ``False`` invalid inputs may silently render incorrect outputs.
allow_nan_stats
Python ``bool``, default ``True``. When ``True``, statistics (e.g., mean,
mode, variance) use the value ``NaN`` to indicate the result is undefined.
When ``False``, an exception is raised if one or more of the statistic's
batch members are undefined.
name
Python ``str`` name prefixed to ``Ops`` created by this class.
Warnings
--------
If the log-pseudo-determinant is provided as an argument, it must be
of the penalty matrix ``pen``, **not** of the precision matrix.
Notes
-----
If they are not provided as arguments, the constructor computes ``rank`` and
``log_pdet`` based on the eigenvalues of the penalty matrix ``pen``. This is
an expensive operation and can be avoided by specifying the corresponding
arguments.
"""
prec = pen * jnp.expand_dims(smooth, axis=(-2, -1))
if rank is None or log_pdet is None:
evals = jax.numpy.linalg.eigvalsh(pen)
rank = _rank(evals) if rank is None else rank
log_pdet = _log_pdet(evals, rank=rank) if log_pdet is None else log_pdet
log_pdet_prec = log_pdet + rank * jnp.log(smooth)
mvnd = cls(
loc=loc,
prec=prec,
rank=rank,
log_pdet=log_pdet_prec,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=name,
)
return mvnd
@cached_property
def eig(self) -> tuple[Array, Array]:
"""Eigenvalues and eigenvectors of the distribution's precision matrices."""
return jnpla.eigh(self._prec)
@cached_property
def _sqrt_pcov(self) -> Array:
"""
Square roots of the distribution's pseudo-covariance matrices.
Let ``prec = Q @ A @ Q.T`` be the eigendecomposition of the precision matrix.
In essence, this property returns ``Q @ jnp.sqrt(1/A)``.
"""
eigenvalues, evecs = self.eig
sqrt_eval = jnp.sqrt(1 / eigenvalues)
sqrt_eval = jnp.where(eigenvalues < self._tol, 0.0, sqrt_eval)
event_shape = sqrt_eval.shape[-1]
shape = sqrt_eval.shape + (event_shape,)
r = tuple(range(event_shape))
diags = jnp.zeros(shape).at[..., r, r].set(sqrt_eval)
return evecs @ diags
@cached_property
def rank(self) -> Array | float:
"""Ranks of the distribution's precision matrices."""
if self._rank is not None:
return self._rank
evals, _ = self.eig
return _rank(evals, tol=self._tol)
@cached_property
def log_pdet(self) -> Array | float:
"""Log-pseudo-determinants of the distribution's precision matrices."""
if self._log_pdet is not None:
return self._log_pdet
evals, _ = self.eig
return _log_pdet(evals, self.rank, tol=self._tol)
@property
def prec(self) -> Array:
"""Precision matrices."""
return self._prec
@property
def loc(self) -> Array:
"""Locations."""
return self._loc
def _sample_n(self, n, seed=None) -> Array:
shape = [n] + self.batch_shape + self.event_shape
# The added dimension at the end here makes sure that matrix multiplication
# with the "sqrt pcov" matrices works out correctly.
z = jax.random.normal(key=seed, shape=shape + [1])
# Add a dimension at 0 for the sample size.
sqrt_pcov = jnp.expand_dims(self._sqrt_pcov, 0)
centered_samples = jnp.reshape(sqrt_pcov @ z, shape)
# Add a dimension at 0 for the sample size.
loc = jnp.expand_dims(self._loc, 0)
return centered_samples + loc
def _log_prob(self, x: Array) -> Array | float:
x = x - self._loc
# necessary for correct broadcasting in the quadratic form
x = jnp.expand_dims(x, axis=-2)
x_T = jnp.swapaxes(x, -2, -1)
prob1 = -jnp.squeeze(x @ self._prec @ x_T, axis=(-2, -1))
prob2 = self.rank * jnp.log(2 * jnp.pi) - self.log_pdet
return 0.5 * (prob1 - prob2)
def _event_shape(self):
return tf.TensorShape((jnp.shape(self._prec)[-1],))
def _event_shape_tensor(self):
return jnp.array((jnp.shape(self._prec)[-1],), dtype=jnp.int32)
def _batch_shape(self):
return tf.TensorShape(self._broadcast_batch_shape)
def _batch_shape_tensor(self):
return jnp.array(self._broadcast_batch_shape, dtype=jnp.int32)