Parameter transformations
This tutorial builds on the linear regression tutorial. Here, we demonstrate how we can easily transform a parameter in our model to sample it with NUTS instead of a Gibbs kernel.
First, let’s set up our model again. This is the same model as in the linear regression tutorial, so we will not go into the details here.
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
# We use distributions and bijectors from TensorFlow Probability
import tensorflow_probability.substrates.jax.distributions as tfd
import tensorflow_probability.substrates.jax.bijectors as tfb
import liesel.goose as gs
import liesel.model as lsl
rng = np.random.default_rng(42)
# Data-generating process
n = 500
true_beta = np.array([1.0, 2.0])
true_sigma = 1.0
x0 = rng.uniform(size=n)
X_mat = np.c_[np.ones(n), x0]
y_vec = X_mat @ true_beta + rng.normal(scale=true_sigma, size=n)
# Model
# Part 1: Model for the mean
beta_prior = lsl.Dist(tfd.Normal, loc=0.0, scale=100.0)
beta = lsl.param(value=np.array([0.0, 0.0]), distribution=beta_prior, name="beta")
X = lsl.obs(X_mat, name="X")
mu = lsl.Var(lsl.Calc(jnp.dot, X, beta), name="mu")
# Part 2: Model for the standard deviation
a = lsl.Var(0.01, name="a")
b = lsl.Var(0.01, name="b")
sigma_sq_prior = lsl.Dist(tfd.InverseGamma, concentration=a, scale=b)
sigma_sq = lsl.param(value=10.0, distribution=sigma_sq_prior, name="sigma_sq")
sigma = lsl.Var(lsl.Calc(jnp.sqrt, sigma_sq), name="sigma")
# Observation model
y_dist = lsl.Dist(tfd.Normal, loc=mu, scale=sigma)
y = lsl.Var(y_vec, distribution=y_dist, name="y")
Now let’s try to sample the full parameter vector \((\boldsymbol{\beta}', \sigma^2)'\) with a single NUTS kernel instead of using a NUTS kernel for \(\boldsymbol{\beta}\) and a Gibbs kernel for \(\sigma^2\). Since the variance is a positive-valued parameter, we need to log-transform it to sample it with a NUTS kernel. The GraphBuilder class provides the transform() method for this purpose.
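Note that the bijector is specified in the direction from the transformed (unconstrained) space to the original space: the forward transformation of tfb.Exp is the exponential, so its inverse, the logarithm, maps \(\sigma^2\) to the unconstrained parameter. Here is a minimal sketch using the TensorFlow Probability bijector directly, purely for illustration and not part of the model code:
# Illustration only: tfb.Exp maps unconstrained values to positive values.
exp_bijector = tfb.Exp()
log_sigma_sq = exp_bijector.inverse(10.0)           # log-transform of the initial value 10.0
sigma_sq_back = exp_bijector.forward(log_sigma_sq)  # back-transformation, again 10.0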
gb = lsl.GraphBuilder().add(y)
gb.transform(sigma_sq, tfb.Exp)
Var(name="sigma_sq_transformed")
model = gb.build_model()
lsl.plot_vars(model)
The response distribution still requires the standard deviation on the original scale. The model graph shows that the back-transformation from the logarithmic to the original scale is performed by inserting the sigma_sq_transformed node and turning the sigma_sq node into a weak node. This weak node now deterministically depends on sigma_sq_transformed: its value is the back-transformed variance.
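We can check this relationship on the built model itself. The following is a minimal sketch, assuming that the model exposes its variables through model.vars and that each Var carries value and weak attributes, as in recent Liesel versions:
# sigma_sq is now a weak node whose value is exp(sigma_sq_transformed).
sigma_sq_var = model.vars["sigma_sq"]
transformed_var = model.vars["sigma_sq_transformed"]
print(sigma_sq_var.weak)  # True: sigma_sq is computed deterministically
print(jnp.allclose(sigma_sq_var.value, jnp.exp(transformed_var.value)))  # True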
Now we can set up and run an MCMC algorithm with a NUTS kernel for all parameters.
builder = gs.EngineBuilder(seed=1339, num_chains=4)
builder.set_model(gs.LieselInterface(model))
builder.set_initial_values(model.state)
builder.add_kernel(gs.NUTSKernel(["beta", "sigma_sq_transformed"]))
builder.set_duration(warmup_duration=1000, posterior_duration=1000)
# By default, goose only stores the parameters specified in the kernels.
# Let's also store the variance on the original scale.
builder.positions_included = ["sigma_sq"]
engine = builder.build()
engine.sample_all_epochs()
Judging from the trace plots, it seems that all chains have converged.
results = engine.get_results()
g = gs.plot_trace(results)
We can also take a look at the summary table, which includes the original \(\sigma^2\) and the transformed \(\log(\sigma^2)\).
gs.Summary(results)
Parameter summary:
| parameter | index | kernel | mean | sd | q_0.05 | q_0.5 | q_0.95 | sample_size | ess_bulk | ess_tail | rhat |
|---|---|---|---|---|---|---|---|---|---|---|---|
| beta | (0,) | kernel_00 | 0.986 | 0.089 | 0.835 | 0.987 | 1.131 | 4000 | 1434.460 | 1666.438 | 1.001 |
| beta | (1,) | kernel_00 | 1.906 | 0.157 | 1.646 | 1.902 | 2.168 | 4000 | 1433.341 | 1896.020 | 1.001 |
| sigma_sq | () | - | 1.045 | 0.068 | 0.939 | 1.045 | 1.159 | 4000 | 2210.900 | 1867.822 | 1.003 |
| sigma_sq_transformed | () | kernel_00 | 0.042 | 0.065 | -0.062 | 0.044 | 0.148 | 4000 | 2210.906 | 1867.822 | 1.003 |
Error summary:
| kernel | error_code | error_msg | phase | count | relative |
|---|---|---|---|---|---|
| kernel_00 | 1 | divergent transition | warmup | 70 | 0.018 |
| kernel_00 | 1 | divergent transition | posterior | 0 | 0.000 |
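Because we stored sigma_sq via positions_included, we can also work with the samples on the original scale directly. Here is a small sketch, assuming the results object provides get_posterior_samples() as in current Goose versions:
samples = results.get_posterior_samples()
# The stored variance samples are the exponentiated transformed samples.
print(jnp.allclose(samples["sigma_sq"], jnp.exp(samples["sigma_sq_transformed"])))
# Posterior mean of the standard deviation sigma on the original scale.
print(jnp.sqrt(samples["sigma_sq"]).mean())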
Finally, let’s check the autocorrelation of the samples.
g = gs.plot_cor(results)