Source code for liesel.goose.iwls_utils
"""
Utilities for the IWLS sampler.
"""
import jax
import jax.numpy as jnp
import jax.scipy
from .types import Array, KeyArray
triangular_solve = jax.lax.linalg.triangular_solve
[docs]def solve(chol_lhs: Array, rhs: Array) -> Array:
"""
Solves a system of linear equations `chol_lhs @ x = rhs` for x by applying
forward and backward substitution. Returns x.
Parameters
----------
chol_lhs
The lower triangular matrix of the Cholesky decomposition.
rhs
The right-hand side of the system.
"""
tmp = triangular_solve(chol_lhs, rhs, left_side=True, lower=True)
return triangular_solve(chol_lhs, tmp, lower=True)
[docs]def mvn_log_prob(x: Array, mean: Array, chol_inv_cov: Array) -> Array:
"""
Returns the log-density of a multivariate normal distribution.
Parameters
----------
x
The vector of observations.
mean
The mean vector.
chol_inv_cov
The lower triangular matrix of the Cholesky decomposition of the inverse
variance.
"""
standardized = (x - mean) @ chol_inv_cov
log_prob = jnp.sum(jax.scipy.stats.norm.logpdf(standardized))
adjustment = jnp.sum(jnp.log(jnp.diag(chol_inv_cov)))
return log_prob + adjustment
[docs]def mvn_sample(prng_key: KeyArray, mean: Array, chol_inv_cov: Array) -> Array:
"""
Samples from the normal distribution based on the Cholesky decomposition of the
inverse covariance matrix.
Parameters
----------
prng_key
The key for JAX' pseudo-random number generator.
mean
The mean vector.
chol_inv_cov
The lower triangular matrix of the Cholesky decomposition of the inverse
variance.
"""
standardized = jax.random.normal(prng_key, mean.shape)
centered = triangular_solve(chol_inv_cov, standardized, lower=True)
return centered + mean